File size: 6,194 Bytes
e17e8cc
 
 
ae0d2c8
e17e8cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Imports
import os
import copy
import torch
import glob
import pandas as pd
import pickle
from xml.dom import minidom
from svgpathtools import svg2paths2
from svgpathtools import wsvg
import sys
sys.path.append(os.getcwd())
from src.preprocessing.deepsvg.deepsvg_svglib.svg import SVG
from src.preprocessing.deepsvg.deepsvg_config import config_hierarchical_ordered
from src.preprocessing.deepsvg.deepsvg_utils import train_utils
from src.preprocessing.deepsvg.deepsvg_utils import utils
from src.preprocessing.deepsvg.deepsvg_dataloader import svg_dataset

# ---- Methods for embedding logos ----

def compute_embedding_folder(folder_path: str, model_path: str, save: str = None) -> pd.DataFrame:
    """Compute DeepSVG path embeddings for every file in a folder.

    Each entry of ``folder_path`` is passed to ``compute_embedding``. The
    returned rows are tagged with a ``filename`` column and concatenated.
    Files whose embedding fails are reported and skipped.

    Note: ``compute_embedding`` rewrites each input file in place.

    Args:
        folder_path: Directory containing the SVG logos to embed.
        model_path: Path to the pretrained DeepSVG checkpoint.
        save: Optional output directory. When given, the combined DataFrame
            is pickled to ``<save>/svg_embedding_5000.pkl``.

    Returns:
        The concatenated embeddings of all successfully processed files, or
        an empty DataFrame when no file could be embedded.
    """
    data_list = []
    # Sorted so that repeated runs produce rows in the same order.
    for file in sorted(os.listdir(folder_path)):
        print('File: ' + file)
        try:
            embedding = compute_embedding(os.path.join(folder_path, file), model_path)
        except Exception as err:
            # Best-effort: one malformed logo must not abort the whole batch.
            # Unlike a bare except, Ctrl-C / SystemExit still stop the run.
            print(f'Embedding failed: {err!r}')
            continue
        embedding['filename'] = file
        data_list.append(embedding)
    print('Concatenating')
    # pd.concat raises ValueError on an empty list, e.g. when every file failed.
    data = pd.concat(data_list) if data_list else pd.DataFrame()
    if save is not None:
        with open(os.path.join(save, 'svg_embedding_5000.pkl'), 'wb') as output:
            pickle.dump(data, output)
    return data
    

