FashionGen / netdissect /modelconfig.py
Prathm's picture
Duplicate from safi842/FashionGen
337965d
raw
history blame
5.78 kB
'''
Original from https://github.com/CSAILVision/GANDissect
Modified by Erik Härkönen, 29.11.2019
'''
import numbers
import torch
from netdissect.autoeval import autoimport_eval
from netdissect.progress import print_progress
from netdissect.nethook import InstrumentedModel
from netdissect.easydict import EasyDict
def create_instrumented_model(args, **kwargs):
    '''
    Creates an instrumented model out of a namespace of arguments that
    correspond to ArgumentParser command-line args:
      model: a string to evaluate as a constructor for the model, or an
             already-constructed torch.nn.Module.
      pthfile: (optional) filename of .pth file for the model.
      submodule: (optional) key prefix; when given, only state-dict
             entries starting with "<submodule>." are loaded, with that
             prefix removed.
      unstrict: (optional) True to load the state dict with strict=False.
      layer: (optional) a single layer name; takes precedence over layers.
      layers: a list of layers to instrument, defaulted if not provided.
      edit: True to instrument the layers for editing (not consulted by
             this function itself).
      gen: True for a generator model.  One-pixel input assumed.
      imgsize: For non-generator models, (y, x) dimensions for RGB input.
      latent_shape: (optional) explicit input shape for a generator.
      cuda: True to use CUDA.
    Extra keyword arguments are merged with the namespace entries through
    the EasyDict constructor.

    The constructed model will be decorated with the following attributes:
      input_shape: (usually 4d) tensor shape for single-image input.
      output_shape: 4d tensor shape for output.
      feature_shape: map of layer names to 4d tensor shape for featuremaps.
      retained: map of layernames to tensors, filled after every evaluation.
      ablation: if editing, map of layernames to [0..1] alpha values to fill.
      replacement: if editing, map of layernames to values to fill.
      meta: map of the numeric top-level entries found in the checkpoint
             (empty if the checkpoint has no 'state_dict' key).

    When editing, the feature value x will be replaced by:
        `x = (replacement * ablation) + (x * (1 - ablation))`

    Returns the instrumented model, or None when no model is specified or
    when the requested submodule prefix matches nothing in the checkpoint.
    '''
    args = EasyDict(vars(args), **kwargs)

    # Construct the network
    model_spec = getattr(args, 'model', None)
    if model_spec is None:
        print_progress('No model specified')
        return None
    if isinstance(model_spec, torch.nn.Module):
        model = model_spec
    else:
        # NOTE(review): the string is evaluated as code by autoimport_eval;
        # it must only ever come from a trusted source.
        model = autoimport_eval(model_spec)
    # Unwrap any DataParallel-wrapped model
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    # Load its state dict
    meta = {}
    pthfile = getattr(args, 'pthfile', None)
    if pthfile is not None:
        # Load onto the CPU so that checkpoints saved from a GPU can also be
        # read on CPU-only hosts; the model is moved to CUDA further below
        # when requested.
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        data = torch.load(pthfile, map_location='cpu')
        if 'state_dict' in data:
            # Keep the numeric bookkeeping entries stored next to the weights.
            meta = {key: value for key, value in data.items()
                    if isinstance(value, numbers.Number)}
            data = data['state_dict']
        submodule = getattr(args, 'submodule', None)
        if submodule:
            remove_prefix = submodule + '.'
            data = {k[len(remove_prefix):]: v for k, v in data.items()
                    if k.startswith(remove_prefix)}
            if not data:
                print_progress('No submodule %s found in %s' %
                               (submodule, pthfile))
                return None
        model.load_state_dict(data,
                              strict=not getattr(args, 'unstrict', False))

    # Decide which layers to instrument.
    if getattr(args, 'layer', None) is not None:
        args.layers = [args.layer]
    if getattr(args, 'layers', None) is None:
        # Skip wrappers with only one named model
        container = model
        prefix = ''
        while len(list(container.named_children())) == 1:
            name, container = next(container.named_children())
            prefix += name + '.'
        # Default to all nontrivial top-level layers except last.
        skipped_modules = (
            # Skip ReLU and other activations.
            'torch.nn.modules.activation',
            # Skip pooling layers.
            'torch.nn.modules.pooling',
        )
        candidates = [prefix + name
                      for name, module in container.named_children()
                      if type(module).__module__ not in skipped_modules]
        args.layers = candidates[:-1]
        print_progress('Defaulting to layers: %s' % ' '.join(args.layers))

    # Now wrap the model for instrumentation.
    model = InstrumentedModel(model)
    model.meta = meta

    # Instrument the layers.
    model.retain_layers(args.layers)
    model.eval()
    # cuda is optional like the other flags; a missing entry means CPU.
    if getattr(args, 'cuda', False):
        model.cuda()

    # Annotate input, output, and feature shapes
    annotate_model_shapes(model,
                          gen=getattr(args, 'gen', False),
                          imgsize=getattr(args, 'imgsize', None),
                          latent_shape=getattr(args, 'latent_shape', None))
    return model
