"""
https://github.com/pytorch/vision/blob/main/torchvision/models/_utils.py

by lyuwenyu
"""

from collections import OrderedDict
from typing import List

import torch.nn as nn


class IntermediateLayerGetter(nn.ModuleDict):
""" |
|
Module wrapper that returns intermediate layers from a model |
|
|
|
It has a strong assumption that the modules have been registered |
|
into the model in the same order as they are used. |
|
This means that one should **not** reuse the same nn.Module |
|
twice in the forward if you want this to work. |
|
|
|
Additionally, it is only able to query submodules that are directly |
|
assigned to the model. So if `model` is passed, `model.feature1` can |
|
be returned, but not `model.feature1.layer2`. |
|
""" |
|
|
|
_version = 3 |
|
|
|

    def __init__(self, model: nn.Module, return_layers: List[str]) -> None:
        # Every requested layer must be a direct child of `model`.
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model. Available layers: {}".format(
                [name for name, _ in model.named_children()]))
        orig_return_layers = return_layers
        return_layers = {str(k): str(k) for k in return_layers}
        # Keep the children up to (and including) the last requested layer;
        # anything registered after it is never needed in forward and is dropped.
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super().__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        outputs = []
        # Children run in registration order; collect the output of each
        # requested layer as it is produced.
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                outputs.append(x)

        return outputs
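

# Minimal usage sketch (not part of the original file): assumes torchvision is
# installed and uses resnet18 purely as an illustrative backbone. Any model
# whose direct children include the requested names works the same way.
if __name__ == '__main__':
    import torch
    import torchvision

    backbone = torchvision.models.resnet18(weights=None)  # no pretrained weights needed for the demo
    # 'layer2'/'layer3'/'layer4' are direct children of resnet18, so they can
    # be queried; a nested module such as 'layer4.1' could not be.
    body = IntermediateLayerGetter(backbone, return_layers=['layer2', 'layer3', 'layer4'])
    feats = body(torch.rand(1, 3, 224, 224))
    print([f.shape for f in feats])
    # expected shapes: [1, 128, 28, 28], [1, 256, 14, 14], [1, 512, 7, 7]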