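# Reconstructs an image purely from a content or style representation
# (feature maps or Gram matrices), in the spirit of the neural style transfer
# reconstruction experiments: a white-noise image is optimized until its own
# representation matches the target one, and intermediate frames are saved.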
import os

import numpy as np
import torch
from torch.autograd import Variable
from torch.optim import Adam, LBFGS

import src.utils.utils as utils
from src.utils.video_utils import create_video_from_intermediate_results
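

# Builds the per-iteration update used by the Adam branch below: one forward
# pass through the network, loss from the current representation, backward
# pass, then an optimizer step.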
def make_tuning_step(optimizer, config):
    def tuning_step(optimizing_img):
        config.current_set_of_feature_maps = config.neural_net(optimizing_img)
        loss, config.current_representation = utils.getCurrentData(config)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        return loss.item(), config.current_representation

    return tuning_step
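

# Sets up the target representation from the input image, then optimizes a
# white-noise image towards it with either Adam or L-BFGS, saving every
# intermediate result to dump_path.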
def reconstruct_image_from_representation():
    dump_path = os.path.join(os.path.dirname(__file__), "data/reconstruct")
    config = utils.Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    img, img_path = utils.getImageAndPath(device)

    # Start from white noise and optimize the pixel values directly.
    white_noise_img = np.random.uniform(-90., 90.,
                                        img.shape).astype(np.float32)
    init_img = torch.from_numpy(white_noise_img).float().to(device)
    optimizing_img = Variable(init_img, requires_grad=True)
    # indices pick relevant feature maps (say conv4_1, relu1_1, etc.)
    output = list(utils.prepare_model(device))
    config.neural_net = output[0]
    content_feature_maps_index_name = output[1]
    style_feature_maps_indices_names = output[2]
    config.content_feature_maps_index = content_feature_maps_index_name[0]
    config.style_feature_maps_indices = style_feature_maps_indices_names[0]

    config.current_set_of_feature_maps = config.neural_net(img)
    config.target_content_representation = config.current_set_of_feature_maps[
        config.content_feature_maps_index].squeeze(axis=0)
    config.target_style_representation = [
        utils.gram_matrix(fmaps)
        for i, fmaps in enumerate(config.current_set_of_feature_maps)
        if i in config.style_feature_maps_indices
    ]
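
    # Depending on the 'reconstruct' setting, pick the content or the style
    # target. (The commented-out lines below would additionally dump the
    # individual feature maps / Gram matrices as images for inspection.)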
    if utils.yamlGet('reconstruct') == "Content":
        config.target_representation = config.target_content_representation
        num_of_feature_maps = config.target_content_representation.size()[0]
        for i in range(num_of_feature_maps):
            feature_map = config.target_content_representation[i].to(
                'cpu').numpy()
            feature_map = np.uint8(utils.get_uint8_range(feature_map))
            # filename = f'fm_{config["model"]}_{content_feature_maps_index_name[1]}_{str(i).zfill(config["img_format"][0])}{config["img_format"][1]}'
            # utils.save_image(feature_map, os.path.join(dump_path, filename))
    elif utils.yamlGet('reconstruct') == "Style":
        config.target_representation = config.target_style_representation
        num_of_gram_matrices = len(config.target_style_representation)
        for i in range(num_of_gram_matrices):
            Gram_matrix = config.target_style_representation[i].squeeze(
                axis=0).to('cpu').numpy()
            Gram_matrix = np.uint8(utils.get_uint8_range(Gram_matrix))
            # filename = f'gram_{config["model"]}_{style_feature_maps_indices_names[1][i]}_{str(i).zfill(config["img_format"][0])}{config["img_format"][1]}'
            # utils.save_image(Gram_matrix, os.path.join(dump_path, filename))
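
    # Optimize the noise image so that its representation matches the target,
    # using the optimizer selected in the config.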
    if utils.yamlGet('optimizer') == 'Adam':
        optimizer = Adam((optimizing_img, ), lr=utils.yamlGet('learning_rate'))
        tuning_step = make_tuning_step(optimizer, config)
        # The loop bound must be an integer iteration count; 'iterations' is an
        # assumed config key (the original read the 'optimizer' name here,
        # which is a string).
        for it in range(utils.yamlGet('iterations')):
            tuning_step(optimizing_img)
            with torch.no_grad():
                utils.save_optimizing_image(optimizing_img, dump_path, it)
    elif utils.yamlGet('optimizer') == 'LBFGS':
        # Same assumed 'iterations' key used as the L-BFGS iteration budget.
        optimizer = LBFGS((optimizing_img, ),
                          max_iter=utils.yamlGet('iterations'),
                          line_search_fn='strong_wolfe')
        cnt = 0

        def closure():
            nonlocal cnt
            # Clear gradients accumulated by the previous closure call.
            optimizer.zero_grad()
            loss = utils.getLBFGSReconstructLoss(config, optimizing_img)
            loss.backward()
            with torch.no_grad():
                utils.save_optimizing_image(optimizing_img, dump_path, cnt)
            cnt += 1
            return loss

        optimizer.step(closure)

    return dump_path


if __name__ == "__main__":
    # reconstruct style or content image purely from their representation
    results_path = reconstruct_image_from_representation()
    create_video_from_intermediate_results(results_path)