Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod | |
from typing import Any, Dict, Union | |
from torch.utils.data import DataLoader | |
class BaseLoop(metaclass=ABCMeta):
    """Base loop class.

    All subclasses inherited from ``BaseLoop`` should overwrite the
    :meth:`run` method. ``run`` is declared abstract, so neither
    ``BaseLoop`` itself nor a subclass that fails to override ``run`` can
    be instantiated.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): An iterator to generate one batch of
            dataset each iteration. If a dict is given, it is treated as a
            dataloader config and built with ``runner.build_dataloader``;
            any other object is stored unchanged.
    """

    def __init__(self, runner, dataloader: Union[DataLoader, Dict]) -> None:
        self._runner = runner
        if isinstance(dataloader, dict):
            # Determine whether or not different ranks use different seed.
            # Falls back to ``False`` when the randomness config omits it.
            diff_rank_seed = runner._randomness_cfg.get(
                'diff_rank_seed', False)
            self.dataloader = runner.build_dataloader(
                dataloader, seed=runner.seed, diff_rank_seed=diff_rank_seed)
        else:
            # Non-dict inputs (e.g. an already-constructed ``DataLoader``)
            # are used exactly as passed in.
            self.dataloader = dataloader

    @property
    def runner(self):
        """The runner this loop was constructed with (read-only)."""
        return self._runner

    @abstractmethod
    def run(self) -> Any:
        """Execute loop.

        Subclasses must override this method.
        """