File size: 4,858 Bytes
6fc683c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import collections
from random import Random
from typing import Dict, Iterable, Optional

import numpy as np
from infinibatch import iterators


# Special marker strings. Presumably inserted into text streams by the data
# pipeline to mark end-of-line and the beginning/end of an image span —
# NOTE(review): confirm at the call sites, which are not visible here.
EOL_SYMBOL = "</line>"
BOI_SYMBOL = "<image>"
EOI_SYMBOL = "</image>"


def apply_to_sample(f, sample):
    """Recursively apply ``f`` to every ``np.ndarray`` found in ``sample``.

    ``sample`` may be an arbitrarily nested structure of dicts (including
    ``OrderedDict``), lists, tuples and sets. Containers are rebuilt with the
    same kind of container; leaves that are not ``np.ndarray`` are returned
    unchanged.

    Args:
        f: callable applied to each ``np.ndarray`` leaf.
        sample: the (possibly nested) structure to transform.

    Returns:
        The transformed structure. A sized top-level ``sample`` of length 0
        (e.g. ``[]``, ``{}``, ``""`` or an array whose first dimension is 0)
        yields ``{}``.
    """
    # 0-d arrays define ``__len__`` but raise TypeError when it is called, so
    # arrays are checked via ``ndim`` instead of relying on ``hasattr``.
    if isinstance(sample, np.ndarray):
        if sample.ndim > 0 and len(sample) == 0:
            return {}
    elif hasattr(sample, "__len__") and len(sample) == 0:
        return {}

    def _apply(x):
        if isinstance(x, np.ndarray):
            return f(x)
        if isinstance(x, collections.OrderedDict):
            # OrderedDict instances may carry extra attributes that need to be
            # preserved; note the instance __dict__ is shared, not copied.
            od = collections.OrderedDict(
                (key, _apply(value)) for key, value in x.items()
            )
            od.__dict__ = x.__dict__
            return od
        if isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        if isinstance(x, list):
            return [_apply(item) for item in x]
        if isinstance(x, tuple):
            return tuple(_apply(item) for item in x)
        if isinstance(x, set):
            return {_apply(item) for item in x}
        return x

    return _apply(sample)


class NativeCheckpointableIterator(iterators.CheckpointableIterator):
    """Checkpointable wrapper around a plain Python iterable.

    The checkpoint records only how many items have been yielded. Restoring
    re-creates an iterator from the wrapped iterable and skips that many
    items, so a restore resumes at the same item only if the iterable yields
    the same sequence every time it is iterated.
    """

    def __init__(self, iterable: Iterable):
        self._input_iterable = iterable
        self.setstate(None)

    def getstate(self) -> Dict:
        """Return the checkpoint: the number of items yielded so far."""
        return {"num_items_yielded": self._num_items_yielded}

    def setstate(self, checkpoint: Optional[Dict]):
        """Restart from the beginning, or fast-forward to ``checkpoint``."""
        self._iterator = iter(self._input_iterable)
        if checkpoint is None:
            self._num_items_yielded = 0
        else:
            skip = checkpoint["num_items_yielded"]
            self._num_items_yielded = iterators._advance_iterator(self._iterator, skip)

    def __next__(self):
        value = next(self._iterator)
        self._num_items_yielded += 1
        return value

    def close(self):
        """Nothing to release; present to satisfy the iterator interface."""


class WeightIterator(object):
    """Infinite iterator drawing indices ``0 .. len(weights) - 1`` by relative weight.

    Draws come from a ``random.Random`` seeded with ``seed``. The generator
    state after every draw is remembered so that ``getstate``/``setstate`` can
    resume the exact same sequence of indices.
    """

    def __init__(self, weights, seed):
        self.weights = weights
        self.seed = seed
        self.control_index = list(range(len(weights)))
        self.setstate(None)

    def __iter__(self):
        return self

    def getstate(self):
        """Return a checkpoint holding the generator state after the last draw (or None)."""
        return {"random_state": self._random_state}

    def setstate(self, checkpoint):
        """Restore from ``checkpoint``; a falsy checkpoint restarts from ``seed``."""
        if checkpoint:
            self._random_state = checkpoint["random_state"]
        else:
            self._random_state = None
        # The generator is rebuilt lazily on the next draw.
        self._random = None

    def __next__(self):
        rng = self._random
        if rng is None:
            rng = self._random = Random(self.seed)
            saved = self._random_state
            if saved is not None:
                rng.setstate(saved)
        (choice,) = rng.choices(self.control_index, weights=self.weights)
        self._random_state = rng.getstate()
        return choice

    def close(self):
        """Nothing to release; present to satisfy the iterator interface."""


