Spaces: Runtime error — the inference code hard-codes the CUDA device; the diff below threads a configurable `device` through instead.
diff --git a/src/gan_control/inference/controller.py b/src/gan_control/inference/controller.py
index ee464ba..d1907dd 100644
--- a/src/gan_control/inference/controller.py
+++ b/src/gan_control/inference/controller.py
 class Controller(Inference):
-    def __init__(self, controller_dir):
+    def __init__(self, controller_dir, device):
         _log.info('Init Controller class...')
-        super(Controller, self).__init__(os.path.join(controller_dir, 'generator'))
+        super(Controller, self).__init__(os.path.join(controller_dir, 'generator'), device)
         self.fc_controls = {}
         self.config_controls = {}
         for sub_group_name in self.batch_utils.sub_group_names:

     def gen_batch_by_controls(self, batch_size=1, latent=None, normalize=True, input_is_latent=False, static_noise=True, **kwargs):
         if latent is None:
-            latent = torch.randn(batch_size, self.config.model_config['latent_size'], device='cuda')
+            latent = torch.randn(batch_size, self.config.model_config['latent_size'], device=self.device)
         latent = latent.clone()
         if input_is_latent:
             latent_w = latent
         else:
             if isinstance(self.model, torch.nn.DataParallel):
-                latent_w = self.model.module.style(latent.cuda())
+                latent_w = self.model.module.style(latent.to(self.device))
             else:
-                latent_w = self.model.style(latent.cuda())
+                latent_w = self.model.style(latent.to(self.device))
         for group_key in kwargs.keys():
             if self.check_if_group_has_control(group_key):
                 if group_key == 'expression' and kwargs[group_key].shape[1] == 8:
-                    group_w_latent = self.fc_controls['expression_q'](kwargs[group_key].cuda().float())
+                    group_w_latent = self.fc_controls['expression_q'](kwargs[group_key].to(self.device).float())
                 else:
-                    group_w_latent = self.fc_controls[group_key](kwargs[group_key].cuda().float())
+                    group_w_latent = self.fc_controls[group_key](kwargs[group_key].to(self.device).float())
                 latent_w = self.insert_group_w_latent(latent_w, group_w_latent, group_key)
         injection_noise = None
         if static_noise:

             ckpt_path = ckpt_list[-1]
             ckpt_iter = ckpt_path.split('.')[0]
             config = read_json(config_path, return_obj=True)
-            ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path))
+            ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path), map_location=self.device)
             group_chunk = self.batch_utils.place_in_latent_dict[sub_group_name if sub_group_name is not 'expression_q' else 'expression']
             group_latent_size = group_chunk[1] - group_chunk[0]
             _log.info('Init %s Controller...' % sub_group_name)
-            controller = FcStack(config.model_config['lr_mlp'], config.model_config['n_mlp'], config.model_config['in_dim'], config.model_config['mid_dim'], group_latent_size).cuda()
+            controller = FcStack(config.model_config['lr_mlp'], config.model_config['n_mlp'], config.model_config['in_dim'], config.model_config['mid_dim'], group_latent_size).to(self.device)
             controller.print()
             _log.info('Loading Controller: %s, ckpt iter %s' % (controller_dir_path, ckpt_iter))
diff --git a/src/gan_control/inference/inference.py b/src/gan_control/inference/inference.py
index e6ccedb..4393bb7 100644
--- a/src/gan_control/inference/inference.py
+++ b/src/gan_control/inference/inference.py
 class Inference():
-    def __init__(self, model_dir):
+    def __init__(self, model_dir, device):
         _log.info('Init inference class...')
         self.model_dir = model_dir
-        self.model, self.batch_utils, self.config, self.ckpt_iter = self.retrieve_model(model_dir)
+        self.device = device
+        self.model, self.batch_utils, self.config, self.ckpt_iter = self.retrieve_model(model_dir, device)
         self.noise = None
         self.reset_noise()
         self.mean_w_latent = None

         _log.info('Calc mean_w_latents...')
         mean_latent_w_list = []
         for i in range(100):
