# FashionGen / netdissect/__init__.py
# Prathm — "Duplicate from safi842/FashionGen" (commit 337965d)
'''
Netdissect package.

To run dissection:

1. Load up the convolutional model you wish to dissect, and wrap it
   in an InstrumentedModel. Call imodel.retain_layers([layernames,..])
   to analyze a specified set of layers.

2. Load the segmentation dataset using the BrodenDataset class;
   use the transform_image argument to normalize images to be
   suitable for the model, or the size argument to truncate the dataset.

3. Write a function to recover the original image (with RGB scaled to
   [0, 1]) given a normalized dataset image; ReverseNormalize in this
   package inverts transforms.Normalize for this purpose.

4. Choose a directory in which to write the output, and call
   dissect(outdir, model, dataset).

Example (load_my_model, IMAGE_MEAN and IMAGE_STDEV are placeholders for
your own model loader and normalization constants):

    from torchvision import transforms
    from netdissect import InstrumentedModel, dissect
    from netdissect import BrodenDataset, ReverseNormalize

    model = InstrumentedModel(load_my_model())
    model.eval()
    model.cuda()
    model.retain_layers(['conv1', 'conv2', 'conv3', 'conv4', 'conv5'])
    bds = BrodenDataset('dataset/broden1_227',
            transform_image=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]),
            size=1000)
    dissect('result/dissect', model, bds,
            recover_image=ReverseNormalize(IMAGE_MEAN, IMAGE_STDEV),
            examples_per_unit=10)
'''
from .dissection import dissect, ReverseNormalize
from .dissection import ClassifierSegRunner, GeneratorSegRunner
from .dissection import ImageOnlySegRunner
from .broden import BrodenDataset, ScaleSegmentation, scatter_batch
from .segdata import MultiSegmentDataset
from .nethook import InstrumentedModel
from .zdataset import z_dataset_for_model, z_sample_for_model, standard_z_sample
from . import actviz
from . import progress
from . import runningstats
from . import sampler
# Public API re-exported by ``from netdissect import *``.
#
# Every entry must name an attribute bound by the imports above.  Adjacent
# string literals without a comma between them are silently concatenated
# (e.g. 'standard_z_sample' 'actviz' -> 'standard_z_sampleactviz'), which
# drops both names from the export list and makes the star-import raise
# AttributeError.  Every entry therefore carries a trailing comma.
__all__ = [
    # From .dissection
    'dissect', 'ReverseNormalize',
    'ClassifierSegRunner', 'GeneratorSegRunner', 'ImageOnlySegRunner',
    # From .broden
    'BrodenDataset', 'ScaleSegmentation', 'scatter_batch',
    # From .segdata
    'MultiSegmentDataset',
    # From .nethook
    'InstrumentedModel',
    # From .zdataset
    'z_dataset_for_model', 'z_sample_for_model', 'standard_z_sample',
    # Submodules imported with ``from . import ...``
    'actviz',
    'progress',
    'runningstats',
    'sampler',
]