# Spaces:
# Runtime error
# Runtime error
import os
from abc import ABC, abstractmethod

import numpy as np
import torch

import utils
class Tuning(ABC):
    """Base class for tuning steps; subclasses override ``Image``."""

    def Image(self, image):
        """Default hook: ignore ``image``, do nothing and return ``None``."""
        return None
class TuningReconstruction(Tuning):
    """Run single optimisation steps that pull an image's features to a target.

    The value returned by ``utils.yamlGet('reconstruct')`` selects the mode:
    ``'Content'`` compares one selected feature map against the target,
    ``'Style'`` compares Gram matrices of several feature maps.
    """

    def __init__(self, model, optimizer, target_representation,
                 content_feature_maps_index, style_feature_maps_indices):
        """Store the collaborators used by every ``Image`` call.

        Args:
            model: Callable; ``model(image)`` must return a collection of
                feature maps that can be both indexed and iterated.
            optimizer: Object exposing ``step()`` and ``zero_grad()``.
            target_representation: What the image should match -- a single
                tensor in 'Content' mode, a sequence of Gram matrices in
                'Style' mode.
            content_feature_maps_index: Index of the feature map used in
                'Content' mode.
            style_feature_maps_indices: Container of feature-map indices
                used in 'Style' mode.
        """
        self.model = model
        self.optimizer = optimizer
        self.target_representation = target_representation
        self.content_feature_maps_index = content_feature_maps_index
        self.style_feature_maps_indices = style_feature_maps_indices

    def Image(self, image):
        """Perform one optimisation step and report the resulting loss.

        Args:
            image: Input handed to ``self.model``.
                # presumably the tensor registered with ``self.optimizer``
                # -- TODO confirm against the caller.

        Returns:
            tuple: ``(loss.item(), current_representation)`` where the
            representation is a tensor ('Content') or a list of Gram
            matrices ('Style').

        Raises:
            ValueError: If the configured mode is neither 'Content' nor
                'Style', or if 'Style' mode has nothing to compare.
        """
        # Read the setting once per step rather than once per branch.
        mode = utils.yamlGet('reconstruct')
        if mode not in ('Content', 'Style'):
            # Previously this fell through and crashed later with an
            # unrelated AttributeError on ``loss.backward()``.
            raise ValueError(
                f"Unsupported 'reconstruct' setting: {mode!r} "
                "(expected 'Content' or 'Style')"
            )

        # Find the current representation.
        set_of_feature_maps = self.model(image)

        if mode == 'Content':
            # Drop dimension 0 when it has size 1.
            # presumably the batch dimension -- TODO confirm.
            current_representation = set_of_feature_maps[
                self.content_feature_maps_index].squeeze(0)
            loss = torch.nn.MSELoss(reduction='mean')(
                self.target_representation, current_representation)
        else:
            targets = self.target_representation
            current_representation = [
                utils.gram_matrix(fmaps)
                for i, fmaps in enumerate(set_of_feature_maps)
                if i in self.style_feature_maps_indices
            ]
            if len(targets) == 0 or not current_representation:
                # With nothing to accumulate the loss would stay a plain
                # float, which has no ``backward()``.
                raise ValueError(
                    'Style reconstruction needs at least one target and '
                    'one current Gram matrix'
                )

            # Every compared layer contributes with the same weight.
            weight = 1 / len(targets)
            sum_squared_error = torch.nn.MSELoss(reduction='sum')
            loss = 0.0
            for gram_gt, gram_hat in zip(targets, current_representation):
                loss += weight * sum_squared_error(gram_gt[0], gram_hat[0])

        loss.backward()
        self.optimizer.step()
        # Clear gradients so the next step starts from a clean slate.
        self.optimizer.zero_grad()
        return loss.item(), current_representation
class Reconstruct(ABC):
    """Base class for reconstruction visualisers.

    Subclasses override ``Visualize``; the default implementation is a
    no-op so that subclasses without their own ``Visualize`` stay
    instantiable.
    """

    def Visualize(self):
        """Default hook: do nothing and return ``None``."""
        return None
class ContentReconstruct(Reconstruct):
    """Write every feature map of the target content representation to disk.

    Attribute names are kept from the original code:
        fm  -> the ``feature_maps`` mapping given to the constructor
        tcr -> target_content_representation (selected feature-map tensor
               with dimension 0 squeezed out)
        nfm -> number of feature maps, i.e. ``tcr.size()[0]``
    """

    def __init__(self, feature_maps):
        """Pick the content feature maps out of ``feature_maps``.

        Args:
            feature_maps: Mapping with the keys ``'set_of_feature_maps'``
                (indexable collection of tensors) and
                ``'content_feature_maps_index_name'`` (a pair whose first
                item indexes that collection and whose second item is used
                as the layer name in output file names).
        """
        self.fm = feature_maps
        layer_index = self.fm['content_feature_maps_index_name'][0]
        self.tcr = self.fm['set_of_feature_maps'][layer_index].squeeze(0)
        self.nfm = self.tcr.size()[0]

    def Visualize(self, dump_path=None):
        """Save each feature map in ``tcr`` as an image file.

        Args:
            dump_path: Directory that receives the images. Defaults to the
                current directory; it is created when missing.
        """
        out_dir = os.curdir if dump_path is None else dump_path
        os.makedirs(out_dir, exist_ok=True)

        # NOTE(review): the original read these from an undefined ``config``
        # dict. They are fetched through ``utils.yamlGet`` like the other
        # settings in this file -- confirm that the YAML really uses the
        # keys 'model' and 'img_format' (zero-pad width, file extension).
        model_name = utils.yamlGet('model')
        img_format = utils.yamlGet('img_format')
        layer_name = self.fm['content_feature_maps_index_name'][1]

        for i in range(self.nfm):
            # ``detach`` keeps ``numpy()`` from failing on tensors that
            # are still attached to an autograd graph.
            feature_map = self.tcr[i].detach().to('cpu').numpy()
            feature_map = np.uint8(utils.get_uint8_range(feature_map))
            filename = (
                f'fm_{model_name}_{layer_name}_'
                f'{str(i).zfill(img_format[0])}{img_format[1]}'
            )
            utils.save_image(feature_map, os.path.join(out_dir, filename))
class StyleReconstruct(Reconstruct):
    """Empty subclass of ``Reconstruct``; it adds no behaviour of its own."""
class Invoker:
    """Empty placeholder class; it defines no attributes or methods yet."""