import collections
from itertools import repeat
from typing import List, Dict, Any

__all__ = ['consume_prefix_in_state_dict_if_present']


def _ntuple(n, name="parse"):
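    """Turn a scalar into an `n`-tuple; pass any iterable through as a tuple."""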
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return tuple(x)
        return tuple(repeat(x, n))

    parse.__name__ = name
    return parse


_single = _ntuple(1, "_single")
_pair = _ntuple(2, "_pair")
_triple = _ntuple(3, "_triple")
_quadruple = _ntuple(4, "_quadruple")
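
# Illustrative behavior of the parsers above (shown for reference only):
#
#   >>> _pair(3)
#   (3, 3)
#   >>> _pair((1, 2))
#   (1, 2)
#   >>> _triple(5)
#   (5, 5, 5)
#
# Note that strings are iterable, so _pair("ab") yields ('a', 'b').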


def _reverse_repeat_tuple(t, n):
    r"""Reverse the order of `t` and repeat each element `n` times.

    This can be used to translate the padding arg used by Conv and Pooling modules
    to the ones used by `F.pad`.
    """
    return tuple(x for x in reversed(t) for _ in range(n))
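
# For example, a Conv2d `padding=(1, 2)` given as (height, width) maps to the
# (left, right, top, bottom) ordering that `F.pad` expects:
#
#   >>> _reverse_repeat_tuple((1, 2), 2)
#   (2, 2, 1, 1)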


def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]:
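    """Return `out_size` with each `None` entry replaced by the corresponding
    trailing element of `defaults`; a plain int (or SymInt) passes through."""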
    import torch  # imported locally to avoid a circular import at module load time
    if isinstance(out_size, (int, torch.SymInt)):
        return out_size
    if len(defaults) <= len(out_size):
        raise ValueError(
            f"Input dimension should be at least {len(out_size) + 1}"
        )
    return [
        v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size) :])
    ]
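
# For example, `None` entries are filled from the trailing defaults:
#
#   >>> _list_with_default([None, 7], [1, 4, 8])
#   [4, 7]
#
# A `defaults` list not strictly longer than `out_size` raises a ValueError.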


def consume_prefix_in_state_dict_if_present(
    state_dict: Dict[str, Any], prefix: str
) -> None:
    r"""Strip the prefix in state_dict in place, if any.



    ..note::

        Given a `state_dict` from a DP/DDP model, a local model can load it by applying

        `consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling

        :meth:`torch.nn.Module.load_state_dict`.



    Args:

        state_dict (OrderedDict): a state-dict to be loaded to the model.

        prefix (str): prefix.

    """
    keys = list(state_dict.keys())
    for key in keys:
        if key.startswith(prefix):
            newkey = key[len(prefix) :]
            state_dict[newkey] = state_dict.pop(key)

    # Also strip the prefix from the metadata keys, if the state_dict carries any.
    if hasattr(state_dict, "_metadata"):
        keys = list(state_dict._metadata.keys())
        for key in keys:
            # For the metadata dict, the key can be:
            # '': for the DP/DDP wrapper itself (left alone here; it is
            #     overwritten below when 'module' is stripped down to '').
            # 'module': for the wrapped model.
            # 'module.xx.xx': for the wrapped model's submodules.
            if len(key) == 0:
                continue
            # Handle both the bare prefix with dots removed (e.g. 'module'
            # for prefix 'module.') and keys that start with the prefix.
            if key == prefix.replace('.', '') or key.startswith(prefix):
                newkey = key[len(prefix) :]
                state_dict._metadata[newkey] = state_dict._metadata.pop(key)
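

# Minimal usage sketch (assumes `PATH` points at a checkpoint saved from a
# DDP-wrapped model and `model` is the matching local module):
#
#   state_dict = torch.load(PATH)  # keys look like 'module.layer.weight'
#   consume_prefix_in_state_dict_if_present(state_dict, "module.")
#   model.load_state_dict(state_dict)  # keys now match the local model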