|
''' |
|
Utilities for instrumenting a torch model. |
|
|
|
InstrumentedModel will wrap a pytorch model and allow hooking |
|
arbitrary layers to monitor or modify their output directly. |
|
|
|
Modified by Erik Härkönen: |
|
- 29.11.2019: Unhooking bugfix |
|
- 25.01.2020: Offset edits, removed old API |
|
''' |
|
|
|
import torch, numpy, types |
|
from collections import OrderedDict |
|
|
|
class InstrumentedModel(torch.nn.Module):
    '''
    A wrapper for hooking, probing and intervening in pytorch Modules.
    Example usage:

    ```
    model = load_my_model()
    with InstrumentedModel(model) as inst:
        inst.retain_layer(layername)
        inst.edit_layer(layername, 0.5, target_features)
        inst.edit_layer(layername, offset=offset_tensor)
        inst(inputs)
        original_features = inst.retained_layer(layername)
    ```
    '''

    def __init__(self, model):
        super(InstrumentedModel, self).__init__()
        self.model = model
        # aka -> detached output captured on the most recent forward pass.
        self._retained = OrderedDict()
        # aka -> interpolation strength `a` applied in _postprocess_forward.
        self._ablation = {}
        # aka -> replacement features `r` applied in _postprocess_forward.
        self._replacement = {}
        # aka -> additive offset `b` applied in _postprocess_forward.
        self._offset = {}
        # aka -> fully-qualified name of the layer hooked under that aka.
        self._hooked_layer = {}
        # layername -> (layer, aka, original instance forward-or-None),
        # everything needed to restore the layer in _unhook_layer.
        self._old_forward = {}

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Always restore the wrapped model's original forward methods,
        # even when the with-body raised.
        self.close()

    def forward(self, *inputs, **kwargs):
        # Delegate to the wrapped model; hooks fire inside its layers.
        return self.model(*inputs, **kwargs)

    def retain_layer(self, layername):
        '''
        Pass a fully-qualified layer name (E.g., module.submodule.conv3)
        to hook that layer and retain its output each time the model is run.
        A pair (layername, aka) can be provided, and the aka will be used
        as the key for the retained value instead of the layername.
        '''
        self.retain_layers([layername])

    def retain_layers(self, layernames):
        '''
        Retains a list of layers at once; each entry is a layername
        string or a (layername, aka) pair.
        '''
        self.add_hooks(layernames)
        for layername in layernames:
            aka = layername
            if not isinstance(aka, str):
                layername, aka = layername
            if aka not in self._retained:
                self._retained[aka] = None

    def retained_features(self):
        '''
        Returns a dict of all currently retained features.
        '''
        return OrderedDict(self._retained)

    def retained_layer(self, aka=None, clear=False):
        '''
        Retrieve retained data that was previously hooked by retain_layer.
        Call this after the model is run. If clear is set, then the
        retained value will be returned and also cleared.
        '''
        if aka is None:
            # Default to the first retained layer, for the common case
            # where only a single layer is being monitored.
            aka = next(iter(self._retained))
        result = self._retained[aka]
        if clear:
            self._retained[aka] = None
        return result

    def edit_layer(self, layername, ablation=None, replacement=None, offset=None):
        '''
        Pass a fully-qualified layer name (E.g., module.submodule.conv3)
        to hook that layer and modify its output each time the model is run.
        The output of the layer will be modified to be a convex combination
        of the replacement and x interpolated according to the ablation, i.e.:
        `output = x * (1 - a) + (r * a)`.
        Additionally or independently, an offset can be added to the output.
        '''
        if not isinstance(layername, str):
            layername, aka = layername
        else:
            aka = layername

        # A replacement with no explicit ablation means a full swap (a=1).
        if ablation is None and replacement is not None:
            ablation = 1.0
        self.add_hooks([(layername, aka)])
        if ablation is not None:
            self._ablation[aka] = ablation
        if replacement is not None:
            self._replacement[aka] = replacement
        if offset is not None:
            self._offset[aka] = offset

    def remove_edits(self, layername=None, remove_offset=True, remove_replacement=True):
        '''
        Removes edits at the specified layer, or removes edits at all layers
        if no layer name is specified.
        '''
        if layername is None:
            if remove_replacement:
                # Ablation only matters together with replacement, so
                # it is cleared under the same flag.
                self._ablation.clear()
                self._replacement.clear()
            if remove_offset:
                self._offset.clear()
            return

        if not isinstance(layername, str):
            layername, aka = layername
        else:
            aka = layername
        if remove_replacement and aka in self._ablation:
            del self._ablation[aka]
        if remove_replacement and aka in self._replacement:
            del self._replacement[aka]
        if remove_offset and aka in self._offset:
            del self._offset[aka]

    def add_hooks(self, layernames):
        '''
        Sets up a set of layers to be hooked.

        Usually not called directly: use edit_layer or retain_layer instead.
        '''
        needed = set()
        aka_map = {}
        for name in layernames:
            aka = name
            if not isinstance(aka, str):
                name, aka = name
            # Skip layers already hooked under the same aka.
            if self._hooked_layer.get(aka, None) != name:
                aka_map[name] = aka
                needed.add(name)
        if not needed:
            return
        for name, layer in self.model.named_modules():
            if name in aka_map:
                needed.remove(name)
                aka = aka_map[name]
                self._hook_layer(layer, name, aka)
        # Anything left in `needed` was never seen by named_modules().
        for name in needed:
            raise ValueError('Layer %s not found in model' % name)

    def _hook_layer(self, layer, layername, aka):
        '''
        Internal method to replace a forward method with a closure that
        intercepts the call, and tracks the hook so that it can be reverted.
        '''
        if aka in self._hooked_layer:
            raise ValueError('Layer %s already hooked' % aka)
        if layername in self._old_forward:
            raise ValueError('Layer %s already hooked' % layername)
        self._hooked_layer[aka] = layername
        # Record any pre-existing instance-level forward override (usually
        # None, since forward normally lives on the class) so that
        # _unhook_layer can restore the layer exactly as it was.
        self._old_forward[layername] = (layer, aka,
                                        layer.__dict__.get('forward', None))
        editor = self
        original_forward = layer.forward
        def new_forward(self, *inputs, **kwargs):
            original_x = original_forward(*inputs, **kwargs)
            x = editor._postprocess_forward(original_x, aka)
            return x
        layer.forward = types.MethodType(new_forward, layer)

    def _unhook_layer(self, aka):
        '''
        Internal method to remove a hook, restoring the original forward method.
        '''
        if aka not in self._hooked_layer:
            return
        layername = self._hooked_layer[aka]
        layer, check, old_forward = self._old_forward[layername]
        assert check == aka
        if old_forward is None:
            # No instance-level forward existed before hooking: removing
            # ours re-exposes the class-level forward.
            if 'forward' in layer.__dict__:
                del layer.__dict__['forward']
        else:
            layer.forward = old_forward
        del self._old_forward[layername]
        del self._hooked_layer[aka]
        if aka in self._ablation:
            del self._ablation[aka]
        if aka in self._replacement:
            del self._replacement[aka]
        if aka in self._offset:
            del self._offset[aka]
        if aka in self._retained:
            del self._retained[aka]

    def _postprocess_forward(self, x, aka):
        '''
        The internal method called by the hooked layers after they are run.
        '''
        # Retain a detached copy of the (unedited) output if requested.
        if aka in self._retained:
            self._retained[aka] = x.detach()

        # Interpolate with replacement features: x * (1 - a) + r * a.
        a = make_matching_tensor(self._ablation, aka, x)
        if a is not None:
            x = x * (1 - a)
            v = make_matching_tensor(self._replacement, aka, x)
            if v is not None:
                x += (v * a)

        # Apply an additive offset, independently of any replacement.
        b = make_matching_tensor(self._offset, aka, x)
        if b is not None:
            x = x + b

        return x

    def close(self):
        '''
        Unhooks all hooked layers in the model.
        '''
        # Iterate over akas (the keys _unhook_layer expects), not over
        # _old_forward's layername keys: the two differ whenever a layer
        # was hooked under an alias, which previously made close() skip
        # that hook and trip the assertion below.
        for aka in list(self._hooked_layer.keys()):
            self._unhook_layer(aka)
        assert len(self._old_forward) == 0
|
|
|
|
|
def make_matching_tensor(valuedict, name, data):
    '''
    Converts `valuedict[name]` to be a tensor with the same dtype, device,
    and dimension count as `data`, and caches the converted tensor.
    '''
    tensor = valuedict.get(name, None)
    if tensor is None:
        return None
    if not isinstance(tensor, torch.Tensor):
        # Promote a plain number / list / array to a tensor, and cache it.
        tensor = torch.from_numpy(numpy.array(tensor))
        valuedict[name] = tensor
    if tensor.dtype != data.dtype or tensor.device != data.device:
        # Move to the target device / dtype once, and cache the result.
        assert not tensor.requires_grad, '%s wrong device or type' % (name)
        tensor = tensor.to(device=data.device, dtype=data.dtype)
        valuedict[name] = tensor
    if tensor.dim() < data.dim():
        # Pad the shape with leading and trailing singleton axes so the
        # value broadcasts over batch and spatial dimensions; cache it.
        assert not tensor.requires_grad, '%s wrong dimensions' % (name)
        pad = (1,) * (data.dim() - tensor.dim() - 1)
        tensor = tensor.view((1,) + tuple(tensor.shape) + pad)
        valuedict[name] = tensor
    return tensor
|
|