Spaces:
Running
on
T4
Running
on
T4
| import os, sys | |
| import torch | |
| # Import files from same folder | |
| root_path = os.path.abspath('.') | |
| sys.path.append(root_path) | |
| from opt import opt | |
| from architecture.rrdb import RRDBNet | |
| from architecture.grl import GRL | |
| from architecture.dat import DAT | |
| from architecture.swinir import SwinIR | |
| from architecture.cunet import UNet_Full | |
def load_rrdb(generator_weight_PATH, scale, print_options=False):
    ''' A simpler API to load RRDB model from Real-ESRGAN

    Args:
        generator_weight_PATH (str): The path to the weight
        scale (int): the scaling factor
        print_options (bool): whether to print the options stored in the
            checkpoint under the 'opt' key (if present) to show what kinds of
            setting is used

    Returns:
        generator (torch.nn.Module): the RRDBNet generator, in eval mode, on CUDA

    Raises:
        ValueError: if the checkpoint holds none of the supported weight keys
            ('params_ema', 'params', 'model_state_dict')
    '''
    # Load onto CPU first: the model is moved to CUDA explicitly below, and this
    # avoids failures when the checkpoint was saved on a device that does not
    # exist on this machine.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    # Find the generator weight, checked in priority order:
    #   'params_ema' / 'params' -> official ESRNET/ESRGAN weight
    #   'model_state_dict'      -> personally trained weight
    for weight_key in ('params_ema', 'params', 'model_state_dict'):
        if weight_key in checkpoint_g:
            weight = checkpoint_g[weight_key]
            break
    else:
        # Raise instead of terminating the process so callers (e.g. a web UI)
        # can handle or report the failure.
        raise ValueError(f"This weight is not supported: {generator_weight_PATH}")

    # Every supported format uses the same architecture, built with RRDBNet's
    # default number of blocks (none is passed here).
    generator = RRDBNet(3, 3, scale=scale)

    # Weights saved from a torch.compile-wrapped model carry a "_orig_mod."
    # prefix on every key; strip it so the keys match the plain module.
    prefix = "_orig_mod."
    weight = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in weight.items()
    }
    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Print options to show what kinds of setting is used
    if print_options and 'opt' in checkpoint_g:
        options = checkpoint_g['opt']
        for key in options:
            print(f'{key} : {options[key]}')
    return generator
def load_cunet(generator_weight_PATH, scale, print_options=False):
    ''' A simpler API to load CUNET model from Real-CUGAN

    Note: this function is deprecated.

    Args:
        generator_weight_PATH (str): The path to the weight
        scale (int): the scaling factor (only 2 is supported)
        print_options (bool): whether to print the options stored in the
            checkpoint under the 'opt' key (if present) to show what kinds of
            setting is used

    Returns:
        generator (torch.nn.Module): the UNet_Full generator, in eval mode, on CUDA

    Raises:
        NotImplementedError: if scale is not 2
        ValueError: if the checkpoint has no 'model_state_dict' entry
    '''
    # This func is deprecated now
    if scale != 2:
        raise NotImplementedError("We only support 2x in CUNET")

    # Load onto CPU first; the model is moved to CUDA explicitly below.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    # Only personally trained weights (saved under 'model_state_dict') are supported.
    if 'model_state_dict' not in checkpoint_g:
        # Raise instead of terminating the process so callers can handle it.
        raise ValueError(f"This weight is not supported: {generator_weight_PATH}")
    weight = checkpoint_g['model_state_dict']

    # These entries are only printed for information, so a missing one should
    # not prevent the model from loading.
    loss = checkpoint_g.get("lowest_generator_weight", "NAN")
    iteration = checkpoint_g.get("iteration", "NAN")
    generator = UNet_Full()
    print(f"the generator weight is {loss} at iteration {iteration}")

    # Weights saved from a torch.compile-wrapped model carry a "_orig_mod."
    # prefix on every key; strip it so the keys match the plain module.
    prefix = "_orig_mod."
    weight = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in weight.items()
    }
    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Print options to show what kinds of setting is used
    if print_options and 'opt' in checkpoint_g:
        options = checkpoint_g['opt']
        for key in options:
            print(f'{key} : {options[key]}')
    return generator
def load_grl(generator_weight_PATH, scale=4):
    ''' A simpler API to load GRL model (tiny2 configuration)

    Args:
        generator_weight_PATH (str): The path to the weight
        scale (int): Scale Factor (Usually Set as 4)

    Returns:
        generator (torch.nn.Module): the GRL generator, in eval mode, on CUDA

    Raises:
        ValueError: if the checkpoint has no 'model_state_dict' entry
    '''
    # Load onto CPU first; the model is moved to CUDA once, after loading.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    if 'model_state_dict' not in checkpoint_g:
        # Raise instead of terminating the process so callers can handle it.
        raise ValueError(f"This weight is not supported: {generator_weight_PATH}")
    weight = checkpoint_g['model_state_dict']

    # GRL tiny model (Note: tiny2 version)
    generator = GRL(
        upscale=scale,
        img_size=64,
        window_size=8,
        depths=[4, 4, 4, 4],
        embed_dim=64,
        num_heads_window=[2, 2, 2, 2],
        num_heads_stripe=[2, 2, 2, 2],
        mlp_ratio=2,
        qkv_proj_type="linear",
        anchor_proj_type="avgpool",
        anchor_window_down_factor=2,
        out_proj_type="linear",
        conv_type="1conv",
        # NOTE(review): originally tagged "# Change"; presumably differs from
        # the upstream default upsampler — confirm.
        upsampler="nearest+conv",
    )

    # Weights saved from a torch.compile-wrapped model carry a "_orig_mod."
    # prefix on every key; strip it so the keys match the plain module.
    prefix = "_orig_mod."
    weight = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in weight.items()
    }
    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Report the trainable parameter count in millions.
    num_params = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Number of parameters {num_params / 10 ** 6: 0.2f}")
    return generator
def load_dat(generator_weight_PATH, scale=4):
    ''' A simpler API to load DAT model (small configuration)

    Args:
        generator_weight_PATH (str): The path to the weight
        scale (int): Scale Factor (Usually Set as 4)

    Returns:
        generator (torch.nn.Module): the DAT generator, in eval mode, on CUDA

    Raises:
        ValueError: if the checkpoint has no 'model_state_dict' entry
    '''
    # Load onto CPU first; the model is moved to CUDA once, after loading.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    if 'model_state_dict' not in checkpoint_g:
        # Raise instead of terminating the process so callers can handle it.
        raise ValueError(f"This weight is not supported: {generator_weight_PATH}")
    weight = checkpoint_g['model_state_dict']

    # DAT small model in default. The upscale factor follows `scale`
    # (previously hard-coded to 4, silently ignoring the argument).
    generator = DAT(
        upscale=scale,
        in_chans=3,
        img_size=64,
        img_range=1.,
        depth=[6, 6, 6, 6, 6, 6],
        embed_dim=180,
        num_heads=[6, 6, 6, 6, 6, 6],
        expansion_factor=2,
        resi_connection='1conv',
        split_size=[8, 16],
        upsampler='pixelshuffledirect',
    )

    # Weights saved from a torch.compile-wrapped model carry a "_orig_mod."
    # prefix on every key; strip it so the keys match the plain module.
    prefix = "_orig_mod."
    weight = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in weight.items()
    }
    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Report the trainable parameter count in millions.
    num_params = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Number of parameters {num_params / 10 ** 6: 0.2f}")
    return generator