Spaces:
Build error
Build error
import os, sys, argparse, json, shutil | |
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas | |
from matplotlib.figure import Figure | |
from matplotlib.ticker import MaxNLocator | |
import matplotlib | |
def main():
    """Parse the command line and plot the ablation-effect curves.

    Builds the argument parser and hands the parsed namespace to
    ``run_command``.  ``--classname`` and ``--outdir`` are mandatory:
    ``run_command`` uses both to build the paths of the JSON summaries it
    reads and of the figures it writes.  When either is missing, argparse
    prints a usage message and exits with status 2 instead of failing
    later while joining paths.

    Options:
        --classname  intervention classname (required).
        --layer      layer name, defaults to ``'layer4'``.
        --outdir     dissection directory (required).
        --metric     experiment variant, defaults to ``'ace'``.
    """
    parser = argparse.ArgumentParser(
        description='ACE optimization utility',
        prog='python -m netdissect.aceoptimize')
    parser.add_argument('--classname', type=str, required=True,
                        help='intervention classname')
    parser.add_argument('--layer', type=str, default='layer4',
                        help='layer name')
    parser.add_argument('--outdir', type=str, required=True,
                        help='dissection directory')
    # Defaulting here replaces the former "if args.metric is None" fix-up.
    parser.add_argument('--metric', type=str, default='ace',
                        help='experiment variant')
    args = parser.parse_args()
    run_command(args)
def run_command(args):
    """Plot normalized ablation effects for two unit rankings and save them.

    For ``args.metric`` and then for ``'iou'``, reads
    ``<outdir>/<layer>/fullablation/<classname>-<metric>.json``, takes its
    ``'baseline'`` value and the first 26 entries of ``'ablation_effects'``,
    and plots ``1 - effect / baseline`` for each entry, preceded by a
    leading 0.  The finished figure is written to the same directory as
    ``effect-<classname>-<args.metric>.png`` and ``.pdf``.

    Args:
        args: namespace providing ``outdir``, ``layer``, ``classname`` and
            ``metric`` attributes.
    """
    figure = Figure(figsize=(4.5,3.5))
    # Attach an Agg canvas to the figure; the canvas object itself is unused.
    FigureCanvas(figure)
    axes = figure.add_subplot(111)
    ablation_dir = os.path.join(args.outdir, args.layer, 'fullablation')

    for ranking in (args.metric, 'iou'):
        summary_path = os.path.join(
            ablation_dir, '%s-%s.json' % (args.classname, ranking))
        with open(summary_path) as handle:
            report = json.load(handle)
        reference = report['baseline']
        # Start the curve at 0, then add the normalized effect per entry.
        curve = [0]
        curve.extend(1.0 - effect / reference
                     for effect in report['ablation_effects'][:26])
        if 'ace' in ranking:
            legend_text = 'Units by ACE'
        else:
            legend_text = 'Top units by IoU'
        axes.plot(curve, label=legend_text)

    axes.set_title('Effect of ablating units for %s' % (args.classname))
    axes.grid(True)
    axes.legend()
    axes.set_ylabel('Portion of %s pixels removed' % args.classname)
    axes.set_xlabel('Number of units ablated')
    axes.set_ylim(0, 1.0)
    axes.set_xlim(0, 25)
    figure.tight_layout()

    # Save the same figure once per output format.
    for name_template in ('effect-%s-%s.png', 'effect-%s-%s.pdf'):
        figure.savefig(os.path.join(
            ablation_dir, name_template % (args.classname, args.metric)))
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()