# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch
import torch.nn as nn

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from .base import BaseClassifier


@MODELS.register_module()
class ImageClassifier(BaseClassifier):
    """Image classifier for supervised classification tasks.

    Args:
        backbone (dict): The backbone module. See
            :mod:`mmpretrain.models.backbones`.
        neck (dict, optional): The neck module to process features from the
            backbone. See :mod:`mmpretrain.models.necks`. Defaults to None.
        head (dict, optional): The head module to do prediction and calculate
            loss from processed features. See :mod:`mmpretrain.models.heads`.
            Notice that if the head is not set, almost all methods cannot be
            used except :meth:`extract_feat`. Defaults to None.
        pretrained (str, optional): The pretrained checkpoint path, supporting
            both local and remote paths. Defaults to None.
        train_cfg (dict, optional): The training setting. The acceptable
            fields are:

            - augments (List[dict]): The batch augmentation methods to use.
              More details can be found in
              :mod:`mmpretrain.models.utils.batch_augments`.
            - probs (List[float], optional): The probability of every batch
              augmentation method. If None, choose evenly. Defaults to None.

            Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing input
            data. If None or no specified type, it will use
            "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor`
            for more details. Defaults to None.
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to None.
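
    Examples:
        A minimal construction sketch. The ``type`` values below are registry
        names assumed from other parts of mmpretrain; the exact backbone and
        head settings are illustrative, not a tested config.

        >>> import torch
        >>> from mmpretrain.registry import MODELS
        >>> cfg = dict(
        ...     type='ImageClassifier',
        ...     backbone=dict(type='ResNet', depth=18),
        ...     neck=dict(type='GlobalAveragePooling'),
        ...     head=dict(
        ...         type='LinearClsHead',
        ...         num_classes=1000,
        ...         in_channels=512,
        ...         loss=dict(type='CrossEntropyLoss')))
        >>> model = MODELS.build(cfg)
        >>> model(torch.rand(1, 3, 224, 224)).shape  # doctest: +SKIP
        torch.Size([1, 1000])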
| """ | |
    def __init__(self,
                 backbone: dict,
                 neck: Optional[dict] = None,
                 head: Optional[dict] = None,
                 pretrained: Optional[str] = None,
                 train_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[dict] = None):
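        # An explicit `pretrained` path is translated into an MMEngine
        # ``Pretrained`` init_cfg, overriding any ``init_cfg`` passed in.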
        if pretrained is not None:
            init_cfg = dict(type='Pretrained', checkpoint=pretrained)

        data_preprocessor = data_preprocessor or {}

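        # Build the data preprocessor from config, defaulting to
        # ``ClsDataPreprocessor`` and forwarding ``train_cfg`` as its batch
        # augmentation setting.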
        if isinstance(data_preprocessor, dict):
            data_preprocessor.setdefault('type', 'ClsDataPreprocessor')
            data_preprocessor.setdefault('batch_augments', train_cfg)
            data_preprocessor = MODELS.build(data_preprocessor)
        elif not isinstance(data_preprocessor, nn.Module):
            raise TypeError('data_preprocessor should be a `dict` or '
                            f'`nn.Module` instance, but got '
                            f'{type(data_preprocessor)}')

        super(ImageClassifier, self).__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

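        # Backbone, neck and head may be passed either as config dicts or as
        # already-built ``nn.Module`` instances; only the former are built.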
        if not isinstance(backbone, nn.Module):
            backbone = MODELS.build(backbone)
        if neck is not None and not isinstance(neck, nn.Module):
            neck = MODELS.build(neck)
        if head is not None and not isinstance(head, nn.Module):
            head = MODELS.build(head)

        self.backbone = backbone
        self.neck = neck
        self.head = head

        # If the model needs to load pretrained weights from a third party,
        # the state dict keys can be remapped with this hook.
        if hasattr(self.backbone, '_checkpoint_filter'):
            self._register_load_state_dict_pre_hook(
                self.backbone._checkpoint_filter)

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[DataSample]] = None,
                mode: str = 'tensor'):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor(s) without any
          post-processing, same as a common PyTorch Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`DataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[DataSample], optional): The annotation
                data of each sample. It's required if ``mode="loss"``.
                Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensors.
            - If ``mode="predict"``, return a list of
              :obj:`mmpretrain.structures.DataSample`.
            - If ``mode="loss"``, return a dict of tensors.
| """ | |
| if mode == 'tensor': | |
| feats = self.extract_feat(inputs) | |
| return self.head(feats) if self.with_head else feats | |
| elif mode == 'loss': | |
| return self.loss(inputs, data_samples) | |
| elif mode == 'predict': | |
| return self.predict(inputs, data_samples) | |
| else: | |
| raise RuntimeError(f'Invalid mode "{mode}".') | |
    def extract_feat(self, inputs, stage='neck'):
        """Extract features from the input tensor with shape (N, C, ...).

        Args:
            inputs (Tensor): A batch of inputs. The shape of it should be
                ``(num_samples, num_channels, *img_shape)``.
            stage (str): Which stage to output the feature. Choose from:

                - "backbone": The output of the backbone network. Returns a
                  tuple of features from multiple stages.
                - "neck": The output of the neck module. Returns a tuple of
                  features from multiple stages.
                - "pre_logits": The feature before the final classification
                  linear layer. Usually returns a tensor.

                Defaults to "neck".

        Returns:
            tuple | Tensor: The output of the specified stage.
            The output depends on the detailed implementation. In general,
            the outputs of the backbone and neck are tuples and the output
            of pre_logits is a tensor.

        Examples:
            1. Backbone output

            >>> import torch
            >>> from mmengine import Config
            >>> from mmpretrain.models import build_classifier
            >>>
            >>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model
            >>> cfg.backbone.out_indices = (0, 1, 2, 3)  # Output multi-scale feature maps
            >>> model = build_classifier(cfg)
            >>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone')
            >>> for out in outs:
            ...     print(out.shape)
            torch.Size([1, 64, 56, 56])
            torch.Size([1, 128, 28, 28])
            torch.Size([1, 256, 14, 14])
            torch.Size([1, 512, 7, 7])

            2. Neck output

            >>> import torch
            >>> from mmengine import Config
            >>> from mmpretrain.models import build_classifier
            >>>
            >>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model
            >>> cfg.backbone.out_indices = (0, 1, 2, 3)  # Output multi-scale feature maps
            >>> model = build_classifier(cfg)
            >>>
            >>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck')
            >>> for out in outs:
            ...     print(out.shape)
            torch.Size([1, 64])
            torch.Size([1, 128])
            torch.Size([1, 256])
            torch.Size([1, 512])

            3. Pre-logits output (without the final linear classifier head)

            >>> import torch
            >>> from mmengine import Config
            >>> from mmpretrain.models import build_classifier
            >>>
            >>> cfg = Config.fromfile('configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py').model
            >>> model = build_classifier(cfg)
            >>>
            >>> out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits')
            >>> print(out.shape)  # The hidden dim of the head is 3072
            torch.Size([1, 3072])
        """  # noqa: E501
        assert stage in ['backbone', 'neck', 'pre_logits'], \
            (f'Invalid output stage "{stage}", please choose from "backbone", '
             '"neck" and "pre_logits"')

        x = self.backbone(inputs)
        if stage == 'backbone':
            return x

        if self.with_neck:
            x = self.neck(x)
        if stage == 'neck':
            return x

        assert self.with_head and hasattr(self.head, 'pre_logits'), \
            "No head or the head doesn't implement `pre_logits` method."
        return self.head.pre_logits(x)

    def loss(self, inputs: torch.Tensor,
             data_samples: List[DataSample]) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[DataSample]): The annotation data of
                each sample.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
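
        Examples:
            A minimal sketch, assuming ``model`` is an already-built
            ``ImageClassifier`` and that the head reports a single ``loss``
            entry (the exact keys are head-dependent):

            >>> inputs = torch.rand(1, 3, 224, 224)
            >>> data_samples = [DataSample().set_gt_label(3)]
            >>> losses = model.loss(inputs, data_samples)  # doctest: +SKIP
            >>> list(losses)  # doctest: +SKIP
            ['loss']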
| """ | |
| feats = self.extract_feat(inputs) | |
| return self.head.loss(feats, data_samples) | |
    def predict(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[DataSample]] = None,
                **kwargs) -> List[DataSample]:
        """Predict results from a batch of inputs.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[DataSample], optional): The annotation
                data of each sample. Defaults to None.
            **kwargs: Other keyword arguments accepted by the ``predict``
                method of :attr:`head`.

        Returns:
            List[DataSample]: The prediction results.
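
        Examples:
            A minimal sketch, assuming ``model`` is an already-built
            ``ImageClassifier`` with a classification head:

            >>> preds = model.predict(torch.rand(1, 3, 224, 224))
            >>> preds[0].pred_label  # doctest: +SKIP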
| """ | |
| feats = self.extract_feat(inputs) | |
| return self.head.predict(feats, data_samples, **kwargs) | |
    def get_layer_depth(self, param_name: str):
        """Get the layer-wise depth of a parameter.

        Args:
            param_name (str): The name of the parameter.

        Returns:
            Tuple[int, int]: The layer-wise depth and the max depth.
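
        Examples:
            A hypothetical sketch: the parameter name is illustrative and
            assumes a backbone (e.g. a ViT-style model) that implements
            ``get_layer_depth``:

            >>> depth, max_depth = model.get_layer_depth(
            ...     'backbone.layers.2.ln1.weight')  # doctest: +SKIP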
| """ | |
| if hasattr(self.backbone, 'get_layer_depth'): | |
| return self.backbone.get_layer_depth(param_name, 'backbone.') | |
| else: | |
| raise NotImplementedError( | |
| f"The backbone {type(self.backbone)} doesn't " | |
| 'support `get_layer_depth` by now.') | |