FashionGen / netdissect /nethook.py
Prathm's picture
Duplicate from safi842/FashionGen
337965d
raw
history blame
9.33 kB
'''
Utilities for instrumenting a torch model.
InstrumentedModel will wrap a pytorch model and allow hooking
arbitrary layers to monitor or modify their output directly.
Modified by Erik Härkönen:
- 29.11.2019: Unhooking bugfix
- 25.01.2020: Offset edits, removed old API
'''
import torch, numpy, types
from collections import OrderedDict
class InstrumentedModel(torch.nn.Module):
    '''
    A wrapper for hooking, probing and intervening in pytorch Modules.

    Layers are addressed by their fully-qualified name as reported by
    `model.named_modules()`.  Wherever a layer name is accepted, a pair
    `(layername, aka)` may be given instead; the `aka` is then used as
    the key for retained values and edits.

    Example usage:

    ```
    model = load_my_model()
    with InstrumentedModel(model) as inst:
        inst.retain_layer(layername)
        inst.edit_layer(layername, 0.5, target_features)
        inst.edit_layer(layername, offset=offset_tensor)
        inst(inputs)
        original_features = inst.retained_layer(layername)
    ```
    '''

    def __init__(self, model):
        super(InstrumentedModel, self).__init__()
        self.model = model
        # aka -> last retained (detached) output of the hooked layer.
        self._retained = OrderedDict()
        # aka -> interpolation weight `a` of the replacement edit.
        self._ablation = {}
        # aka -> replacement value `r` of the replacement edit.
        self._replacement = {}
        # aka -> offset added to the layer output.
        self._offset = {}
        # aka -> fully-qualified name of the hooked layer.
        self._hooked_layer = {}
        # layername -> (layer, aka, instance-level forward or None).
        self._old_forward = {}

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Always unhook on leaving the `with` block, even on error.
        self.close()

    def forward(self, *inputs, **kwargs):
        '''
        Runs the wrapped model; hooked layers retain/edit their outputs.
        '''
        return self.model(*inputs, **kwargs)

    def retain_layer(self, layername):
        '''
        Pass a fully-qualified layer name (E.g., module.submodule.conv3)
        to hook that layer and retain its output each time the model is run.
        A pair (layername, aka) can be provided, and the aka will be used
        as the key for the retained value instead of the layername.
        '''
        self.retain_layers([layername])

    def retain_layers(self, layernames):
        '''
        Retains a list of layers at once.  Each entry is either a layer
        name or a (layername, aka) pair.
        '''
        # Materialize once: the names are walked twice below, which would
        # silently skip the second pass for a one-shot iterator.
        layernames = list(layernames)
        self.add_hooks(layernames)
        for layername in layernames:
            aka = layername
            if not isinstance(aka, str):
                layername, aka = layername
            # Do not wipe a value that has already been retained.
            if aka not in self._retained:
                self._retained[aka] = None

    def retained_features(self):
        '''
        Returns a dict of all currently retained features.
        The returned OrderedDict is a shallow copy keyed by aka.
        '''
        return OrderedDict(self._retained)

    def retained_layer(self, aka=None, clear=False):
        '''
        Retrieve retained data that was previously hooked by retain_layer.
        Call this after the model is run.  If clear is set, then the
        retained value will be returned and also cleared.
        If aka is None, the first retained layer is used.
        '''
        if aka is None:
            # Default to the first retained layer.
            aka = next(iter(self._retained))
        result = self._retained[aka]
        if clear:
            self._retained[aka] = None
        return result

    def edit_layer(self, layername, ablation=None, replacement=None,
                   offset=None):
        '''
        Pass a fully-qualified layer name (E.g., module.submodule.conv3)
        to hook that layer and modify its output each time the model is run.
        The output of the layer will be modified to be a convex combination
        of the replacement and x interpolated according to the ablation, i.e.:
        `output = x * (1 - a) + (r * a)`.
        Additionally or independently, an offset can be added to the output.

        Arguments left as None keep whatever value was set by an earlier
        call for the same layer.
        '''
        if not isinstance(layername, str):
            layername, aka = layername
        else:
            aka = layername
        # The default ablation if a replacement is specified is 1.0.
        if ablation is None and replacement is not None:
            ablation = 1.0
        self.add_hooks([(layername, aka)])
        if ablation is not None:
            self._ablation[aka] = ablation
        if replacement is not None:
            self._replacement[aka] = replacement
        if offset is not None:
            self._offset[aka] = offset
        # If needed, could add an arbitrary postprocessing lambda here.

    def remove_edits(self, layername=None, remove_offset=True,
                     remove_replacement=True):
        '''
        Removes edits at the specified layer, or removes edits at all layers
        if no layer name is specified.  `remove_replacement` controls the
        ablation/replacement pair and `remove_offset` controls the offset.
        The layer itself stays hooked.
        '''
        if layername is None:
            if remove_replacement:
                self._ablation.clear()
                self._replacement.clear()
            if remove_offset:
                self._offset.clear()
            return
        if not isinstance(layername, str):
            layername, aka = layername
        else:
            aka = layername
        if remove_replacement:
            self._ablation.pop(aka, None)
            self._replacement.pop(aka, None)
        if remove_offset:
            self._offset.pop(aka, None)

    def add_hooks(self, layernames):
        '''
        Sets up a set of layers to be hooked.
        Usually not called directly: use edit_layer or retain_layer instead.

        Raises ValueError if a requested layer is not found in the model,
        or if the layer or aka is already hooked under a different pairing.
        '''
        needed = set()
        aka_map = {}
        for name in layernames:
            aka = name
            if not isinstance(aka, str):
                name, aka = name
            # Skip pairs that are already hooked exactly as requested.
            if self._hooked_layer.get(aka, None) != name:
                aka_map[name] = aka
                needed.add(name)
        if not needed:
            return
        for name, layer in self.model.named_modules():
            if name in aka_map:
                needed.remove(name)
                aka = aka_map[name]
                self._hook_layer(layer, name, aka)
        # Anything left over was requested but never seen in the model.
        if needed:
            raise ValueError('Layer %s not found in model'
                             % next(iter(needed)))

    def _hook_layer(self, layer, layername, aka):
        '''
        Internal method to replace a forward method with a closure that
        intercepts the call, and tracks the hook so that it can be reverted.
        '''
        if aka in self._hooked_layer:
            raise ValueError('Layer %s already hooked' % aka)
        if layername in self._old_forward:
            raise ValueError('Layer %s already hooked' % layername)
        self._hooked_layer[aka] = layername
        # Remember only an instance-level forward override (or None), so
        # that unhooking can restore the layer to exactly its prior state.
        self._old_forward[layername] = (layer, aka,
                                        layer.__dict__.get('forward', None))
        editor = self
        original_forward = layer.forward

        def new_forward(self, *inputs, **kwargs):
            original_x = original_forward(*inputs, **kwargs)
            x = editor._postprocess_forward(original_x, aka)
            return x

        layer.forward = types.MethodType(new_forward, layer)

    def _unhook_layer(self, aka):
        '''
        Internal method to remove a hook, restoring the original forward
        method.  Also drops any edits and retained value stored for `aka`.
        Does nothing if `aka` is not hooked.
        '''
        if aka not in self._hooked_layer:
            return
        layername = self._hooked_layer[aka]
        layer, check, old_forward = self._old_forward[layername]
        assert check == aka
        if old_forward is None:
            # There was no instance-level forward before hooking: remove
            # ours so the class-level forward becomes visible again.
            if 'forward' in layer.__dict__:
                del layer.__dict__['forward']
        else:
            layer.forward = old_forward
        del self._old_forward[layername]
        del self._hooked_layer[aka]
        self._ablation.pop(aka, None)
        self._replacement.pop(aka, None)
        self._offset.pop(aka, None)
        self._retained.pop(aka, None)

    def _postprocess_forward(self, x, aka):
        '''
        The internal method called by the hooked layers after they are run.
        Retains the unedited output if requested, then applies the
        replacement edit followed by the offset edit.
        '''
        # Retain output before edits, if desired.
        if aka in self._retained:
            self._retained[aka] = x.detach()
        # Apply replacement edit
        a = make_matching_tensor(self._ablation, aka, x)
        if a is not None:
            x = x * (1 - a)
            v = make_matching_tensor(self._replacement, aka, x)
            if v is not None:
                x += (v * a)
        # Apply offset edit
        b = make_matching_tensor(self._offset, aka, x)
        if b is not None:
            x = x + b
        return x

    def close(self):
        '''
        Unhooks all hooked layers in the model.
        '''
        # _unhook_layer is keyed by aka, so walk the aka table rather than
        # _old_forward (keyed by layer name); the two differ whenever a
        # layer was hooked under an alias.
        for aka in list(self._hooked_layer.keys()):
            self._unhook_layer(aka)
        assert len(self._old_forward) == 0
def make_matching_tensor(valuedict, name, data):
    '''
    Looks up `valuedict[name]` and coerces it into a tensor that agrees
    with `data` in dtype, device and number of dimensions.  Each converted
    form is written back into `valuedict` so later calls can reuse it.
    Returns None when there is no (non-None) entry under `name`.
    '''
    value = valuedict.get(name, None)
    if value is None:
        return None

    if not isinstance(value, torch.Tensor):
        # Accept non-torch data by routing it through a numpy array.
        value = torch.from_numpy(numpy.array(value))
        valuedict[name] = value

    if value.device != data.device or value.dtype != data.dtype:
        # Bring the value over to the device and type of the data.
        assert not value.requires_grad, '%s wrong device or type' % (name)
        value = value.to(device=data.device, dtype=data.dtype)
        valuedict[name] = value

    missing_dims = len(data.shape) - len(value.shape)
    if missing_dims > 0:
        # Pad with one leading singleton dimension and as many trailing
        # singleton dimensions as are needed to reach data's rank.
        assert not value.requires_grad, '%s wrong dimensions' % (name)
        padded_shape = (1,) + tuple(value.shape) + (1,) * (missing_dims - 1)
        value = value.view(padded_shape)
        valuedict[name] = value

    return value