# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
import os
import shutil
import tempfile
import time
from unittest import TestCase
from uuid import uuid4

import torch
import torch.nn as nn
from torch.distributed import destroy_process_group
from torch.utils.data import Dataset

import mmengine.hooks  # noqa F401
import mmengine.optim  # noqa F401
from mmengine.config import Config
from mmengine.dist import is_distributed
from mmengine.evaluator import BaseMetric
from mmengine.logging import MessageHub, MMLogger
from mmengine.model import BaseModel
from mmengine.registry import DATASETS, METRICS, MODELS, DefaultScope
from mmengine.runner import Runner
from mmengine.visualization import Visualizer
class ToyModel(BaseModel):
    """Minimal two-layer linear model used to exercise the Runner."""

    def __init__(self, data_preprocessor=None):
        super().__init__(data_preprocessor=data_preprocessor)
        # Tiny network: 2 features -> 2 -> 1 output.
        self.linear1 = nn.Linear(2, 2)
        self.linear2 = nn.Linear(2, 1)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        """Run the model in ``'tensor'``, ``'loss'`` or ``'predict'`` mode."""
        # Dataloaders may deliver lists of per-sample tensors; stack them
        # into single batch tensors before the linear layers.
        if isinstance(inputs, list):
            inputs = torch.stack(inputs)
        if isinstance(data_samples, list):
            data_samples = torch.stack(data_samples)
        feats = self.linear2(self.linear1(inputs))
        if mode == 'tensor':
            return feats
        if mode == 'loss':
            # Toy "loss": summed residual between labels and predictions.
            return dict(loss=(data_samples - feats).sum())
        if mode == 'predict':
            return feats
        # NOTE(review): any other mode falls through and returns None,
        # matching the original behaviour.
class ToyDataset(Dataset):
    """Tiny in-memory dataset: 12 random 2-d points, all labelled 1."""

    METAINFO = dict()  # type: ignore
    # Class-level fixtures shared by every instance. Values are random but
    # the shapes (12 samples of dim 2) match what ToyModel expects.
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        """dict: Meta information of the dataset.

        Exposed as a property so callers can use the mmengine-wide
        attribute-style access ``dataset.metainfo``; as a plain method the
        expression would yield a bound method instead of the dict.
        """
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        return dict(inputs=self.data[index], data_samples=self.label[index])
class ToyMetric(BaseMetric):
    """Metric stub that always reports an accuracy of 1."""

    def __init__(self, collect_device='cpu', dummy_metrics=None):
        super().__init__(collect_device=collect_device)
        # Kept only so configs may pass `dummy_metrics`; never read here.
        self.dummy_metrics = dummy_metrics

    def process(self, data_batch, predictions):
        """Record a constant per-batch result."""
        self.results.append({'acc': 1})

    def compute_metrics(self, results):
        """Collapse the collected results into the final metric dict."""
        return dict(acc=1)
