import os

import torch


class BaseModel(torch.nn.Module):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.device = torch.device('cuda:0') if opt.num_gpus > 0 else torch.device('cpu')

    def initialize(self):
        pass

    def per_gpu_initialize(self):
        pass

    def compute_generator_losses(self, data_i):
        return {}

    def compute_discriminator_losses(self, data_i):
        return {}

    def get_visuals_for_snapshot(self, data_i):
        return {}

    def get_parameters_for_mode(self, mode):
        return {}

    def save(self, total_steps_so_far):
        savedir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
        checkpoint_name = "%dk_checkpoint.pth" % (total_steps_so_far // 1000)
        savepath = os.path.join(savedir, checkpoint_name)
        torch.save(self.state_dict(), savepath)
        # Repoint latest_checkpoint.pth at the checkpoint that was just written.
        sympath = os.path.join(savedir, "latest_checkpoint.pth")
        if os.path.exists(sympath):
            os.remove(sympath)
        os.symlink(checkpoint_name, sympath)
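
    # For illustration, assuming a run that saved at total_steps_so_far = 50000
    # and again at 100000, the checkpoint directory would contain:
    #     <checkpoints_dir>/<name>/50k_checkpoint.pth
    #     <checkpoints_dir>/<name>/100k_checkpoint.pth
    #     <checkpoints_dir>/<name>/latest_checkpoint.pth -> 100k_checkpoint.pth (symlink)
    # load() below selects one of these through opt.resume_iter, e.g. "latest" or "50k".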

    def load(self):
        if self.opt.isTrain and self.opt.pretrained_name is not None:
            loaddir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
        else:
            loaddir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
        checkpoint_name = "%s_checkpoint.pth" % self.opt.resume_iter
        checkpoint_path = os.path.join(loaddir, checkpoint_name)
        if not os.path.exists(checkpoint_path):
            print("\n\ncheckpoint %s does not exist!" % checkpoint_path)
            assert self.opt.isTrain, "In test mode, the checkpoint file must exist"
            print("Training will start from scratch")
            return
        state_dict = torch.load(checkpoint_path,
                                map_location=str(self.device))

        # Copy parameters one by one instead of calling self.load_state_dict(),
        # so that missing keys and shape mismatches can be handled gracefully.
        own_state = self.state_dict()
        skip_all = False
        for name, own_param in own_state.items():
            # Discriminator weights are not needed at test time.
            if not self.opt.isTrain and (name.startswith("D.") or name.startswith("Dpatch.")):
                continue
            if name not in state_dict:
                print("Key %s does not exist in checkpoint. Skipping..." % name)
                continue
            param = state_dict[name]
            if own_param.shape != param.shape:
                message = "Key [%s]: Shape does not match the created model (%s) and loaded checkpoint (%s)" % (name, str(own_param.shape), str(param.shape))
                if skip_all:
                    print(message)
                    # Copy the overlapping region and zero out the remainder.
                    ms = [min(s1, s2) for s1, s2 in zip(own_param.shape, param.shape)]
                    if len(ms) == 1:
                        own_param[:ms[0]].copy_(param[:ms[0]])
                        own_param[ms[0]:].copy_(own_param[ms[0]:] * 0)
                    elif len(ms) == 2:
                        own_param[:ms[0], :ms[1]].copy_(param[:ms[0], :ms[1]])
                        own_param[ms[0]:, ms[1]:].copy_(own_param[ms[0]:, ms[1]:] * 0)
                    elif len(ms) == 4:
                        own_param[:ms[0], :ms[1], :ms[2], :ms[3]].copy_(param[:ms[0], :ms[1], :ms[2], :ms[3]])
                        own_param[ms[0]:, ms[1]:, ms[2]:, ms[3]:].copy_(own_param[ms[0]:, ms[1]:, ms[2]:, ms[3]:] * 0)
                    else:
                        print("Skipping min_shape of %s" % str(ms))
                    continue
                userinput = input("%s. Force loading? (yes, no, all) " % message)
                if userinput.lower() == "yes":
                    pass
                elif userinput.lower() == "no":
                    continue
                elif userinput.lower() == "all":
                    skip_all = True
                else:
                    raise ValueError(userinput)
                # Force loading: copy the overlapping region and zero out the remainder.
                ms = [min(s1, s2) for s1, s2 in zip(own_param.shape, param.shape)]
                if len(ms) == 1:
                    own_param[:ms[0]].copy_(param[:ms[0]])
                    own_param[ms[0]:].copy_(own_param[ms[0]:] * 0)
                elif len(ms) == 2:
                    own_param[:ms[0], :ms[1]].copy_(param[:ms[0], :ms[1]])
                    own_param[ms[0]:, ms[1]:].copy_(own_param[ms[0]:, ms[1]:] * 0)
                elif len(ms) == 4:
                    own_param[:ms[0], :ms[1], :ms[2], :ms[3]].copy_(param[:ms[0], :ms[1], :ms[2], :ms[3]])
                    own_param[ms[0]:, ms[1]:, ms[2]:, ms[3]:].copy_(own_param[ms[0]:, ms[1]:, ms[2]:, ms[3]:] * 0)
                else:
                    print("Skipping min_shape of %s" % str(ms))
                continue
            own_param.copy_(param)
        print("checkpoint loaded from %s" % os.path.join(loaddir, checkpoint_name))

    def forward(self, *args, command=None, **kwargs):
        """ Wrapper for multi-GPU training. BaseModel is expected to be
        wrapped in nn.parallel.DataParallel, which distributes the call to
        the BaseModel replica on each GPU. The "command" keyword selects
        which method to run on each replica. """
        if command is not None:
            method = getattr(self, command)
            assert callable(method), "[%s] is not a method of %s" % (command, type(self).__name__)
            return method(*args, **kwargs)
        else:
            raise ValueError("forward() requires a 'command' keyword argument naming the method to call")
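

# Illustrative usage sketch: a minimal, CPU-only smoke test of the command
# dispatch in forward(). The option fields below mirror the attributes this
# class reads from self.opt; in real runs, opt comes from the option parser,
# and the model is wrapped in torch.nn.parallel.DataParallel so each call is
# replicated across GPUs.
if __name__ == "__main__":
    from argparse import Namespace

    opt = Namespace(num_gpus=0, checkpoints_dir="./checkpoints", name="debug",
                    isTrain=True, pretrained_name=None, resume_iter="latest")
    model = BaseModel(opt)
    # Calls are routed through forward(), exactly as DataParallel would route
    # them to the replica on each GPU; the base class stub returns {}.
    print(model(None, command="compute_generator_losses"))  # -> {}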