def compute_embedding(path: str, model_path: str, save: str = None) -> pd.DataFrame:
    """Embed every path of one SVG logo with a pretrained DeepSVG encoder.

    Pipeline:
      1. Rewrite ``path`` in place via svgpathtools (``svg2paths2``/``wsvg``).
         Then normalize it with DeepSVG and tag each ``<path>`` element with
         an ``animation_id`` attribute. The tagged document is written back
         to ``path``.
      2. Split the document into one SVG per path and canonicalize each one.
         Write them to ``data/temp_svg/path_<animation_id>.svg`` and write
         per-path metadata to ``data/metadata/metadata.csv``.
      3. Load the hierarchical DeepSVG model from ``model_path`` (on every
         call) and encode each temporary SVG.

    Args:
        path: SVG file to embed. NOTE: this file is overwritten in place.
        model_path: Path to the pretrained DeepSVG checkpoint.
        save: Optional output directory. When given, the result is pickled
            to ``<save>/svg_embedding_5000.pkl``.

    Returns:
        DataFrame with an ``animation_id`` column plus one column per
        embedding component, one row per path. Rows containing NaN are
        dropped.
    """
    temp_dir = 'data/temp_svg'
    metadata_dir = 'data/metadata'

    # Convert all primitives to SVG paths - TODO text
    svg_paths, attributes, svg_attributes = svg2paths2(path)  # In previous project, this is performed at the end
    wsvg(svg_paths, attributes=attributes, svg_attributes=svg_attributes, filename=path)

    svg = SVG.load_svg(path)
    svg.normalize()  # Using DeepSVG normalize instead of expanding viewbox - TODO check is this equal?
    svg_str = svg.to_str()

    # Assign animation id to every path - TODO this changes the original logo!
    document = minidom.parseString(svg_str)
    path_nodes = document.getElementsByTagName('path')
    for index, node in enumerate(path_nodes):
        node.setAttribute('animation_id', str(index))
    with open(path, 'wb') as svg_file:
        svg_file.write(document.toxml(encoding='iso-8859-1'))

    # Decompose SVGs: for each path i, build a copy of the document that keeps
    # only path i (plus any paths that live inside a <clipPath>).
    decomposed_svgs = {}
    for i in range(len(path_nodes)):
        doc_temp = copy.deepcopy(document)
        paths_temp = doc_temp.getElementsByTagName('path')
        current_path = paths_temp[i]
        for other in paths_temp[:i] + paths_temp[i + 1:]:
            if other.parentNode.nodeName != 'clipPath':
                other.parentNode.removeChild(other)
        # Check for style attributes; add in case there are none
        if not current_path.getAttribute('style'):
            current_path.setAttribute('stroke', 'black')
            current_path.setAttribute('stroke-width', '2')
        anim_id = current_path.getAttribute('animation_id')
        decomposed_svgs[anim_id] = doc_temp.toprettyxml(encoding='iso-8859-1')
        doc_temp.unlink()

    # Start from an empty temp directory. Files left by a previous call (e.g.
    # a logo with more paths) would otherwise be globbed below and reported
    # as paths of this logo.
    os.makedirs(temp_dir, exist_ok=True)
    for stale_file in glob.glob(os.path.join(temp_dir, '*.svg')):
        os.remove(stale_file)

    meta = {}
    for anim_id in decomposed_svgs:
        # Load into SVG and canonicalize
        current_svg = SVG.from_str(decomposed_svgs[anim_id])
        current_svg.canonicalize()  # Applies DeepSVG canonicalize; previously custom methods were used
        decomposed_svgs[anim_id] = current_svg.to_str()
        with open(temp_dir + '/path_' + str(anim_id) + '.svg', 'w') as svg_file:
            svg_file.write(decomposed_svgs[anim_id])

        # Collect metadata
        len_groups = [path_group.total_len() for path_group in current_svg.svg_path_groups]
        start_pos = [path_group.svg_paths[0].start_pos for path_group in current_svg.svg_path_groups]
        meta[anim_id] = {
            'id': anim_id,
            'total_len': sum(len_groups),
            'nb_groups': len(len_groups),
            'len_groups': len_groups,
            # default=0: max() of an empty list (no path groups) raises
            'max_len_group': max(len_groups, default=0),
            'start_pos': start_pos,
        }
    metadata = pd.DataFrame(list(meta.values()))
    os.makedirs(metadata_dir, exist_ok=True)
    metadata.to_csv(metadata_dir + '/metadata.csv', index=False)

    # Load pretrained DeepSVG model (rebuilt on every call)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cfg = config_hierarchical_ordered.Config()
    model = cfg.make_model().to(device)
    train_utils.load_model(model_path, model)
    model.eval()
    # Load dataset; its simplify/preprocess/get helpers are used below
    cfg.data_dir = 'data/temp_svg/'
    cfg.meta_filepath = 'data/metadata/metadata.csv'
    dataset = svg_dataset.load_dataset(cfg)

    svg_list = []
    for svg_file in sorted(glob.glob(os.path.join(temp_dir, '*.svg'))):
        # 'path_<animation_id>.svg' -> '<animation_id>'. basename handles both
        # '/' and '\\' separators, unlike the former split on '\\'.
        anim_id = os.path.splitext(os.path.basename(svg_file))[0].split('_')[1]
        # Preprocessing
        svg = SVG.load_svg(svg_file)
        svg = dataset.simplify(svg)
        svg = dataset.preprocess(svg, augment=False)
        sample = dataset.get(svg=svg)
        # Get embedding
        model_args = utils.batchify((sample[key] for key in cfg.model_args), device)
        with torch.no_grad():
            z = model(*model_args, encode_mode=True).cpu().numpy()[0][0][0]
        svg_list.append({'animation_id': anim_id, 'embedding': z})

    # Expand each embedding vector into one column per component.
    data = pd.DataFrame.from_records(svg_list, index='animation_id')['embedding'].apply(pd.Series)
    data.reset_index(level=0, inplace=True)
    data.dropna(inplace=True)
    data.reset_index(drop=True, inplace=True)
    if save is not None:
        with open(os.path.join(save, 'svg_embedding_5000.pkl'), 'wb') as output:
            pickle.dump(data, output)
    print('Embedding computed')
    return data


#compute_embedding_folder('data/raw_dataset', 'src/preprocessing/deepsvg/deepsvg_models/deepSVG_hierarchical_ordered.pth.tar', 'data/embedding')