File size: 2,618 Bytes
5769ee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import warnings

import torch
from torch import Tensor


@torch.jit.script
def torch_linspace(start: Tensor, stop: Tensor, num: int) -> torch.Tensor:
    """Return ``num`` evenly spaced samples from ``start`` to ``stop``, inclusive.

    Adapted from https://github.com/pytorch/pytorch/issues/61292.
    Replicates the multi-dimensional behaviour of ``numpy.linspace`` in PyTorch:
    the interpolation is performed element-wise, and the samples are stacked
    along a new leading dimension.

    Args:
        start: tensor holding the first value of each sequence.
        stop: tensor holding the last value of each sequence; it is combined
            with ``start`` through ``stop - start``, so both must be
            broadcast-compatible.
        num: number of samples to generate. With ``num == 1`` only ``start``
            is returned (as ``numpy.linspace`` does); with ``num == 0`` the
            leading dimension is empty.

    Returns:
        A tensor of shape ``[num, *start.shape]`` (when ``stop`` has the same
        shape as ``start``), created on ``start.device``.
    """
    # Interpolation fractions going from 0 to 1 in 'num' steps. The denominator
    # is clamped to 1 so that num == 1 gives the fraction 0 instead of
    # 0 / 0 = nan; for num >= 2 (and num == 0) the result is unchanged.
    denominator = max(num - 1, 1)
    steps = torch.arange(num, dtype=torch.float32, device=start.device) / denominator

    # Reshape 'steps' to [-1, *([1] * start.ndim)] to allow for broadcasting.
    # - 'steps.reshape([-1, *([1] * start.ndim)])' would be nicer, but TorchScript
    #   "cannot statically infer the expected size of a list in this context",
    #   hence one trailing singleton dimension is appended per dimension of 'start'.
    for _dim in range(start.ndim):
        steps = steps.unsqueeze(-1)

    # The output starts at 'start' and increments until 'stop' in each dimension.
    out = start[None] + steps * (stop - start)[None]

    return out


def load_weights(
    model: torch.nn.Module, checkpoint: dict, strict: bool = True
) -> torch.nn.Module:
    """Load the weights stored in ``checkpoint["state_dict"]`` into ``model``.

    This function is used instead of the one provided by pytorch lightning
    because, for unexplained reasons, the pytorch lightning load function did
    not behave as intended: loading several times from the same checkpoint
    resulted in different loaded weight values...

    Args:
        model: a model in which new weights should be set
        checkpoint: a loaded pytorch checkpoint (probably resulting from
            torch.load(filename)); only its "state_dict" entry is read
        strict: Defaults to True. When True, the checkpoint state dict is
            handed unchanged to ``model.load_state_dict(..., strict=True)``,
            which is in charge of rejecting mismatching keys. When False,
            checkpoint keys unknown to the model are ignored and model keys
            absent from the checkpoint keep their current values; a warning
            is emitted in each of these two cases.

    Returns:
        The ``model`` instance that was passed in, with the weights loaded.
    """
    checkpoint_state = checkpoint["state_dict"]

    if strict:
        model.load_state_dict(checkpoint_state, strict=strict)
        return model

    # Non-strict mode: start from the model's own state and overwrite only the
    # entries that the checkpoint also provides.
    model_dict = model.state_dict()

    unexpected_keys = checkpoint_state.keys() - model_dict.keys()
    if unexpected_keys:
        warnings.warn(
            f"Found keys {unexpected_keys} in checkpoint without any match in the model, ignoring corresponding values."
        )

    missing_keys = model_dict.keys() - checkpoint_state.keys()
    if missing_keys:
        warnings.warn(
            f"Missing keys {missing_keys} from the checkpoint, the corresponding weights will keep their initial values."
        )

    model_dict.update(
        {key: value for key, value in checkpoint_state.items() if key in model_dict}
    )

    model.load_state_dict(model_dict, strict=strict)
    return model