# File size: 4,466 Bytes
# cbcb207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import src.utils.utils as utils
from src.utils.video_utils import create_video_from_intermediate_results
import torch
from torch.autograd import Variable
from torch.optim import Adam, LBFGS
import numpy as np


def make_tuning_step(optimizer, config):
    """Build a closure that advances the optimization of an image by one step.

    The returned callable runs a forward pass through ``config.neural_net``,
    derives the loss via ``utils.getCurrentData`` (which reads the cached
    feature maps off ``config``), backpropagates, and applies one
    ``optimizer`` update.
    """

    def step(optimizing_img):
        # Forward pass; feature maps are stashed on the shared config object
        # so the loss helper can read them.
        config.current_set_of_feature_maps = config.neural_net(optimizing_img)
        loss, config.current_representation = utils.getCurrentData(config)
        # Backprop, apply the update, then clear gradients for the next call.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss_value = loss.item()
        return loss_value, config.current_representation

    return step


def reconstruct_image_from_representation():
    """Reconstruct an image purely from its content or style representation.

    Starting from white noise, iteratively optimizes the image so that its
    feature maps (content mode) or Gram matrices (style mode) match those of
    the target image. Intermediate snapshots are written to ``dump_path``.

    Returns:
        str: directory containing the intermediate optimization images.
    """
    dump_path = os.path.join(os.path.dirname(__file__), "data/reconstruct")
    config = utils.Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    img, img_path = utils.getImageAndPath(device)
    # Start from uniform white noise in [-90, 90) with the target's shape.
    white_noise_img = np.random.uniform(-90., 90.,
                                        img.shape).astype(np.float32)
    init_img = torch.from_numpy(white_noise_img).float().to(device)
    optimizing_img = Variable(init_img, requires_grad=True)

    # indices pick relevant feature maps (say conv4_1, relu1_1, etc.)
    output = list(utils.prepare_model(device))
    config.neural_net = output[0]
    content_feature_maps_index_name = output[1]
    style_feature_maps_indices_names = output[2]

    config.content_feature_maps_index = content_feature_maps_index_name[0]
    config.style_feature_maps_indices = style_feature_maps_indices_names[0]

    # Forward pass on the *target* image to capture the representations the
    # white-noise image will be optimized toward.
    config.current_set_of_feature_maps = config.neural_net(img)

    config.target_content_representation = config.current_set_of_feature_maps[
        config.content_feature_maps_index].squeeze(axis=0)
    config.target_style_representation = [
        utils.gram_matrix(fmaps)
        for i, fmaps in enumerate(config.current_set_of_feature_maps)
        if i in config.style_feature_maps_indices
    ]

    # Hoist repeated YAML lookups out of the branches.
    reconstruct_mode = utils.yamlGet('reconstruct')
    optimizer_name = utils.yamlGet('optimizer')
    # BUGFIX: the iteration count was previously read from the 'optimizer'
    # key, producing range('Adam') / max_iter='LBFGS' — a TypeError at
    # runtime in both branches.
    # NOTE(review): assumes the YAML config exposes an integer 'iterations'
    # key — confirm against the config schema.
    num_iterations = utils.yamlGet('iterations')

    if reconstruct_mode == "Content":
        config.target_representation = config.target_content_representation
        num_of_feature_maps = config.target_content_representation.size()[0]
        for i in range(num_of_feature_maps):
            feature_map = config.target_content_representation[i].to(
                'cpu').numpy()
            feature_map = np.uint8(utils.get_uint8_range(feature_map))
            # Saving individual feature maps is currently disabled:
            # filename = f'fm_{config["model"]}_{content_feature_maps_index_name[1]}_{str(i).zfill(config["img_format"][0])}{config["img_format"][1]}'
            # utils.save_image(feature_map, os.path.join(dump_path, filename))

    elif reconstruct_mode == "Style":
        config.target_representation = config.target_style_representation
        num_of_gram_matrices = len(config.target_style_representation)
        for i in range(num_of_gram_matrices):
            Gram_matrix = config.target_style_representation[i].squeeze(
                axis=0).to('cpu').numpy()
            Gram_matrix = np.uint8(utils.get_uint8_range(Gram_matrix))
            # Saving individual Gram matrices is currently disabled:
            # filename = f'gram_{config["model"]}_{style_feature_maps_indices_names[1][i]}_{str(i).zfill(config["img_format"][0])}{config["img_format"][1]}'
            # utils.save_image(Gram_matrix, os.path.join(dump_path, filename))

    if optimizer_name == 'Adam':
        optimizer = Adam((optimizing_img, ), lr=utils.yamlGet('learning_rate'))
        tuning_step = make_tuning_step(optimizer, config)
        for it in range(num_iterations):
            tuning_step(optimizing_img)
            with torch.no_grad():
                utils.save_optimizing_image(optimizing_img, dump_path, it)

    elif optimizer_name == 'LBFGS':
        optimizer = LBFGS((optimizing_img, ),
                          max_iter=num_iterations,
                          line_search_fn='strong_wolfe')
        cnt = 0

        def closure():
            nonlocal cnt
            # BUGFIX: LBFGS may evaluate the closure many times per step;
            # without zeroing, gradients accumulate across evaluations.
            optimizer.zero_grad()
            loss = utils.getLBFGSReconstructLoss(config, optimizing_img)
            loss.backward()
            with torch.no_grad():
                utils.save_optimizing_image(optimizing_img, dump_path, cnt)
                cnt += 1
            return loss

        optimizer.step(closure)

    return dump_path


if __name__ == "__main__":

    # reconstruct style or content image purely from their representation
    results_path = reconstruct_image_from_representation()

    create_video_from_intermediate_results(results_path)