# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import collections
from random import Random
from typing import Dict, Iterable, Optional

import numpy as np
from infinibatch import iterators

# NOTE(review): all three symbols are empty strings in this copy of the file.
# They look like special-token markers whose angle-bracket text may have been
# stripped by some tooling -- TODO confirm the intended values against the
# upstream source / tokenizer vocabulary before relying on them.
EOL_SYMBOL = ""
BOI_SYMBOL = ""
EOI_SYMBOL = ""


def apply_to_sample(f, sample):
    """Apply ``f`` to every ``np.ndarray`` nested inside ``sample``.

    The container structure is rebuilt recursively: ``OrderedDict``, ``dict``,
    ``list``, ``tuple`` and ``set`` are traversed, arrays are replaced by
    ``f(array)``, and every other leaf is returned unchanged.

    Args:
        f: callable taking one ``np.ndarray`` and returning its replacement.
        sample: an array, a (possibly nested) container, or any other object.

    Returns:
        ``{}`` if ``sample`` has a length of zero, otherwise a new structure
        mirroring ``sample`` with ``f`` applied to each array.
    """
    # An explicit length test (rather than truthiness) is kept on purpose:
    # the truth value of a multi-element numpy array is ambiguous.
    if hasattr(sample, "__len__") and len(sample) == 0:
        return {}

    def _apply(x):
        if isinstance(x, np.ndarray):
            return f(x)
        elif isinstance(x, collections.OrderedDict):
            # OrderedDict has attributes that need to be preserved
            od = collections.OrderedDict(
                (key, _apply(value)) for key, value in x.items()
            )
            # The attribute dict is assigned, not copied, so ``od`` and ``x``
            # share the same instance attributes afterwards.
            od.__dict__ = x.__dict__
            return od
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(item) for item in x]
        elif isinstance(x, tuple):
            return tuple(_apply(item) for item in x)
        elif isinstance(x, set):
            return {_apply(item) for item in x}
        else:
            return x

    return _apply(sample)


class NativeCheckpointableIterator(iterators.CheckpointableIterator):
    """Checkpointable wrapper around a plain Python iterable.

    The checkpoint is the number of items yielded so far; restoring it
    re-creates the iterator from the iterable and advances it by that count.
    """

    def __init__(self, iterable: Iterable):
        """
        Args:
            iterable: the iterable to wrap; ``iter()`` is called on it again
                every time the state is (re)set.
        """
        self._input_iterable = iterable
        self.setstate(None)

    def getstate(self) -> Dict:
        """Return a checkpoint holding the number of items yielded so far."""
        return {"num_items_yielded": self._num_items_yielded}

    def setstate(self, checkpoint: Optional[Dict]):
        """Restart iteration, optionally skipping ahead to ``checkpoint``.

        Args:
            checkpoint: a dict as produced by ``getstate()``, or ``None`` to
                start from the beginning.
        """
        self._iterator = iter(self._input_iterable)
        # The helper's return value is stored as the yielded-item counter.
        self._num_items_yielded = (
            iterators._advance_iterator(self._iterator, checkpoint["num_items_yielded"])
            if checkpoint is not None
            else 0
        )

    def __next__(self):
        item = next(self._iterator)
        self._num_items_yielded += 1
        return item

    def close(self):
        """No resources to release."""
        pass


class WeightIterator(object):
    """Endless iterator yielding indices sampled according to ``weights``.

    The state of the underlying ``random.Random`` is captured after every
    draw so the sequence can be resumed through ``getstate``/``setstate``.
    """

    def __init__(self, weights, seed):
        """
        Args:
            weights: sequence of sampling weights, one per index.
            seed: seed used to create the random generator.
        """
        self.weights = weights
        self.seed = seed
        self.control_index = list(range(len(weights)))
        self.setstate(None)

    def __iter__(self):
        return self

    def getstate(self):
        """Return a checkpoint holding the generator state after the last draw."""
        return {"random_state": self._random_state}

    def setstate(self, checkpoint):
        """Restore from ``checkpoint``; a falsy checkpoint resets to the seed."""
        self._random_state = checkpoint["random_state"] if checkpoint else None
        self._random = (
            None  # this will trigger the lazy initialization in self.__next__
        )

    def __next__(self):
        if self._random is None:
            # Lazy initialization: build the generator from the seed, then
            # overwrite its state if a checkpoint supplied one.
            self._random = Random(self.seed)
            if self._random_state is not None:
                self._random.setstate(self._random_state)
        idx = self._random.choices(self.control_index, self.weights)[0]
        # Remember the post-draw state so getstate() resumes after this item.
        self._random_state = self._random.getstate()
        return idx

    def close(self):
        """No resources to release."""
        pass