def safe_getattr(obj, k, default=None):
    """Return ``obj.k`` (``obj[k]`` for OmegaConf configs) if it exists and is not None.

    Args:
        obj: a plain object (e.g. an argparse Namespace) or an OmegaConf config.
        k: attribute / key name to look up.
        default: value returned when ``k`` is missing or its value is None.

    Returns:
        The looked-up value, or ``default``.
    """
    try:
        from omegaconf import OmegaConf
    except ImportError:
        # Without omegaconf installed, ``obj`` cannot be an OmegaConf config.
        OmegaConf = None

    if OmegaConf is not None and OmegaConf.is_config(obj):
        return obj[k] if k in obj and obj[k] is not None else default

    # Treat an attribute explicitly set to None like a missing one, as the
    # docstring promises and consistent with safe_hasattr.
    value = getattr(obj, k, None)
    return default if value is None else value


def safe_hasattr(obj, k):
    """Tell whether ``obj`` has an attribute ``k`` whose value is not None.

    Both a missing attribute and one explicitly set to None yield False.
    """
    value = getattr(obj, k, None)
    return value is not None


def image_code_to_token(code):
    """Wrap an image code in a token string, e.g. ``5`` -> ``"<image5>"``."""
    return f"<image{code}>"


class ConcatIterator(iterators.CheckpointableIterator):
    """
    Merge one item from each of several iterators into a single dict per step.

    Each call to ``next`` pulls one item from every source iterator, in order,
    and merges them with ``dict.update``, so keys from later iterators
    overwrite equal keys from earlier ones. A StopIteration from any source
    ends the iteration.
    """
    def __init__(self, source_iterators):
        """
        Args:
            source_iterators: iterable of CheckpointableIterator instances whose
                items are dicts (or anything ``dict.update`` accepts) to merge.

        Raises:
            ValueError: if any element is not a CheckpointableIterator.
        """
        # Materialize once so a generator argument is not exhausted by the
        # validation below before it can be stored.
        source_iterators = list(source_iterators)
        if not all(isinstance(it, iterators.CheckpointableIterator) for it in source_iterators):
            raise ValueError('all iterators in source_iterators have to be CheckpointableIterator')
        self._source_iterators = source_iterators  # list of iterators.CheckpointableIterator

    def getstate(self):
        """Return a checkpoint bundling the state of every source iterator, in order."""
        return {'input_states': tuple(iterator.getstate() for iterator in self._source_iterators)}

    def setstate(self, checkpoint):
        """Restore every source iterator from ``checkpoint``.

        ``None`` is passed through to every source iterator's ``setstate``.

        Raises:
            ValueError: if the checkpoint holds a different number of states
                than there are source iterators.
        """
        if checkpoint is None:
            for iterator in self._source_iterators:
                iterator.setstate(None)
            return
        input_states = checkpoint['input_states']
        # Guard against silently restoring only part of the sources (zip would truncate).
        if len(input_states) != len(self._source_iterators):
            raise ValueError(
                'checkpoint has {} input states but there are {} source iterators'.format(
                    len(input_states), len(self._source_iterators)))
        for iterator, state in zip(self._source_iterators, input_states):
            iterator.setstate(state)

    def __next__(self):
        # A plain loop is used on purpose: inside a generator expression, the
        # StopIteration raised by next() would be turned into a RuntimeError (PEP 479).
        res = {}
        for iterator in self._source_iterators:
            res.update(next(iterator))
        return res

    def close(self):
        """Close every source iterator."""
        for it in self._source_iterators:
            it.close()