# Copyright (c) Meta, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from contextlib import contextmanager
import math
import os
import tempfile
import typing as tp

import torch
from torch.nn import functional as F
from torch.utils.data import Subset


def unfold(a, kernel_size, stride):
    """Given input of size `[*OT, T]`, return a tensor of size `[*OT, F, K]`,
    with `K` the kernel size, by extracting frames with the given stride.

    This will pad the input so that `F = ceil(T / stride)`.

    See https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, 'data should be contiguous'
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)


def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
    """Center trim `tensor` with respect to `reference`, along the last dimension.

    `reference` can also be a number, representing the length to trim to.
    If the size difference is odd, the extra sample is removed on the right side.
    """
    ref_size: int
    if isinstance(reference, torch.Tensor):
        ref_size = reference.size(-1)
    else:
        ref_size = reference
    delta = tensor.size(-1) - ref_size
    if delta < 0:
        raise ValueError("tensor must be larger than reference. "
                         f"Delta is {delta}.")
    if delta:
        tensor = tensor[..., delta // 2:-(delta - delta // 2)]
    return tensor
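
# Illustrative usage sketch (added for documentation; the helper below is
# hypothetical and not part of the original module). It shows the frame
# layout produced by `unfold` and how `center_trim` matches lengths.
def _example_unfold_and_trim():
    x = torch.arange(10.).view(1, 10)
    # 10 samples with stride 2 give ceil(10 / 2) = 5 frames of kernel_size 4;
    # the input is zero-padded on the right to (5 - 1) * 2 + 4 = 12 samples.
    frames = unfold(x, kernel_size=4, stride=2)
    assert frames.shape == (1, 5, 4)
    assert frames[0, 1].tolist() == [2., 3., 4., 5.]  # frame i starts at i * stride
    # Trimming a length-15 tensor to a length-10 reference drops
    # delta // 2 = 2 samples on the left and the remaining 3 on the right.
    y = torch.arange(15.).view(1, 15)
    trimmed = center_trim(y, x)
    assert trimmed.shape == (1, 10)
    assert trimmed[0, 0].item() == 2.
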
def pull_metric(history: tp.List[dict], name: str):
    """Extract the metric `name` from every entry in `history`. Dotted names
    (e.g. `"valid.loss"`) index into nested dicts."""
    out = []
    for metrics in history:
        metric = metrics
        for part in name.split("."):
            metric = metric[part]
        out.append(metric)
    return out


def EMA(beta: float = 1):
    """
    Exponential moving average callback.
    Returns a single function that can be called repeatedly to update the EMA
    with a dict of metrics. The callback returns the new averaged dict of metrics.
    Note that for `beta=1`, this is just plain averaging.
    """
    fix: tp.Dict[str, float] = defaultdict(float)
    total: tp.Dict[str, float] = defaultdict(float)

    def _update(metrics: dict, weight: float = 1) -> dict:
        nonlocal total, fix
        for key, value in metrics.items():
            total[key] = total[key] * beta + weight * float(value)
            fix[key] = fix[key] * beta + weight
        return {key: tot / fix[key] for key, tot in total.items()}
    return _update


def sizeof_fmt(num: float, suffix: str = 'B'):
    """Given `num` bytes, return a human-readable size.

    Taken from https://stackoverflow.com/a/1094933
    """
    for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
        if abs(num) < 1024.0:
            return "%3.1f%s%s" % (num, unit, suffix)
        num /= 1024.0
    return "%.1f%s%s" % (num, 'Yi', suffix)


@contextmanager
def temp_filenames(count: int, delete=True):
    """Yield `count` temporary file names; the files are removed on exit
    when `delete` is True."""
    names = []
    try:
        for _ in range(count):
            names.append(tempfile.NamedTemporaryFile(delete=False).name)
        yield names
    finally:
        if delete:
            for name in names:
                os.unlink(name)


def random_subset(dataset, max_samples: int, seed: int = 42):
    """Return a deterministic random subset of `dataset` with at most
    `max_samples` items."""
    if max_samples >= len(dataset):
        return dataset

    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(len(dataset), generator=generator)
    return Subset(dataset, perm[:max_samples].tolist())


class DummyPoolExecutor:
    """Synchronous stand-in for a `concurrent.futures` executor: each submitted
    function runs in the calling process when `result()` is called."""
    class DummyResult:
        def __init__(self, func, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs

        def result(self):
            return self.func(*self.args, **self.kwargs)

    def __init__(self, workers=0):
        pass

    def submit(self, func, *args, **kwargs):
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        return
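
# A minimal smoke test, added here as an illustrative sketch (not part of the
# original file); it exercises the pure-Python helpers above.
if __name__ == "__main__":
    print(sizeof_fmt(3.5 * 1024 ** 3))  # -> "3.5GiB"

    # With beta=1, EMA reduces to a plain running average.
    update = EMA(beta=1.)
    update({"loss": 2.0})
    print(update({"loss": 4.0}))  # -> {'loss': 3.0}

    # pull_metric walks dotted paths through a history of metric dicts.
    history = [{"valid": {"loss": 0.5}}, {"valid": {"loss": 0.3}}]
    print(pull_metric(history, "valid.loss"))  # -> [0.5, 0.3]

    # DummyPoolExecutor mimics the executor interface but runs each submitted
    # function lazily, in process, when .result() is called.
    with DummyPoolExecutor() as pool:
        future = pool.submit(pow, 2, 10)
        print(future.result())  # -> 1024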