def safe_getattr(obj, k, default=None):
    """Look up ``k`` on ``obj``, falling back to ``default``.

    For OmegaConf configs, returns ``obj[k]`` if the key exists and its value
    is not None, otherwise ``default``. For any other object this is plain
    ``getattr(obj, k, default)``: ``default`` is used only when the attribute
    is missing, so an attribute explicitly set to None is returned as None.
    """
    # Imported lazily so omegaconf is only needed when this helper is called.
    from omegaconf import OmegaConf

    if OmegaConf.is_config(obj):
        return obj[k] if k in obj and obj[k] is not None else default

    return getattr(obj, k, default)


def safe_hasattr(obj, k):
    """Returns True if the given key exists and is not None."""
    return getattr(obj, k, None) is not None


def image_code_to_token(code):
    """Format ``code`` into its image-token string.

    NOTE(review): the format template is an empty string in this copy of the
    file, so the result is always ``""`` and ``code`` is ignored. The template
    may have lost angle-bracket markup -- TODO confirm the intended template
    against the upstream source.
    """
    return "".format(code)


class ConcatIterator(iterators.CheckpointableIterator):
    """
    Concat items from all given iterators.

    Each call to ``next()`` pulls one item from every source iterator, in
    order, and merges them into a single dict via ``dict.update``; on key
    collisions the later iterator wins.
    """

    def __init__(self, source_iterators):
        """
        Args:
            source_iterators: list of CheckpointableIterators whose items are
                merged, item by item. Each item must be something
                ``dict.update`` accepts (typically a dict).

        Raises:
            ValueError: if any element is not a CheckpointableIterator.
        """
        # Materialize once so a one-shot iterable is not exhausted by the
        # validation below and so its length is available in setstate().
        source_iterators = list(source_iterators)
        if not all(
            isinstance(source_iterator, iterators.CheckpointableIterator)
            for source_iterator in source_iterators
        ):
            raise ValueError('all iterators in source_iterators have to be CheckpointableIterator')
        # list of CheckpointableIterator
        self._source_iterators = source_iterators

    def getstate(self):
        """Return a checkpoint holding the state of every source iterator."""
        return {'input_states': tuple(iterator.getstate() for iterator in self._source_iterators)}

    def setstate(self, checkpoint):
        """Restore every source iterator from ``checkpoint``.

        Args:
            checkpoint: a dict as produced by ``getstate()``, or ``None`` to
                reset all source iterators.

        Raises:
            ValueError: if the checkpoint does not hold exactly one state per
                source iterator.
        """
        if checkpoint is None:
            for iterator in self._source_iterators:
                iterator.setstate(None)
            return

        input_states = checkpoint['input_states']
        # zip() would silently drop the surplus on a length mismatch and leave
        # some iterators un-restored, so reject such checkpoints explicitly.
        if len(input_states) != len(self._source_iterators):
            raise ValueError(
                'checkpoint holds {} input states but there are {} source iterators'.format(
                    len(input_states), len(self._source_iterators)
                )
            )
        for iterator, state in zip(self._source_iterators, input_states):
            iterator.setstate(state)

    def __next__(self):
        res = {}
        # (note: can't use a generator expression, as it gets confused when a next() call raises StopIteration)
        # If a later iterator raises StopIteration, the earlier ones have
        # already been advanced for this step.
        for iterator in self._source_iterators:
            res.update(next(iterator))
        return res

    def close(self):
        """Close every source iterator."""
        for it in self._source_iterators:
            it.close()