class RunnerTestCase(TestCase):
    """A test case to build runner easily.

    `RunnerTestCase` will do the following things:

    1. Registers a toy model, a toy metric, and a toy dataset, which can be
       used to run the `Runner` successfully.
    2. Provides epoch based and iteration based cfg to build runner.
    3. Provides `build_runner` method to build runner easily.
    4. Clean the global variable used by the runner.
    """
    # Class-level (shared) on purpose: `setup_dist_env` bumps MASTER_PORT
    # in place, so consecutive tests in one process get distinct ports.
    dist_cfg = dict(
        MASTER_ADDR='127.0.0.1',
        MASTER_PORT=29600,
        RANK='0',
        WORLD_SIZE='1',
        LOCAL_RANK='0')

    def setUp(self) -> None:
        """Register the toy components and prepare epoch/iter based cfgs."""
        self.temp_dir = tempfile.TemporaryDirectory()
        # Prevent from registering module with the same name by other unit
        # test. These registries will be cleared in `tearDown`
        MODELS.register_module(module=ToyModel, force=True)
        METRICS.register_module(module=ToyMetric, force=True)
        DATASETS.register_module(module=ToyDataset, force=True)
        epoch_based_cfg = dict(
            work_dir=self.temp_dir.name,
            model=dict(type='ToyModel'),
            train_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            val_evaluator=[dict(type='ToyMetric')],
            test_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            test_evaluator=[dict(type='ToyMetric')],
            optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.1)),
            train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
            val_cfg=dict(),
            test_cfg=dict(),
            default_hooks=dict(logger=dict(type='LoggerHook', interval=1)),
            custom_hooks=[],
            env_cfg=dict(dist_cfg=dict(backend='nccl')),
            experiment_name='test1')
        self.epoch_based_cfg = Config(epoch_based_cfg)
        # prepare iter based cfg.
        self.iter_based_cfg: Config = copy.deepcopy(self.epoch_based_cfg)
        self.iter_based_cfg.train_dataloader = dict(
            dataset=dict(type='ToyDataset'),
            sampler=dict(type='InfiniteSampler', shuffle=True),
            batch_size=3,
            num_workers=0)
        self.iter_based_cfg.log_processor = dict(by_epoch=False)
        self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12)
        self.iter_based_cfg.default_hooks = dict(
            logger=dict(type='LoggerHook', interval=1),
            checkpoint=dict(
                type='CheckpointHook', interval=12, by_epoch=False))

    def tearDown(self):
        """Undo everything `setUp` and a `Runner` run left behind."""
        # `FileHandler` should be closed in Windows, otherwise we cannot
        # delete the temporary directory
        logging.shutdown()
        MMLogger._instance_dict.clear()
        Visualizer._instance_dict.clear()
        DefaultScope._instance_dict.clear()
        MessageHub._instance_dict.clear()
        # Drop only the names registered in `setUp`.
        MODELS.module_dict.pop('ToyModel', None)
        METRICS.module_dict.pop('ToyMetric', None)
        DATASETS.module_dict.pop('ToyDataset', None)
        self.temp_dir.cleanup()
        if is_distributed():
            destroy_process_group()

    def build_runner(self, cfg: Config):
        """Build a `Runner` from ``cfg`` with a unique experiment name."""
        cfg.experiment_name = self.experiment_name
        runner = Runner.from_cfg(cfg)
        return runner

    @property
    def experiment_name(self):
        """str: A unique experiment name generated on each access.

        Must be a property: `build_runner` assigns ``self.experiment_name``
        (no call) to ``cfg.experiment_name``, which the runner expects to
        be a string; without the decorator the bound method itself would
        have been assigned.
        """
        # Since runners could be built too fast to have a unique experiment
        # name(timestamp is the same), here we use uuid to make sure each
        # runner has the unique experiment name.
        return f'{self._testMethodName}_{time.time()} + ' \
               f'{uuid4()}'

    def setup_dist_env(self):
        """Export single-process distributed env vars (rank 0, world 1)."""
        # Bump the port so repeated launches do not collide with a port
        # still held by a previous test.
        self.dist_cfg['MASTER_PORT'] += 1
        os.environ['MASTER_PORT'] = str(self.dist_cfg['MASTER_PORT'])
        os.environ['MASTER_ADDR'] = self.dist_cfg['MASTER_ADDR']
        os.environ['RANK'] = self.dist_cfg['RANK']
        os.environ['WORLD_SIZE'] = self.dist_cfg['WORLD_SIZE']
        os.environ['LOCAL_RANK'] = self.dist_cfg['LOCAL_RANK']

    def clear_work_dir(self):
        """Remove every file and subdirectory inside the temp work dir."""
        # Close log file handles first, otherwise removal fails on Windows.
        logging.shutdown()
        for filename in os.listdir(self.temp_dir.name):
            filepath = os.path.join(self.temp_dir.name, filename)
            if os.path.isfile(filepath):
                os.remove(filepath)
            else:
                shutil.rmtree(filepath)