import warnings

import torch
from torch import Tensor


def torch_linspace(start: Tensor, stop: Tensor, num: int) -> Tensor:
    """
    Copy-pasted from https://github.com/pytorch/pytorch/issues/61292
    Creates a tensor of shape [num, *start.shape] whose values are evenly
    spaced from start to stop, inclusive.
    Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.
    """
    # create a tensor of 'num' steps from 0 to 1; the max() guards against a
    # division by zero when num == 1, in which case the output is just 'start'
    steps = torch.arange(num, dtype=torch.float32, device=start.device) / max(num - 1, 1)
    # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcasting
    # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
    #   "cannot statically infer the expected size of a list in this context", hence the loop below
    for _ in range(start.ndim):
        steps = steps.unsqueeze(-1)
    # the output starts at 'start' and increments towards 'stop' along the new leading dimension
    out = start[None] + steps * (stop - start)[None]
    return out
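

# A minimal usage sketch ('_demo_torch_linspace' and the shapes below are
# illustrative assumptions, not part of the original source): start and stop
# may be tensors of any matching shape, and the interpolated values are
# stacked along a new leading dimension of size 'num'.
def _demo_torch_linspace() -> None:
    start = torch.zeros(2, 3)
    stop = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    ramp = torch_linspace(start, stop, num=5)
    assert ramp.shape == (5, 2, 3)
    # the first slice equals 'start' and the last slice equals 'stop'
    assert torch.allclose(ramp[0], start) and torch.allclose(ramp[-1], stop)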


def load_weights(
    model: torch.nn.Module, checkpoint: dict, strict: bool = True
) -> torch.nn.Module:
    """This function is used instead of the one provided by pytorch lightning
    because, for unexplained reasons, the pytorch lightning load function did
    not behave as intended: loading several times from the same checkpoint
    resulted in different loaded weight values...

    Args:
        model: a model in which new weights should be set
        checkpoint: a loaded pytorch checkpoint (probably resulting from torch.load(filename))
        strict: defaults to True; whether to fail when the checkpoint keys and
            the model keys do not match exactly
    """
    if not strict:
        model_dict = model.state_dict()
        diff1 = checkpoint["state_dict"].keys() - model_dict.keys()
        if diff1:
            warnings.warn(
                f"Found keys {diff1} in checkpoint without any match in the model, ignoring corresponding values."
            )
        diff2 = model_dict.keys() - checkpoint["state_dict"].keys()
        if diff2:
            warnings.warn(
                f"Missing keys {diff2} from the checkpoint, the corresponding weights will keep their initial values."
            )
        # only keep the checkpoint entries that have a matching key in the model
        pretrained_dict = {
            k: v for k, v in checkpoint["state_dict"].items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
    else:
        model_dict = checkpoint["state_dict"]
    model.load_state_dict(model_dict, strict=strict)
    return model
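

# A minimal round-trip sketch ('_demo_load_weights' and the tiny Linear models
# are illustrative assumptions, not part of the original source). A checkpoint
# here is simply a dict with a "state_dict" entry, as produced e.g. by
# torch.save({"state_dict": model.state_dict()}, path).
def _demo_load_weights() -> None:
    source = torch.nn.Linear(4, 2)
    target = torch.nn.Linear(4, 2)
    checkpoint = {"state_dict": source.state_dict()}
    load_weights(target, checkpoint, strict=True)
    assert torch.equal(target.weight, source.weight)
    # with strict=False, extra checkpoint keys are ignored with a warning
    checkpoint["state_dict"]["unused.weight"] = torch.zeros(1)
    load_weights(target, checkpoint, strict=False)


if __name__ == "__main__":
    _demo_torch_linspace()
    _demo_load_weights()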