-            latent_z = torch.randn(1000, self.config.model_config['latent_size'], device='cuda')
+            latent_z = torch.randn(1000, self.config.model_config['latent_size'], device=self.device)
             if isinstance(self.model, torch.nn.DataParallel):
                 latent_w = self.model.module.style(latent_z).cpu()
             else:

     def reset_noise(self):
         if isinstance(self.model, torch.nn.DataParallel):
-            self.noise = self.model.module.make_noise(device='cuda')
+            self.noise = self.model.module.make_noise(device=self.device)
         else:
-            self.noise = self.model.make_noise(device='cuda')
+            self.noise = self.model.make_noise(device=self.device)

         def expend_noise(noise, batch_size):

         self.calc_mean_w_latents()
         injection_noise = None
         if latent is None:
-            latent = torch.randn(batch_size, self.config.model_config['latent_size'], device='cuda')
+            latent = torch.randn(batch_size, self.config.model_config['latent_size'], device=self.device)
         elif input_is_latent:
-            latent = latent.cuda()
+            latent = latent.to(self.device)
         for group_key in kwargs.keys():
             if group_key not in self.batch_utils.sub_group_names:
                 raise ValueError('group_key: %s not in sub_group_names %s' % (group_key, str(self.batch_utils.sub_group_names)))
             if isinstance(kwargs[group_key], str) and kwargs[group_key] == 'random':
-                group_latent_w = self.model.style(torch.randn(latent.shape[0], self.config.model_config['latent_size'], device='cuda'))
+                group_latent_w = self.model.style(torch.randn(latent.shape[0], self.config.model_config['latent_size'], device=self.device))
                 group_latent_w = group_latent_w[:, self.batch_utils.place_in_latent_dict[group_key][0], self.batch_utils.place_in_latent_dict[group_key][0]]
                 latent[:, self.batch_utils.place_in_latent_dict[group_key][0], self.batch_utils.place_in_latent_dict[group_key][0]] = group_latent_w
         if static_noise:

                 latent[:, place_in_latent[0]: place_in_latent[1]] = \
                     truncation * (latent[:, place_in_latent[0]: place_in_latent[1]] - torch.cat(
                         [self.mean_w_latents[key].clone().unsqueeze(0) for _ in range(latent.shape[0])], dim=0
-                    ).cuda()) + torch.cat(
+                    ).to(self.device)) + torch.cat(
                         [self.mean_w_latents[key].clone().unsqueeze(0) for _ in range(latent.shape[0])], dim=0
-                    ).cuda()
+                    ).to(self.device)

-        tensor, latent_w = self.model([latent.cuda()], return_latents=True, input_is_latent=input_is_latent, noise=injection_noise)
+        tensor, latent_w = self.model([latent.to(self.device)], return_latents=True, input_is_latent=input_is_latent, noise=injection_noise)
         if normalize:
             tensor = tensor.mul(0.5).add(0.5).clamp(min=0., max=1.).cpu()
         return tensor, latent, latent_w

         return grid_image

-    def retrieve_model(model_dir):
+    def retrieve_model(model_dir, device):
         config_path = os.path.join(model_dir, 'args.json')
         _log.info('Retrieve config from %s' % config_path)

         ckpt_path = ckpt_list[-1]
         ckpt_iter = ckpt_path.split('.')[0]
         config = read_json(config_path, return_obj=True)
-        ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path))
+        ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path), map_location=device)
         batch_utils = None
         if not config.model_config['vanilla']:

             fc_config=None if config.model_config['vanilla'] else batch_utils.get_fc_config(),
             conv_transpose=config.model_config['conv_transpose'],
             noise_mode=config.model_config['g_noise_mode']
-        ).cuda()
+        ).to(device)
         _log.info('Loading Model: %s, ckpt iter %s' % (model_dir, ckpt_iter))
         model.load_state_dict(ckpt['g_ema'])
         model = torch.nn.DataParallel(model)