def annotate_model_shapes(model, gen=False, imgsize=None, latent_shape=None):
    '''
    Runs the model once on an all-zeros input and records the observed
    shapes as attributes on the model:
      input_shape: the shape of the dry-run input.
      feature_shape: map of layer names to the shapes of the tensors
          returned by model.retained_features() after the dry run.
      output_shape: the shape of the model output.

    Arguments:
      model: the model to annotate; it must provide retained_features().
      gen: True for a generator model.  Unless latent_shape is given, the
          input shape is guessed from the first Conv2d, ConvTranspose2d or
          Linear module: (1, in_channels, 1, 1) for a convolution and
          (1, in_features) for a linear layer.
      imgsize: for non-generator models, the (y, x) image size; the input
          shape is then (1, 3, y, x).  Required unless gen is True.
      latent_shape: (optional) explicit input shape for a generator.

    Returns the same model object.
    Raises IndexError if the generator input shape has to be guessed but
    the model contains no Conv2d, ConvTranspose2d or Linear module.
    '''
    assert (imgsize is not None) or gen, \
        'imgsize is required unless gen is True'

    # Figure the input shape.
    if gen:
        if latent_shape is None:
            # We can guess a generator's input shape by looking at the model.
            # Examine first conv in model to determine input feature size.
            conv_types = (torch.nn.Conv2d, torch.nn.ConvTranspose2d)
            first_layer = next(
                (m for m in model.modules()
                 if isinstance(m, conv_types + (torch.nn.Linear,))),
                None)
            if first_layer is None:
                # Same exception type as indexing an empty list, but with
                # an explanation of what went wrong.
                raise IndexError(
                    'cannot guess generator input shape: model has no '
                    'Conv2d, ConvTranspose2d or Linear layer')
            # 4d input if convolutional, 2d input if first layer is linear.
            if isinstance(first_layer, conv_types):
                input_shape = (1, first_layer.in_channels, 1, 1)
            else:
                input_shape = (1, first_layer.in_features)
        else:
            # Specify input shape manually
            input_shape = latent_shape
    else:
        # For a classifier, the input image shape is given as an argument.
        input_shape = (1, 3) + tuple(imgsize)

    # Run the model once to observe feature shapes.  The dry-run input is
    # placed on the device of the first parameter; models without any
    # parameters fall back to their first buffer, and finally to the CPU.
    anchor = next(model.parameters(), None)
    if anchor is None:
        anchor = next(model.buffers(), None)
    device = anchor.device if anchor is not None else torch.device('cpu')
    dry_run = torch.zeros(input_shape).to(device)
    with torch.no_grad():
        output = model(dry_run)

    # Annotate shapes.
    model.input_shape = input_shape
    model.feature_shape = {layer: feature.shape
                           for layer, feature
                           in model.retained_features().items()}
    model.output_shape = output.shape
    return model