# File size: 3,113 Bytes
# cbcb207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from abc import ABC, abstractmethod
import os

import numpy as np
import torch

import utils


class Tuning(ABC):
    """Abstract interface for objects exposing an ``Image`` step."""

    @abstractmethod
    def Image(self, image):
        """Process ``image``; concrete subclasses must override this.

        The base implementation does nothing and returns ``None``.
        """
        return None


class TuningReconstruction(Tuning):
    """Run single optimisation steps towards a target representation.

    Whether the content feature map or the style Gram matrices are matched
    is decided by the value ``utils.yamlGet('reconstruct')`` returns
    ('Content' or 'Style').
    """

    def __init__(self, model, optimizer, target_representation,
                 content_feature_maps_index, style_feature_maps_indices):
        """Store the collaborators used by :meth:`Image`.

        Args:
            model: Callable applied to the image; its result is indexed and
                enumerated as a collection of feature maps.
            optimizer: Object providing ``step()`` and ``zero_grad()``.
            target_representation: Representation to match. Compared
                directly with MSE in 'Content' mode; iterated (and passed to
                ``len``) as a sequence of Gram matrices in 'Style' mode.
            content_feature_maps_index: Index into the model output used in
                'Content' mode.
            style_feature_maps_indices: Container of model-output positions
                used in 'Style' mode.
        """
        self.model = model
        self.optimizer = optimizer
        self.target_representation = target_representation
        self.content_feature_maps_index = content_feature_maps_index
        self.style_feature_maps_indices = style_feature_maps_indices

    def Image(self, image):
        """Perform one optimisation step and report the resulting loss.

        Args:
            image: Input handed to ``self.model``.

        Returns:
            A tuple ``(loss_value, current_representation)`` where
            ``loss_value`` is ``loss.item()`` and ``current_representation``
            is the representation computed from ``image`` in this step.

        Raises:
            ValueError: If the configured mode is neither 'Content' nor
                'Style', or if 'Style' mode produces no pair of Gram
                matrices to compare.
        """
        # Read the setting once so the representation and the loss are
        # guaranteed to be computed for the same mode.
        mode = utils.yamlGet('reconstruct')
        if mode not in ('Content', 'Style'):
            raise ValueError(
                "Unsupported 'reconstruct' setting: "
                f"{mode!r} (expected 'Content' or 'Style')")

        # Finds the current representation
        set_of_feature_maps = self.model(image)

        if mode == 'Content':
            current_representation = set_of_feature_maps[
                self.content_feature_maps_index].squeeze(axis=0)
            loss = torch.nn.MSELoss(reduction='mean')(
                self.target_representation, current_representation)
        else:
            current_representation = [
                utils.gram_matrix(fmaps)
                for i, fmaps in enumerate(set_of_feature_maps)
                if i in self.style_feature_maps_indices
            ]
            gram_pairs = list(zip(self.target_representation,
                                  current_representation))
            if not gram_pairs:
                # Without any pair the loss would stay a plain float and
                # ``backward()`` would fail with an unhelpful AttributeError.
                raise ValueError(
                    "No style Gram matrices to compare; check the target "
                    "representation and style_feature_maps_indices")

            # Every layer contributes equally to the total style loss.
            weight = 1 / len(self.target_representation)
            criterion = torch.nn.MSELoss(reduction='sum')
            loss = 0.0
            for gram_gt, gram_hat in gram_pairs:
                loss += weight * criterion(gram_gt[0], gram_hat[0])

        loss.backward()
        self.optimizer.step()
        # Clear gradients so they do not accumulate into the next step.
        self.optimizer.zero_grad()
        return loss.item(), current_representation


class Reconstruct(ABC):
    """Abstract interface for objects exposing a ``Visualize`` action."""

    @abstractmethod
    def Visualize(self):
        """Visualise the reconstruction; concrete subclasses must override.

        The base implementation does nothing and returns ``None``.
        """
        return None


class ContentReconstruct(Reconstruct):
    """Save each feature map of the target content representation as an image.

        tcr -> target_content_representation
        nfm -> number of feature maps (size of the first dimension of tcr)
    """

    def __init__(self, feature_maps, config=None, dump_path=None):
        """Extract the target content representation from ``feature_maps``.

        Args:
            feature_maps: Mapping with the keys ``'set_of_feature_maps'`` and
                ``'content_feature_maps_index_name'``. Element ``[0]`` of the
                latter selects the entry of the former; element ``[1]`` is
                embedded in the output file names (presumably the layer
                name - confirm with the caller).
            config: Mapping providing ``"model"`` and ``"img_format"``.
                ``img_format[0]`` is used as zero-padding width of the
                feature-map number and ``img_format[1]`` is appended to the
                file name (presumably the file extension). Needed by
                :meth:`Visualize`.
            dump_path: Directory the images are written to. Needed by
                :meth:`Visualize`.
        """
        self.fm = feature_maps
        self.config = config
        self.dump_path = dump_path
        self.tcr = self.fm['set_of_feature_maps'][
            self.fm['content_feature_maps_index_name'][0]].squeeze(axis=0)
        self.nfm = self.tcr.size()[0]

    def Visualize(self):
        """Write every feature map of ``tcr`` to ``dump_path`` as an image.

        Raises:
            ValueError: If ``config`` or ``dump_path`` was not supplied to
                the constructor.
        """
        # These used to be read from undefined globals, which made every
        # call fail with NameError; they are now explicit instance state.
        if self.config is None or self.dump_path is None:
            raise ValueError(
                "ContentReconstruct.Visualize needs 'config' and "
                "'dump_path'; pass them to the constructor")

        model_name = self.config["model"]
        pad_width = self.config["img_format"][0]
        suffix = self.config["img_format"][1]
        index_name = self.fm['content_feature_maps_index_name'][1]

        for i in range(self.nfm):
            # detach() so tensors that track gradients can be converted.
            feature_map = self.tcr[i].detach().to('cpu').numpy()
            feature_map = np.uint8(utils.get_uint8_range(feature_map))
            filename = (f'fm_{model_name}_{index_name}_'
                        f'{str(i).zfill(pad_width)}{suffix}')
            utils.save_image(feature_map,
                             os.path.join(self.dump_path, filename))


class StyleReconstruct(Reconstruct):
    """Placeholder for style reconstruction; defines no behaviour yet."""


class Invoker:
    """Placeholder class; defines no attributes or behaviour yet."""