# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import abstractmethod
from collections import OrderedDict
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmengine.optim import OptimWrapper
from mmengine.registry import MODELS
from mmengine.utils import is_list_of
from ..base_module import BaseModule
from .data_preprocessor import BaseDataPreprocessor
class BaseModel(BaseModule):
"""Base class for all algorithmic models.
    BaseModel implements the basic functions of an algorithmic model, such
    as weight initialization, batch input preprocessing (see
    :class:`BaseDataPreprocessor` for more information), loss parsing, and
    model parameter updating. Subclasses that inherit from ``BaseModel``
    only need to implement the :meth:`forward` method, which holds the
    logic for calculating losses and predictions; they can then be trained
    by the runner.

Examples:
>>> @MODELS.register_module()
>>> class ToyModel(BaseModel):
>>>
>>> def __init__(self):
>>> super().__init__()
>>> self.backbone = nn.Sequential()
>>> self.backbone.add_module('conv1', nn.Conv2d(3, 6, 5))
>>> self.backbone.add_module('pool', nn.MaxPool2d(2, 2))
>>> self.backbone.add_module('conv2', nn.Conv2d(6, 16, 5))
>>> self.backbone.add_module('fc1', nn.Linear(16 * 5 * 5, 120))
>>> self.backbone.add_module('fc2', nn.Linear(120, 84))
>>> self.backbone.add_module('fc3', nn.Linear(84, 10))
>>>
>>> self.criterion = nn.CrossEntropyLoss()
>>>
>>> def forward(self, batch_inputs, data_samples, mode='tensor'):
>>> data_samples = torch.stack(data_samples)
>>> if mode == 'tensor':
>>> return self.backbone(batch_inputs)
>>> elif mode == 'predict':
>>> feats = self.backbone(batch_inputs)
>>> predictions = torch.argmax(feats, 1)
>>> return predictions
>>> elif mode == 'loss':
>>> feats = self.backbone(batch_inputs)
>>> loss = self.criterion(feats, data_samples)
>>> return dict(loss=loss)
Args:
data_preprocessor (dict, optional): The pre-process config of
:class:`BaseDataPreprocessor`.
        init_cfg (dict, optional): The weight initialization config for
            :class:`BaseModule`.
Attributes:
data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
pre-processing data sampled by dataloader to the format accepted by
:meth:`forward`.
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self,
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
init_cfg: Optional[dict] = None):
super().__init__(init_cfg)
if data_preprocessor is None:
data_preprocessor = dict(type='BaseDataPreprocessor')
if isinstance(data_preprocessor, nn.Module):
self.data_preprocessor = data_preprocessor
elif isinstance(data_preprocessor, dict):
self.data_preprocessor = MODELS.build(data_preprocessor)
else:
raise TypeError('data_preprocessor should be a `dict` or '
f'`nn.Module` instance, but got '
f'{type(data_preprocessor)}')
def train_step(self, data: Union[dict, tuple, list],
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
"""Implements the default model training process including
preprocessing, model forward propagation, loss calculation,
optimization, and back-propagation.
During non-distributed training. If subclasses do not override the
:meth:`train_step`, :class:`EpochBasedTrainLoop` or
:class:`IterBasedTrainLoop` will call this method to update model
parameters. The default parameter update process is as follows:
1. Calls ``self.data_processor(data, training=False)`` to collect
batch_inputs and corresponding data_samples(labels).
2. Calls ``self(batch_inputs, data_samples, mode='loss')`` to get raw
loss
3. Calls ``self.parse_losses`` to get ``parsed_losses`` tensor used to
backward and dict of loss tensor used to log messages.
4. Calls ``optim_wrapper.update_params(loss)`` to update model.
Args:
data (dict or tuple or list): Data sampled from dataset.
optim_wrapper (OptimWrapper): OptimWrapper instance
used to update model parameters.
Returns:
Dict[str, torch.Tensor]: A ``dict`` of tensor for logging.
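
        A minimal usage sketch is shown below; ``model`` is assumed to be
        an instance of the ``ToyModel`` from the class docstring and
        ``data`` a batch sampled from a dataloader.

        Examples:
            >>> from mmengine.optim import OptimWrapper
            >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
            >>> optim_wrapper = OptimWrapper(optimizer)
            >>> # preprocess, forward, parse losses, then update parameters
            >>> log_vars = model.train_step(data, optim_wrapper)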
"""
# Enable automatic mixed precision training context.
with optim_wrapper.optim_context(self):
data = self.data_preprocessor(data, True)
losses = self._run_forward(data, mode='loss') # type: ignore
parsed_losses, log_vars = self.parse_losses(losses) # type: ignore
optim_wrapper.update_params(parsed_losses)
return log_vars
def val_step(self, data: Union[tuple, dict, list]) -> list:
"""Gets the predictions of given data.
Calls ``self.data_preprocessor(data, False)`` and
``self(inputs, data_sample, mode='predict')`` in order. Return the
predictions which will be passed to evaluator.
Args:
data (dict or tuple or list): Data sampled from dataset.
Returns:
list: The predictions of given data.
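
        A minimal sketch, assuming ``model`` is a ``ToyModel`` instance and
        ``data`` is a batch sampled from a validation dataloader:

        Examples:
            >>> predictions = model.val_step(data)
            >>> # ``predictions`` is the list later consumed by the evaluator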
"""
data = self.data_preprocessor(data, False)
return self._run_forward(data, mode='predict') # type: ignore
def test_step(self, data: Union[dict, tuple, list]) -> list:
"""``BaseModel`` implements ``test_step`` the same as ``val_step``.
Args:
data (dict or tuple or list): Data sampled from dataset.
Returns:
list: The predictions of given data.
"""
data = self.data_preprocessor(data, False)
return self._run_forward(data, mode='predict') # type: ignore
def parse_losses(
self, losses: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Parses the raw outputs (losses) of the network.
Args:
losses (dict): Raw output of the network, which usually contain
losses and other necessary information.
Returns:
tuple[Tensor, dict]: There are two elements. The first is the
loss tensor passed to optim_wrapper which may be a weighted sum
of all losses, and the second is log_vars which will be sent to
the logger.
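
        A sketch of the parsing rule with hypothetical keys: each value is
        reduced to its mean, and all keys containing ``loss`` are summed
        into the total loss.

        Examples:
            >>> losses = dict(
            ...     loss_cls=torch.tensor(1.0),
            ...     loss_bbox=[torch.tensor(0.5), torch.tensor(0.5)])
            >>> loss, log_vars = model.parse_losses(losses)
            >>> loss
            tensor(2.)
            >>> list(log_vars)
            ['loss', 'loss_cls', 'loss_bbox']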
"""
log_vars = []
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars.append([loss_name, loss_value.mean()])
elif is_list_of(loss_value, torch.Tensor):
log_vars.append(
[loss_name,
sum(_loss.mean() for _loss in loss_value)])
else:
raise TypeError(
f'{loss_name} is not a tensor or list of tensors')
loss = sum(value for key, value in log_vars if 'loss' in key)
log_vars.insert(0, ['loss', loss])
log_vars = OrderedDict(log_vars) # type: ignore
return loss, log_vars # type: ignore
def to(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.to`
additionally.
Returns:
nn.Module: The model itself.
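
        A sketch assuming a CUDA device is available; moving the model also
        updates the device recorded by :attr:`data_preprocessor`:

        Examples:
            >>> model = model.to('cuda:0')
            >>> model.data_preprocessor.device
            device(type='cuda', index=0)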
"""
        # Since PyTorch has not officially merged the NPU-related fields,
        # calling the ``_parse_to`` function directly would fail to
        # recognize the NPU device. The input arguments are therefore
        # rewritten here to avoid errors.
if args and isinstance(args[0], str) and 'npu' in args[0]:
import torch_npu
args = tuple([
list(args)[0].replace(
'npu', torch_npu.npu.native_device if hasattr(
torch_npu.npu, 'native_device') else 'privateuseone')
])
if kwargs and 'npu' in str(kwargs.get('device', '')):
import torch_npu
kwargs['device'] = kwargs['device'].replace(
'npu', torch_npu.npu.native_device if hasattr(
torch_npu.npu, 'native_device') else 'privateuseone')
device = torch._C._nn._parse_to(*args, **kwargs)[0]
if device is not None:
self._set_device(torch.device(device))
return super().to(*args, **kwargs)
def cuda(
self,
device: Optional[Union[int, str, torch.device]] = None,
) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.cuda`
additionally.
Returns:
nn.Module: The model itself.
"""
if device is None or isinstance(device, int):
device = torch.device('cuda', index=device)
self._set_device(torch.device(device))
return super().cuda(device)
def mlu(
self,
device: Union[int, str, torch.device, None] = None,
) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.mlu`
additionally.
Returns:
nn.Module: The model itself.
"""
device = torch.device('mlu', torch.mlu.current_device())
self._set_device(device)
return super().mlu()
def npu(
self,
device: Union[int, str, torch.device, None] = None,
) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.npu`
additionally.
Returns:
nn.Module: The model itself.
Note:
            This generation of NPU (Ascend 910) does not support the use of
            multiple cards in a single process, so the index here needs to
            be consistent with the default device.
"""
device = torch.npu.current_device()
self._set_device(device)
return super().npu()
def cpu(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.cpu`
additionally.
Returns:
nn.Module: The model itself.
"""
self._set_device(torch.device('cpu'))
return super().cpu()
def _set_device(self, device: torch.device) -> None:
"""Recursively set device for `BaseDataPreprocessor` instance.
Args:
device (torch.device): the desired device of the parameters and
buffers in this module.
"""
def apply_fn(module):
if not isinstance(module, BaseDataPreprocessor):
return
if device is not None:
module._device = device
self.apply(apply_fn)
@abstractmethod
def forward(self,
inputs: torch.Tensor,
data_samples: Optional[list] = None,
mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]:
"""Returns losses or predictions of training, validation, testing, and
simple inference process.
``forward`` method of BaseModel is an abstract method, its subclasses
must implement this method.
Accepts ``batch_inputs`` and ``data_sample`` processed by
:attr:`data_preprocessor`, and returns results according to mode
arguments.
During non-distributed training, validation, and testing process,
``forward`` will be called by ``BaseModel.train_step``,
``BaseModel.val_step`` and ``BaseModel.test_step`` directly.
During distributed data parallel training process,
``MMSeparateDistributedDataParallel.train_step`` will first call
``DistributedDataParallel.forward`` to enable automatic
gradient synchronization, and then call ``forward`` to get training
loss.
Args:
inputs (torch.Tensor): batch input tensor collated by
:attr:`data_preprocessor`.
data_samples (list, optional):
data samples collated by :attr:`data_preprocessor`.
            mode (str): mode should be one of ``loss``, ``predict`` and
                ``tensor``.

                - ``loss``: Called by ``train_step`` and returns a loss
                  ``dict`` used for logging.
                - ``predict``: Called by ``val_step`` and ``test_step``
                  and returns a list of results used for computing metrics.
                - ``tensor``: Called by custom code to get ``Tensor``-type
                  results.

Returns:
dict or list:
- If ``mode == loss``, return a ``dict`` of loss tensor used
for backward and logging.
- If ``mode == predict``, return a ``list`` of inference
results.
- If ``mode == tensor``, return a tensor or ``tuple`` of tensor
or ``dict`` of tensor for custom use.
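
        A sketch of how the three modes are typically called, reusing the
        ``ToyModel`` setup from the class docstring:

        Examples:
            >>> feats = model(batch_inputs, data_samples, mode='tensor')
            >>> predictions = model(batch_inputs, data_samples,
            ...                     mode='predict')
            >>> losses = model(batch_inputs, data_samples, mode='loss')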
"""
def _run_forward(self, data: Union[dict, tuple, list],
mode: str) -> Union[Dict[str, torch.Tensor], list]:
"""Unpacks data for :meth:`forward`
Args:
data (dict or tuple or list): Data sampled from dataset.
mode (str): Mode of forward.
Returns:
dict or list: Results of training or testing mode.
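
        A sketch of the dispatch rule with hypothetical names: a ``dict``
        is unpacked as keyword arguments, while a ``tuple`` or ``list`` is
        unpacked as positional arguments of :meth:`forward`.

        Examples:
            >>> data = dict(inputs=batch_inputs, data_samples=data_samples)
            >>> results = model._run_forward(data, mode='predict')
            >>> # equivalent to calling
            >>> # model(inputs=batch_inputs, data_samples=data_samples,
            >>> #       mode='predict')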
"""
if isinstance(data, dict):
results = self(**data, mode=mode)
elif isinstance(data, (list, tuple)):
results = self(*data, mode=mode)
else:
raise TypeError('Output of `data_preprocessor` should be '
f'list, tuple or dict, but got {type(data)}')
return results