Daniel Gil-U Fuhge committed
Commit e17e8cc
1 Parent(s): 2f22ac0

add model files

Files changed (38)
  1. AnimationTransformer.py +272 -0
  2. models/animation_transformer.pth +3 -0
  3. models/reward_function_mode_state_dict.pth +3 -0
  4. src/postprocessing/__init__.py +0 -0
  5. src/postprocessing/get_style_attributes.py +318 -0
  6. src/postprocessing/get_svg_color_tendency.py +19 -0
  7. src/postprocessing/get_svg_size_pos.py +268 -0
  8. src/postprocessing/insert_animation.py +333 -0
  9. src/postprocessing/logo_0.svg +809 -0
  10. src/postprocessing/postprocessing.py +604 -0
  11. src/postprocessing/transform_animation_predictor_output.py +78 -0
  12. src/preprocessing/deepsvg/deepsvg_config/config.py +106 -0
  13. src/preprocessing/deepsvg/deepsvg_config/config_hierarchical_ordered.py +29 -0
  14. src/preprocessing/deepsvg/deepsvg_config/default_icons.py +102 -0
  15. src/preprocessing/deepsvg/deepsvg_dataloader/svg_dataset.py +239 -0
  16. src/preprocessing/deepsvg/deepsvg_difflib/tensor.py +305 -0
  17. src/preprocessing/deepsvg/deepsvg_models/basic_blocks.py +70 -0
  18. src/preprocessing/deepsvg/deepsvg_models/deepSVG_hierarchical_ordered.pth.tar +0 -0
  19. src/preprocessing/deepsvg/deepsvg_models/layers/attention.py +166 -0
  20. src/preprocessing/deepsvg/deepsvg_models/layers/functional.py +261 -0
  21. src/preprocessing/deepsvg/deepsvg_models/layers/improved_transformer.py +146 -0
  22. src/preprocessing/deepsvg/deepsvg_models/layers/positional_encoding.py +48 -0
  23. src/preprocessing/deepsvg/deepsvg_models/layers/transformer.py +398 -0
  24. src/preprocessing/deepsvg/deepsvg_models/loss.py +70 -0
  25. src/preprocessing/deepsvg/deepsvg_models/model.py +484 -0
  26. src/preprocessing/deepsvg/deepsvg_models/model_config.py +113 -0
  27. src/preprocessing/deepsvg/deepsvg_models/model_utils.py +89 -0
  28. src/preprocessing/deepsvg/deepsvg_schedulers/warmup.py +68 -0
  29. src/preprocessing/deepsvg/deepsvg_svglib/geom.py +493 -0
  30. src/preprocessing/deepsvg/deepsvg_svglib/svg.py +579 -0
  31. src/preprocessing/deepsvg/deepsvg_svglib/svg_command.py +531 -0
  32. src/preprocessing/deepsvg/deepsvg_svglib/svg_path.py +659 -0
  33. src/preprocessing/deepsvg/deepsvg_svglib/svg_primitive.py +452 -0
  34. src/preprocessing/deepsvg/deepsvg_svglib/svglib_utils.py +95 -0
  35. src/preprocessing/deepsvg/deepsvg_svglib/util_fns.py +22 -0
  36. src/preprocessing/deepsvg/deepsvg_utils/train_utils.py +241 -0
  37. src/preprocessing/deepsvg/deepsvg_utils/utils.py +54 -0
  38. src/preprocessing/preprocessing.py +157 -0
AnimationTransformer.py ADDED
@@ -0,0 +1,272 @@
+ import math
+ import time
+
+ import torch
+ import torch.nn as nn
+
+ import dataset_helper
+
+
+ class AnimationTransformer(nn.Module):
+     def __init__(
+             self,
+             dim_model,  # hidden_size; corresponds to embedding length
+             num_heads,
+             num_encoder_layers,
+             num_decoder_layers,
+             dropout_p,
+             use_positional_encoder=True
+     ):
+         super().__init__()
+
+         self.model_type = "Transformer"
+         self.dim_model = dim_model
+
+         # TODO: Currently optional, as the input sequence is shuffled. Check later whether it is beneficial.
+         self.use_positional_encoder = use_positional_encoder
+         self.positional_encoder = PositionalEncoding(
+             dim_model=dim_model,
+             dropout_p=dropout_p
+         )
+
+         self.transformer = nn.Transformer(
+             d_model=dim_model,
+             nhead=num_heads,
+             num_encoder_layers=num_encoder_layers,
+             num_decoder_layers=num_decoder_layers,
+             dropout=dropout_p,
+             batch_first=True
+         )
+
+     def forward(self, src, tgt, tgt_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None):
+         # Src size must be (batch_size, src sequence length, dim_model)
+         # Tgt size must be (batch_size, tgt sequence length, dim_model)
+
+         if self.use_positional_encoder:
+             src = self.positional_encoder(src)
+             tgt = self.positional_encoder(tgt)
+
+         # Transformer blocks - out size = (batch_size, sequence length, dim_model), since batch_first=True
+         out = self.transformer(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_key_padding_mask,
+                                tgt_key_padding_mask=tgt_key_padding_mask)
+         return out
+
+
+ def get_tgt_mask(size) -> torch.Tensor:
+     # Generates a square matrix where each row allows one token more to be seen
+     mask = torch.tril(torch.ones(size, size) == 1)  # Lower triangular matrix
+     mask = mask.float()
+     mask = mask.masked_fill(mask == 0, float('-inf'))  # Convert zeros to -inf
+     mask = mask.masked_fill(mask == 1, float(0.0))  # Convert ones to 0
+
+     # EX for size=5:
+     # [[0., -inf, -inf, -inf, -inf],
+     #  [0., 0., -inf, -inf, -inf],
+     #  [0., 0., 0., -inf, -inf],
+     #  [0., 0., 0., 0., -inf],
+     #  [0., 0., 0., 0., 0.]]
+
+     return mask
+
+
+ def create_pad_mask(matrix: torch.Tensor) -> torch.Tensor:
+     pad_masks = []
+
+     # Iterate over each sequence in the batch.
+     for i in range(0, matrix.size(0)):
+         sequence = []
+
+         # Iterate over each element in the sequence and append True if it is a padding value.
+         for j in range(0, matrix.size(1)):
+             sequence.append(matrix[i, j, 0] == dataset_helper.PADDING_VALUE)
+
+         pad_masks.append(sequence)
+
+     return torch.tensor(pad_masks)
+
+
+ def _transformer_call_in_loops(model, batch, device, loss_function):
+     source, target = batch[0], batch[1]
+     source, target = source.to(device), target.to(device)
+
+     # First index selects all batch entries, second index slices the sequence dimension.
+     target_input = target[:, :-1]  # tgt input is offset by one (starts at SOS token, excludes EOS)
+     target_expected = target[:, 1:]  # tgt expected is offset by one (excludes SOS token)
+
+     # SOS - 1 - 2 - 3 - 4 - EOS - PAD - PAD   // target_input
+     # 1 - 2 - 3 - 4 - EOS - PAD - PAD - PAD   // target_expected
+
+     # Get mask to mask out the next tokens
+     tgt_mask = get_tgt_mask(target_input.size(1)).to(device)
+
+     # Standard training, except that we pass in target_input and tgt_mask
+     prediction = model(source, target_input,
+                        tgt_mask=tgt_mask,
+                        src_key_padding_mask=create_pad_mask(source).to(device),
+                        # Mask with expected, as EOS is no input (see above)
+                        tgt_key_padding_mask=create_pad_mask(target_expected).to(device))
+
+     return loss_function(prediction, target_expected, create_pad_mask(target_expected).to(device))
+
+
+ def train_loop(model, opt, loss_function, dataloader, device):
+     model.train()
+     total_loss = 0
+
+     t0 = time.time()
+     i = 1
+     for batch in dataloader:
+         loss = _transformer_call_in_loops(model, batch, device, loss_function)
+
+         opt.zero_grad()
+         loss.backward()
+         opt.step()
+
+         total_loss += loss.detach().item()
+
+         if i == 1 or i % 10 == 0:
+             elapsed_time = time.time() - t0
+             total_expected = elapsed_time / i * len(dataloader)
+             print(f">> {i}: Time per Batch {elapsed_time / i : .2f}s | "
+                   f"Total expected {total_expected / 60 : .2f} min | "
+                   f"Remaining {(total_expected - elapsed_time) / 60 : .2f} min ")
+         i += 1
+
+     print(f">> Epoch time: {(time.time() - t0) / 60:.2f} min")
+     return total_loss / len(dataloader)
+
+
+ def validation_loop(model, loss_function, dataloader, device):
+     model.eval()
+     total_loss = 0
+
+     with torch.no_grad():
+         for batch in dataloader:
+             loss = _transformer_call_in_loops(model, batch, device, loss_function)
+
+             total_loss += loss.detach().item()
+
+     return total_loss / len(dataloader)
+
+
+ def fit(model, optimizer, loss_function, train_dataloader, val_dataloader, epochs, device):
+     train_loss_list, validation_loss_list = [], []
+
+     print("Training and validating model")
+     for epoch in range(epochs):
+         print("-" * 25, f"Epoch {epoch + 1}", "-" * 25)
+
+         train_loss = train_loop(model, optimizer, loss_function, train_dataloader, device)
+         train_loss_list += [train_loss]
+
+         validation_loss = validation_loop(model, loss_function, val_dataloader, device)
+         validation_loss_list += [validation_loss]
+
+         print(f"Training loss: {train_loss:.4f}")
+         print(f"Validation loss: {validation_loss:.4f}")
+         print()
+
+     return train_loss_list, validation_loss_list
+
+
+ def predict(model, source_sequence, sos_token: torch.Tensor, device, max_length=32, eos_scaling=1,
+             backpropagate=False, show_result=True):
+     if backpropagate:
+         model.train()
+     else:
+         model.eval()
+
+     source_sequence = source_sequence.float().to(device)
+     y_input = torch.unsqueeze(sos_token, dim=0).float().to(device)
+
+     i = 0
+     while i < max_length:
+         # Get source mask
+         prediction = model(source_sequence.unsqueeze(0), y_input.unsqueeze(0),  # un-squeeze for batch
+                            # tgt_mask=get_tgt_mask(y_input.size(0)).to(device),
+                            src_key_padding_mask=create_pad_mask(source_sequence.unsqueeze(0)).to(device))
+
+         next_embedding = prediction[0, -1, :]  # prediction on last token
+         pred_deep_svg, pred_type, pred_parameters = dataset_helper.unpack_embedding(next_embedding, dim=0)
+         pred_deep_svg, pred_type, pred_parameters = (pred_deep_svg.to(device), pred_type.to(device),
+                                                      pred_parameters.to(device))
+
+         # === TYPE ===
+         # Apply softmax
+         type_softmax = torch.softmax(pred_type, dim=0)
+         type_softmax[0] = type_softmax[0] * eos_scaling  # Reduce EOS probability
+         animation_type = torch.argmax(type_softmax, dim=0)
+
+         # Break if EOS is most likely
+         if animation_type == 0:
+             print("END OF ANIMATION")
+             y_input = torch.cat((y_input, sos_token.unsqueeze(0).to(device)), dim=0)
+             return y_input
+
+         pred_type = torch.zeros(11)
+         pred_type[animation_type] = 1
+
+         # === DEEP SVG ===
+         # Find the closest path
+         distances = [torch.norm(pred_deep_svg - embedding[:-26]) for embedding in source_sequence]
+         closest_index = distances.index(min(distances))
+         closest_token = source_sequence[closest_index]
+
+         # === PARAMETERS ===
+         # Overwrite unused parameters
+         for j in range(len(pred_parameters)):
+             if j in dataset_helper.ANIMATION_PARAMETER_INDICES[int(animation_type)]:
+                 continue
+             pred_parameters[j] = -1
+
+         # === SEQUENCE ===
+         y_new = torch.concat([closest_token[:-26], pred_type.to(device), pred_parameters], dim=0)
+         y_input = torch.cat((y_input, y_new.unsqueeze(0)), dim=0)
+
+         # === INFO PRINT ===
+         if show_result:
+             print(f"{int(y_input.size(0))}: Path {closest_index} ({round(float(distances[closest_index]), 3)}) "
+                   f"got animation {animation_type} ({round(float(type_softmax[animation_type]), 3)}%) "
+                   f"with parameters {[round(num, 2) for num in pred_parameters.tolist()]}")
+
+         i += 1
+
+     return y_input
+
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, dim_model, dropout_p, max_len=5000):
+         """
+         Initializes the PositionalEncoding module, which injects information about the relative or absolute position
+         of the tokens in the sequence. The positional encodings have the same dimension as the embeddings so that
+         the two can be summed. Uses a sinusoidal pattern for positional encoding.
+
+         Args:
+             dim_model (int): The dimension of the embeddings and the expected dimension of the positional encoding.
+             dropout_p (float): Dropout probability to be applied to the summed embeddings and positional encodings.
+             max_len (int): The max length of the sequences for which positional encodings are precomputed and stored.
+         """
+         super(PositionalEncoding, self).__init__()
+         self.dropout = nn.Dropout(p=dropout_p)
+
+         position = torch.arange(max_len).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0) / dim_model))
+         pos_encoding = torch.zeros(max_len, 1, dim_model)
+         pos_encoding[:, 0, 0::2] = torch.sin(position * div_term)
+         pos_encoding[:, 0, 1::2] = torch.cos(position * div_term)
+
+         self.register_buffer('pos_encoding', pos_encoding)
+
+     def forward(self, embedding: torch.Tensor) -> torch.Tensor:
+         """
+         Applies positional encoding to the input embeddings and applies dropout.
+
+         Args:
+             embedding (torch.Tensor): The input embeddings with shape [batch_size, seq_len, dim_model].
+
+         Returns:
+             torch.Tensor: The embeddings with positional encoding and dropout applied, with the same shape as the
+             input token embeddings [batch_size, seq_len, dim_model].
+         """
+         # pos_encoding is stored as [max_len, 1, dim_model]; select seq_len positions and move them to the
+         # sequence axis so they broadcast over the batch (inputs are batch-first).
+         return self.dropout(embedding + self.pos_encoding[:embedding.size(1)].transpose(0, 1))
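A minimal usage sketch of the model above. The hyperparameter values and the embedding width are illustrative assumptions, not values recorded in this commit; only the call signatures come from the file itself, and importing the module assumes the repository's dataset_helper is on the path.

# Sketch: forward pass with a causal target mask (illustrative sizes).
import torch
from AnimationTransformer import AnimationTransformer, get_tgt_mask

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AnimationTransformer(
    dim_model=282,          # assumed embedding width; must be divisible by num_heads
    num_heads=6,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dropout_p=0.1,
).to(device)

src = torch.rand(4, 10, 282, device=device)  # (batch, src sequence length, dim_model)
tgt = torch.rand(4, 8, 282, device=device)   # (batch, tgt sequence length, dim_model)
out = model(src, tgt, tgt_mask=get_tgt_mask(tgt.size(1)).to(device))
print(out.shape)  # torch.Size([4, 8, 282])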
models/animation_transformer.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:12ae92d0b1a5ada8a8681122f76ea7c4e6b3fdf0169dd4b3a5d908899e563f86
+ size 60658902
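This is a Git LFS pointer, not the weights themselves; the ~60 MB checkpoint is fetched on an LFS-enabled clone. A hedged loading sketch (assuming, from the filename, that this file pickles the full model rather than a state dict like the checkpoint below):

# Sketch: run `git lfs pull` first so the pointer is replaced by the real checkpoint.
import torch

# Assumption: the checkpoint stores the whole module, so the AnimationTransformer
# class must be importable when unpickling.
model = torch.load("models/animation_transformer.pth", map_location="cpu")
model.eval()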
models/reward_function_mode_state_dict.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:49ea58f01ad6a281e005b9b2793e53f901037402c96411f444ed9630fff05fbf
+ size 111027985
src/postprocessing/__init__.py ADDED
File without changes
src/postprocessing/get_style_attributes.py ADDED
@@ -0,0 +1,318 @@
+ from svgpathtools import svg2paths
+ import pandas as pd
+ import numpy as np
+ from xml.dom import minidom
+
+ pd.options.mode.chained_assignment = None  # default='warn'
+
+
+ def get_style_attributes_svg(file):
+     """ Get style attributes of an SVG.
+
+     Args:
+         file (str): Path of SVG file.
+
+     Returns:
+         pd.DataFrame: Dataframe containing the attributes of each path.
+
+     """
+     local_styles = get_local_style_attributes(file)
+     global_styles = get_global_style_attributes(file)
+     global_group_styles = get_global_group_style_attributes(file)
+     return combine_style_attributes(local_styles, global_styles, global_group_styles)
+
+
+ def get_style_attributes_path(file, animation_id, attribute):
+     """ Get style attributes of a specific path in an SVG.
+
+     Args:
+         file (str): Path of SVG file.
+         animation_id (int): ID of element.
+         attribute (str): One of the following: fill, stroke, stroke_width, opacity, stroke_opacity.
+
+     Returns:
+         str: Attribute of specific path.
+
+     """
+     styles = get_style_attributes_svg(file)
+     styles_animation_id = styles[styles["animation_id"] == str(animation_id)]
+     return styles_animation_id.iloc[0][attribute]
+
+
+ def parse_svg(file):
+     """ Parse an SVG file.
+
+     Args:
+         file (str): Path of SVG file.
+
+     Returns:
+         list, list: List of path objects, list of dictionaries containing the attributes of each path.
+
+     """
+     paths, attrs = svg2paths(file)
+     return paths, attrs
+
+
+ def get_local_style_attributes(file):
+     """ Generate dataframe containing local style attributes of an SVG.
+
+     Args:
+         file (str): Path of SVG file.
+
+     Returns:
+         pd.DataFrame: Dataframe containing filename, animation_id, class, fill, stroke, stroke_width, opacity,
+             stroke_opacity.
+
+     """
+     return pd.DataFrame.from_records(_get_local_style_attributes(file))
+
+
+ def _get_local_style_attributes(file):
+     try:
+         _, attributes = parse_svg(file)
+     except Exception as e:
+         print(f'{file}: Attributes not defined. {e}')
+         return  # stop the generator; the SVG could not be parsed
+
+     for i, attr in enumerate(attributes):
+         animation_id = attr['animation_id']
+         class_ = ''
+         fill = '#000000'
+         stroke = '#000000'
+         stroke_width = '0'
+         opacity = '1.0'
+         stroke_opacity = '1.0'
+
+         if 'style' in attr:
+             a = attr['style']
+             if a.find('fill') != -1:
+                 fill = a.split('fill:', 1)[-1].split(';', 1)[0]
+             if a.find('stroke') != -1:
+                 stroke = a.split('stroke:', 1)[-1].split(';', 1)[0]
+             if a.find('stroke-width') != -1:
+                 stroke_width = a.split('stroke-width:', 1)[-1].split(';', 1)[0]
+             if a.find('opacity') != -1:
+                 opacity = a.split('opacity:', 1)[-1].split(';', 1)[0]
+             if a.find('stroke-opacity') != -1:
+                 stroke_opacity = a.split('stroke-opacity:', 1)[-1].split(';', 1)[0]
+         else:
+             if 'fill' in attr:
+                 fill = attr['fill']
+             if 'stroke' in attr:
+                 stroke = attr['stroke']
+             if 'stroke-width' in attr:
+                 stroke_width = attr['stroke-width']
+             if 'opacity' in attr:
+                 opacity = attr['opacity']
+             if 'stroke-opacity' in attr:
+                 stroke_opacity = attr['stroke-opacity']
+
+         if 'class' in attr:
+             class_ = attr['class']
+
+         # transform 'none' and RGB to hex
+         if '#' not in fill and fill != '':
+             fill = transform_to_hex(fill)
+         if '#' not in stroke and stroke != '':
+             stroke = transform_to_hex(stroke)
+
+         yield dict(filename=file.split('.svg')[0], animation_id=animation_id, class_=class_, fill=fill, stroke=stroke,
+                    stroke_width=stroke_width, opacity=opacity, stroke_opacity=stroke_opacity)
+
+
+ def get_global_style_attributes(file):
+     """ Generate dataframe containing global style attributes of an SVG.
+
+     Args:
+         file (str): Path of SVG file.
+
+     Returns:
+         pd.DataFrame: Dataframe containing filename, class, fill, stroke, stroke_width, opacity, stroke_opacity.
+
+     """
+     return pd.DataFrame.from_records(_get_global_style_attributes(file))
+
+
+ def _get_global_style_attributes(file):
+     doc = minidom.parse(file)
+     style = doc.getElementsByTagName('style')
+     for i, attr in enumerate(style):
+         a = attr.toxml()
+         for j in range(0, len(a.split(';}')) - 1):
+             fill = ''
+             stroke = ''
+             stroke_width = ''
+             opacity = ''
+             stroke_opacity = ''
+             attr = a.split(';}')[j]
+             class_ = attr.split('.', 1)[-1].split('{', 1)[0]
+             if attr.find('fill:') != -1:
+                 fill = attr.split('fill:', 1)[-1].split(';', 1)[0]
+             if attr.find('stroke:') != -1:
+                 stroke = attr.split('stroke:', 1)[-1].split(';', 1)[0]
+             if attr.find('stroke-width:') != -1:
+                 stroke_width = attr.split('stroke-width:', 1)[-1].split(';', 1)[0]
+             if attr.find('opacity:') != -1:
+                 opacity = attr.split('opacity:', 1)[-1].split(';', 1)[0]
+             if attr.find('stroke-opacity:') != -1:
+                 stroke_opacity = attr.split('stroke-opacity:', 1)[-1].split(';', 1)[0]
+
+             # transform 'none' and RGB to hex
+             if '#' not in fill and fill != '':
+                 fill = transform_to_hex(fill)
+             if '#' not in stroke and stroke != '':
+                 stroke = transform_to_hex(stroke)
+
+             yield dict(filename=file.split('.svg')[0], class_=class_, fill=fill, stroke=stroke,
+                        stroke_width=stroke_width, opacity=opacity, stroke_opacity=stroke_opacity)
+
+
+ def get_global_group_style_attributes(file):
+     """ Generate dataframe containing global style attributes defined through <g> tags of an SVG.
+
+     Args:
+         file (str): Path of SVG file.
+
+     Returns:
+         pd.DataFrame: Dataframe containing filename, href, animation_id, fill, stroke, stroke_width, opacity,
+             stroke_opacity.
+
+     """
+     df_group_animation_id_matching = pd.DataFrame.from_records(_get_group_animation_id_matching(file))
+
+     df_group_attributes = pd.DataFrame.from_records(_get_global_group_style_attributes(file))
+     df_group_attributes.drop_duplicates(inplace=True)
+     df_group_attributes.replace("", float("NaN"), inplace=True)
+     df_group_attributes.dropna(thresh=3, inplace=True)
+
+     if "href" in df_group_attributes.columns:
+         df_group_attributes.dropna(subset=["href"], inplace=True)
+
+     if df_group_attributes.empty:
+         return df_group_attributes
+     else:
+         return df_group_animation_id_matching.merge(df_group_attributes, how='left', on=['filename', 'href'])
+
+
+ def _get_global_group_style_attributes(file):
+     doc = minidom.parse(file)
+     groups = doc.getElementsByTagName('g')
+     for i, _ in enumerate(groups):
+         style = groups[i].getAttribute('style')
+         href = ''
+         fill = ''
+         stroke = ''
+         stroke_width = ''
+         opacity = ''
+         stroke_opacity = ''
+         if len(groups[i].getElementsByTagName('use')) != 0:
+             href = groups[i].getElementsByTagName('use')[0].getAttribute('xlink:href')
+         if style != '':
+             attributes = style.split(';')
+             for j, _ in enumerate(attributes):
+                 attr = attributes[j]
+                 if attr.find('fill:') != -1:
+                     fill = attr.split('fill:', 1)[-1].split(';', 1)[0]
+                 if attr.find('stroke:') != -1:
+                     stroke = attr.split('stroke:', 1)[-1].split(';', 1)[0]
+                 if attr.find('stroke-width:') != -1:
+                     stroke_width = attr.split('stroke-width:', 1)[-1].split(';', 1)[0]
+                 if attr.find('opacity:') != -1:
+                     opacity = attr.split('opacity:', 1)[-1].split(';', 1)[0]
+                 if attr.find('stroke-opacity:') != -1:
+                     stroke_opacity = attr.split('stroke-opacity:', 1)[-1].split(';', 1)[0]
+         else:
+             fill = groups[i].getAttribute('fill')
+             stroke = groups[i].getAttribute('stroke')
+             stroke_width = groups[i].getAttribute('stroke-width')
+             opacity = groups[i].getAttribute('opacity')
+             stroke_opacity = groups[i].getAttribute('stroke-opacity')
+
+         # transform 'none' and RGB to hex
+         if '#' not in fill and fill != '':
+             fill = transform_to_hex(fill)
+         if '#' not in stroke and stroke != '':
+             stroke = transform_to_hex(stroke)
+
+         yield dict(filename=file.split('.svg')[0], href=href.replace('#', ''), fill=fill, stroke=stroke,
+                    stroke_width=stroke_width, opacity=opacity, stroke_opacity=stroke_opacity)
+
+
+ def _get_group_animation_id_matching(file):
+     doc = minidom.parse(file)
+     try:
+         symbol = doc.getElementsByTagName('symbol')
+         for i, _ in enumerate(symbol):
+             href = symbol[i].getAttribute('id')
+             animation_id = symbol[i].getElementsByTagName('path')[0].getAttribute('animation_id')
+             yield dict(filename=file.split('.svg')[0], href=href, animation_id=animation_id)
+     except Exception:
+         # Fall back to <symbol> elements nested inside <defs>/<clipPath>
+         defs = doc.getElementsByTagName('defs')
+         for i, _ in enumerate(defs):
+             href = defs[i].getElementsByTagName('symbol')[0].getAttribute('id')
+             animation_id = defs[i].getElementsByTagName('clipPath')[0].getElementsByTagName('path')[0].getAttribute(
+                 'animation_id')
+             yield dict(filename=file.split('.svg')[0], href=href, animation_id=animation_id)
+
+
+ def combine_style_attributes(df_local, df_global, df_global_groups):
+     """ Combine local and global style attributes. Global attributes have priority.
+
+     Args:
+         df_local (pd.DataFrame): Dataframe with local style attributes.
+         df_global (pd.DataFrame): Dataframe with global style attributes.
+         df_global_groups (pd.DataFrame): Dataframe with global style attributes defined through <g> tags.
+
+     Returns:
+         pd.DataFrame: Dataframe with all style attributes.
+
+     """
+     if df_global.empty and df_global_groups.empty:
+         df_local.insert(loc=3, column='href', value="")
+         return df_local
+
+     if not df_global.empty:
+         df = df_local.merge(df_global, how='left', on=['filename', 'class_'])
+         df_styles = df[["filename", "animation_id", "class_"]]
+         df_styles["fill"] = _combine_columns(df, "fill")
+         df_styles["stroke"] = _combine_columns(df, "stroke")
+         df_styles["stroke_width"] = _combine_columns(df, "stroke_width")
+         df_styles["opacity"] = _combine_columns(df, "opacity")
+         df_styles["stroke_opacity"] = _combine_columns(df, "stroke_opacity")
+         df_local = df_styles.copy(deep=True)
+     if not df_global_groups.empty:
+         df = df_local.merge(df_global_groups, how='left', on=['filename', 'animation_id'])
+         df_styles = df[["filename", "animation_id", "class_", "href"]]
+         df_styles["href"] = df_styles["href"].fillna('')
+         df_styles["fill"] = _combine_columns(df, "fill")
+         df_styles["stroke"] = _combine_columns(df, "stroke")
+         df_styles["stroke_width"] = _combine_columns(df, "stroke_width")
+         df_styles["opacity"] = _combine_columns(df, "opacity")
+         df_styles["stroke_opacity"] = _combine_columns(df, "stroke_opacity")
+
+     return df_styles
+
+
+ def _combine_columns(df, col_name):
+     col = np.where(~df[f"{col_name}_y"].astype(str).isin(["", "nan"]),
+                    df[f"{col_name}_y"], df[f"{col_name}_x"])
+     return col
+
+
+ def transform_to_hex(rgb):
+     """ Transform an RGB color (or 'none') to hex.
+
+     Args:
+         rgb (str): RGB code.
+
+     Returns:
+         str: Hex code.
+
+     """
+     if rgb == 'none':
+         return '#000000'
+     if 'rgb' in rgb:
+         rgb = rgb.replace('rgb(', '').replace(')', '')
+         if '%' in rgb:
+             rgb = rgb.replace('%', '')
+             rgb_list = rgb.split(',')
+             r_value, g_value, b_value = [int(float(i) / 100 * 255) for i in rgb_list]
+         else:
+             rgb_list = rgb.split(',')
+             r_value, g_value, b_value = [int(float(i)) for i in rgb_list]
+         return '#%02x%02x%02x' % (r_value, g_value, b_value)
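A short usage sketch of the helpers above; logo_0.svg is the sample file added by this commit, and the animation_id value is an illustrative guess.

# Sketch: extract style attributes from the sample SVG shipped with this commit.
from src.postprocessing.get_style_attributes import get_style_attributes_svg, get_style_attributes_path

df = get_style_attributes_svg("src/postprocessing/logo_0.svg")
print(df[["animation_id", "fill", "stroke", "opacity"]].head())

# Single attribute of one path (animation_id 0 is illustrative).
fill = get_style_attributes_path("src/postprocessing/logo_0.svg", 0, "fill")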
src/postprocessing/get_svg_color_tendency.py ADDED
@@ -0,0 +1,19 @@
+ from src.postprocessing.get_style_attributes import get_style_attributes_svg
+
+
+ def get_svg_color_tendencies(file):
+     """ Get the two most frequent fill colors in an SVG. White is excluded; black is appended as a fallback.
+
+     Args:
+         file (str): Path of SVG file.
+
+     Returns:
+         list: List of the two most frequent colors in the SVG.
+
+     """
+     df = get_style_attributes_svg(file)
+     df = df[~df['fill'].isin(['#FFFFFF', '#ffffff'])]
+     colour_tendencies_list = df["fill"].value_counts()[:2].index.tolist()
+     colour_tendencies_list.append("#000000")
+     return colour_tendencies_list[:2]
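Usage is a one-liner, again against the sample logo from this commit:

from src.postprocessing.get_svg_color_tendency import get_svg_color_tendencies

print(get_svg_color_tendencies("src/postprocessing/logo_0.svg"))  # -> list of up to two hex color strings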
src/postprocessing/get_svg_size_pos.py ADDED
@@ -0,0 +1,268 @@
+ from xml.dom import minidom
+ from svgpathtools import svg2paths
+
+
+ def get_svg_size(file):
+     """ Get width and height of an SVG.
+
+     Args:
+         file (str): Path of SVG file.
+
+     Returns:
+         float, float: Width and height of SVG.
+
+     """
+     doc = minidom.parse(file)
+     width = doc.getElementsByTagName('svg')[0].getAttribute('width')
+     height = doc.getElementsByTagName('svg')[0].getAttribute('height')
+
+     if width != "" and height != "":
+         if not width[-1].isdigit():
+             width = width.replace('px', '').replace('pt', '')
+         if not height[-1].isdigit():
+             height = height.replace('px', '').replace('pt', '')
+
+     if width == "" or height == "" or not width[-1].isdigit() or not height[-1].isdigit():
+         # fall back to the bounding box of the SVG
+         xmin_svg, xmax_svg, ymin_svg, ymax_svg = 100, -100, 100, -100
+         paths, _ = svg2paths(file)
+         for path in paths:
+             xmin, xmax, ymin, ymax = path.bbox()
+             if xmin < xmin_svg:
+                 xmin_svg = xmin
+             if xmax > xmax_svg:
+                 xmax_svg = xmax
+             if ymin < ymin_svg:
+                 ymin_svg = ymin
+             if ymax > ymax_svg:
+                 ymax_svg = ymax
+         width = xmax_svg - xmin_svg
+         height = ymax_svg - ymin_svg
+
+     return float(width), float(height)
+
+
+ def get_svg_bbox(file):
+     """ Get bounding box coordinates of an SVG.
+
+     xmin, ymin: Upper left corner.
+
+     xmax, ymax: Lower right corner.
+
+     Args:
+         file (str): Path of SVG file.
+
+     Returns:
+         float, float, float, float: Bounding box of SVG (xmin, xmax, ymin, ymax).
+
+     """
+     try:
+         paths, _ = svg2paths(file)
+     except Exception as e:
+         print(f"{file}: svg2paths fails. SVG bbox is computed by using get_svg_size. {e}")
+         width, height = get_svg_size(file)
+         return 0, width, 0, height
+
+     xmin_svg, xmax_svg, ymin_svg, ymax_svg = 100, -100, 100, -100
+     for path in paths:
+         try:
+             xmin, xmax, ymin, ymax = path.bbox()
+             if xmin < xmin_svg:
+                 xmin_svg = xmin
+             if xmax > xmax_svg:
+                 xmax_svg = xmax
+             if ymin < ymin_svg:
+                 ymin_svg = ymin
+             if ymax > ymax_svg:
+                 ymax_svg = ymax
+         except Exception:
+             pass  # skip paths without a well-defined bbox
+
+     return xmin_svg, xmax_svg, ymin_svg, ymax_svg
+
+
+ def get_path_bbox(file, animation_id):
+     """ Get bounding box coordinates of a path in an SVG.
+
+     Args:
+         file (str): Path of SVG file.
+         animation_id (int): ID of element.
+
+     Returns:
+         float, float, float, float: Bounding box of path (xmin, xmax, ymin, ymax).
+
+     """
+     try:
+         paths, attributes = svg2paths(file)
+     except Exception as e1:
+         print(f"{file}, animation ID {animation_id}: svg2paths fails and path bbox cannot be computed. {e1}")
+         return 0, 0, 0, 0
+
+     for i, path in enumerate(paths):
+         if attributes[i]["animation_id"] == str(animation_id):
+             try:
+                 xmin, xmax, ymin, ymax = path.bbox()
+                 return xmin, xmax, ymin, ymax
+             except Exception as e2:
+                 print(f"{file}, animation ID {animation_id}: path bbox cannot be computed. {e2}")
+                 return 0, 0, 0, 0
+
+     # No path carries this animation ID
+     print(f"{file}, animation ID {animation_id}: no matching path found. Path bbox set to 0.")
+     return 0, 0, 0, 0
+
+
+ def get_midpoint_of_path_bbox(file, animation_id):
+     """ Get midpoint of bounding box of path.
+
+     Args:
+         file (str): Path of SVG file.
+         animation_id (int): ID of element.
+
+     Returns:
+         float, float: Midpoint of bounding box of path (x_midpoint, y_midpoint).
+
+     """
+     try:
+         xmin, xmax, ymin, ymax = get_path_bbox(file, animation_id)
+         x_midpoint = (xmin + xmax) / 2
+         y_midpoint = (ymin + ymax) / 2
+
+         return x_midpoint, y_midpoint
+     except Exception as e:
+         print(f'Could not get midpoint for file {file} and animation ID {animation_id}: {e}')
+         return 0, 0
+
+
+ def get_bbox_of_multiple_paths(file, animation_ids):
+     """ Get bounding box of multiple paths in an SVG.
+
+     Args:
+         file (str): Path of SVG file.
+         animation_ids (list(int)): List of element IDs.
+
+     Returns:
+         float, float, float, float: Bounding box of given paths (xmin, xmax, ymin, ymax).
+
+     """
+     try:
+         paths, attributes = svg2paths(file)
+     except Exception as e1:
+         print(f"{file}: svg2paths fails and bbox of multiple paths cannot be computed. {e1}")
+         return 0, 0, 0, 0
+
+     xmin_paths, xmax_paths, ymin_paths, ymax_paths = 100, -100, 100, -100
+
+     for i, path in enumerate(paths):
+         if attributes[i]["animation_id"] in list(map(str, animation_ids)):
+             try:
+                 xmin, xmax, ymin, ymax = path.bbox()
+                 if xmin < xmin_paths:
+                     xmin_paths = xmin
+                 if xmax > xmax_paths:
+                     xmax_paths = xmax
+                 if ymin < ymin_paths:
+                     ymin_paths = ymin
+                 if ymax > ymax_paths:
+                     ymax_paths = ymax
+             except Exception:
+                 pass  # skip paths without a well-defined bbox
+
+     return xmin_paths, xmax_paths, ymin_paths, ymax_paths
+
+
+ def get_relative_path_pos(file, animation_id):
+     """ Get relative position of a path in an SVG.
+
+     Args:
+         file (str): Path of SVG file.
+         animation_id (int): ID of element.
+
+     Returns:
+         float, float: Relative x- and y-position of path.
+
+     """
+     path_midpoint_x, path_midpoint_y = get_midpoint_of_path_bbox(file, animation_id)
+     svg_xmin, svg_xmax, svg_ymin, svg_ymax = get_svg_bbox(file)
+     rel_x_position = (path_midpoint_x - svg_xmin) / (svg_xmax - svg_xmin)
+     rel_y_position = (path_midpoint_y - svg_ymin) / (svg_ymax - svg_ymin)
+     return rel_x_position, rel_y_position
+
+
+ def get_relative_pos_to_bounding_box_of_animated_paths(file, animation_id, animated_animation_ids):
+     """ Get relative position of a path to the bounding box of all animated paths.
+
+     Args:
+         file (str): Path of SVG file.
+         animation_id (int): ID of element.
+         animated_animation_ids (list(int)): List of animated element IDs.
+
+     Returns:
+         float, float: Relative x- and y-position of path to bounding box of all animated paths.
+
+     """
+     path_midpoint_x, path_midpoint_y = get_midpoint_of_path_bbox(file, animation_id)
+     xmin, xmax, ymin, ymax = get_bbox_of_multiple_paths(file, animated_animation_ids)
+     try:
+         rel_x_position = (path_midpoint_x - xmin) / (xmax - xmin)
+     except Exception as e1:
+         rel_x_position = 0.5
+         print(f"{file}, animation_id {animation_id}, animated_animation_ids {animated_animation_ids}: "
+               f"rel_x_position not defined and set to 0.5. {e1}")
+     try:
+         rel_y_position = (path_midpoint_y - ymin) / (ymax - ymin)
+     except Exception as e2:
+         rel_y_position = 0.5
+         print(f"{file}, animation_id {animation_id}, animated_animation_ids {animated_animation_ids}: "
+               f"rel_y_position not defined and set to 0.5. {e2}")
+
+     return rel_x_position, rel_y_position
+
+
+ def get_relative_path_size(file, animation_id):
+     """ Get relative size of a path in an SVG.
+
+     Args:
+         file (str): Path of SVG file.
+         animation_id (int): ID of element.
+
+     Returns:
+         float, float: Relative width and height of path.
+
+     """
+     svg_xmin, svg_xmax, svg_ymin, svg_ymax = get_svg_bbox(file)
+     svg_width = float(svg_xmax - svg_xmin)
+     svg_height = float(svg_ymax - svg_ymin)
+
+     path_xmin, path_xmax, path_ymin, path_ymax = get_path_bbox(file, animation_id)
+     path_width = float(path_xmax - path_xmin)
+     path_height = float(path_ymax - path_ymin)
+
+     rel_width = path_width / svg_width
+     rel_height = path_height / svg_height
+
+     return rel_width, rel_height
+
+
+ def get_begin_values_by_starting_pos(file, animation_ids, start=1, step=0.5):
+     """ Get begin values by sorting the elements from left to right.
+
+     Args:
+         file (str): Path of SVG file.
+         animation_ids (list(int)): List of element IDs.
+         start (float): First begin value.
+         step (float): Time between begin values.
+
+     Returns:
+         list: Begin values of element IDs.
+
+     """
+     starting_point_list = []
+     begin_list = []
+     begin = start
+     for i in range(len(animation_ids)):
+         x, _, _, _ = get_path_bbox(file, animation_ids[i])  # x value of the upper left corner
+         starting_point_list.append(x)
+         begin_list.append(begin)
+         begin = begin + step
+
+     animation_id_order = [z for _, z in sorted(zip(starting_point_list, range(len(starting_point_list))))]
+     begin_values = [z for _, z in sorted(zip(animation_id_order, begin_list))]
+
+     return begin_values
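A usage sketch of the geometry helpers; the file is the sample logo from this commit and the element IDs are illustrative.

# Sketch: query geometry of the sample logo.
from src.postprocessing.get_svg_size_pos import (
    get_svg_size,
    get_svg_bbox,
    get_relative_path_pos,
    get_begin_values_by_starting_pos,
)

file = "src/postprocessing/logo_0.svg"
print(get_svg_size(file))              # (width, height)
print(get_svg_bbox(file))              # (xmin, xmax, ymin, ymax)
print(get_relative_path_pos(file, 0))  # midpoint of path 0 relative to the SVG bbox
print(get_begin_values_by_starting_pos(file, [0, 1, 2], start=1, step=0.5))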
src/postprocessing/insert_animation.py ADDED
@@ -0,0 +1,333 @@
+ import numpy as np
+ from xml.dom import minidom
+ from pathlib import Path
+ from src.postprocessing.get_svg_size_pos import get_midpoint_of_path_bbox, get_begin_values_by_starting_pos
+ from src.postprocessing.transform_animation_predictor_output import transform_animation_predictor_output
+
+
+ def create_animated_svg(file, animation_ids, model_output, filename_suffix="", save=True):
+     """ Insert multiple animation statements.
+
+     Args:
+         file (str): Path of SVG file.
+         animation_ids (list[int]): List of element IDs that get animated.
+         model_output (ndarray): Array of 13-dimensional arrays with animation predictor model output.
+         filename_suffix (str): Suffix of the animated SVG's filename.
+         save (bool): Whether to save the animated SVG to data/animated_svgs.
+
+     Returns:
+         list(float): List of begin values of elements in SVG.
+         xml.dom.minidom.Document: Parsed file with inserted animation statements.
+
+     """
+     doc = svg_to_doc(file)
+     begin_values = get_begin_values_by_starting_pos(file, animation_ids, start=1, step=0.25)
+     for i in range(len(animation_ids)):
+         if not (model_output[i][:6] == np.array([0] * 6)).all():
+             try:  # there are some paths that can't be embedded and don't have style attributes
+                 output_dict = transform_animation_predictor_output(file, animation_ids[i], model_output[i])
+                 output_dict["begin"] = begin_values[i]
+                 if output_dict["type"] == "translate":
+                     doc = insert_translate_statement(doc, animation_ids[i], output_dict)
+                 if output_dict["type"] == "scale":
+                     doc = insert_scale_statement(doc, animation_ids[i], output_dict, file)
+                 if output_dict["type"] == "rotate":
+                     doc = insert_rotate_statement(doc, animation_ids[i], output_dict)
+                 if output_dict["type"] in ["skewX", "skewY"]:
+                     doc = insert_skew_statement(doc, animation_ids[i], output_dict)
+                 if output_dict["type"] == "fill":
+                     doc = insert_fill_statement(doc, animation_ids[i], output_dict)
+                 if output_dict["type"] == "opacity":
+                     doc = insert_opacity_statement(doc, animation_ids[i], output_dict)
+             except Exception as e:
+                 print(f"File {file}, animation ID {animation_ids[i]} can't be animated. {e}")
+
+     if save:
+         filename = file.split('/')[-1].replace(".svg", "") + "_animated" + filename_suffix
+         save_animated_svg(doc, filename)
+
+     return begin_values, doc
+
+
+ def svg_to_doc(file):
+     """ Parse an SVG file.
+
+     Args:
+         file (str): Path of SVG file.
+
+     Returns:
+         xml.dom.minidom.Document: Parsed file.
+
+     """
+     return minidom.parse(file)
+
+
+ def save_animated_svg(doc, filename):
+     """ Save animated SVGs to folder animated_svgs.
+
+     Args:
+         doc (xml.dom.minidom.Document): Parsed file.
+         filename (str): Name of output file.
+
+     """
+     Path("data/animated_svgs").mkdir(parents=True, exist_ok=True)
+
+     with open('data/animated_svgs/' + filename + '.svg', 'wb') as f:
+         f.write(doc.toprettyxml(encoding="iso-8859-1"))
+
+
+ def insert_translate_statement(doc, animation_id, model_output_dict):
+     """ Insert translate statement.
+
+     Args:
+         doc (xml.dom.minidom.Document): Parsed file.
+         animation_id (int): ID of element that gets animated.
+         model_output_dict (dict): Dictionary containing animation statement.
+
+     Returns:
+         xml.dom.minidom.Document: Parsed file with inserted animation statement.
+
+     """
+     pre_animations = []
+     opacity_dict_1, opacity_dict_2 = create_opacity_pre_animation_dicts(model_output_dict)
+     pre_animations.append(create_animation_statement(opacity_dict_1))
+     pre_animations.append(create_animation_statement(opacity_dict_2))
+
+     animation = create_animation_statement(model_output_dict)
+     doc = insert_animation(doc, animation_id, animation, pre_animations)
+     return doc
+
+
+ def insert_scale_statement(doc, animation_id, model_output_dict, file):
+     """ Insert scale statement.
+
+     Args:
+         doc (xml.dom.minidom.Document): Parsed file.
+         animation_id (int): ID of element that gets animated.
+         model_output_dict (dict): Dictionary containing animation statement.
+         file (str): Path of SVG file. Needed to get the midpoint of the path bbox to suppress a simultaneous
+             translate movement.
+
+     Returns:
+         xml.dom.minidom.Document: Parsed file with inserted animation statement.
+
+     """
+     pre_animations = []
+     opacity_dict_1, opacity_dict_2 = create_opacity_pre_animation_dicts(model_output_dict)
+     pre_animations.append(create_animation_statement(opacity_dict_1))
+     pre_animations.append(create_animation_statement(opacity_dict_2))
+
+     x_midpoint, y_midpoint = get_midpoint_of_path_bbox(file, animation_id)
+     if model_output_dict["from_"] > 1:
+         model_output_dict["from_"] = 2
+         pre_animation_from = f"-{x_midpoint} -{y_midpoint}"  # negative midpoint
+     else:
+         model_output_dict["from_"] = 0
+         pre_animation_from = f"{x_midpoint} {y_midpoint}"  # positive midpoint
+
+     translate_pre_animation_dict = {"type": "translate",
+                                     "begin": model_output_dict["begin"],
+                                     "dur": model_output_dict["dur"],
+                                     "from_": pre_animation_from,
+                                     "to": "0 0",
+                                     "fill": "freeze"}
+     pre_animations.append(create_animation_statement(translate_pre_animation_dict))
+
+     animation = create_animation_statement(model_output_dict) + ' additive="sum" '
+     doc = insert_animation(doc, animation_id, animation, pre_animations)
+     return doc
+
+
+ def insert_rotate_statement(doc, animation_id, model_output_dict):
+     """ Insert rotate statement.
+
+     Args:
+         doc (xml.dom.minidom.Document): Parsed file.
+         animation_id (int): ID of element that gets animated.
+         model_output_dict (dict): Dictionary containing animation statement.
+
+     Returns:
+         xml.dom.minidom.Document: Parsed file with inserted animation statement.
+
+     """
+     pre_animations = []
+     opacity_dict_1, opacity_dict_2 = create_opacity_pre_animation_dicts(model_output_dict)
+     pre_animations.append(create_animation_statement(opacity_dict_1))
+     pre_animations.append(create_animation_statement(opacity_dict_2))
+
+     animation = create_animation_statement(model_output_dict)
+     doc = insert_animation(doc, animation_id, animation, pre_animations)
+     return doc
+
+
+ def insert_skew_statement(doc, animation_id, model_output_dict):
+     """ Insert skew statement.
+
+     Args:
+         doc (xml.dom.minidom.Document): Parsed file.
+         animation_id (int): ID of element that gets animated.
+         model_output_dict (dict): Dictionary containing animation statement.
+
+     Returns:
+         xml.dom.minidom.Document: Parsed file with inserted animation statement.
+
+     """
+     pre_animations = []
+     opacity_dict_1, opacity_dict_2 = create_opacity_pre_animation_dicts(model_output_dict)
+     pre_animations.append(create_animation_statement(opacity_dict_1))
+     pre_animations.append(create_animation_statement(opacity_dict_2))
+
+     animation = create_animation_statement(model_output_dict)
+     doc = insert_animation(doc, animation_id, animation, pre_animations)
+     return doc
+
+
+ def insert_fill_statement(doc, animation_id, model_output_dict):
+     """ Insert fill statement.
+
+     Args:
+         doc (xml.dom.minidom.Document): Parsed file.
+         animation_id (int): ID of element that gets animated.
+         model_output_dict (dict): Dictionary containing animation statement.
+
+     Returns:
+         xml.dom.minidom.Document: Parsed file with inserted animation statement.
+
+     """
+     pre_animations = []
+     model_output_dict['dur'] = 2
+     if model_output_dict['begin'] < 2:
+         model_output_dict['begin'] = 0
+     else:  # Wave
+         pre_animation_dict = {"type": "fill",
+                               "begin": 0,
+                               "dur": model_output_dict["begin"],
+                               "from_": model_output_dict["to"],
+                               "to": model_output_dict["from_"],
+                               "fill": "remove"}
+         pre_animations.append(create_animation_statement(pre_animation_dict))
+
+     animation = create_animation_statement(model_output_dict)
+     doc = insert_animation(doc, animation_id, animation, pre_animations)
+     return doc
+
+
+ def insert_opacity_statement(doc, animation_id, model_output_dict):
+     """ Insert opacity statement.
+
+     Args:
+         doc (xml.dom.minidom.Document): Parsed file.
+         animation_id (int): ID of element that gets animated.
+         model_output_dict (dict): Dictionary containing animation statement.
+
+     Returns:
+         xml.dom.minidom.Document: Parsed file with inserted animation statement.
+
+     """
+     pre_animations = []
+     opacity_pre_animation_dict = {"type": "opacity",
+                                   "begin": "0",
+                                   "dur": model_output_dict["begin"],
+                                   "from_": "0",
+                                   "to": "0",
+                                   "fill": "remove"}
+     pre_animations.append(create_animation_statement(opacity_pre_animation_dict))
+
+     animation = create_animation_statement(model_output_dict)
+     doc = insert_animation(doc, animation_id, animation, pre_animations)
+     return doc
+
+
+ def insert_animation(doc, animation_id, animation, pre_animations=None):
+     """ Insert animation statements including pre-animation statements.
+
+     Args:
+         doc (xml.dom.minidom.Document): Parsed file.
+         animation_id (int): ID of element that gets animated.
+         animation (str): Animation that needs to be inserted.
+         pre_animations (list): List of animations that need to be inserted before the actual animation.
+
+     Returns:
+         xml.dom.minidom.Document: Parsed file with inserted animation statement.
+
+     """
+     elements = doc.getElementsByTagName('path') + doc.getElementsByTagName('circle') + doc.getElementsByTagName(
+         'ellipse') + doc.getElementsByTagName('line') + doc.getElementsByTagName(
+         'polygon') + doc.getElementsByTagName('polyline') + doc.getElementsByTagName(
+         'rect') + doc.getElementsByTagName('text')
+
+     for element in elements:
+         if element.getAttribute('animation_id') == str(animation_id):
+             if pre_animations is not None:
+                 for i in range(len(pre_animations)):
+                     element.appendChild(doc.createElement(pre_animations[i]))
+             element.appendChild(doc.createElement(animation))
+
+     return doc
+
+
+ def create_animation_statement(animation_dict):
+     """ Set up animation statement from a dictionary.
+
+     Args:
+         animation_dict (dict): Dictionary that is transformed into an animation statement.
+
+     Returns:
+         str: Animation statement.
+
+     """
+     if animation_dict["type"] in ["translate", "scale", "rotate", "skewX", "skewY"]:
+         return _create_animate_transform_statement(animation_dict)
+     elif animation_dict["type"] in ["fill", "opacity"]:
+         return _create_animate_statement(animation_dict)
+
+
+ def _create_animate_transform_statement(animation_dict):
+     """ Set up animation statement from model output for ANIMATETRANSFORM animations """
+     animation = f'animateTransform attributeName = "transform" attributeType = "XML" ' \
+                 f'type = "{animation_dict["type"]}" ' \
+                 f'begin = "{str(animation_dict["begin"])}" ' \
+                 f'dur = "{str(animation_dict["dur"])}" ' \
+                 f'from = "{str(animation_dict["from_"])}" ' \
+                 f'to = "{str(animation_dict["to"])}" ' \
+                 f'fill = "{str(animation_dict["fill"])}"'
+
+     return animation
+
+
+ def _create_animate_statement(animation_dict):
+     """ Set up animation statement from model output for ANIMATE animations """
+     animation = f'animate attributeName = "{animation_dict["type"]}" ' \
+                 f'begin = "{str(animation_dict["begin"])}" ' \
+                 f'dur = "{str(animation_dict["dur"])}" ' \
+                 f'from = "{str(animation_dict["from_"])}" ' \
+                 f'to = "{str(animation_dict["to"])}" ' \
+                 f'fill = "{str(animation_dict["fill"])}"'
+
+     return animation
+
+
+ def create_opacity_pre_animation_dicts(animation_dict):
+     """ Set up opacity pre-animation statements.
+
+     Args:
+         animation_dict (dict): Dictionary from the animation that is needed to set up the opacity pre-animations.
+
+     Returns:
+         dict, dict: Dictionaries describing the two opacity pre-animations.
+
+     """
+     opacity_pre_animation_dict_1 = {"type": "opacity",
+                                     "begin": "0",
+                                     "dur": animation_dict["begin"],
+                                     "from_": "0",
+                                     "to": "0",
+                                     "fill": "remove"}
+
+     opacity_pre_animation_dict_2 = {"type": "opacity",
+                                     "begin": animation_dict["begin"],
+                                     "dur": "0.5",
+                                     "from_": "0",
+                                     "to": "1",
+                                     "fill": "remove"}
+
+     return opacity_pre_animation_dict_1, opacity_pre_animation_dict_2
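A sketch of how the pieces above combine. The animation dictionary is hand-written; the 13-dimensional model_output row is a hypothetical placeholder whose semantics are defined by transform_animation_predictor_output, added elsewhere in this commit.

import numpy as np
from src.postprocessing.insert_animation import create_animation_statement, create_animated_svg

# A hand-written translate statement
stmt = create_animation_statement({
    "type": "translate", "begin": 1, "dur": 2,
    "from_": "0 -50", "to": "0 0", "fill": "freeze",
})
print(stmt)  # 'animateTransform attributeName = "transform" ...'

# End-to-end insertion; writes data/animated_svgs/logo_0_animated.svg.
# The model_output row below is a hypothetical placeholder.
begin_values, doc = create_animated_svg(
    "src/postprocessing/logo_0.svg",
    animation_ids=[0],
    model_output=np.array([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
)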
src/postprocessing/logo_0.svg ADDED
src/postprocessing/postprocessing.py ADDED
@@ -0,0 +1,604 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ import random
4
+ import os
5
+ import sys
6
+ from xml.dom import minidom
7
+ from collections import defaultdict
8
+
9
+ sys.path.append(os.getcwd())
10
+ from src.postprocessing.get_svg_size_pos import get_svg_bbox, get_path_bbox, get_midpoint_of_path_bbox
11
+ from src.postprocessing.get_style_attributes import get_style_attributes_path
12
+
13
+ random.seed(0)
14
+
15
+ filter_id = 0
16
+
17
+ def animate_logo(model_output: pd.DataFrame, logo_path: str):
18
+ logo_xmin, logo_xmax, logo_ymin, logo_ymax = get_svg_bbox(logo_path)
19
+ # ---- Normalize model output ----
20
+ animations_by_id = defaultdict(list)
21
+ for row in model_output.iterrows():
22
+ # Structure animations by animation id
23
+ animation_id = row[1]['animation_id']
24
+ output = row[1]['model_output']
25
+ animations_by_id[animation_id].append(output)
26
+ total_animations = []
27
+ for animation_id in animations_by_id.keys():
28
+ print(animation_id)
29
+ path_xmin, path_xmax, path_ymin, path_ymax = get_path_bbox(logo_path, animation_id)
30
+ xmin = logo_xmin - path_xmin
31
+ xmax = logo_xmax - path_xmax
32
+ ymin = logo_ymin - path_ymin
33
+ ymax = logo_ymax - path_ymax
34
+ # Structure animations by type (check first 10 parameters)
35
+ animations_by_type = defaultdict(list)
36
+ for animation in animations_by_id[animation_id]:
37
+ if animation[0] == 1:
38
+ # EOS
39
+ continue
40
+ try:
41
+ animation_type = animation[1:10].index(1)
42
+ animations_by_type[animation_type].append(animation)
43
+ except:
44
+ # No value found
45
+ print('Model output invalid: no animation type found')
46
+ return
47
+
48
+
49
+
50
+ for animation_type in animations_by_type.keys():
51
+ # Set up list of animations for later distribution
52
+ current_animations = []
53
+ # Sort animations by begin
54
+ animations_by_type[animation_type].sort(key=lambda l : l[10]) # Sort by begin
55
+ # For every animation, check consistency of begin and duration, then set parameters
56
+ for i in range(len(animations_by_type[animation_type])):
57
+ # Check if begin is equal to next animation's begin - in this case, set second begin to average of first and third animation
58
+ # Get next animation with different begin time
59
+ if len(animations_by_type[animation_type]) > 1:
60
+ j = 1
61
+ next_animation = animations_by_type[animation_type][j]
62
+ while (i + j) < len(animations_by_type[animation_type]) and animations_by_type[animation_type][i][10] == next_animation[10]:
63
+ j += 1
64
+ next_animation = animations_by_type[animation_type][j]
65
+ if j != 1:
66
+ # Get difference
67
+ difference = animations_by_type[animation_type][j][10] - animations_by_type[animation_type][i][10]
68
+ interval = difference / (j - i)
69
+ factor = 0
70
+ for a in range(i, j):
71
+ animations_by_type[animation_type][a][10] = animations_by_type[animation_type][i][10] + interval * factor
72
+ factor += 1
73
+ # Check if duration and begin of next animation are consistent - if not, shorten duration
74
+ if i < len(animations_by_type[animation_type]) - 1:
75
+ max_duration = animations_by_type[animation_type][i+1][10] - animations_by_type[animation_type][i][10]
76
+ if animations_by_type[animation_type][i][11] > max_duration:
77
+ animations_by_type[animation_type][i][11] = max_duration
78
+
79
+ # Get general parameters
80
+ begin = animations_by_type[animation_type][i][10]
81
+ dur = animations_by_type[animation_type][i][10]
82
+ # Check type and call method
83
+ if animation_type == 1:
84
+ # animation: translate
85
+ from_x = animations_by_type[animation_type][i][12]
86
+ from_y = animations_by_type[animation_type][i][13]
87
+ # Check if there is a next translate animation
88
+ if i < len(animations_by_type[animation_type]) - 1:
89
+ # animation endpoint is next translate animation's starting point
90
+ to_x = animations_by_type[animation_type][i+1][12]
91
+ to_y = animations_by_type[animation_type][i+1][13]
92
+ else:
93
+ # animation endpoint is final position of object
94
+ to_x = 0
95
+ to_y = 0
96
+ # Check if parameters are within boundary
97
+ if from_x < xmin:
98
+ from_x = xmin
99
+ elif from_x > xmax:
100
+ from_x = xmax
101
+ if from_y < ymin:
102
+ from_y = ymin
103
+ elif from_y > ymax:
104
+ from_y = ymax
105
+ if to_x < xmin:
106
+ to_x = xmin
107
+ elif to_x > xmax:
108
+ to_x = xmax
109
+ if to_y < ymin:
110
+ to_y = ymin
111
+ elif to_y > ymax:
112
+ to_y = ymax
113
+ # Append animation to list
114
+ current_animations.append(_animation_translate(animation_id, begin, dur, from_x, from_y, to_x, to_y))
115
+ elif animation_type == 2:
116
+ print('curve')
117
+ from_x = animations_by_type[animation_type][i][12]
118
+ from_y = animations_by_type[animation_type][i][13]
119
+ via_x = animations_by_type[animation_type][i][14]
120
+ via_y = animations_by_type[animation_type][i][15]
121
+ # Check if there is a next curve animation
122
+ if i < len(animations_by_type[animation_type]) - 1:
123
+ # animation endpoint is next curve animation's starting point
124
+ to_x = animations_by_type[animation_type][i+1][12]
125
+ to_y = animations_by_type[animation_type][i+1][13]
126
+ else:
127
+ # animation endpoint is final position of object
128
+ to_x = 0
129
+ to_y = 0
130
+ # Check if parameters are within boundary
131
+ if from_x < xmin:
132
+ from_x = xmin
133
+ elif from_x > xmax:
134
+ from_x = xmax
135
+ if from_y < ymin:
136
+ from_y = ymin
137
+ elif from_y > ymax:
138
+ from_y = ymax
139
+ if via_x < xmin:
140
+ via_x = xmin
141
+ elif via_x > xmax:
142
+ via_x = xmax
143
+ if via_y < ymin:
144
+ via_y = ymin
145
+ elif via_y > ymax:
146
+ via_y = ymax
147
+ if to_x < xmin:
148
+ to_x = xmin
149
+ elif to_x > xmax:
150
+ to_x = xmax
151
+ if to_y < ymin:
152
+ to_y = ymin
153
+ elif to_y > ymax:
154
+ to_y = ymax
155
+ # Append animation to list
156
+ current_animations.append(_animation_curve(animation_id, begin, dur, from_x, from_y, via_x, via_y, to_x, to_y))
157
+ elif animation_type == 3:
158
+ # animation: scale
159
+ from_f = animations_by_type[animation_type][i][16]
160
+ # Check if there is a next scale animation
161
+ if i < len(animations_by_type[animation_type]) - 1:
162
+ # animation endpoint is next scale animation's starting point
163
+ to_f = animations_by_type[animation_type][i+1][16]
164
+ else:
165
+ # animation endpoint is final position of object
166
+ to_f = 1
167
+ current_animations.append(_animation_scale(animation_id, begin, dur, from_f, to_f))
168
+ elif animation_type == 4:
169
+ # animation: rotate
170
+ from_degree = animations_by_type[animation_type][i][17]
171
+ # Get midpoints
172
+ midpoints = get_midpoint_of_path_bbox(logo_path, animation_id)
173
+ # Check if there is a next scale animation
174
+ if i < len(animations_by_type[animation_type]) - 1:
175
+ # animation endpoint is next scale animation's starting point
176
+ to_degree = animations_by_type[animation_type][i+1][17]
177
+ else:
178
+ # animation endpoint is final position of object
179
+ to_degree = 360
180
+ current_animations.append(_animation_rotate(animation_id, begin, dur, from_degree, to_degree, midpoints))
181
+ elif animation_type == 5:
182
+ # animation: skewX
183
+ from_x = animations_by_type[animation_type][i][18]
184
+ # Check if there is a next skewX animation
185
+ if i < len(animations_by_type[animation_type]) - 1:
186
+ # animation endpoint is next skewX animation's starting point
187
+ to_x = animations_by_type[animation_type][i+1][18]
188
+ else:
189
+ # animation endpoint is final position of object
190
+ to_x = 0 # no skew in the final state
191
+ # Check if parameters are within boundary
192
+ if from_x < xmin:
193
+ from_x = xmin
194
+ elif from_x > xmax:
195
+ from_x = xmax
196
+ if to_x < xmin:
197
+ to_x = xmin
198
+ elif to_x > xmax:
199
+ to_x = xmax
200
+ current_animations.append(_animation_skewX(animation_id, begin, dur, from_x, to_x))
201
+ elif animation_type == 6:
202
+ # animation: skewY
203
+ from_y = animations_by_type[animation_type][i][19]
204
+ # Check if there is a next skewY animation
205
+ if i < len(animations_by_type[animation_type]) - 1:
206
+ # animation endpoint is next skewY animation's starting point
207
+ to_y = animations_by_type[animation_type][i+1][19]
208
+ else:
209
+ # animation endpoint is final position of object
210
+ to_y = 0 # no skew in the final state
211
+ # Check if parameters are within boundary
212
+ if from_y < ymin:
213
+ from_y = ymin
214
+ elif from_y > ymax:
215
+ from_y = ymax
216
+ if to_y < ymin:
217
+ to_y = ymin
218
+ elif to_y > ymax:
219
+ to_y = ymax
220
+ current_animations.append(_animation_skewY(animation_id, begin, dur, from_y, to_y))
221
+ elif animation_type == 7:
222
+ # animation: fill
223
+ from_rgb = '#' + _convert_to_hex_str(animations_by_type[animation_type][i][20]) + _convert_to_hex_str(animations_by_type[animation_type][i][21]) + _convert_to_hex_str(animations_by_type[animation_type][i][22])
224
+ # Check if there is a next fill animation
225
+ if i < len(animations_by_type[animation_type]) - 1:
226
+ # animation endpoint is next fill animation's starting point
227
+ to_rgb = '#' + _convert_to_hex_str(animations_by_type[animation_type][i+1][20]) + _convert_to_hex_str(animations_by_type[animation_type][i+1][21]) + _convert_to_hex_str(animations_by_type[animation_type][i+1][22])
228
+ else:
229
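+ # animation endpoint is the element's original color (falling back to the stroke color when fill is "none")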
+ fill_style = get_style_attributes_path(logo_path, animation_id, "fill")
230
+ stroke_style = get_style_attributes_path(logo_path, animation_id, "stroke")
231
+ if fill_style == "none" and stroke_style != "none":
232
+ color_hex = stroke_style
233
+ else:
234
+ color_hex = fill_style
235
+ to_rgb = color_hex
236
+ current_animations.append(_animation_fill(animation_id, begin, dur, from_rgb, to_rgb))
237
+ elif animation_type == 8:
238
+ # animation: opacity
239
+ from_f = animations_by_type[animation_type][i][23] / 100 # percent
240
+ # Check if there is a next opacity animation
241
+ if i < len(animations_by_type[animation_type]) - 1:
242
+ # animation endpoint is next opacity animation's starting point
243
+ to_f = animations_by_type[animation_type][i+1][23] / 100 # percent
244
+ else:
245
+ # animation endpoint is final position of object
246
+ to_f = 1
247
+ current_animations.append(_animation_opacity(animation_id, begin, dur, from_f, to_f))
248
+ elif animation_type == 9:
249
+ # animation: blur
250
+ from_f = animations_by_type[animation_type][i][24]
251
+ # Check if there is a next blur animation
252
+ if i < len(animations_by_type[animation_type]) - 1:
253
+ # animation endpoint is next blur animation's starting point
254
+ to_f = animations_by_type[animation_type][i+1][24]
255
+ else:
256
+ # animation endpoint is final position of object
257
+ to_f = 1
258
+ current_animations.append(_animation_blur(animation_id, begin, dur, from_f, to_f))
259
+ total_animations += current_animations
260
+ # Shift all begin times so that the earliest animation starts at 0 - TODO: test
261
+ min_b = np.inf
262
+ for animation in total_animations:
263
+ print(animation["begin"], min_b)
264
+ if float(animation["begin"]) < float(min_b):
265
+ min_b = animation["begin"]
266
+ for animation in total_animations:
267
+ animation["begin"] = float(animation["begin"]) - float(min_b)
268
+
269
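+ # Note: source and target are the same file here, so the animations are written back into the logo in place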
+ _insert_animations(total_animations, logo_path, logo_path)
270
+
271
+ def _convert_to_hex_str(i: int):
272
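+ # Convert a color channel value (0-255) to a zero-padded two-digit hex string, e.g. 10 -> '0a'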
+ h = str(hex(i))[2:]
273
+ if i < 16:
274
+ h = '0' + h
275
+ return h
276
+
277
+ def _animation_translate(animation_id: int, begin: float, dur: float, from_x: int, from_y: int, to_x: int, to_y: int):
278
+ print('animation: translate')
279
+ animation_dict = {}
280
+ animation_dict['animation_id'] = animation_id
281
+ animation_dict['animation_type'] = 'animate_transform'
282
+ animation_dict['attributeName'] = 'transform'
283
+ animation_dict['attributeType'] = 'XML'
284
+ animation_dict['type'] = 'translate'
285
+ animation_dict['begin'] = str(begin)
286
+ animation_dict['dur'] = str(dur)
287
+ animation_dict['fill'] = 'freeze'
288
+ animation_dict['from'] = f'{from_x} {from_y}'
289
+ animation_dict['to'] = f'{to_x} {to_y}'
290
+ return animation_dict
291
+
292
+ def _animation_curve(animation_id: int, begin: float, dur: float, from_x: int, from_y: int, via_x: int, via_y: int, to_x: int, to_y: int):
293
+ print('animation: curve')
294
+ animation_dict = {}
295
+ animation_dict['animation_id'] = animation_id
296
+ animation_dict['animation_type'] = 'animate_motion'
297
+ animation_dict['begin'] = str(begin)
298
+ animation_dict['dur'] = str(dur)
299
+ animation_dict['fill'] = 'freeze'
300
+ animation_dict['from'] = f'{from_x} {from_y}'
301
+ animation_dict['via'] = f'{via_x} {via_y}'
302
+ animation_dict['to'] = f'{to_x} {to_y}'
303
+ return animation_dict
304
+
305
+ def _animation_scale(animation_id: int, begin: float, dur: float, from_f: float, to_f: float):
306
+ print('animation: scale')
307
+ animation_dict = {}
308
+ animation_dict['animation_id'] = animation_id
309
+ animation_dict['animation_type'] = 'animate_transform'
310
+ animation_dict['attributeName'] = 'transform'
311
+ animation_dict['attributeType'] = 'XML'
312
+ animation_dict['type'] = 'scale'
313
+ animation_dict['begin'] = str(begin)
314
+ animation_dict['dur'] = str(dur)
315
+ animation_dict['fill'] = 'freeze'
316
+ animation_dict['from'] = str(from_f)
317
+ animation_dict['to'] = str(to_f)
318
+ return animation_dict
319
+
320
+ def _animation_rotate(animation_id: int, begin: float, dur: float, from_degree: int, to_degree: int, midpoints: list):
321
+ print('animation: rotate')
322
+ animation_dict = {}
323
+ animation_dict['animation_id'] = animation_id
324
+ animation_dict['animation_type'] = 'animate_transform'
325
+ animation_dict['attributeName'] = 'transform'
326
+ animation_dict['attributeType'] = 'XML'
327
+ animation_dict['type'] = 'rotate'
328
+ animation_dict['begin'] = str(begin)
329
+ animation_dict['dur'] = str(dur)
330
+ animation_dict['fill'] = 'freeze'
331
+ animation_dict['from'] = f'{from_degree} {midpoints[0]} {midpoints[1]}'
332
+ animation_dict['to'] = f'{to_degree} {midpoints[0]} {midpoints[1]}'
333
+ return animation_dict
334
+
335
+ def _animation_skewX(animation_id: int, begin: float, dur: float, from_i: int, to_i: int):
336
+ print('animation: skew')
337
+ animation_dict = {}
338
+ animation_dict['animation_id'] = animation_id
339
+ animation_dict['animation_type'] = 'animate_transform'
340
+ animation_dict['attributeName'] = 'transform'
341
+ animation_dict['attributeType'] = 'XML'
342
+ animation_dict['type'] = 'skewX'
343
+ animation_dict['begin'] = str(begin)
344
+ animation_dict['dur'] = str(dur)
345
+ animation_dict['fill'] = 'freeze'
346
+ animation_dict['from'] = f'{from_i}'
347
+ animation_dict['to'] = f'{to_i}'
348
+ return animation_dict
349
+
350
+ def _animation_skewY(animation_id: int, begin: float, dur: float, from_i: int, to_i: int):
351
+ print('animation: skew')
352
+ animation_dict = {}
353
+ animation_dict['animation_id'] = animation_id
354
+ animation_dict['animation_type'] = 'animate_transform'
355
+ animation_dict['attributeName'] = 'transform'
356
+ animation_dict['attributeType'] = 'XML'
357
+ animation_dict['type'] = 'skewY'
358
+ animation_dict['begin'] = str(begin)
359
+ animation_dict['dur'] = str(dur)
360
+ animation_dict['fill'] = 'freeze'
361
+ animation_dict['from'] = f'{from_i}'
362
+ animation_dict['to'] = f'{to_i}'
363
+ return animation_dict
364
+
365
+ def _animation_fill(animation_id: int, begin: float, dur: float, from_rgb: str, to_rgb: str):
366
+ print('animation: fill')
367
+ animation_dict = {}
368
+ animation_dict['animation_id'] = animation_id
369
+ animation_dict['animation_type'] = 'animate'
370
+ animation_dict['attributeName'] = 'fill'
371
+ animation_dict['attributeType'] = 'XML'
372
+ animation_dict['type'] = 'fill'
373
+ animation_dict['begin'] = str(begin)
374
+ animation_dict['dur'] = str(dur)
375
+ animation_dict['fill'] = 'freeze'
376
+ animation_dict['from'] = from_rgb
377
+ animation_dict['to'] = to_rgb
378
+ return animation_dict
379
+
380
+ def _animation_opacity(animation_id: int, begin: float, dur: float, from_f: float, to_f: float):
381
+ print('animation: opacity')
382
+ animation_dict = {}
383
+ animation_dict['animation_id'] = animation_id
384
+ animation_dict['animation_type'] = 'animate'
385
+ animation_dict['attributeName'] = 'opacity'
386
+ animation_dict['attributeType'] = 'XML'
387
+ animation_dict['type'] = 'opacity'
388
+ animation_dict['begin'] = str(begin)
389
+ animation_dict['dur'] = str(dur)
390
+ animation_dict['fill'] = 'freeze'
391
+ animation_dict['from'] = str(from_f)
392
+ animation_dict['to'] = str(to_f)
393
+ return animation_dict
394
+
395
+ def _animation_blur(animation_id: int, begin: float, dur: float, from_f: float, to_f: float):
396
+ print('animation: blur')
397
+ animation_dict = {}
398
+ animation_dict['animation_id'] = animation_id
399
+ animation_dict['animation_type'] = 'animate_filter'
400
+ animation_dict['attributeName'] = 'transform'
401
+ animation_dict['attributeType'] = 'XML'
402
+ animation_dict['type'] = 'blur'
403
+ animation_dict['begin'] = str(begin)
404
+ animation_dict['dur'] = str(dur)
405
+ animation_dict['fill'] = 'freeze'
406
+ animation_dict['from'] = str(from_f)
407
+ animation_dict['to'] = str(to_f)
408
+ return animation_dict
409
+
410
+ def _insert_animations(animations: list, path: str, target_path: str):
411
+ print('Insert animations')
412
+ # Load XML
413
+ document = minidom.parse(path)
414
+ # Collect all elements
415
+ elements = document.getElementsByTagName('path') + document.getElementsByTagName('circle') + document.getElementsByTagName(
416
+ 'ellipse') + document.getElementsByTagName('line') + document.getElementsByTagName(
417
+ 'polygon') + document.getElementsByTagName('polyline') + document.getElementsByTagName(
418
+ 'rect') + document.getElementsByTagName('text')
419
+ # Create statement
420
+ for animation in animations:
421
+
422
+ # Search for element
423
+ current_element = None
424
+ for element in elements:
425
+ if element.getAttribute('animation_id') == str(animation['animation_id']):
426
+ current_element = element
427
+ if current_element is None:
428
+ # Animation id not found - take next animation
429
+ continue
430
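+ # The statement strings built below contain the tag name followed by its attributes; this relies on
+ # minidom's createElement not validating element names, so the full string is serialized as the tag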
+ if animation['animation_type'] == 'animate_transform':
431
+ animate_statement = _create_animate_transform_statement(animation)
432
+ current_element.appendChild(document.createElement(animate_statement))
433
+ elif animation['animation_type'] == 'animate_motion':
434
+ animate_statement = _create_animate_motion_statement(animation)
435
+ current_element.appendChild(document.createElement(animate_statement))
436
+ elif animation['animation_type'] == 'animate':
437
+ animate_statement = _create_animate_statement(animation)
438
+ current_element.appendChild(document.createElement(animate_statement))
439
+ elif animation['animation_type'] == 'animate_filter':
440
+ filter_element, fe_element, animate_statement = _create_animate_filter_statement(animation, document)
441
+ defs = document.getElementsByTagName('defs')
442
+ current_defs = None
443
+ # Check if defs tag exists; create otherwise
444
+ if len(defs) == 0:
445
+ svg = document.getElementsByTagName('svg')[0]
446
+ current_defs = document.createElement('defs')
447
+ svg.appendChild(current_defs)
448
+ else:
449
+ current_defs = defs[0]
450
+ # Check if filter to be appended
451
+ if filter_element is not None:
452
+ # Create filter
453
+ print('append filter')
454
+ current_defs.appendChild(filter_element)
455
+ # Check if FE to be created
456
+ if fe_element is not None:
457
+ print('create fe statement')
458
+ # Check if filter set; else search
459
+ if filter_element is None:
460
+ # Search for filter
461
+ id = 'filter_' + str(animation['animation_id'])
462
+ for f in document.getElementsByTagName('filter'):
463
+ if f.getAttribute('id') == id:
464
+ filter_element = f
465
+ # Create FE
466
+ filter_element.appendChild(fe_element)
467
+ current_defs.appendChild(document.createElement(animate_statement))
468
+ current_element.setAttribute('filter', f'url(#filter_{animation["animation_id"]})')
469
+
470
+ # Save XML to target path
471
+ with open(target_path, 'wb') as f:
472
+ f.write(document.toprettyxml(encoding="iso-8859-1"))
473
+
474
+
475
+
476
+ def _create_animate_transform_statement(animation_dict: dict):
477
+ """ Set up animation statement from model output for ANIMATETRANSFORM animations
478
+ (Adapted from AnimateSVG)
479
+ """
480
+ animation = f'animateTransform attributeName="transform" attributeType="XML" ' \
481
+ f'type="{animation_dict["type"]}" ' \
482
+ f'begin="{str(animation_dict["begin"])}" ' \
483
+ f'dur="{str(animation_dict["dur"])}" ' \
484
+ f'from="{str(animation_dict["from"])}" ' \
485
+ f'to="{str(animation_dict["to"])}" ' \
486
+ f'fill="{str(animation_dict["fill"])}" ' \
487
+ 'additive="sum"'
488
+
489
+ return animation
490
+
491
+ def _create_animate_statement(animation_dict: dict):
492
+ """ Set up animation statement from model output for ANIMATE animations
493
+ (adapted from AnimateSVG)
494
+ """
495
+ animation = f'animate attributeName="{animation_dict["type"]}" ' \
496
+ f'begin="{str(animation_dict["begin"])}" ' \
497
+ f'dur="{str(animation_dict["dur"])}" ' \
498
+ f'from="{str(animation_dict["from"])}" ' \
499
+ f'to="{str(animation_dict["to"])}" ' \
500
+ f'fill="{str(animation_dict["fill"])}" '\
501
+ 'additive="sum"'
502
+
503
+ return animation
504
+
505
+ def _create_animate_motion_statement(animation_dict: dict):
506
+ """ Set up animatie motion statement from model output for ANIMATE_MOTION animations
507
+ """
508
+ animation = f'animateMotion ' \
509
+ f'begin="{str(animation_dict["begin"])}" ' \
510
+ f'dur="{str(animation_dict["dur"])}" ' \
511
+ f'path="M{animation_dict["from"]}" Q{animation_dict["via"]} {animation_dict["to"]}' \
512
+ f'fill="{str(animation_dict["fill"])}" '\
513
+ 'additive="sum"'
514
+ return animation
515
+
516
+ def _create_animate_filter_statement(animation_dict: dict, document: minidom.Document):
517
+ global filter_id
518
+ filter_id += 1
519
+ filter_element = None
520
+ fe_element = None
521
+ animate_statement = None
522
+ if animation_dict['type'] == 'blur':
523
+ # Check if filter already exists
524
+ filters = document.getElementsByTagName('filter')
525
+ current_filter = None
526
+ current_fe = None
527
+ for f in filters:
528
+ #print(f.getAttribute('id') == f'filter_{str(animation_dict["animation_id"])}')
529
+ if f.getAttribute('id') == f'filter_{str(animation_dict["animation_id"])}':
530
+ current_filter = f
531
+ fe_elements = document.getElementsByTagName('feGaussianBlur')
532
+ for fe in fe_elements:
533
+ if fe.getAttribute('id') == f'filter_blur_{str(animation_dict["animation_id"])}':
534
+ current_fe = fe
535
+ if current_filter is None:
536
+ filter_element = document.createElement('filter')
537
+ filter_element.setAttribute('id', f'filter_{str(animation_dict["animation_id"])}')
538
+ if current_fe is None:
539
+ fe_element = document.createElement('feGaussianBlur')
540
+ fe_element.setAttribute('id', f'filter_blur_{str(animation_dict["animation_id"])}')
541
+ fe_element.setAttribute('stdDeviation', '0')
542
+ animate_statement = f'animate href="#filter_blur_{str(animation_dict["animation_id"])}" ' \
543
+ f'attributeName="stdDeviation" ' \
544
+ f'begin="{str(animation_dict["begin"])}" ' \
545
+ f'dur="{str(animation_dict["dur"])}" ' \
546
+ f'from="{str(animation_dict["from"])}" ' \
547
+ f'to="{str(animation_dict["to"])}" ' \
548
+ f'fill="{str(animation_dict["fill"])}"'\
549
+ 'additive="sum"'
550
+ return filter_element, fe_element, animate_statement
551
+
552
+
553
+
554
+
555
+
556
+
557
+
558
+
559
+ def randomly_animate_logo(logo_path: str, target_path: str, number_of_animations: int, previously_generated: pd.DataFrame = None):
560
+ # Creates as many model outputs as the defined number of animations; they are then randomly distributed over the paths.
561
+ # Assign animation id to every path - TODO this changes the original logo!
562
+ document = minidom.parse(logo_path)
563
+ paths = document.getElementsByTagName('path') + document.getElementsByTagName('circle') + document.getElementsByTagName(
564
+ 'ellipse') + document.getElementsByTagName('line') + document.getElementsByTagName(
565
+ 'polygon') + document.getElementsByTagName('polyline') + document.getElementsByTagName(
566
+ 'rect') + document.getElementsByTagName('text')
567
+ for i in range(len(paths)):
568
+ paths[i].setAttribute('animation_id', str(i))
569
+ with open(target_path, 'wb') as svg_file:
570
+ svg_file.write(document.toxml(encoding='iso-8859-1'))
571
+ # Create random animations
572
+ for i in range(0, number_of_animations):
573
+ animation_type = random.randint(0, 8) # Determine animation type (as of now only primitive animation types)
574
+ model_output = np.zeros(18)
575
+ model_output[animation_type] = 1 # Set animation type
576
+ # Set animation parameters
577
+
578
+
579
+
580
+
581
+
582
+
583
+ # model_output = [
584
+ # {
585
+ # 'animation_id': 1,
586
+ # 'model_output': [0, 0, 0, 0, 0, 0, 0, 1, 1, 10, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10]
587
+ # },
588
+ # {
589
+ # 'animation_id': 1,
590
+ # 'model_output': [0, 0, 0, 0, 0, 0, 0, 1, 5, 3, 4, 5, 2, 1, 2, 3, 4, 5, 6, 7, 1000, 20]
591
+ # }
592
+ # ]
593
+ # model_output = pd.DataFrame(model_output)
594
+ # #print(model_output)
595
+ # path = 'src/postprocessing/logo_0.svg'
596
+ # # Assign animation id to every path - TODO this changes the original logo!
597
+ # document = minidom.parse(path)
598
+ # paths = document.getElementsByTagName('path')
599
+ # for i in range(len(paths)):
600
+ # paths[i].setAttribute('animation_id', str(i))
601
+ # with open(path, 'wb') as svg_file:
602
+ # svg_file.write(document.toxml(encoding='iso-8859-1'))
603
+ # #print('Inserted animation id')
604
+ # animate_logo(model_output, path)
src/postprocessing/transform_animation_predictor_output.py ADDED
@@ -0,0 +1,78 @@
1
+ from src.postprocessing.get_svg_size_pos import get_svg_size, get_midpoint_of_path_bbox
2
+ from src.postprocessing.get_style_attributes import get_style_attributes_path
3
+ from src.postprocessing.get_svg_color_tendency import get_svg_color_tendencies
4
+
5
+
6
+ def transform_animation_predictor_output(file, animation_id, output):
7
+ """ Function to translate the numeric model output to animation commands.
8
+ Example: transform_animation_predictor_output("data/svgs/logo_1.svg", 0, [0,0,1,0,0,0,-1,-1,-1,0.42,-1,-1])
9
+
10
+ Args:
11
+ file (str): Path of SVG file.
12
+ animation_id (int): ID of element in SVG that gets animated.
13
+ output (list): 12-dimensional list of numeric values, of which the first 6 one-hot-encode the animation to be used and
14
+ the last 6 provide the values for the from attribute. Format: [translate, scale, rotate, skew, fill, opacity, translate_from_1, translate_from_2, scale_from, rotate_from, skew_from_1, skew_from_2].
15
+
16
+ Returns:
17
+ dict: Animation statement as dictionary.
18
+
19
+ """
20
+ animation = {}
21
+ width, height = get_svg_size(file)
22
+ x_midpoint, y_midpoint = get_midpoint_of_path_bbox(file, animation_id)
23
+ fill_style = get_style_attributes_path(file, animation_id, "fill")
24
+ stroke_style = get_style_attributes_path(file, animation_id, "stroke")
25
+ opacity_style = get_style_attributes_path(file, animation_id, "opacity")
26
+ color_1, color_2 = get_svg_color_tendencies(file)
27
+
28
+ if output[0] == 1:
29
+ animation["type"] = "translate"
30
+ x = (output[6] * 2 - 1) * width # between -width and width
31
+ y = (output[7] * 2 - 1) * height # between -height and height
32
+ animation["from_"] = f"{str(x)} {str(y)}"
33
+ animation["to"] = "0 0"
34
+
35
+ elif output[1] == 1:
36
+ animation["type"] = "scale"
37
+ animation["from_"] = output[8] * 2 # between 0 and 2
38
+ animation["to"] = 1
39
+
40
+ elif output[2] == 1:
41
+ animation["type"] = "rotate"
42
+ degree = int(output[9]*720) - 360 # between -360 and 360
43
+ animation["from_"] = f"{str(degree)} {str(x_midpoint)} {str(y_midpoint)}"
44
+ animation["to"] = f"0 {str(x_midpoint)} {str(y_midpoint)}"
45
+
46
+ elif output[3] == 1:
47
+ if output[10] > 0.5:
48
+ animation["type"] = "skewX"
49
+ animation["from_"] = (output[11] * 2 - 1) * width/20 # between -width/20 and width/20
50
+ else:
51
+ animation["type"] = "skewY"
52
+ animation["from_"] = (output[11] * 2 - 1) * height/20 # between -height/20 and height/20
53
+ animation["to"] = 0
54
+
55
+ elif output[4] == 1:
56
+ animation["type"] = "fill"
57
+ if fill_style == "none" and stroke_style != "none":
58
+ color_hex = stroke_style
59
+ else:
60
+ color_hex = fill_style
61
+ animation["to"] = color_hex
62
+
63
+ if color_hex != color_1:
64
+ color_from = color_1
65
+ else:
66
+ color_from = color_2
67
+ animation["from_"] = color_from
68
+
69
+ elif output[5] == 1:
70
+ animation["type"] = "opacity"
71
+ animation["from_"] = 0
72
+ animation["to"] = opacity_style
73
+
74
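+ # Fixed timing defaults: every animation lasts 4 seconds, starts after 1 second and freezes in its end state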
+ animation["dur"] = 4
75
+ animation["begin"] = 1
76
+ animation["fill"] = "freeze"
77
+
78
+ return animation
src/preprocessing/deepsvg/deepsvg_config/config.py ADDED
@@ -0,0 +1,106 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ import torch.optim as optim
7
+ from src.preprocessing.deepsvg.deepsvg_schedulers.warmup import GradualWarmupScheduler
8
+
9
+
10
+ class _Config:
11
+ """
12
+ Training config.
13
+ """
14
+ def __init__(self, num_gpus=1):
15
+
16
+ self.num_gpus = num_gpus #
17
+
18
+ self.dataloader_module = "deepsvg.svgtensor_dataset" #
19
+ self.collate_fn = None #
20
+ self.data_dir = "./data/svgs_tensors/" #
21
+ self.meta_filepath = "./data/svgs_meta.csv" #
22
+ self.loader_num_workers = 0 #
23
+
24
+ self.pretrained_path = "./models/hierarchical_ordered.pth.tar" #
25
+
26
+ self.model_cfg = None #
27
+
28
+ self.num_epochs = None #
29
+ self.num_steps = None #
30
+ self.learning_rate = 1e-3 #
31
+ self.batch_size = 100 #
32
+ self.warmup_steps = 500 #
33
+
34
+
35
+ # Dataset
36
+ self.train_ratio = 1.0 #
37
+ self.nb_augmentations = 1 #
38
+
39
+ self.max_num_groups = 15 #
40
+ self.max_seq_len = 30 #
41
+ self.max_total_len = None #
42
+
43
+ self.filter_uni = None #
44
+ self.filter_category = None #
45
+ self.filter_platform = None #
46
+
47
+ self.filter_labels = None #
48
+
49
+ self.grad_clip = None #
50
+
51
+ self.log_every = 20 #
52
+ self.val_every = 1000 #
53
+ self.ckpt_every = 1000 #
54
+
55
+ self.stats_to_print = {
56
+ "train": ["lr", "time"]
57
+ }
58
+
59
+ self.model_args = [] #
60
+ self.optimizer_starts = [0] #
61
+
62
+ # Overridable methods
63
+ def make_model(self):
64
+ raise NotImplementedError
65
+
66
+ def make_losses(self):
67
+ raise NotImplementedError
68
+
69
+ def make_optimizers(self, model):
70
+ return [optim.AdamW(model.parameters(), self.learning_rate)]
71
+
72
+ def make_schedulers(self, optimizers, epoch_size):
73
+ return [None] * len(optimizers)
74
+
75
+ def make_warmup_schedulers(self, optimizers, scheduler_lrs):
76
+ return [GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=self.warmup_steps, after_scheduler=scheduler_lr)
77
+ for optimizer, scheduler_lr in zip(optimizers, scheduler_lrs)]
78
+
79
+ def get_params(self, step, epoch):
80
+ return {}
81
+
82
+ def get_weights(self, step, epoch):
83
+ return {}
84
+
85
+ def set_train_vars(self, train_vars, dataloader):
86
+ pass
87
+
88
+ def visualize(self, model, output, train_vars, step, epoch, summary_writer, visualization_dir):
89
+ pass
90
+
91
+ # Utility methods
92
+ def values(self):
93
+ for key in dir(self):
94
+ if not key.startswith("__") and not callable(getattr(self, key)):
95
+ yield key, getattr(self, key)
96
+
97
+ def to_dict(self):
98
+ return {key: val for key, val in self.values()}
99
+
100
+ def load_dict(self, dict):
101
+ for key, val in dict.items():
102
+ setattr(self, key, val)
103
+
104
+ def print_params(self):
105
+ for key, val in self.values():
106
+ print(f" {key} = {val}")
src/preprocessing/deepsvg/deepsvg_config/config_hierarchical_ordered.py ADDED
@@ -0,0 +1,29 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ from .default_icons import *
7
+
8
+
9
+ class ModelConfig(Hierarchical):
10
+ def __init__(self):
11
+ super().__init__()
12
+
13
+ self.label_condition = False
14
+ self.use_vae = False
15
+
16
+
17
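+ # Note: the "Config" base class on the right-hand side is the one star-imported from default_icons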
+ class Config(Config):
18
+ def __init__(self, num_gpus=1):
19
+ super().__init__(num_gpus=num_gpus)
20
+
21
+ self.model_cfg = ModelConfig()
22
+ self.model_args = self.model_cfg.get_model_args()
23
+
24
+ self.filter_category = None
25
+
26
+ self.learning_rate = 1e-3 * num_gpus
27
+ self.batch_size = 20 #60 * num_gpus
28
+
29
+ self.val_every = 10 #2000
src/preprocessing/deepsvg/deepsvg_config/default_icons.py ADDED
@@ -0,0 +1,102 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ from src.preprocessing.deepsvg.deepsvg_config.config import _Config
7
+ from src.preprocessing.deepsvg.deepsvg_models.model import SVGTransformer
8
+ from src.preprocessing.deepsvg.deepsvg_models.loss import SVGLoss
9
+ from src.preprocessing.deepsvg.deepsvg_models.model_config import *
10
+ from src.preprocessing.deepsvg.deepsvg_svglib.svg import SVG
11
+ from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor
12
+ from src.preprocessing.deepsvg.deepsvg_svglib.svglib_utils import make_grid
13
+ from src.preprocessing.deepsvg.deepsvg_svglib.geom import Bbox
14
+ from src.preprocessing.deepsvg.deepsvg_utils.utils import batchify, linear
15
+
16
+ import torchvision.transforms.functional as TF
17
+ import torch.optim.lr_scheduler as lr_scheduler
18
+ import random
19
+
20
+
21
+ class ModelConfig(Hierarchical):
22
+ """
23
+ Overriding default model config.
24
+ """
25
+ def __init__(self):
26
+ super().__init__()
27
+
28
+
29
+ class Config(_Config):
30
+ """
31
+ Overriding default training config.
32
+ """
33
+ def __init__(self, num_gpus=1):
34
+ super().__init__(num_gpus=num_gpus)
35
+
36
+ # Model
37
+ self.model_cfg = ModelConfig()
38
+ self.model_args = self.model_cfg.get_model_args()
39
+
40
+ # Dataset
41
+ self.filter_category = None
42
+
43
+ self.train_ratio = 1.0
44
+
45
+ self.max_num_groups = 8
46
+ self.max_total_len = 50
47
+
48
+ # Dataloader
49
+ self.loader_num_workers = 4 * num_gpus
50
+
51
+ # Training
52
+ self.num_epochs = 50
53
+ self.val_every = 1000
54
+
55
+ # Optimization
56
+ self.learning_rate = 1e-3 * num_gpus
57
+ self.batch_size = 60 * num_gpus
58
+ self.grad_clip = 1.0
59
+
60
+ def make_schedulers(self, optimizers, epoch_size):
61
+ optimizer, = optimizers
62
+ return [lr_scheduler.StepLR(optimizer, step_size=2.5 * epoch_size, gamma=0.9)]
63
+
64
+ def make_model(self):
65
+ return SVGTransformer(self.model_cfg)
66
+
67
+ def make_losses(self):
68
+ return [SVGLoss(self.model_cfg)]
69
+
70
+ def get_weights(self, step, epoch):
71
+ return {
72
+ "kl_tolerance": 0.1,
73
+ "loss_kl_weight": linear(0, 10, step, 0, 10000),
74
+ "loss_hierarch_weight": 1.0,
75
+ "loss_cmd_weight": 1.0,
76
+ "loss_args_weight": 2.0,
77
+ "loss_visibility_weight": 1.0
78
+ }
79
+
80
+ def set_train_vars(self, train_vars, dataloader):
81
+ train_vars.x_inputs_train = [dataloader.dataset.get(idx, [*self.model_args, "tensor_grouped"])
82
+ for idx in random.sample(range(len(dataloader.dataset)), k=10)]
83
+
84
+ def visualize(self, model, output, train_vars, step, epoch, summary_writer, visualization_dir):
85
+ device = next(model.parameters()).device
86
+
87
+ # Reconstruction
88
+ for i, data in enumerate(train_vars.x_inputs_train):
89
+ model_args = batchify((data[key] for key in self.model_args), device)
90
+ commands_y, args_y = model.module.greedy_sample(*model_args)
91
+ tensor_pred = SVGTensor.from_cmd_args(commands_y[0].cpu(), args_y[0].cpu())
92
+
93
+ try:
94
+ svg_path_sample = SVG.from_tensor(tensor_pred.data, viewbox=Bbox(256), allow_empty=True).normalize().split_paths().set_color("random")
95
+ except Exception:
96
+ continue
97
+
98
+ tensor_target = data["tensor_grouped"][0].copy().drop_sos().unpad()
99
+ svg_path_gt = SVG.from_tensor(tensor_target.data, viewbox=Bbox(256)).normalize().split_paths().set_color("random")
100
+
101
+ img = make_grid([svg_path_sample, svg_path_gt]).draw(do_display=False, return_png=True, fill=False, with_points=False)
102
+ summary_writer.add_image(f"reconstructions_train/{i}", TF.to_tensor(img), step)
src/preprocessing/deepsvg/deepsvg_dataloader/svg_dataset.py ADDED
@@ -0,0 +1,239 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ from src.preprocessing.deepsvg.deepsvg_config.config import _Config
7
+ from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor
8
+ from src.preprocessing.deepsvg.deepsvg_svglib.svg import SVG
9
+ from src.preprocessing.deepsvg.deepsvg_svglib.geom import Point
10
+
11
+ import math
12
+ import torch
13
+ import torch.utils.data
14
+ import random
15
+ from typing import List, Union
16
+ import pandas as pd
17
+ import os
18
+ import pickle
19
+ Num = Union[int, float]
20
+
21
+
22
+ class SVGDataset(torch.utils.data.Dataset):
23
+ def __init__(self, data_dir, meta_filepath, model_args, max_num_groups, max_seq_len, max_total_len=None,
24
+ filter_uni=None, filter_platform=None, filter_category=None, train_ratio=1.0, df=None, PAD_VAL=-1,
25
+ nb_augmentations=1, already_preprocessed=True):
26
+ self.data_dir = data_dir
27
+
28
+ self.already_preprocessed = already_preprocessed
29
+
30
+ self.MAX_NUM_GROUPS = max_num_groups
31
+ self.MAX_SEQ_LEN = max_seq_len
32
+ self.MAX_TOTAL_LEN = max_total_len
33
+
34
+ if max_total_len is None:
35
+ self.MAX_TOTAL_LEN = max_num_groups * max_seq_len
36
+
37
+ if df is None:
38
+ df = pd.read_csv(meta_filepath)
39
+
40
+ if len(df) > 0:
41
+ if filter_uni is not None:
42
+ df = df[df.uni.isin(filter_uni)]
43
+
44
+ if filter_platform is not None:
45
+ df = df[df.platform.isin(filter_platform)]
46
+
47
+ if filter_category is not None:
48
+ df = df[df.category.isin(filter_category)]
49
+
50
+ df = df[(df.nb_groups <= max_num_groups) & (df.max_len_group <= max_seq_len)]
51
+ if max_total_len is not None:
52
+ df = df[df.total_len <= max_total_len]
53
+
54
+ self.df = df.sample(frac=train_ratio) if train_ratio < 1.0 else df
55
+
56
+ self.model_args = model_args
57
+
58
+ self.PAD_VAL = PAD_VAL
59
+
60
+ self.nb_augmentations = nb_augmentations
61
+
62
+ def search_name(self, name):
63
+ return self.df[self.df.commonName.str.contains(name)]
64
+
65
+ def _filter_categories(self, filter_category):
66
+ self.df = self.df[self.df.category.isin(filter_category)]
67
+
68
+ @staticmethod
69
+ def _uni_to_label(uni):
70
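+ # Map a character code to a class label: digits '0'-'9' -> 0-9, 'A'-'Z' -> 10-35, 'a'-'z' -> 36-61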
+ if 48 <= uni <= 57:
71
+ return uni - 48
72
+ elif 65 <= uni <= 90:
73
+ return uni - 65 + 10
74
+ return uni - 97 + 36
75
+
76
+ @staticmethod
77
+ def _label_to_uni(label_id):
78
+ if 0 <= label_id <= 9:
79
+ return label_id + 48
80
+ elif 10 <= label_id <= 35:
81
+ return label_id + 65 - 10
82
+ return label_id + 97 - 36
83
+
84
+ @staticmethod
85
+ def _category_to_label(category):
86
+ categories = ['characters', 'free-icons', 'logos', 'alphabet', 'animals', 'arrows', 'astrology', 'baby', 'beauty',
87
+ 'business', 'cinema', 'city', 'clothing', 'computer-hardware', 'crime', 'cultures', 'data', 'diy',
88
+ 'drinks', 'ecommerce', 'editing', 'files', 'finance', 'folders', 'food', 'gaming', 'hands', 'healthcare',
89
+ 'holidays', 'household', 'industry', 'maps', 'media-controls', 'messaging', 'military', 'mobile',
90
+ 'music', 'nature', 'network', 'photo-video', 'plants', 'printing', 'profile', 'programming', 'science',
91
+ 'security', 'shopping', 'social-networks', 'sports', 'time-and-date', 'transport', 'travel', 'user-interface',
92
+ 'users', 'weather', 'flags', 'emoji', 'men', 'women']
93
+ return categories.index(category)
94
+
95
+ def get_label(self, idx=0, entry=None):
96
+ if entry is None:
97
+ entry = self.df.iloc[idx]
98
+
99
+ if "uni" in self.df.columns: # Font dataset
100
+ label = self._uni_to_label(entry.uni)
101
+ return torch.tensor(label)
102
+ elif "category" in self.df.columns: # Icons dataset
103
+ label = self._category_to_label(entry.category)
104
+ return torch.tensor(label)
105
+
106
+ return None
107
+
108
+ def idx_to_id(self, idx):
109
+ return self.df.iloc[idx].id
110
+
111
+ def entry_from_id(self, id):
112
+ return self.df[self.df.id == str(id)].iloc[0]
113
+
114
+ def _load_svg(self, icon_id):
115
+ svg = SVG.load_svg(os.path.join(self.data_dir, f"{icon_id}.svg"))
116
+
117
+ if not self.already_preprocessed:
118
+ svg.fill_(False)
119
+ svg.normalize().zoom(0.9)
120
+ svg.canonicalize()
121
+ svg = svg.simplify_heuristic()
122
+
123
+ return svg
124
+
125
+ def __len__(self):
126
+ return len(self.df) * self.nb_augmentations
127
+
128
+ def random_icon(self):
129
+ return self[random.randrange(0, len(self))]
130
+
131
+ def random_id(self):
132
+ idx = random.randrange(0, len(self)) % len(self.df)
133
+ return self.idx_to_id(idx)
134
+
135
+ def random_id_by_uni(self, uni):
136
+ df = self.df[self.df.uni == uni]
137
+ return df.id.sample().iloc[0]
138
+
139
+ def __getitem__(self, idx):
140
+ return self.get(idx, self.model_args)
141
+
142
+ @staticmethod
143
+ def _augment(svg, mean=False):
144
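+ # Random augmentation: translate by up to +/-2.5 units and zoom by a factor in [0.6, 0.8); mean=True picks the deterministic midpoint instead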
+ dx, dy = (0, 0) if mean else (5 * random.random() - 2.5, 5 * random.random() - 2.5)
145
+ factor = 0.7 if mean else 0.2 * random.random() + 0.6
146
+
147
+ return svg.zoom(factor).translate(Point(dx, dy))
148
+
149
+ @staticmethod
150
+ def simplify(svg, normalize=True):
151
+ svg.canonicalize(normalize=normalize)
152
+ svg = svg.simplify_heuristic()
153
+ return svg.normalize()
154
+
155
+ @staticmethod
156
+ def preprocess(svg, augment=True, numericalize=True, mean=False):
157
+ if augment:
158
+ svg = SVGDataset._augment(svg, mean=mean)
159
+ if numericalize:
160
+ return svg.numericalize(256)
161
+ return svg
162
+
163
+ def get(self, idx=0, model_args=None, random_aug=True, id=None, svg: SVG=None):
164
+ if id is None:
165
+ idx = idx % len(self.df)
166
+ id = self.idx_to_id(idx)
167
+
168
+ if svg is None:
169
+ svg = self._load_svg(id)
170
+
171
+ svg = SVGDataset.preprocess(svg, augment=random_aug)
172
+
173
+ t_sep, fillings = svg.to_tensor(concat_groups=False, PAD_VAL=self.PAD_VAL), svg.to_fillings()
174
+
175
+ # Note: DeepSVG can only handle 8 paths per SVG and 30 commands per path
176
+ if len(t_sep) > 8:
177
+ #print(f"SVG {id} has more than 30 segments.")
178
+ t_sep = t_sep[0:8]
179
+ fillings = fillings[0:8]
180
+
181
+ for i in range(len(t_sep)):
182
+ if len(t_sep[i]) > 30:
183
+ #print(f"SVG {id}: Path nr {i} has more than 30 segments.")
184
+ t_sep[i] = t_sep[i][0:30]
185
+
186
+ label = self.get_label(idx)
187
+
188
+ return self.get_data(t_sep, fillings, model_args=model_args, label=label)
189
+
190
+ def get_data(self, t_sep, fillings, model_args=None, label=None):
191
+ res = {}
192
+
193
+ if model_args is None:
194
+ model_args = self.model_args
195
+
196
+ pad_len = max(self.MAX_NUM_GROUPS - len(t_sep), 0)
197
+
198
+ t_sep.extend([torch.empty(0, 14)] * pad_len)
199
+ fillings.extend([0] * pad_len)
200
+
201
+ t_grouped = [SVGTensor.from_data(torch.cat(t_sep, dim=0), PAD_VAL=self.PAD_VAL).add_eos().add_sos().pad(
202
+ seq_len=self.MAX_TOTAL_LEN + 2)]
203
+
204
+ t_sep = [SVGTensor.from_data(t, PAD_VAL=self.PAD_VAL, filling=f).add_eos().add_sos().pad(seq_len=self.MAX_SEQ_LEN + 2) for
205
+ t, f in zip(t_sep, fillings)]
206
+
207
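+ # Keys with a "_grouped" suffix are served from the single concatenated tensor; all other keys come from the per-path tensors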
+ for arg in set(model_args):
208
+ if "_grouped" in arg:
209
+ arg_ = arg.split("_grouped")[0]
210
+ t_list = t_grouped
211
+ else:
212
+ arg_ = arg
213
+ t_list = t_sep
214
+
215
+ if arg_ == "tensor":
216
+ res[arg] = t_list
217
+
218
+ if arg_ == "commands":
219
+ res[arg] = torch.stack([t.cmds() for t in t_list])
220
+
221
+ if arg_ == "args_rel":
222
+ res[arg] = torch.stack([t.get_relative_args() for t in t_list])
223
+ if arg_ == "args":
224
+ res[arg] = torch.stack([t.args() for t in t_list])
225
+
226
+ if "filling" in model_args:
227
+ res["filling"] = torch.stack([torch.tensor(t.filling) for t in t_sep]).unsqueeze(-1)
228
+
229
+ if "label" in model_args:
230
+ res["label"] = label
231
+
232
+ return res
233
+
234
+
235
+ def load_dataset(cfg: _Config, already_preprocessed=True):
236
+ dataset = SVGDataset(cfg.data_dir, cfg.meta_filepath, cfg.model_args, cfg.max_num_groups, cfg.max_seq_len, cfg.max_total_len,
237
+ cfg.filter_uni, cfg.filter_platform, cfg.filter_category, cfg.train_ratio,
238
+ nb_augmentations=cfg.nb_augmentations, already_preprocessed=already_preprocessed)
239
+ return dataset
src/preprocessing/deepsvg/deepsvg_difflib/tensor.py ADDED
@@ -0,0 +1,305 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ from __future__ import annotations
7
+ import torch
8
+ import torch.utils.data
9
+ from typing import Union
10
+ Num = Union[int, float]
11
+
12
+
13
+ class AnimationTensor:
14
+
15
+ COMMANDS_SIMPLIFIED = ['a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6', 'a7', 'a8', 'a9']
16
+
17
+ CMD_ARGS_MASK = torch.tensor([[0, 0, 0], # a0
18
+ [0, 0, 0], # a1
19
+ [0, 0, 0], # a2
20
+ [1, 1, 1], # a3
21
+ [0, 0, 0], # a4
22
+ [0, 0, 0], # a5
23
+ [0, 0, 0], # a6
24
+ [0, 0, 0], # a7
25
+ [1, 1, 1], # a8
26
+ [0, 0, 0]]) # a9
27
+
28
+ class Index:
29
+ COMMAND = 0
30
+ DURATION = 1
31
+ FROM = 2
32
+ BEGIN = 3
33
+
34
+ class IndexArgs:
35
+ DURATION = 0
36
+ FROM = 1
37
+ BEGIN = 2
38
+
39
+ all_arg_keys = ['duration', 'from', 'begin']
40
+ cmd_arg_keys = ["commands", *all_arg_keys]
41
+ all_keys = ["commands", *all_arg_keys]
42
+
43
+ def __init__(self, commands, duration, from_, begin,
44
+ seq_len=None, label=None, PAD_VAL=-1, ARGS_DIM=256, filling=0):
45
+
46
+ self.commands = commands.reshape(-1, 1).float()
47
+
48
+ self.duration = duration.float()
49
+ self.from_ = from_.float()
50
+ self.begin = begin.float()
51
+
52
+ self.seq_len = torch.tensor(len(commands)) if seq_len is None else seq_len
53
+ self.label = label
54
+
55
+ self.PAD_VAL = PAD_VAL
56
+ self.ARGS_DIM = ARGS_DIM
57
+
58
+ # self.sos_token = torch.Tensor([self.COMMANDS_SIMPLIFIED.index("SOS")]).unsqueeze(-1)
59
+ # self.eos_token = self.pad_token = torch.Tensor([self.COMMANDS_SIMPLIFIED.index("EOS")]).unsqueeze(-1)
60
+
61
+ self.filling = filling
62
+
63
+
64
+ class SVGTensor:
65
+ # 0 1 2 3 4 5 6
66
+ COMMANDS_SIMPLIFIED = ["m", "l", "c", "a", "EOS", "SOS", "z"]
67
+
68
+ # rad x lrg sw ctrl ctrl end
69
+ # ius axs arc eep 1 2 pos
70
+ # rot fg fg
71
+ CMD_ARGS_MASK = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1], # m
72
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1], # l
73
+ [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], # c
74
+ [1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1], # a
75
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # EOS
76
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # SOS
77
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) # z
78
+
79
+ class Index:
80
+ COMMAND = 0
81
+ RADIUS = slice(1, 3)
82
+ X_AXIS_ROT = 3
83
+ LARGE_ARC_FLG = 4
84
+ SWEEP_FLG = 5
85
+ START_POS = slice(6, 8)
86
+ CONTROL1 = slice(8, 10)
87
+ CONTROL2 = slice(10, 12)
88
+ END_POS = slice(12, 14)
89
+
90
+ class IndexArgs:
91
+ RADIUS = slice(0, 2)
92
+ X_AXIS_ROT = 2
93
+ LARGE_ARC_FLG = 3
94
+ SWEEP_FLG = 4
95
+ CONTROL1 = slice(5, 7)
96
+ CONTROL2 = slice(7, 9)
97
+ END_POS = slice(9, 11)
98
+
99
+ position_keys = ["control1", "control2", "end_pos"]
100
+ all_position_keys = ["start_pos", *position_keys]
101
+ arg_keys = ["radius", "x_axis_rot", "large_arc_flg", "sweep_flg", *position_keys]
102
+ all_arg_keys = [*arg_keys[:4], "start_pos", *arg_keys[4:]]
103
+ cmd_arg_keys = ["commands", *arg_keys]
104
+ all_keys = ["commands", *all_arg_keys]
105
+
106
+ def __init__(self, commands, radius, x_axis_rot, large_arc_flg, sweep_flg, control1, control2, end_pos,
107
+ seq_len=None, label=None, PAD_VAL=-1, ARGS_DIM=256, filling=0):
108
+
109
+ self.commands = commands.reshape(-1, 1).float()
110
+
111
+ self.radius = radius.float()
112
+ self.x_axis_rot = x_axis_rot.reshape(-1, 1).float()
113
+ self.large_arc_flg = large_arc_flg.reshape(-1, 1).float()
114
+ self.sweep_flg = sweep_flg.reshape(-1, 1).float()
115
+
116
+ self.control1 = control1.float()
117
+ self.control2 = control2.float()
118
+ self.end_pos = end_pos.float()
119
+
120
+ self.seq_len = torch.tensor(len(commands)) if seq_len is None else seq_len
121
+ self.label = label
122
+
123
+ self.PAD_VAL = PAD_VAL
124
+ self.ARGS_DIM = ARGS_DIM
125
+
126
+ self.sos_token = torch.Tensor([self.COMMANDS_SIMPLIFIED.index("SOS")]).unsqueeze(-1)
127
+ self.eos_token = self.pad_token = torch.Tensor([self.COMMANDS_SIMPLIFIED.index("EOS")]).unsqueeze(-1)
128
+
129
+ self.filling = filling
130
+
131
+ @property
132
+ def start_pos(self):
133
+ start_pos = self.end_pos[:-1]
134
+
135
+ return torch.cat([
136
+ start_pos.new_zeros(1, 2),
137
+ start_pos
138
+ ])
139
+
140
+ @staticmethod
141
+ def from_data(data, *args, **kwargs):
142
+ return SVGTensor(data[:, SVGTensor.Index.COMMAND], data[:, SVGTensor.Index.RADIUS], data[:, SVGTensor.Index.X_AXIS_ROT],
143
+ data[:, SVGTensor.Index.LARGE_ARC_FLG], data[:, SVGTensor.Index.SWEEP_FLG], data[:, SVGTensor.Index.CONTROL1],
144
+ data[:, SVGTensor.Index.CONTROL2], data[:, SVGTensor.Index.END_POS], *args, **kwargs)
145
+
146
+ @staticmethod
147
+ def from_cmd_args(commands, args, *nargs, **kwargs):
148
+ return SVGTensor(commands, args[:, SVGTensor.IndexArgs.RADIUS], args[:, SVGTensor.IndexArgs.X_AXIS_ROT],
149
+ args[:, SVGTensor.IndexArgs.LARGE_ARC_FLG], args[:, SVGTensor.IndexArgs.SWEEP_FLG], args[:, SVGTensor.IndexArgs.CONTROL1],
150
+ args[:, SVGTensor.IndexArgs.CONTROL2], args[:, SVGTensor.IndexArgs.END_POS], *nargs, **kwargs)
151
+
152
+ def get_data(self, keys):
153
+ return torch.cat([self.__getattribute__(key) for key in keys], dim=-1)
154
+
155
+ @property
156
+ def data(self):
157
+ return self.get_data(self.all_keys)
158
+
159
+ def copy(self):
160
+ return SVGTensor(*[self.__getattribute__(key).clone() for key in self.cmd_arg_keys],
161
+ seq_len=self.seq_len.clone(), label=self.label, PAD_VAL=self.PAD_VAL, ARGS_DIM=self.ARGS_DIM,
162
+ filling=self.filling)
163
+
164
+ def add_sos(self):
165
+ self.commands = torch.cat([self.sos_token, self.commands])
166
+
167
+ for key in self.arg_keys:
168
+ v = self.__getattribute__(key)
169
+ self.__setattr__(key, torch.cat([v.new_full((1, v.size(-1)), self.PAD_VAL), v]))
170
+
171
+ self.seq_len += 1
172
+ return self
173
+
174
+ def drop_sos(self):
175
+ for key in self.cmd_arg_keys:
176
+ self.__setattr__(key, self.__getattribute__(key)[1:])
177
+
178
+ self.seq_len -= 1
179
+ return self
180
+
181
+ def add_eos(self):
182
+ self.commands = torch.cat([self.commands, self.eos_token])
183
+
184
+ for key in self.arg_keys:
185
+ v = self.__getattribute__(key)
186
+ self.__setattr__(key, torch.cat([v, v.new_full((1, v.size(-1)), self.PAD_VAL)]))
187
+
188
+ return self
189
+
190
+ def pad(self, seq_len=51):
191
+ pad_len = max(seq_len - len(self.commands), 0)
192
+
193
+ self.commands = torch.cat([self.commands, self.pad_token.repeat(pad_len, 1)])
194
+
195
+ for key in self.arg_keys:
196
+ v = self.__getattribute__(key)
197
+ self.__setattr__(key, torch.cat([v, v.new_full((pad_len, v.size(-1)), self.PAD_VAL)]))
198
+
199
+ return self
200
+
201
+ def unpad(self):
202
+ # Remove EOS + padding
203
+ for key in self.cmd_arg_keys:
204
+ self.__setattr__(key, self.__getattribute__(key)[:self.seq_len])
205
+ return self
206
+
207
+ def draw(self, *args, **kwargs):
208
+ from src.preprocessing.deepsvg.deepsvg_svglib.svg import SVGPath
209
+ return SVGPath.from_tensor(self.data).draw(*args, **kwargs)
210
+
211
+ def cmds(self):
212
+ return self.commands.reshape(-1)
213
+
214
+ def args(self, with_start_pos=False):
215
+ if with_start_pos:
216
+ return self.get_data(self.all_arg_keys)
217
+
218
+ return self.get_data(self.arg_keys)
219
+
220
+ def _get_real_commands_mask(self):
221
+ mask = self.cmds() < self.COMMANDS_SIMPLIFIED.index("EOS")
222
+ return mask
223
+
224
+ def _get_args_mask(self):
225
+ mask = SVGTensor.CMD_ARGS_MASK[self.cmds().long()].bool()
226
+ return mask
227
+
228
+ def get_relative_args(self):
229
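+ # Express control points and end positions relative to the previous command's end point, then offset
+ # valid arguments by ARGS_DIM - 1 so relative values become non-negative; slots a command does not use keep PAD_VAL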
+ data = self.args().clone()
230
+
231
+ real_commands = self._get_real_commands_mask()
232
+ data_real_commands = data[real_commands]
233
+
234
+ start_pos = data_real_commands[:-1, SVGTensor.IndexArgs.END_POS].clone()
235
+
236
+ data_real_commands[1:, SVGTensor.IndexArgs.CONTROL1] -= start_pos
237
+ data_real_commands[1:, SVGTensor.IndexArgs.CONTROL2] -= start_pos
238
+ data_real_commands[1:, SVGTensor.IndexArgs.END_POS] -= start_pos
239
+ data[real_commands] = data_real_commands
240
+
241
+ mask = self._get_args_mask()
242
+ data[mask] += self.ARGS_DIM - 1
243
+ data[~mask] = self.PAD_VAL
244
+
245
+ return data
246
+
247
+ def sample_points(self, n=10):
248
+ device = self.commands.device
249
+
250
+ z = torch.linspace(0, 1, n, device=device)
251
+ Z = torch.stack([torch.ones_like(z), z, z.pow(2), z.pow(3)], dim=1)
252
+
253
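+ # Z above evaluates the power basis [1, z, z^2, z^3] at n points along the segment; Q holds one coefficient
+ # matrix per command type: all zeros for commands without geometry, the linear interpolation matrix for "l",
+ # and the cubic Bezier basis matrix for "c"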
+ Q = torch.tensor([
254
+ [[0., 0., 0., 0.], #  "m"
255
+ [0., 0., 0., 0.],
256
+ [0., 0., 0., 0.],
257
+ [0., 0., 0., 0.]],
258
+
259
+ [[1., 0., 0., 0.], # "l"
260
+ [-1, 0., 0., 1.],
261
+ [0., 0., 0., 0.],
262
+ [0., 0., 0., 0.]],
263
+
264
+ [[1., 0., 0., 0.], #  "c"
265
+ [-3, 3., 0., 0.],
266
+ [3., -6, 3., 0.],
267
+ [-1, 3., -3, 1.]],
268
+
269
+ torch.zeros(4, 4), # "a", no support yet
270
+
271
+ torch.zeros(4, 4), # "EOS"
272
+ torch.zeros(4, 4), # "SOS"
273
+ torch.zeros(4, 4), # "z"
274
+ ], device=device)
275
+
276
+ commands, pos = self.commands.reshape(-1).long(), self.get_data(self.all_position_keys).reshape(-1, 4, 2)
277
+ inds = (commands == self.COMMANDS_SIMPLIFIED.index("l")) | (commands == self.COMMANDS_SIMPLIFIED.index("c"))
278
+ commands, pos = commands[inds], pos[inds]
279
+
280
+ Z_coeffs = torch.matmul(Q[commands], pos)
281
+
282
+ # Since the last point of each command is the first point of the next, drop the last point of every command except the final one
283
+ sample_points = torch.matmul(Z, Z_coeffs)
284
+ sample_points = torch.cat([sample_points[:, :-1].reshape(-1, 2), sample_points[-1, -1].unsqueeze(0)])
285
+
286
+ return sample_points
287
+
288
+ @staticmethod
289
+ def get_length_distribution(p, normalize=True):
290
+ start, end = p[:-1], p[1:]
291
+ length_distr = torch.norm(end - start, dim=-1).cumsum(dim=0)
292
+ length_distr = torch.cat([length_distr.new_zeros(1), length_distr])
293
+ if normalize:
294
+ length_distr = length_distr / length_distr[-1]
295
+ return length_distr
296
+
297
+ def sample_uniform_points(self, n=100):
298
+ p = self.sample_points(n=n)
299
+
300
+ distr_unif = torch.linspace(0., 1., n).to(p.device)
301
+ distr = self.get_length_distribution(p, normalize=True)
302
+ d = torch.cdist(distr_unif.unsqueeze(-1), distr.unsqueeze(-1))
303
+ matching = d.argmin(dim=-1)
304
+
305
+ return p[matching]
src/preprocessing/deepsvg/deepsvg_models/basic_blocks.py ADDED
@@ -0,0 +1,70 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class FCN(nn.Module):
11
+ def __init__(self, d_model, n_commands, n_args, args_dim=256):
12
+ super().__init__()
13
+
14
+ self.n_args = n_args
15
+ self.args_dim = args_dim
16
+
17
+ self.command_fcn = nn.Linear(d_model, n_commands)
18
+ self.args_fcn = nn.Linear(d_model, n_args * args_dim)
19
+
20
+ def forward(self, out):
21
+ S, N, _ = out.shape
22
+
23
+ command_logits = self.command_fcn(out) # Shape [S, N, n_commands]
24
+
25
+ args_logits = self.args_fcn(out) # Shape [S, N, n_args * args_dim]
26
+ args_logits = args_logits.reshape(S, N, self.n_args, self.args_dim) # Shape [S, N, n_args, args_dim]
27
+
28
+ return command_logits, args_logits
29
+
30
+
31
+ class HierarchFCN(nn.Module):
32
+ def __init__(self, d_model, dim_z):
33
+ super().__init__()
34
+
35
+ self.visibility_fcn = nn.Linear(d_model, 2)
36
+ self.z_fcn = nn.Linear(d_model, dim_z)
37
+
38
+ def forward(self, out):
39
+ G, N, _ = out.shape
40
+
41
+ visibility_logits = self.visibility_fcn(out) # Shape [G, N, 2]
42
+ z = self.z_fcn(out) # Shape [G, N, dim_z]
43
+
44
+ return visibility_logits.unsqueeze(0), z.unsqueeze(0)
45
+
46
+
47
+ class ResNet(nn.Module):
48
+ def __init__(self, d_model):
49
+ super().__init__()
50
+
51
+ self.linear1 = nn.Sequential(
52
+ nn.Linear(d_model, d_model), nn.ReLU()
53
+ )
54
+ self.linear2 = nn.Sequential(
55
+ nn.Linear(d_model, d_model), nn.ReLU()
56
+ )
57
+ self.linear3 = nn.Sequential(
58
+ nn.Linear(d_model, d_model), nn.ReLU()
59
+ )
60
+ self.linear4 = nn.Sequential(
61
+ nn.Linear(d_model, d_model), nn.ReLU()
62
+ )
63
+
64
+ def forward(self, z):
65
+ z = z + self.linear1(z)
66
+ z = z + self.linear2(z)
67
+ z = z + self.linear3(z)
68
+ z = z + self.linear4(z)
69
+
70
+ return z
src/preprocessing/deepsvg/deepsvg_models/deepSVG_hierarchical_ordered.pth.tar ADDED
File without changes
src/preprocessing/deepsvg/deepsvg_models/layers/attention.py ADDED
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ import torch
7
+ from torch.nn import Linear
8
+ from torch.nn.init import xavier_uniform_
9
+ from torch.nn.init import constant_
10
+ from torch.nn.init import xavier_normal_
11
+ from torch.nn.parameter import Parameter
12
+ from torch.nn.modules.module import Module
13
+
14
+ from .functional import multi_head_attention_forward
15
+
16
+
17
+ class MultiheadAttention(Module):
18
+ r"""Allows the model to jointly attend to information
19
+ from different representation subspaces.
20
+ See reference: Attention Is All You Need
21
+
22
+ .. math::
23
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
24
+ \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
25
+
26
+ Args:
27
+ embed_dim: total dimension of the model.
28
+ num_heads: parallel attention heads.
29
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
30
+ bias: add bias as module parameter. Default: True.
31
+ add_bias_kv: add bias to the key and value sequences at dim=0.
32
+ add_zero_attn: add a new batch of zeros to the key and
33
+ value sequences at dim=1.
34
+ kdim: total number of features in key. Default: None.
35
+ vdim: total number of features in value. Default: None.
36
+
37
+ Note: if kdim and vdim are None, they will be set to embed_dim such that
38
+ query, key, and value have the same number of features.
39
+
40
+ Examples::
41
+
42
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
43
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
44
+ """
45
+ __annotations__ = {
46
+ 'bias_k': torch._jit_internal.Optional[torch.Tensor],
47
+ 'bias_v': torch._jit_internal.Optional[torch.Tensor],
48
+ }
49
+ __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']
50
+
51
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
52
+ super(MultiheadAttention, self).__init__()
53
+ self.embed_dim = embed_dim
54
+ self.kdim = kdim if kdim is not None else embed_dim
55
+ self.vdim = vdim if vdim is not None else embed_dim
56
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
57
+
58
+ self.num_heads = num_heads
59
+ self.dropout = dropout
60
+ self.head_dim = embed_dim // num_heads
61
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
62
+
63
+ if self._qkv_same_embed_dim is False:
64
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
65
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
66
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
67
+ self.register_parameter('in_proj_weight', None)
68
+ else:
69
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
70
+ self.register_parameter('q_proj_weight', None)
71
+ self.register_parameter('k_proj_weight', None)
72
+ self.register_parameter('v_proj_weight', None)
73
+
74
+ if bias:
75
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
76
+ else:
77
+ self.register_parameter('in_proj_bias', None)
78
+ self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
79
+
80
+ if add_bias_kv:
81
+ self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
82
+ self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
83
+ else:
84
+ self.bias_k = self.bias_v = None
85
+
86
+ self.add_zero_attn = add_zero_attn
87
+
88
+ self._reset_parameters()
89
+
90
+ def _reset_parameters(self):
91
+ if self._qkv_same_embed_dim:
92
+ xavier_uniform_(self.in_proj_weight)
93
+ else:
94
+ xavier_uniform_(self.q_proj_weight)
95
+ xavier_uniform_(self.k_proj_weight)
96
+ xavier_uniform_(self.v_proj_weight)
97
+
98
+ if self.in_proj_bias is not None:
99
+ constant_(self.in_proj_bias, 0.)
100
+ constant_(self.out_proj.bias, 0.)
101
+ if self.bias_k is not None:
102
+ xavier_normal_(self.bias_k)
103
+ if self.bias_v is not None:
104
+ xavier_normal_(self.bias_v)
105
+
106
+ def __setstate__(self, state):
107
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
108
+ if '_qkv_same_embed_dim' not in state:
109
+ state['_qkv_same_embed_dim'] = True
110
+
111
+ super(MultiheadAttention, self).__setstate__(state)
112
+
113
+ def forward(self, query, key, value, key_padding_mask=None,
114
+ need_weights=True, attn_mask=None):
115
+ # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
116
+ r"""
117
+ Args:
118
+ query, key, value: map a query and a set of key-value pairs to an output.
119
+ See "Attention Is All You Need" for more details.
120
+ key_padding_mask: if provided, specified padding elements in the key will
121
+ be ignored by the attention. This is a binary mask. When the value is True,
122
+ the corresponding value on the attention layer will be filled with -inf.
123
+ need_weights: output attn_output_weights.
124
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. This is an additive mask
125
+ (i.e. the values will be added to the attention layer). A 2D mask will be broadcasted for all
126
+ the batches while a 3D mask allows specifying a different mask for the entries of each batch.
127
+
128
+ Shape:
129
+ - Inputs:
130
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
131
+ the embedding dimension.
132
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
133
+ the embedding dimension.
134
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
135
+ the embedding dimension.
136
+ - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
137
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
138
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
139
+ S is the source sequence length.
140
+
141
+ - Outputs:
142
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
143
+ E is the embedding dimension.
144
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
145
+ L is the target sequence length, S is the source sequence length.
146
+ """
147
+ if not self._qkv_same_embed_dim:
148
+ return multi_head_attention_forward(
149
+ query, key, value, self.embed_dim, self.num_heads,
150
+ self.in_proj_weight, self.in_proj_bias,
151
+ self.bias_k, self.bias_v, self.add_zero_attn,
152
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
153
+ training=self.training,
154
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
155
+ attn_mask=attn_mask, use_separate_proj_weight=True,
156
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
157
+ v_proj_weight=self.v_proj_weight)
158
+ else:
159
+ return multi_head_attention_forward(
160
+ query, key, value, self.embed_dim, self.num_heads,
161
+ self.in_proj_weight, self.in_proj_bias,
162
+ self.bias_k, self.bias_v, self.add_zero_attn,
163
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
164
+ training=self.training,
165
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
166
+ attn_mask=attn_mask)
src/preprocessing/deepsvg/deepsvg_models/layers/functional.py ADDED
@@ -0,0 +1,261 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ from __future__ import division
7
+
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+
13
+ def multi_head_attention_forward(query, # type: Tensor
14
+ key, # type: Tensor
15
+ value, # type: Tensor
16
+ embed_dim_to_check, # type: int
17
+ num_heads, # type: int
18
+ in_proj_weight, # type: Tensor
19
+ in_proj_bias, # type: Tensor
20
+ bias_k, # type: Optional[Tensor]
21
+ bias_v, # type: Optional[Tensor]
22
+ add_zero_attn, # type: bool
23
+ dropout_p, # type: float
24
+ out_proj_weight, # type: Tensor
25
+ out_proj_bias, # type: Tensor
26
+ training=True, # type: bool
27
+ key_padding_mask=None, # type: Optional[Tensor]
28
+ need_weights=True, # type: bool
29
+ attn_mask=None, # type: Optional[Tensor]
30
+ use_separate_proj_weight=False, # type: bool
31
+ q_proj_weight=None, # type: Optional[Tensor]
32
+ k_proj_weight=None, # type: Optional[Tensor]
33
+ v_proj_weight=None, # type: Optional[Tensor]
34
+ static_k=None, # type: Optional[Tensor]
35
+ static_v=None # type: Optional[Tensor]
36
+ ):
37
+ # type: (...) -> Tuple[Tensor, Optional[Tensor]]
38
+ r"""
39
+ Args:
40
+ query, key, value: map a query and a set of key-value pairs to an output.
41
+ See "Attention Is All You Need" for more details.
42
+ embed_dim_to_check: total dimension of the model.
43
+ num_heads: parallel attention heads.
44
+ in_proj_weight, in_proj_bias: input projection weight and bias.
45
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
46
+ add_zero_attn: add a new batch of zeros to the key and
47
+ value sequences at dim=1.
48
+ dropout_p: probability of an element to be zeroed.
49
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
50
+ training: apply dropout if is ``True``.
51
+ key_padding_mask: if provided, specified padding elements in the key will
52
+ be ignored by the attention. This is a binary mask. When the value is True,
53
+ the corresponding value on the attention layer will be filled with -inf.
54
+ need_weights: output attn_output_weights.
55
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. This is an additive mask
56
+ (i.e. the values will be added to the attention layer). A 2D mask will be broadcasted for all
57
+ the batches while a 3D mask allows specifying a different mask for the entries of each batch.
58
+ use_separate_proj_weight: the function accepts the proj. weights for query, key,
59
+ and value in different forms. If false, in_proj_weight will be used, which is
60
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
61
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
62
+ static_k, static_v: static key and value used for attention operators.
63
+ Shape:
64
+ Inputs:
65
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
66
+ the embedding dimension.
67
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
68
+ the embedding dimension.
69
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
70
+ the embedding dimension.
71
+ - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
72
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
73
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
74
+ S is the source sequence length.
75
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
76
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
77
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
78
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
79
+ Outputs:
80
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
81
+ E is the embedding dimension.
82
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
83
+ L is the target sequence length, S is the source sequence length.
84
+ """
85
+
86
+ tgt_len, bsz, embed_dim = query.size()
87
+ assert embed_dim == embed_dim_to_check
88
+ assert key.size() == value.size()
89
+
90
+ head_dim = embed_dim // num_heads
91
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
92
+ scaling = float(head_dim) ** -0.5
93
+
94
+ if not use_separate_proj_weight:
95
+ if torch.equal(query, key) and torch.equal(key, value):
96
+ # self-attention
97
+ q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
98
+
99
+ elif torch.equal(key, value):
100
+ # encoder-decoder attention
101
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
102
+ _b = in_proj_bias
103
+ _start = 0
104
+ _end = embed_dim
105
+ _w = in_proj_weight[_start:_end, :]
106
+ if _b is not None:
107
+ _b = _b[_start:_end]
108
+ q = F.linear(query, _w, _b)
109
+
110
+ if key is None:
111
+ assert value is None
112
+ k = None
113
+ v = None
114
+ else:
115
+
116
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
117
+ _b = in_proj_bias
118
+ _start = embed_dim
119
+ _end = None
120
+ _w = in_proj_weight[_start:, :]
121
+ if _b is not None:
122
+ _b = _b[_start:]
123
+ k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
124
+
125
+ else:
126
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
127
+ _b = in_proj_bias
128
+ _start = 0
129
+ _end = embed_dim
130
+ _w = in_proj_weight[_start:_end, :]
131
+ if _b is not None:
132
+ _b = _b[_start:_end]
133
+ q = F.linear(query, _w, _b)
134
+
135
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
136
+ _b = in_proj_bias
137
+ _start = embed_dim
138
+ _end = embed_dim * 2
139
+ _w = in_proj_weight[_start:_end, :]
140
+ if _b is not None:
141
+ _b = _b[_start:_end]
142
+ k = F.linear(key, _w, _b)
143
+
144
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
145
+ _b = in_proj_bias
146
+ _start = embed_dim * 2
147
+ _end = None
148
+ _w = in_proj_weight[_start:, :]
149
+ if _b is not None:
150
+ _b = _b[_start:]
151
+ v = F.linear(value, _w, _b)
152
+ else:
153
+ q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
154
+ len1, len2 = q_proj_weight_non_opt.size()
155
+ assert len1 == embed_dim and len2 == query.size(-1)
156
+
157
+ k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
158
+ len1, len2 = k_proj_weight_non_opt.size()
159
+ assert len1 == embed_dim and len2 == key.size(-1)
160
+
161
+ v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
162
+ len1, len2 = v_proj_weight_non_opt.size()
163
+ assert len1 == embed_dim and len2 == value.size(-1)
164
+
165
+ if in_proj_bias is not None:
166
+ q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
167
+ k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
168
+ v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
169
+ else:
170
+ q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
171
+ k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
172
+ v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
173
+ q = q * scaling
174
+
175
+ if attn_mask is not None:
176
+ if attn_mask.dim() == 2:
177
+ attn_mask = attn_mask.unsqueeze(0)
178
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
179
+ raise RuntimeError('The size of the 2D attn_mask is not correct.')
180
+ elif attn_mask.dim() == 3:
181
+ if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
182
+ raise RuntimeError('The size of the 3D attn_mask is not correct.')
183
+ else:
184
+ raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
185
+ # attn_mask's dim is 3 now.
186
+
187
+ if bias_k is not None and bias_v is not None:
188
+ if static_k is None and static_v is None:
189
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
190
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
191
+ if attn_mask is not None:
192
+ attn_mask = F.pad(attn_mask, (0, 1))
193
+ if key_padding_mask is not None:
194
+ key_padding_mask = F.pad(key_padding_mask, (0, 1))
195
+ else:
196
+ assert static_k is None, "bias cannot be added to static key."
197
+ assert static_v is None, "bias cannot be added to static value."
198
+ else:
199
+ assert bias_k is None
200
+ assert bias_v is None
201
+
202
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
203
+ if k is not None:
204
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
205
+ if v is not None:
206
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
207
+
208
+ if static_k is not None:
209
+ assert static_k.size(0) == bsz * num_heads
210
+ assert static_k.size(2) == head_dim
211
+ k = static_k
212
+
213
+ if static_v is not None:
214
+ assert static_v.size(0) == bsz * num_heads
215
+ assert static_v.size(2) == head_dim
216
+ v = static_v
217
+
218
+ src_len = k.size(1)
219
+
220
+ if key_padding_mask is not None:
221
+ assert key_padding_mask.size(0) == bsz
222
+ assert key_padding_mask.size(1) == src_len
223
+
224
+ if add_zero_attn:
225
+ src_len += 1
226
+ k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
227
+ v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
228
+ if attn_mask is not None:
229
+ attn_mask = F.pad(attn_mask, (0, 1))
230
+ if key_padding_mask is not None:
231
+ key_padding_mask = F.pad(key_padding_mask, (0, 1))
232
+
233
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
234
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
235
+
236
+ if attn_mask is not None:
237
+ attn_output_weights += attn_mask
238
+
239
+ if key_padding_mask is not None:
240
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
241
+ attn_output_weights = attn_output_weights.masked_fill(
242
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
243
+ float('-inf'),
244
+ )
245
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
246
+
247
+ attn_output_weights = F.softmax(
248
+ attn_output_weights, dim=-1)
249
+ attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
250
+
251
+ attn_output = torch.bmm(attn_output_weights, v)
252
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
253
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
254
+ attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
255
+
256
+ if need_weights:
257
+ # average attention weights over heads
258
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
259
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
260
+ else:
261
+ return attn_output, None
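
The q-scaling and the softmax(bmm(...)) above together implement scaled dot-product attention, softmax(Q K^T / sqrt(head_dim)) V, per head. A stripped-down sketch of that core step with illustrative sizes:

import torch
import torch.nn.functional as F

head_dim = 8
q = torch.randn(2, 7, head_dim)    # (bsz * num_heads, tgt_len, head_dim)
k = torch.randn(2, 5, head_dim)    # (bsz * num_heads, src_len, head_dim)
v = torch.randn(2, 5, head_dim)
w = F.softmax(torch.bmm(q * head_dim ** -0.5, k.transpose(1, 2)), dim=-1)
out = torch.bmm(w, v)              # (bsz * num_heads, tgt_len, head_dim)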
src/preprocessing/deepsvg/deepsvg_models/layers/improved_transformer.py ADDED
@@ -0,0 +1,146 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ import torch
7
+ import copy
8
+
9
+ from torch.nn import functional as F
10
+ from torch.nn.modules.module import Module
11
+ from torch.nn.modules.container import ModuleList
12
+ from torch.nn.init import xavier_uniform_
13
+ from torch.nn.modules.dropout import Dropout
14
+ from torch.nn.modules.linear import Linear
15
+ from torch.nn.modules.normalization import LayerNorm
16
+
17
+ from .attention import MultiheadAttention
18
+ from .transformer import _get_activation_fn
19
+
20
+
21
+ class TransformerEncoderLayerImproved(Module):
22
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", d_global2=None):
23
+ super(TransformerEncoderLayerImproved, self).__init__()
24
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
25
+
26
+ if d_global2 is not None:
27
+ self.linear_global2 = Linear(d_global2, d_model)
28
+
29
+ # Implementation of Feedforward model
30
+ self.linear1 = Linear(d_model, dim_feedforward)
31
+ self.dropout = Dropout(dropout)
32
+ self.linear2 = Linear(dim_feedforward, d_model)
33
+
34
+ self.norm1 = LayerNorm(d_model)
35
+ self.norm2 = LayerNorm(d_model)
36
+ self.dropout1 = Dropout(dropout)
37
+ self.dropout2_2 = Dropout(dropout)
38
+ self.dropout2 = Dropout(dropout)
39
+
40
+ self.activation = _get_activation_fn(activation)
41
+
42
+ def __setstate__(self, state):
43
+ if 'activation' not in state:
44
+ state['activation'] = F.relu
45
+ super(TransformerEncoderLayerImproved, self).__setstate__(state)
46
+
47
+ def forward(self, src, memory2=None, src_mask=None, src_key_padding_mask=None):
48
+ src1 = self.norm1(src)
49
+ src2 = self.self_attn(src1, src1, src1, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
50
+ src = src + self.dropout1(src2)
51
+
52
+ if memory2 is not None:
53
+ src2_2 = self.linear_global2(memory2)
54
+ src = src + self.dropout2_2(src2_2)
55
+
56
+ src1 = self.norm2(src)
57
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src1))))
58
+ src = src + self.dropout2(src2)
59
+ return src
60
+
61
+
62
+ class TransformerDecoderLayerImproved(Module):
63
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
64
+ super(TransformerDecoderLayerImproved, self).__init__()
65
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
66
+ self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
67
+ # Implementation of Feedforward model
68
+ self.linear1 = Linear(d_model, dim_feedforward)
69
+ self.dropout = Dropout(dropout)
70
+ self.linear2 = Linear(dim_feedforward, d_model)
71
+
72
+ self.norm1 = LayerNorm(d_model)
73
+ self.norm2 = LayerNorm(d_model)
74
+ self.norm3 = LayerNorm(d_model)
75
+ self.dropout1 = Dropout(dropout)
76
+ self.dropout2 = Dropout(dropout)
77
+ self.dropout3 = Dropout(dropout)
78
+
79
+ self.activation = _get_activation_fn(activation)
80
+
81
+ def __setstate__(self, state):
82
+ if 'activation' not in state:
83
+ state['activation'] = F.relu
84
+ super(TransformerDecoderLayerImproved, self).__setstate__(state)
85
+
86
+ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
87
+ tgt_key_padding_mask=None, memory_key_padding_mask=None):
88
+ tgt1 = self.norm1(tgt)
89
+ tgt2 = self.self_attn(tgt1, tgt1, tgt1, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
90
+ tgt = tgt + self.dropout1(tgt2)
91
+
92
+ tgt1 = self.norm2(tgt)
93
+ tgt2 = self.multihead_attn(tgt1, memory, memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]
94
+ tgt = tgt + self.dropout2(tgt2)
95
+
96
+ tgt1 = self.norm3(tgt)
97
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt1))))
98
+ tgt = tgt + self.dropout3(tgt2)
99
+ return tgt
100
+
101
+
102
+ class TransformerDecoderLayerGlobalImproved(Module):
103
+ def __init__(self, d_model, d_global, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", d_global2=None):
104
+ super(TransformerDecoderLayerGlobalImproved, self).__init__()
105
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
106
+
107
+ self.linear_global = Linear(d_global, d_model)
108
+
109
+ if d_global2 is not None:
110
+ self.linear_global2 = Linear(d_global2, d_model)
111
+
112
+ # Implementation of Feedforward model
113
+ self.linear1 = Linear(d_model, dim_feedforward)
114
+ self.dropout = Dropout(dropout)
115
+ self.linear2 = Linear(dim_feedforward, d_model)
116
+
117
+ self.norm1 = LayerNorm(d_model)
118
+ self.norm2 = LayerNorm(d_model)
119
+ self.dropout1 = Dropout(dropout)
120
+ self.dropout2 = Dropout(dropout)
121
+ self.dropout2_2 = Dropout(dropout)
122
+ self.dropout3 = Dropout(dropout)
123
+
124
+ self.activation = _get_activation_fn(activation)
125
+
126
+ def __setstate__(self, state):
127
+ if 'activation' not in state:
128
+ state['activation'] = F.relu
129
+ super(TransformerDecoderLayerGlobalImproved, self).__setstate__(state)
130
+
131
+ def forward(self, tgt, memory, memory2=None, tgt_mask=None, tgt_key_padding_mask=None, *args, **kwargs):
132
+ tgt1 = self.norm1(tgt)
133
+ tgt2 = self.self_attn(tgt1, tgt1, tgt1, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
134
+ tgt = tgt + self.dropout1(tgt2)
135
+
136
+ tgt2 = self.linear_global(memory)
137
+ tgt = tgt + self.dropout2(tgt2) # implicit broadcast
138
+
139
+ if memory2 is not None:
140
+ tgt2_2 = self.linear_global2(memory2)
141
+ tgt = tgt + self.dropout2_2(tgt2_2)
142
+
143
+ tgt1 = self.norm2(tgt)
144
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt1))))
145
+ tgt = tgt + self.dropout3(tgt2)
146
+ return tgt
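
Unlike the standard layers in transformer.py, these "improved" layers apply LayerNorm before each sub-block (pre-norm) rather than after, a common choice for training stability. A quick shape check (sizes illustrative):

import torch

layer = TransformerEncoderLayerImproved(d_model=32, nhead=4, dim_feedforward=64)
src = torch.randn(10, 2, 32)          # (S, N, d_model)
assert layer(src).shape == src.shape  # residual pre-norm blocks preserve shape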
src/preprocessing/deepsvg/deepsvg_models/layers/positional_encoding.py ADDED
@@ -0,0 +1,48 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class PositionalEncodingSinCos(nn.Module):
12
+ def __init__(self, d_model, dropout=0.1, max_len=250):
13
+ super(PositionalEncodingSinCos, self).__init__()
14
+ self.dropout = nn.Dropout(p=dropout)
15
+
16
+ pe = torch.zeros(max_len, d_model)
17
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
18
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
19
+ pe[:, 0::2] = torch.sin(position * div_term)
20
+ pe[:, 1::2] = torch.cos(position * div_term)
21
+ pe = pe.unsqueeze(0).transpose(0, 1)
22
+ self.register_buffer('pe', pe)
23
+
24
+ def forward(self, x):
25
+ x = x + self.pe[:x.size(0), :]
26
+ return self.dropout(x)
27
+
28
+
29
+ class PositionalEncodingLUT(nn.Module):
30
+
31
+ def __init__(self, d_model, dropout=0.1, max_len=250):
32
+ super(PositionalEncodingLUT, self).__init__()
33
+ self.dropout = nn.Dropout(p=dropout)
34
+
35
+ position = torch.arange(0, max_len, dtype=torch.long).unsqueeze(1)
36
+ self.register_buffer('position', position)
37
+
38
+ self.pos_embed = nn.Embedding(max_len, d_model)
39
+
40
+ self._init_embeddings()
41
+
42
+ def _init_embeddings(self):
43
+ nn.init.kaiming_normal_(self.pos_embed.weight, mode="fan_in")
44
+
45
+ def forward(self, x):
46
+ pos = self.position[:x.size(0)]
47
+ x = x + self.pos_embed(pos)
48
+ return self.dropout(x)
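
Both modules add a position-dependent term to a sequence-first (S, N, d_model) input: the SinCos variant uses a fixed sinusoidal table, while the LUT variant learns an embedding table. A minimal check with illustrative sizes:

import torch

pe = PositionalEncodingLUT(d_model=32, max_len=250)
x = torch.randn(100, 2, 32)   # (S, N, d_model), S <= max_len
assert pe(x).shape == x.shape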
src/preprocessing/deepsvg/deepsvg_models/layers/transformer.py ADDED
@@ -0,0 +1,398 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ import torch
7
+ import copy
8
+
9
+ from torch.nn import functional as F
10
+ from torch.nn.modules.module import Module
11
+ from torch.nn.modules.container import ModuleList
12
+ from torch.nn.init import xavier_uniform_
13
+ from torch.nn.modules.dropout import Dropout
14
+ from torch.nn.modules.linear import Linear
15
+ from torch.nn.modules.normalization import LayerNorm
16
+
17
+ from .attention import MultiheadAttention
18
+
19
+
20
+ class Transformer(Module):
21
+ r"""A transformer model. User is able to modify the attributes as needed. The architecture
22
+ is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
23
+ Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
24
+ Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
25
+ Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805)
26
+ model with corresponding parameters.
27
+
28
+ Args:
29
+ d_model: the number of expected features in the encoder/decoder inputs (default=512).
30
+ nhead: the number of heads in the multiheadattention models (default=8).
31
+ num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
32
+ num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
33
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
34
+ dropout: the dropout value (default=0.1).
35
+ activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu).
36
+ custom_encoder: custom encoder (default=None).
37
+ custom_decoder: custom decoder (default=None).
38
+
39
+ Examples::
40
+ >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
41
+ >>> src = torch.rand((10, 32, 512))
42
+ >>> tgt = torch.rand((20, 32, 512))
43
+ >>> out = transformer_model(src, tgt)
44
+
45
+ Note: A full example to apply nn.Transformer module for the word language model is available in
46
+ https://github.com/pytorch/examples/tree/master/word_language_model
47
+ """
48
+
49
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
50
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
51
+ activation="relu", custom_encoder=None, custom_decoder=None):
52
+ super(Transformer, self).__init__()
53
+
54
+ if custom_encoder is not None:
55
+ self.encoder = custom_encoder
56
+ else:
57
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
58
+ encoder_norm = LayerNorm(d_model)
59
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
60
+
61
+ if custom_decoder is not None:
62
+ self.decoder = custom_decoder
63
+ else:
64
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
65
+ decoder_norm = LayerNorm(d_model)
66
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
67
+
68
+ self._reset_parameters()
69
+
70
+ self.d_model = d_model
71
+ self.nhead = nhead
72
+
73
+ def forward(self, src, tgt, src_mask=None, tgt_mask=None,
74
+ memory_mask=None, src_key_padding_mask=None,
75
+ tgt_key_padding_mask=None, memory_key_padding_mask=None):
76
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor # noqa
77
+ r"""Take in and process masked source/target sequences.
78
+
79
+ Args:
80
+ src: the sequence to the encoder (required).
81
+ tgt: the sequence to the decoder (required).
82
+ src_mask: the additive mask for the src sequence (optional).
83
+ tgt_mask: the additive mask for the tgt sequence (optional).
84
+ memory_mask: the additive mask for the encoder output (optional).
85
+ src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
86
+ tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
87
+ memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).
88
+
89
+ Shape:
90
+ - src: :math:`(S, N, E)`.
91
+ - tgt: :math:`(T, N, E)`.
92
+ - src_mask: :math:`(S, S)`.
93
+ - tgt_mask: :math:`(T, T)`.
94
+ - memory_mask: :math:`(T, S)`.
95
+ - src_key_padding_mask: :math:`(N, S)`.
96
+ - tgt_key_padding_mask: :math:`(N, T)`.
97
+ - memory_key_padding_mask: :math:`(N, S)`.
98
+
99
+ Note: [src/tgt/memory]_mask should be filled with
100
+ float('-inf') for the masked positions and float(0.0) else. These masks
101
+ ensure that predictions for position i depend only on the unmasked positions
102
+ j and are applied identically for each sequence in a batch.
103
+ [src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions
104
+ that should be masked with float('-inf') and False values will be unchanged.
105
+ This mask ensures that no information will be taken from position i if
106
+ it is masked, and has a separate mask for each sequence in a batch.
107
+
108
+ - output: :math:`(T, N, E)`.
109
+
110
+ Note: Due to the multi-head attention architecture in the transformer model,
111
+ the output sequence length of a transformer is same as the input sequence
112
+ (i.e. target) length of the decode.
113
+
114
+ where S is the source sequence length, T is the target sequence length, N is the
115
+ batch size, E is the feature number
116
+
117
+ Examples:
118
+ >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
119
+ """
120
+
121
+ if src.size(1) != tgt.size(1):
122
+ raise RuntimeError("the batch number of src and tgt must be equal")
123
+
124
+ if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
125
+ raise RuntimeError("the feature number of src and tgt must be equal to d_model")
126
+
127
+ memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
128
+ output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
129
+ tgt_key_padding_mask=tgt_key_padding_mask,
130
+ memory_key_padding_mask=memory_key_padding_mask)
131
+ return output
132
+
133
+
134
+ def generate_square_subsequent_mask(self, sz):
135
+ r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
136
+ Unmasked positions are filled with float(0.0).
137
+ """
138
+ mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
139
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
140
+ return mask
141
+
142
+
143
+ def _reset_parameters(self):
144
+ r"""Initialize parameters in the transformer model."""
145
+
146
+ for p in self.parameters():
147
+ if p.dim() > 1:
148
+ xavier_uniform_(p)
149
+
150
+
151
+ class TransformerEncoder(Module):
152
+ r"""TransformerEncoder is a stack of N encoder layers
153
+
154
+ Args:
155
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
156
+ num_layers: the number of sub-encoder-layers in the encoder (required).
157
+ norm: the layer normalization component (optional).
158
+
159
+ Examples::
160
+ >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
161
+ >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
162
+ >>> src = torch.rand(10, 32, 512)
163
+ >>> out = transformer_encoder(src)
164
+ """
165
+ __constants__ = ['norm']
166
+
167
+ def __init__(self, encoder_layer, num_layers, norm=None):
168
+ super(TransformerEncoder, self).__init__()
169
+ self.layers = _get_clones(encoder_layer, num_layers)
170
+ self.num_layers = num_layers
171
+ self.norm = norm
172
+
173
+ def forward(self, src, memory2=None, mask=None, src_key_padding_mask=None):
174
+ # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
175
+ r"""Pass the input through the encoder layers in turn.
176
+
177
+ Args:
178
+ src: the sequence to the encoder (required).
179
+ mask: the mask for the src sequence (optional).
180
+ src_key_padding_mask: the mask for the src keys per batch (optional).
181
+
182
+ Shape:
183
+ see the docs in Transformer class.
184
+ """
185
+ output = src
186
+
187
+ for mod in self.layers:
188
+ output = mod(output, memory2=memory2, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
189
+
190
+ if self.norm is not None:
191
+ output = self.norm(output)
192
+
193
+ return output
194
+
195
+
196
+ class TransformerDecoder(Module):
197
+ r"""TransformerDecoder is a stack of N decoder layers
198
+
199
+ Args:
200
+ decoder_layer: an instance of the TransformerDecoderLayer() class (required).
201
+ num_layers: the number of sub-decoder-layers in the decoder (required).
202
+ norm: the layer normalization component (optional).
203
+
204
+ Examples::
205
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
206
+ >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
207
+ >>> memory = torch.rand(10, 32, 512)
208
+ >>> tgt = torch.rand(20, 32, 512)
209
+ >>> out = transformer_decoder(tgt, memory)
210
+ """
211
+ __constants__ = ['norm']
212
+
213
+ def __init__(self, decoder_layer, num_layers, norm=None):
214
+ super(TransformerDecoder, self).__init__()
215
+ self.layers = _get_clones(decoder_layer, num_layers)
216
+ self.num_layers = num_layers
217
+ self.norm = norm
218
+
219
+ def forward(self, tgt, memory, memory2=None, tgt_mask=None,
220
+ memory_mask=None, tgt_key_padding_mask=None,
221
+ memory_key_padding_mask=None):
222
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
223
+ r"""Pass the inputs (and mask) through the decoder layer in turn.
224
+
225
+ Args:
226
+ tgt: the sequence to the decoder (required).
227
+ memory: the sequence from the last layer of the encoder (required).
228
+ tgt_mask: the mask for the tgt sequence (optional).
229
+ memory_mask: the mask for the memory sequence (optional).
230
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
231
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
232
+
233
+ Shape:
234
+ see the docs in Transformer class.
235
+ """
236
+ output = tgt
237
+
238
+ for mod in self.layers:
239
+ output = mod(output, memory, memory2=memory2, tgt_mask=tgt_mask,
240
+ memory_mask=memory_mask,
241
+ tgt_key_padding_mask=tgt_key_padding_mask,
242
+ memory_key_padding_mask=memory_key_padding_mask)
243
+
244
+ if self.norm is not None:
245
+ output = self.norm(output)
246
+
247
+ return output
248
+
249
+
250
+ class TransformerEncoderLayer(Module):
251
+ r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
252
+ This standard encoder layer is based on the paper "Attention Is All You Need".
253
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
254
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
255
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
256
+ in a different way during application.
257
+
258
+ Args:
259
+ d_model: the number of expected features in the input (required).
260
+ nhead: the number of heads in the multiheadattention models (required).
261
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
262
+ dropout: the dropout value (default=0.1).
263
+ activation: the activation function of intermediate layer, relu or gelu (default=relu).
264
+
265
+ Examples::
266
+ >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
267
+ >>> src = torch.rand(10, 32, 512)
268
+ >>> out = encoder_layer(src)
269
+ """
270
+
271
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
272
+ super(TransformerEncoderLayer, self).__init__()
273
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
274
+ # Implementation of Feedforward model
275
+ self.linear1 = Linear(d_model, dim_feedforward)
276
+ self.dropout = Dropout(dropout)
277
+ self.linear2 = Linear(dim_feedforward, d_model)
278
+
279
+ self.norm1 = LayerNorm(d_model)
280
+ self.norm2 = LayerNorm(d_model)
281
+ self.dropout1 = Dropout(dropout)
282
+ self.dropout2 = Dropout(dropout)
283
+
284
+ self.activation = _get_activation_fn(activation)
285
+
286
+ def __setstate__(self, state):
287
+ if 'activation' not in state:
288
+ state['activation'] = F.relu
289
+ super(TransformerEncoderLayer, self).__setstate__(state)
290
+
291
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
292
+ # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
293
+ r"""Pass the input through the encoder layer.
294
+
295
+ Args:
296
+ src: the sequence to the encoder layer (required).
297
+ src_mask: the mask for the src sequence (optional).
298
+ src_key_padding_mask: the mask for the src keys per batch (optional).
299
+
300
+ Shape:
301
+ see the docs in Transformer class.
302
+ """
303
+ src2 = self.self_attn(src, src, src, attn_mask=src_mask,
304
+ key_padding_mask=src_key_padding_mask)[0]
305
+ src = src + self.dropout1(src2)
306
+ src = self.norm1(src)
307
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
308
+ src = src + self.dropout2(src2)
309
+ src = self.norm2(src)
310
+ return src
311
+
312
+
313
+ class TransformerDecoderLayer(Module):
314
+ r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
315
+ This standard decoder layer is based on the paper "Attention Is All You Need".
316
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
317
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
318
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
319
+ in a different way during application.
320
+
321
+ Args:
322
+ d_model: the number of expected features in the input (required).
323
+ nhead: the number of heads in the multiheadattention models (required).
324
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
325
+ dropout: the dropout value (default=0.1).
326
+ activation: the activation function of intermediate layer, relu or gelu (default=relu).
327
+
328
+ Examples::
329
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
330
+ >>> memory = torch.rand(10, 32, 512)
331
+ >>> tgt = torch.rand(20, 32, 512)
332
+ >>> out = decoder_layer(tgt, memory)
333
+ """
334
+
335
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
336
+ super(TransformerDecoderLayer, self).__init__()
337
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
338
+ self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
339
+ # Implementation of Feedforward model
340
+ self.linear1 = Linear(d_model, dim_feedforward)
341
+ self.dropout = Dropout(dropout)
342
+ self.linear2 = Linear(dim_feedforward, d_model)
343
+
344
+ self.norm1 = LayerNorm(d_model)
345
+ self.norm2 = LayerNorm(d_model)
346
+ self.norm3 = LayerNorm(d_model)
347
+ self.dropout1 = Dropout(dropout)
348
+ self.dropout2 = Dropout(dropout)
349
+ self.dropout3 = Dropout(dropout)
350
+
351
+ self.activation = _get_activation_fn(activation)
352
+
353
+ def __setstate__(self, state):
354
+ if 'activation' not in state:
355
+ state['activation'] = F.relu
356
+ super(TransformerDecoderLayer, self).__setstate__(state)
357
+
358
+ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
359
+ tgt_key_padding_mask=None, memory_key_padding_mask=None):
360
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
361
+ r"""Pass the inputs (and mask) through the decoder layer.
362
+
363
+ Args:
364
+ tgt: the sequence to the decoder layer (required).
365
+ memory: the sequence from the last layer of the encoder (required).
366
+ tgt_mask: the mask for the tgt sequence (optional).
367
+ memory_mask: the mask for the memory sequence (optional).
368
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
369
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
370
+
371
+ Shape:
372
+ see the docs in Transformer class.
373
+ """
374
+ tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
375
+ key_padding_mask=tgt_key_padding_mask)[0]
376
+ tgt = tgt + self.dropout1(tgt2)
377
+ tgt = self.norm1(tgt)
378
+ tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
379
+ key_padding_mask=memory_key_padding_mask)[0]
380
+ tgt = tgt + self.dropout2(tgt2)
381
+ tgt = self.norm2(tgt)
382
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
383
+ tgt = tgt + self.dropout3(tgt2)
384
+ tgt = self.norm3(tgt)
385
+ return tgt
386
+
387
+
388
+ def _get_clones(module, N):
389
+ return ModuleList([copy.deepcopy(module) for i in range(N)])
390
+
391
+
392
+ def _get_activation_fn(activation):
393
+ if activation == "relu":
394
+ return F.relu
395
+ elif activation == "gelu":
396
+ return F.gelu
397
+
398
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
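
For reference, generate_square_subsequent_mask builds the additive causal mask used for autoregressive decoding: position i may attend only to positions <= i. A sketch with illustrative sizes:

import torch

model = Transformer(d_model=32, nhead=4, num_encoder_layers=1, num_decoder_layers=1)
mask = model.generate_square_subsequent_mask(4)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
src, tgt = torch.randn(5, 2, 32), torch.randn(4, 2, 32)  # (S, N, E), (T, N, E)
out = model(src, tgt, tgt_mask=mask)                     # (T, N, E) == (4, 2, 32)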
src/preprocessing/deepsvg/deepsvg_models/loss.py ADDED
@@ -0,0 +1,70 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor
10
+ from .model_utils import _get_padding_mask, _get_visibility_mask
11
+ from .model_config import _DefaultConfig
12
+
13
+
14
+ class SVGLoss(nn.Module):
15
+ def __init__(self, cfg: _DefaultConfig):
16
+ super().__init__()
17
+
18
+ self.cfg = cfg
19
+
20
+ self.args_dim = 2 * cfg.args_dim if cfg.rel_targets else cfg.args_dim + 1
21
+
22
+ self.register_buffer("cmd_args_mask", SVGTensor.CMD_ARGS_MASK)
23
+
24
+ def forward(self, output, labels, weights):
25
+ loss = 0.
26
+ res = {}
27
+
28
+ # VAE
29
+ if self.cfg.use_vae:
30
+ mu, logsigma = output["mu"], output["logsigma"]
31
+ loss_kl = -0.5 * torch.mean(1 + logsigma - mu.pow(2) - torch.exp(logsigma))
32
+ loss_kl = loss_kl.clamp(min=weights["kl_tolerance"])
33
+
34
+ loss += weights["loss_kl_weight"] * loss_kl
35
+ res["loss_kl"] = loss_kl
36
+
37
+ # Target & predictions
38
+ tgt_commands, tgt_args = output["tgt_commands"], output["tgt_args"]
39
+
40
+ visibility_mask = _get_visibility_mask(tgt_commands, seq_dim=-1)
41
+ padding_mask = _get_padding_mask(tgt_commands, seq_dim=-1, extended=True) * visibility_mask.unsqueeze(-1)
42
+
43
+ command_logits, args_logits = output["command_logits"], output["args_logits"]
44
+
45
+ # 2-stage visibility
46
+ if self.cfg.decode_stages == 2:
47
+ visibility_logits = output["visibility_logits"]
48
+ loss_visibility = F.cross_entropy(visibility_logits.reshape(-1, 2), visibility_mask.reshape(-1).long())
49
+
50
+ loss += weights["loss_visibility_weight"] * loss_visibility
51
+ res["loss_visibility"] = loss_visibility
52
+
53
+ # Commands & args
54
+ tgt_commands, tgt_args, padding_mask = tgt_commands[..., 1:], tgt_args[..., 1:, :], padding_mask[..., 1:]
55
+
56
+ mask = self.cmd_args_mask[tgt_commands.long()]
57
+
58
+ loss_cmd = F.cross_entropy(command_logits[padding_mask.bool()].reshape(-1, self.cfg.n_commands), tgt_commands[padding_mask.bool()].reshape(-1).long())
59
+ loss_args = F.cross_entropy(args_logits[mask.bool()].reshape(-1, self.args_dim), tgt_args[mask.bool()].reshape(-1).long() + 1) # shift due to -1 PAD_VAL
60
+
61
+ loss += weights["loss_cmd_weight"] * loss_cmd \
62
+ + weights["loss_args_weight"] * loss_args
63
+
64
+ res.update({
65
+ "loss": loss,
66
+ "loss_cmd": loss_cmd,
67
+ "loss_args": loss_args
68
+ })
69
+
70
+ return res
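
SVGLoss.forward reads its coefficients from the weights dict; the keys below are exactly the ones used above, while the values are placeholders (the training config defines the real ones):

weights = {
    "kl_tolerance": 0.1,            # KL floor, used only when cfg.use_vae
    "loss_kl_weight": 1.0,          # VAE term
    "loss_visibility_weight": 1.0,  # used only when cfg.decode_stages == 2
    "loss_cmd_weight": 1.0,         # command cross-entropy
    "loss_args_weight": 2.0,        # argument cross-entropy
}
# res = SVGLoss(cfg)(output, labels, weights); res["loss"].backward()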
src/preprocessing/deepsvg/deepsvg_models/model.py ADDED
@@ -0,0 +1,484 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor
7
+ from src.preprocessing.deepsvg.deepsvg_utils.utils import _pack_group_batch, _unpack_group_batch, _make_seq_first, _make_batch_first
8
+
9
+ from .layers.transformer import *
10
+ from .layers.improved_transformer import *
11
+ from .layers.positional_encoding import *
12
+ from .basic_blocks import FCN, HierarchFCN, ResNet
13
+ from .model_config import _DefaultConfig
14
+ from .model_utils import (_get_padding_mask, _get_key_padding_mask, _get_group_mask, _get_visibility_mask,
15
+ _get_key_visibility_mask, _generate_square_subsequent_mask, _sample_categorical, _threshold_sample)
16
+
17
+ from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
18
+ from scipy.optimize import linear_sum_assignment
19
+
20
+
21
+ class SVGEmbedding(nn.Module):
22
+ def __init__(self, cfg: _DefaultConfig, seq_len, rel_args=False, use_group=True, group_len=None):
23
+ super().__init__()
24
+
25
+ self.cfg = cfg
26
+
27
+ self.command_embed = nn.Embedding(cfg.n_commands, cfg.d_model)
28
+
29
+ args_dim = 2 * cfg.args_dim if rel_args else cfg.args_dim + 1
30
+ self.arg_embed = nn.Embedding(args_dim, 64)
31
+ self.embed_fcn = nn.Linear(64 * cfg.n_args, cfg.d_model)
32
+
33
+ self.use_group = use_group
34
+ if use_group:
35
+ if group_len is None:
36
+ group_len = cfg.max_num_groups
37
+ self.group_embed = nn.Embedding(group_len+2, cfg.d_model)
38
+
39
+ self.pos_encoding = PositionalEncodingLUT(cfg.d_model, max_len=seq_len+2)
40
+
41
+ self._init_embeddings()
42
+
43
+ def _init_embeddings(self):
44
+ nn.init.kaiming_normal_(self.command_embed.weight, mode="fan_in")
45
+ nn.init.kaiming_normal_(self.arg_embed.weight, mode="fan_in")
46
+ nn.init.kaiming_normal_(self.embed_fcn.weight, mode="fan_in")
47
+
48
+ if self.use_group:
49
+ nn.init.kaiming_normal_(self.group_embed.weight, mode="fan_in")
50
+
51
+ def forward(self, commands, args, groups=None):
52
+ S, GN = commands.shape
53
+
54
+ src = self.command_embed(commands.long()) + \
55
+ self.embed_fcn(self.arg_embed((args + 1).long()).view(S, GN, -1)) # shift due to -1 PAD_VAL
56
+
57
+ if self.use_group:
58
+ src = src + self.group_embed(groups.long())
59
+
60
+ src = self.pos_encoding(src)
61
+
62
+ return src
63
+
64
+
65
+ class ConstEmbedding(nn.Module):
66
+ def __init__(self, cfg: _DefaultConfig, seq_len):
67
+ super().__init__()
68
+
69
+ self.cfg = cfg
70
+
71
+ self.seq_len = seq_len
72
+
73
+ self.PE = PositionalEncodingLUT(cfg.d_model, max_len=seq_len)
74
+
75
+ def forward(self, z):
76
+ N = z.size(1)
77
+ src = self.PE(z.new_zeros(self.seq_len, N, self.cfg.d_model))
78
+ return src
79
+
80
+
81
+ class LabelEmbedding(nn.Module):
82
+ def __init__(self, cfg: _DefaultConfig):
83
+ super().__init__()
84
+
85
+ self.label_embedding = nn.Embedding(cfg.n_labels, cfg.dim_label)
86
+
87
+ self._init_embeddings()
88
+
89
+ def _init_embeddings(self):
90
+ nn.init.kaiming_normal_(self.label_embedding.weight, mode="fan_in")
91
+
92
+ def forward(self, label):
93
+ src = self.label_embedding(label)
94
+ return src
95
+
96
+
97
+ class Encoder(nn.Module):
98
+ def __init__(self, cfg: _DefaultConfig):
99
+ super().__init__()
100
+
101
+ self.cfg = cfg
102
+
103
+ seq_len = cfg.max_seq_len if cfg.encode_stages == 2 else cfg.max_total_len
104
+ self.use_group = cfg.encode_stages == 1
105
+ self.embedding = SVGEmbedding(cfg, seq_len, use_group=self.use_group)
106
+
107
+ if cfg.label_condition:
108
+ self.label_embedding = LabelEmbedding(cfg)
109
+ dim_label = cfg.dim_label if cfg.label_condition else None
110
+
111
+ if cfg.model_type == "transformer":
112
+ encoder_layer = TransformerEncoderLayerImproved(cfg.d_model, cfg.n_heads, cfg.dim_feedforward, cfg.dropout, d_global2=dim_label)
113
+ encoder_norm = LayerNorm(cfg.d_model)
114
+ self.encoder = TransformerEncoder(encoder_layer, cfg.n_layers, encoder_norm)
115
+ else: # "lstm"
116
+ self.encoder = nn.LSTM(cfg.d_model, cfg.d_model // 2, dropout=cfg.dropout, bidirectional=True)
117
+
118
+ if cfg.encode_stages == 2:
119
+ if not cfg.self_match:
120
+ self.hierarchical_PE = PositionalEncodingLUT(cfg.d_model, max_len=cfg.max_num_groups)
121
+
122
+ hierarchical_encoder_layer = TransformerEncoderLayerImproved(cfg.d_model, cfg.n_heads, cfg.dim_feedforward, cfg.dropout, d_global2=dim_label)
123
+ hierarchical_encoder_norm = LayerNorm(cfg.d_model)
124
+ self.hierarchical_encoder = TransformerEncoder(hierarchical_encoder_layer, cfg.n_layers, hierarchical_encoder_norm)
125
+
126
+ def forward(self, commands, args, label=None):
127
+ S, G, N = commands.shape
128
+ l = self.label_embedding(label).unsqueeze(0).unsqueeze(0).repeat(1, commands.size(1), 1, 1) if self.cfg.label_condition else None
129
+
130
+ if self.cfg.encode_stages == 2:
131
+ visibility_mask, key_visibility_mask = _get_visibility_mask(commands, seq_dim=0), _get_key_visibility_mask(commands, seq_dim=0)
132
+
133
+ commands, args, l = _pack_group_batch(commands, args, l)
134
+ padding_mask, key_padding_mask = _get_padding_mask(commands, seq_dim=0), _get_key_padding_mask(commands, seq_dim=0)
135
+ group_mask = _get_group_mask(commands, seq_dim=0) if self.use_group else None
136
+
137
+ src = self.embedding(commands, args, group_mask)
138
+
139
+ if self.cfg.model_type == "transformer":
140
+ memory = self.encoder(src, mask=None, src_key_padding_mask=key_padding_mask, memory2=l)
141
+
142
+ z = (memory * padding_mask).sum(dim=0, keepdim=True) / padding_mask.sum(dim=0, keepdim=True)
143
+ else: # "lstm"
144
+ hidden_cell = (src.new_zeros(2, N, self.cfg.d_model // 2),
145
+ src.new_zeros(2, N, self.cfg.d_model // 2))
146
+ sequence_lengths = padding_mask.sum(dim=0).squeeze(-1)
147
+ x = pack_padded_sequence(src, sequence_lengths, enforce_sorted=False)
148
+
149
+ packed_output, _ = self.encoder(x, hidden_cell)
150
+
151
+ memory, _ = pad_packed_sequence(packed_output)
152
+ idx = (sequence_lengths - 1).long().view(1, -1, 1).repeat(1, 1, self.cfg.d_model)
153
+ z = memory.gather(dim=0, index=idx)
154
+
155
+ z = _unpack_group_batch(N, z)
156
+
157
+ if self.cfg.encode_stages == 2:
158
+ src = z.transpose(0, 1)
159
+ src = _pack_group_batch(src)
160
+ l = self.label_embedding(label).unsqueeze(0) if self.cfg.label_condition else None
161
+
162
+ if not self.cfg.self_match:
163
+ src = self.hierarchical_PE(src)
164
+
165
+ memory = self.hierarchical_encoder(src, mask=None, src_key_padding_mask=key_visibility_mask, memory2=l)
166
+ z = (memory * visibility_mask).sum(dim=0, keepdim=True) / visibility_mask.sum(dim=0, keepdim=True)
167
+ z = _unpack_group_batch(N, z)
168
+
169
+ return z
170
+
171
+
172
+ class VAE(nn.Module):
173
+ def __init__(self, cfg: _DefaultConfig):
174
+ super(VAE, self).__init__()
175
+
176
+ self.enc_mu_fcn = nn.Linear(cfg.d_model, cfg.dim_z)
177
+ self.enc_sigma_fcn = nn.Linear(cfg.d_model, cfg.dim_z)
178
+
179
+ self._init_embeddings()
180
+
181
+ def _init_embeddings(self):
182
+ nn.init.normal_(self.enc_mu_fcn.weight, std=0.001)
183
+ nn.init.constant_(self.enc_mu_fcn.bias, 0)
184
+ nn.init.normal_(self.enc_sigma_fcn.weight, std=0.001)
185
+ nn.init.constant_(self.enc_sigma_fcn.bias, 0)
186
+
187
+ def forward(self, z):
188
+ mu, logsigma = self.enc_mu_fcn(z), self.enc_sigma_fcn(z)
189
+ sigma = torch.exp(logsigma / 2.)
190
+ z = mu + sigma * torch.randn_like(sigma)
191
+
192
+ return z, mu, logsigma
193
+
194
+
195
+ class Bottleneck(nn.Module):
+     def __init__(self, cfg: _DefaultConfig):
+         super(Bottleneck, self).__init__()
+
+         self.bottleneck = nn.Linear(cfg.d_model, cfg.dim_z)
+
+     def forward(self, z):
+         return self.bottleneck(z)
+
+
+ class Decoder(nn.Module):
+     def __init__(self, cfg: _DefaultConfig):
+         super(Decoder, self).__init__()
+
+         self.cfg = cfg
+
+         if cfg.label_condition:
+             self.label_embedding = LabelEmbedding(cfg)
+         dim_label = cfg.dim_label if cfg.label_condition else None
+
+         if cfg.decode_stages == 2:
+             self.hierarchical_embedding = ConstEmbedding(cfg, cfg.num_groups_proposal)
+
+             hierarchical_decoder_layer = TransformerDecoderLayerGlobalImproved(cfg.d_model, cfg.dim_z, cfg.n_heads, cfg.dim_feedforward, cfg.dropout, d_global2=dim_label)
+             hierarchical_decoder_norm = LayerNorm(cfg.d_model)
+             self.hierarchical_decoder = TransformerDecoder(hierarchical_decoder_layer, cfg.n_layers_decode, hierarchical_decoder_norm)
+             self.hierarchical_fcn = HierarchFCN(cfg.d_model, cfg.dim_z)
+
+         if cfg.pred_mode == "autoregressive":
+             self.embedding = SVGEmbedding(cfg, cfg.max_total_len, rel_args=cfg.rel_targets, use_group=True, group_len=cfg.max_total_len)
+
+             square_subsequent_mask = _generate_square_subsequent_mask(self.cfg.max_total_len+1)
+             self.register_buffer("square_subsequent_mask", square_subsequent_mask)
+         else:  # "one_shot"
+             seq_len = cfg.max_seq_len+1 if cfg.decode_stages == 2 else cfg.max_total_len+1
+             self.embedding = ConstEmbedding(cfg, seq_len)
+
+         if cfg.model_type == "transformer":
+             decoder_layer = TransformerDecoderLayerGlobalImproved(cfg.d_model, cfg.dim_z, cfg.n_heads, cfg.dim_feedforward, cfg.dropout, d_global2=dim_label)
+             decoder_norm = LayerNorm(cfg.d_model)
+             self.decoder = TransformerDecoder(decoder_layer, cfg.n_layers_decode, decoder_norm)
+         else:  # "lstm"
+             self.fc_hc = nn.Linear(cfg.dim_z, 2 * cfg.d_model)
+             self.decoder = nn.LSTM(cfg.d_model, cfg.d_model, dropout=cfg.dropout)
+
+         args_dim = 2 * cfg.args_dim if cfg.rel_targets else cfg.args_dim + 1
+         self.fcn = FCN(cfg.d_model, cfg.n_commands, cfg.n_args, args_dim)
+
+     def _get_initial_state(self, z):
+         hidden, cell = torch.split(torch.tanh(self.fc_hc(z)), self.cfg.d_model, dim=2)
+         hidden_cell = hidden.contiguous(), cell.contiguous()
+         return hidden_cell
+
+     def forward(self, z, commands, args, label=None, hierarch_logits=None, return_hierarch=False):
+         N = z.size(2)
+         l = self.label_embedding(label).unsqueeze(0) if self.cfg.label_condition else None
+         if hierarch_logits is None:
+             z = _pack_group_batch(z)
+
+         if self.cfg.decode_stages == 2:
+             if hierarch_logits is None:
+                 src = self.hierarchical_embedding(z)
+                 out = self.hierarchical_decoder(src, z, tgt_mask=None, tgt_key_padding_mask=None, memory2=l)
+                 hierarch_logits, z = self.hierarchical_fcn(out)
+
+             if self.cfg.label_condition: l = l.unsqueeze(0).repeat(1, z.size(1), 1, 1)
+
+             hierarch_logits, z, l = _pack_group_batch(hierarch_logits, z, l)
+
+         if return_hierarch:
+             return _unpack_group_batch(N, hierarch_logits, z)
+
+         if self.cfg.pred_mode == "autoregressive":
+             S = commands.size(0)
+             commands, args = _pack_group_batch(commands, args)
+
+             group_mask = _get_group_mask(commands, seq_dim=0)
+
+             src = self.embedding(commands, args, group_mask)
+
+             if self.cfg.model_type == "transformer":
+                 key_padding_mask = _get_key_padding_mask(commands, seq_dim=0)
+                 out = self.decoder(src, z, tgt_mask=self.square_subsequent_mask[:S, :S], tgt_key_padding_mask=key_padding_mask, memory2=l)
+             else:  # "lstm"
+                 hidden_cell = self._get_initial_state(z)
+                 out, _ = self.decoder(src, hidden_cell)
+
+         else:  # "one_shot"
+             src = self.embedding(z)
+             out = self.decoder(src, z, tgt_mask=None, tgt_key_padding_mask=None, memory2=l)
+
+         command_logits, args_logits = self.fcn(out)
+
+         out_logits = (command_logits, args_logits) + ((hierarch_logits,) if self.cfg.decode_stages == 2 else ())
+
+         return _unpack_group_batch(N, *out_logits)
+
+
+ class SVGTransformer(nn.Module):
+     def __init__(self, cfg: _DefaultConfig):
+         super(SVGTransformer, self).__init__()
+
+         self.cfg = cfg
+         self.args_dim = 2 * cfg.args_dim if cfg.rel_targets else cfg.args_dim + 1
+
+         if self.cfg.encode_stages > 0:
+
+             self.encoder = Encoder(cfg)
+
+             if cfg.use_resnet:
+                 self.resnet = ResNet(cfg.d_model)
+
+             if cfg.use_vae:
+                 self.vae = VAE(cfg)
+             else:
+                 self.bottleneck = Bottleneck(cfg)
+
+         self.decoder = Decoder(cfg)
+
+         self.register_buffer("cmd_args_mask", SVGTensor.CMD_ARGS_MASK)
+
+     def perfect_matching(self, command_logits, args_logits, hierarch_logits, tgt_commands, tgt_args):
+         with torch.no_grad():
+             N, G, S, n_args = tgt_args.shape
+             visibility_mask = _get_visibility_mask(tgt_commands, seq_dim=-1)
+             padding_mask = _get_padding_mask(tgt_commands, seq_dim=-1, extended=True) * visibility_mask.unsqueeze(-1)
+
+             # Unsqueeze
+             tgt_commands, tgt_args, tgt_hierarch = tgt_commands.unsqueeze(2), tgt_args.unsqueeze(2), visibility_mask.unsqueeze(2)
+             command_logits, args_logits, hierarch_logits = command_logits.unsqueeze(1), args_logits.unsqueeze(1), hierarch_logits.unsqueeze(1).squeeze(-2)
+
+             # Loss
+             tgt_hierarch, hierarch_logits = tgt_hierarch.repeat(1, 1, self.cfg.num_groups_proposal), hierarch_logits.repeat(1, G, 1, 1)
+             tgt_commands, command_logits = tgt_commands.repeat(1, 1, self.cfg.num_groups_proposal, 1), command_logits.repeat(1, G, 1, 1, 1)
+             tgt_args, args_logits = tgt_args.repeat(1, 1, self.cfg.num_groups_proposal, 1, 1), args_logits.repeat(1, G, 1, 1, 1, 1)
+
+             padding_mask, mask = padding_mask.unsqueeze(2).repeat(1, 1, self.cfg.num_groups_proposal, 1), self.cmd_args_mask[tgt_commands.long()]
+
+             loss_args = F.cross_entropy(args_logits.reshape(-1, self.args_dim), tgt_args.reshape(-1).long() + 1, reduction="none").reshape(N, G, self.cfg.num_groups_proposal, S, n_args)  # shift due to -1 PAD_VAL
+             loss_cmd = F.cross_entropy(command_logits.reshape(-1, self.cfg.n_commands), tgt_commands.reshape(-1).long(), reduction="none").reshape(N, G, self.cfg.num_groups_proposal, S)
+             loss_hierarch = F.cross_entropy(hierarch_logits.reshape(-1, 2), tgt_hierarch.reshape(-1).long(), reduction="none").reshape(N, G, self.cfg.num_groups_proposal)
+
+             loss_args = (loss_args * mask).sum(dim=[-1, -2]) / mask.sum(dim=[-1, -2])
+             loss_cmd = (loss_cmd * padding_mask).sum(dim=-1) / padding_mask.sum(dim=-1)
+
+             loss = 2.0 * loss_args + 1.0 * loss_cmd + 1.0 * loss_hierarch
+
+             # Iterate over the batch-dimension
+             assignment_list = []
+
+             full_set = set(range(self.cfg.num_groups_proposal))
+             for i in range(N):
+                 costs = loss[i]
+                 mask = visibility_mask[i]
+                 _, assign = linear_sum_assignment(costs[mask].cpu())
+                 assign = assign.tolist()
+                 assignment_list.append(assign + list(full_set - set(assign)))
+
+             assignment = torch.tensor(assignment_list, device=command_logits.device)
+
+         return assignment.unsqueeze(-1).unsqueeze(-1)
+
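`perfect_matching` above turns the per-pair losses into a (visible targets x proposals) cost matrix and solves it with SciPy's Hungarian solver; indices of unused proposals are appended so the result always permutes all `num_groups_proposal` slots. The pattern in isolation, with made-up costs:

import torch
from scipy.optimize import linear_sum_assignment

num_groups_proposal = 4
costs = torch.tensor([[0.9, 0.1, 0.5, 0.7],    # 2 visible target groups
                      [0.2, 0.8, 0.6, 0.4]])   # 4 proposals
_, assign = linear_sum_assignment(costs.cpu())
assign = assign.tolist()                            # [1, 0]: cheapest one-to-one matching
full_set = set(range(num_groups_proposal))
assignment = assign + list(full_set - set(assign))  # pad with the unused proposals
print(assignment)                                   # [1, 0, 2, 3]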
+     def forward(self, commands_enc, args_enc, commands_dec, args_dec, label=None,
+                 z=None, hierarch_logits=None,
+                 return_tgt=True, params=None, encode_mode=False, return_hierarch=False):
+         commands_enc, args_enc = _make_seq_first(commands_enc, args_enc)  # Possibly None, None
+         commands_dec_, args_dec_ = _make_seq_first(commands_dec, args_dec)
+
+         if z is None:
+             z = self.encoder(commands_enc, args_enc, label)
+
+             if self.cfg.use_resnet:
+                 z = self.resnet(z)
+
+             if self.cfg.use_vae:
+                 z, mu, logsigma = self.vae(z)
+             else:
+                 z = self.bottleneck(z)
+         else:
+             z = _make_seq_first(z)
+
+         if encode_mode: return z
+
+         if return_tgt:  # Train mode
+             commands_dec_, args_dec_ = commands_dec_[:-1], args_dec_[:-1]
+
+         out_logits = self.decoder(z, commands_dec_, args_dec_, label, hierarch_logits=hierarch_logits,
+                                   return_hierarch=return_hierarch)
+
+         if return_hierarch:
+             return out_logits
+
+         out_logits = _make_batch_first(*out_logits)
+
+         if return_tgt and self.cfg.self_match:  # Assignment
+             assert self.cfg.decode_stages == 2  # Self-matching expects two-stage decoder
+             command_logits, args_logits, hierarch_logits = out_logits
+
+             assignment = self.perfect_matching(command_logits, args_logits, hierarch_logits, commands_dec[..., 1:], args_dec[..., 1:, :])
+
+             command_logits = torch.gather(command_logits, dim=1, index=assignment.expand_as(command_logits))
+             args_logits = torch.gather(args_logits, dim=1, index=assignment.unsqueeze(-1).expand_as(args_logits))
+             hierarch_logits = torch.gather(hierarch_logits, dim=1, index=assignment.expand_as(hierarch_logits))
+
+             out_logits = (command_logits, args_logits, hierarch_logits)
+
+         res = {
+             "command_logits": out_logits[0],
+             "args_logits": out_logits[1]
+         }
+
+         if self.cfg.decode_stages == 2:
+             res["visibility_logits"] = out_logits[2]
+
+         if return_tgt:
+             res["tgt_commands"] = commands_dec
+             res["tgt_args"] = args_dec
+
+             if self.cfg.use_vae:
+                 res["mu"] = _make_batch_first(mu)
+                 res["logsigma"] = _make_batch_first(logsigma)
+
+         return res
+
+     def greedy_sample(self, commands_enc=None, args_enc=None, commands_dec=None, args_dec=None, label=None,
+                       z=None, hierarch_logits=None,
+                       concat_groups=True, temperature=0.0001):
+         if self.cfg.pred_mode == "one_shot":
+             res = self.forward(commands_enc, args_enc, commands_dec, args_dec, label=label, z=z, hierarch_logits=hierarch_logits, return_tgt=False)
+             commands_y, args_y = _sample_categorical(temperature, res["command_logits"], res["args_logits"])
+             args_y -= 1  # shift due to -1 PAD_VAL
+             visibility_y = _threshold_sample(res["visibility_logits"], threshold=0.7).bool().squeeze(-1) if self.cfg.decode_stages == 2 else None
+             commands_y, args_y = self._make_valid(commands_y, args_y, visibility_y)
+         else:
+             if z is None:
+                 z = self.forward(commands_enc, args_enc, None, None, label=label, encode_mode=True)
+
+             PAD_VAL = -1
+             commands_y, args_y = z.new_zeros(1, 1, 1).fill_(SVGTensor.COMMANDS_SIMPLIFIED.index("SOS")).long(), z.new_ones(1, 1, 1, self.cfg.n_args).fill_(PAD_VAL).long()
+
+             for i in range(self.cfg.max_total_len):
+                 res = self.forward(None, None, commands_y, args_y, label=label, z=z, hierarch_logits=hierarch_logits, return_tgt=False)
+                 commands_new_y, args_new_y = _sample_categorical(temperature, res["command_logits"], res["args_logits"])
+                 args_new_y -= 1  # shift due to -1 PAD_VAL
+                 _, args_new_y = self._make_valid(commands_new_y, args_new_y)
+
+                 commands_y, args_y = torch.cat([commands_y, commands_new_y[..., -1:]], dim=-1), torch.cat([args_y, args_new_y[..., -1:, :]], dim=-2)
+
+             commands_y, args_y = commands_y[..., 1:], args_y[..., 1:, :]  # Discard SOS token
+
+         if self.cfg.rel_targets:
+             args_y = self._make_absolute(commands_y, args_y)
+
+         if concat_groups:
+             N = commands_y.size(0)
+             padding_mask_y = _get_padding_mask(commands_y, seq_dim=-1).bool()
+             commands_y, args_y = commands_y[padding_mask_y].reshape(N, -1), args_y[padding_mask_y].reshape(N, -1, self.cfg.n_args)
+
+         return commands_y, args_y
+
+     def _make_valid(self, commands_y, args_y, visibility_y=None, PAD_VAL=-1):
+         if visibility_y is not None:
+             S = commands_y.size(-1)
+             commands_y[~visibility_y] = commands_y.new_tensor([SVGTensor.COMMANDS_SIMPLIFIED.index("m"), *[SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")] * (S - 1)])
+             args_y[~visibility_y] = PAD_VAL
+
+         mask = self.cmd_args_mask[commands_y.long()].bool()
+         args_y[~mask] = PAD_VAL
+
+         return commands_y, args_y
+
+     def _make_absolute(self, commands_y, args_y):
+
+         mask = self.cmd_args_mask[commands_y.long()].bool()
+         args_y[mask] -= self.cfg.args_dim - 1
+
+         real_commands = commands_y < SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")
+
+         args_real_commands = args_y[real_commands]
+         end_pos = args_real_commands[:-1, SVGTensor.IndexArgs.END_POS].cumsum(dim=0)
+
+         args_real_commands[1:, SVGTensor.IndexArgs.CONTROL1] += end_pos
+         args_real_commands[1:, SVGTensor.IndexArgs.CONTROL2] += end_pos
+         args_real_commands[1:, SVGTensor.IndexArgs.END_POS] += end_pos
+
+         args_y[real_commands] = args_real_commands
+
+         _, args_y = self._make_valid(commands_y, args_y)
+
+         return args_y
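`_make_absolute` converts relative targets back to absolute coordinates by adding the running sum of all previous end positions to each command's control and end points. The core arithmetic on a toy tensor (hypothetical values):

import torch

rel_end = torch.tensor([[2., 1.], [1., 0.], [0., 3.], [1., 1.]])  # relative end points
abs_end = rel_end.clone()
abs_end[1:] += rel_end[:-1].cumsum(dim=0)  # accumulate previous end points
print(abs_end)  # tensor([[2., 1.], [3., 1.], [3., 4.], [4., 5.]])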
src/preprocessing/deepsvg/deepsvg_models/model_config.py ADDED
@@ -0,0 +1,113 @@
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
+ """
+
+ from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor
+
+
+ class _DefaultConfig:
+     """
+     Model config.
+     """
+     def __init__(self):
+         self.args_dim = 256  # Coordinate numericalization, default: 256 (8-bit)
+         self.n_args = 11  # Tensor nb of arguments, default: 11 (rx,ry,phi,fA,fS,qx1,qy1,qx2,qy2,x1,x2)
+         self.n_commands = len(SVGTensor.COMMANDS_SIMPLIFIED)  # m, l, c, a, EOS, SOS, z
+
+         self.dropout = 0.1  # Dropout rate used in basic layers and Transformers
+
+         self.model_type = "transformer"  # "transformer" ("lstm" implementation is work in progress)
+
+         self.encode_stages = 1  # One-stage or two-stage: 1 | 2
+         self.decode_stages = 1  # One-stage or two-stage: 1 | 2
+
+         self.use_resnet = True  # Use extra fully-connected residual blocks after Encoder
+
+         self.use_vae = True  # Sample latent vector (with reparametrization trick) or use encodings directly
+
+         self.pred_mode = "one_shot"  # Feed-forward (one-shot) or autoregressive: "one_shot" | "autoregressive"
+         self.rel_targets = False  # Predict coordinates in relative or absolute format
+
+         self.label_condition = False  # Make all blocks conditional on the label
+         self.n_labels = 100  # Number of labels (when used)
+         self.dim_label = 64  # Label embedding dimensionality
+
+         self.self_match = False  # Use Hungarian (self-match) or Ordered assignment
+
+         self.n_layers = 4  # Number of Encoder blocks
+         self.n_layers_decode = 4  # Number of Decoder blocks
+         self.n_heads = 8  # Transformer config: number of heads
+         self.dim_feedforward = 512  # Transformer config: FF dimensionality
+         self.d_model = 256  # Transformer config: model dimensionality
+
+         self.dim_z = 256  # Latent vector dimensionality
+
+         self.max_num_groups = 8  # Number of paths (N_P)
+         self.max_seq_len = 30  # Number of commands (N_C)
+         self.max_total_len = self.max_num_groups * self.max_seq_len  # Concatenated sequence length for baselines
+
+         self.num_groups_proposal = self.max_num_groups  # Number of predicted paths, default: N_P
+
+     def get_model_args(self):
+         model_args = []
+
+         model_args += ["commands_grouped", "args_grouped"] if self.encode_stages <= 1 else ["commands", "args"]
+
+         if self.rel_targets:
+             model_args += ["commands_grouped", "args_rel_grouped"] if self.decode_stages == 1 else ["commands", "args_rel"]
+         else:
+             model_args += ["commands_grouped", "args_grouped"] if self.decode_stages == 1 else ["commands", "args"]
+
+         if self.label_condition:
+             model_args.append("label")
+
+         return model_args
+
+
+ class SketchRNN(_DefaultConfig):
+     # LSTM - Autoregressive - One-stage
+     def __init__(self):
+         super().__init__()
+
+         self.model_type = "lstm"
+
+         self.pred_mode = "autoregressive"
+         self.rel_targets = True
+
+
+ class Sketchformer(_DefaultConfig):
+     # Transformer - Autoregressive - One-stage
+     def __init__(self):
+         super().__init__()
+
+         self.pred_mode = "autoregressive"
+         self.rel_targets = True
+
+
+ class OneStageOneShot(_DefaultConfig):
+     # Transformer - One-shot - One-stage
+     def __init__(self):
+         super().__init__()
+
+         self.encode_stages = 1
+         self.decode_stages = 1
+
+
+ class Hierarchical(_DefaultConfig):
+     # Transformer - One-shot - Two-stage - Ordered
+     def __init__(self):
+         super().__init__()
+
+         self.encode_stages = 2
+         self.decode_stages = 2
+
+
+ class HierarchicalSelfMatching(_DefaultConfig):
+     # Transformer - One-shot - Two-stage - Hungarian
+     def __init__(self):
+         super().__init__()
+
+         self.encode_stages = 2
+         self.decode_stages = 2
+         self.self_match = True
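The presets only flip a few `_DefaultConfig` switches; `HierarchicalSelfMatching` is the two-stage Hungarian variant used for the full DeepSVG model. A quick sanity check (assuming the module's imports resolve):

cfg = HierarchicalSelfMatching()
print(cfg.encode_stages, cfg.decode_stages, cfg.self_match)  # 2 2 True
print(cfg.max_total_len)     # 240 = max_num_groups (8) * max_seq_len (30)
print(cfg.get_model_args())  # ['commands', 'args', 'commands', 'args']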
src/preprocessing/deepsvg/deepsvg_models/model_utils.py ADDED
@@ -0,0 +1,89 @@
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
+ """
+
+ import torch
+ from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor
+ from torch.distributions.categorical import Categorical
+ import torch.nn.functional as F
+
+
+ def _get_key_padding_mask(commands, seq_dim=0):
+     """
+     Args:
+         commands: Shape [S, ...]
+     """
+     with torch.no_grad():
+         key_padding_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).cumsum(dim=seq_dim) > 0
+
+         if seq_dim == 0:
+             return key_padding_mask.transpose(0, 1)
+         return key_padding_mask
+
+
+ def _get_padding_mask(commands, seq_dim=0, extended=False):
+     with torch.no_grad():
+         padding_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).cumsum(dim=seq_dim) == 0
+         padding_mask = padding_mask.float()
+
+         if extended:
+             # padding_mask doesn't include the final EOS, extend by 1 position to include it in the loss
+             S = commands.size(seq_dim)
+             torch.narrow(padding_mask, seq_dim, 3, S-3).add_(torch.narrow(padding_mask, seq_dim, 0, S-3)).clamp_(max=1)
+
+         if seq_dim == 0:
+             return padding_mask.unsqueeze(-1)
+         return padding_mask
+
+
+ def _get_group_mask(commands, seq_dim=0):
+     """
+     Args:
+         commands: Shape [S, ...]
+     """
+     with torch.no_grad():
+         group_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("m")).cumsum(dim=seq_dim)
+         return group_mask
+
+
+ def _get_visibility_mask(commands, seq_dim=0):
+     """
+     Args:
+         commands: Shape [S, ...]
+     """
+     S = commands.size(seq_dim)
+     with torch.no_grad():
+         visibility_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).sum(dim=seq_dim) < S - 1
+
+         if seq_dim == 0:
+             return visibility_mask.unsqueeze(-1)
+         return visibility_mask
+
+
+ def _get_key_visibility_mask(commands, seq_dim=0):
+     S = commands.size(seq_dim)
+     with torch.no_grad():
+         key_visibility_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).sum(dim=seq_dim) >= S - 1
+
+         if seq_dim == 0:
+             return key_visibility_mask.transpose(0, 1)
+         return key_visibility_mask
+
+
+ def _generate_square_subsequent_mask(sz):
+     mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+     mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+     return mask
+
+
+ def _sample_categorical(temperature=0.0001, *args_logits):
+     if len(args_logits) == 1:
+         arg_logits, = args_logits
+         return Categorical(logits=arg_logits / temperature).sample()
+     return (*(Categorical(logits=arg_logits / temperature).sample() for arg_logits in args_logits),)
+
+
+ def _threshold_sample(arg_logits, threshold=0.5, temperature=1.0):
+     scores = F.softmax(arg_logits / temperature, dim=-1)[..., 1]
+     return scores > threshold
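All of these masks are derived from the command tokens alone: padding begins at the first EOS, and `_generate_square_subsequent_mask` builds the usual additive causal mask. For example, run inside this module's namespace (EOS has index 4 in `SVGTensor.COMMANDS_SIMPLIFIED`, i.e. m, l, c, a, EOS, SOS, z):

import torch

print(_generate_square_subsequent_mask(3))
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])

commands = torch.tensor([[0, 1, 4, 4]])             # one batch-first sequence: m, l, EOS, EOS
print(_get_padding_mask(commands, seq_dim=-1))      # tensor([[1., 1., 0., 0.]])
print(_get_key_padding_mask(commands, seq_dim=-1))  # tensor([[False, False,  True,  True]])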
src/preprocessing/deepsvg/deepsvg_schedulers/warmup.py ADDED
@@ -0,0 +1,68 @@
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
+ """
+
+ from torch.optim.lr_scheduler import _LRScheduler
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+
+ class GradualWarmupScheduler(_LRScheduler):
+     """ Gradually warm up (increase) the learning rate in the optimizer.
+     Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
+     Args:
+         optimizer (Optimizer): Wrapped optimizer.
+         multiplier: target learning rate = base lr * multiplier if multiplier > 1.0; if multiplier = 1.0, lr starts from 0 and ends up at the base_lr.
+         total_epoch: target learning rate is reached at total_epoch, gradually
+         after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau)
+     """
+
+     def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
+         self.multiplier = multiplier
+         if self.multiplier < 1.:
+             raise ValueError('multiplier should be greater than or equal to 1.')
+         self.total_epoch = total_epoch
+         self.after_scheduler = after_scheduler
+         self.finished = False
+         super(GradualWarmupScheduler, self).__init__(optimizer)
+
+     def get_lr(self):
+         if self.last_epoch > self.total_epoch:
+             if self.after_scheduler:
+                 if not self.finished:
+                     self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
+                     self.finished = True
+                 return self.after_scheduler.get_last_lr()
+             return [base_lr * self.multiplier for base_lr in self.base_lrs]
+
+         if self.multiplier == 1.0:
+             return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
+         else:
+             return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
+
+     def step_ReduceLROnPlateau(self, metrics, epoch=None):
+         if epoch is None:
+             epoch = self.last_epoch + 1
+         self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
+         if self.last_epoch <= self.total_epoch:
+             warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
+             for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
+                 param_group['lr'] = lr
+         else:
+             if epoch is None:
+                 self.after_scheduler.step(metrics, None)
+             else:
+                 self.after_scheduler.step(metrics, epoch - self.total_epoch)
+
+     def step(self, epoch=None, metrics=None):
+         if type(self.after_scheduler) != ReduceLROnPlateau:
+             if self.finished and self.after_scheduler:
+                 if epoch is None:
+                     self.after_scheduler.step(None)
+                 else:
+                     self.after_scheduler.step(epoch - self.total_epoch)
+                 self._last_lr = self.after_scheduler.get_last_lr()
+             else:
+                 return super(GradualWarmupScheduler, self).step(epoch)
+         else:
+             self.step_ReduceLROnPlateau(metrics, epoch)
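A minimal usage sketch (hypothetical model and hyperparameters): warm up over the first 5 epochs to the base learning rate, then defer to ReduceLROnPlateau.

import torch

model = torch.nn.Linear(10, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
plateau = ReduceLROnPlateau(optimizer, mode="min", patience=3)
scheduler = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=5,
                                   after_scheduler=plateau)

for epoch in range(20):
    # ... train one epoch ...
    val_loss = 0.5  # placeholder metric from a validation pass
    scheduler.step(epoch=epoch, metrics=val_loss)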
src/preprocessing/deepsvg/deepsvg_svglib/geom.py ADDED
@@ -0,0 +1,493 @@
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
+ """
+
+ from __future__ import annotations
+ import numpy as np
+ from enum import Enum
+ import torch
+ from typing import List, Union
+ Num = Union[int, float]
+ float_type = (int, float, np.float32)
+
+
+ def det(a: Point, b: Point):
+     return a.pos[0] * b.pos[1] - a.pos[1] * b.pos[0]
+
+
+ def get_rotation_matrix(angle: Union[Angle, float]):
+     if isinstance(angle, Angle):
+         theta = angle.rad
+     else:
+         theta = angle
+     c, s = np.cos(theta), np.sin(theta)
+     rot_m = np.array([[c, -s],
+                       [s, c]], dtype=np.float32)
+     return rot_m
+
+
30
+ def union_bbox(bbox_list: List[Bbox]):
31
+ res = None
32
+ for bbox in bbox_list:
33
+ res = bbox.union(res)
34
+ return res
35
+
36
+
37
+ class Geom:
38
+ def copy(self):
39
+ raise NotImplementedError
40
+
41
+ def to_str(self):
42
+ raise NotImplementedError
43
+
44
+ def to_tensor(self):
45
+ raise NotImplementedError
46
+
47
+ @staticmethod
48
+ def from_tensor(vector: torch.Tensor):
49
+ raise NotImplementedError
50
+
51
+ def scale(self, factor):
52
+ pass
53
+
54
+ def translate(self, vec):
55
+ pass
56
+
57
+ def rotate(self, angle: Union[Angle, float]):
58
+ pass
59
+
60
+ def numericalize(self, n=256):
61
+ raise NotImplementedError
62
+
63
+
64
+ ######### Point
65
+ class Point(Geom):
66
+ num_args = 2
67
+
68
+ def __init__(self, x=None, y=None):
69
+ if isinstance(x, np.ndarray):
70
+ self.pos = x.astype(np.float32)
71
+ elif x is None and y is None:
72
+ self.pos = np.array([0., 0.], dtype=np.float32)
73
+ elif (isinstance(x, float_type) or x is None) and (isinstance(y, float_type) or y is None):
74
+ if x is None:
75
+ x = y
76
+ if y is None:
77
+ y = x
78
+ self.pos = np.array([x, y], dtype=np.float32)
79
+ else:
80
+ raise ValueError()
81
+
82
+ def copy(self):
83
+ return Point(self.pos.copy())
84
+
85
+ @property
86
+ def x(self):
87
+ return self.pos[0]
88
+
89
+ @property
90
+ def y(self):
91
+ return self.pos[1]
92
+
93
+ def xproj(self):
94
+ return Point(self.x, 0.)
95
+
96
+ def yproj(self):
97
+ return Point(0., self.y)
98
+
99
+ def __add__(self, other):
100
+ return Point(self.pos + other.pos)
101
+
102
+ def __sub__(self, other):
103
+ return self + other.__neg__()
104
+
105
+ def __mul__(self, lmbda):
106
+ if isinstance(lmbda, Point):
107
+ return Point(self.pos * lmbda.pos)
108
+
109
+ assert isinstance(lmbda, float_type)
110
+ return Point(lmbda * self.pos)
111
+
112
+ def __rmul__(self, lmbda):
113
+ return self * lmbda
114
+
115
+ def __truediv__(self, lmbda):
116
+ if isinstance(lmbda, Point):
117
+ return Point(self.pos / lmbda.pos)
118
+
119
+ assert isinstance(lmbda, float_type)
120
+ return self * (1 / lmbda)
121
+
122
+ def __neg__(self):
123
+ return self * -1
124
+
125
+ def __repr__(self):
126
+ return f"P({self.x}, {self.y})"
127
+
128
+ def to_str(self):
129
+ return f"{self.x} {self.y}"
130
+
131
+ def tolist(self):
132
+ return self.pos.tolist()
133
+
134
+ def to_tensor(self):
135
+ return torch.tensor(self.pos)
136
+
137
+ @staticmethod
138
+ def from_tensor(vector: torch.Tensor):
139
+ return Point(*vector.tolist())
140
+
141
+ def translate(self, vec: Point):
142
+ self.pos += vec.pos
143
+
144
+ def matmul(self, m):
145
+ return Point(m @ self.pos)
146
+
147
+ def rotate(self, angle: Union[Angle, float]):
148
+ rot_m = get_rotation_matrix(angle)
149
+ return self.matmul(rot_m)
150
+
151
+ def rotate_(self, angle: Union[Angle, float]):
152
+ rot_m = get_rotation_matrix(angle)
153
+ self.pos = rot_m @ self.pos
154
+
155
+ def scale(self, factor):
156
+ self.pos *= factor
157
+
158
+ def dot(self, other: Point):
159
+ return self.pos.dot(other.pos)
160
+
161
+ def norm(self):
162
+ return float(np.linalg.norm(self.pos))
163
+
164
+ def cross(self, other: Point):
165
+ return np.cross(self.pos, other.pos)
166
+
167
+ def dist(self, other: Point):
168
+ return (self - other).norm()
169
+
170
+ def angle(self, other: Point, signed=False):
171
+ rad = np.arccos(np.clip(self.normalize().dot(other.normalize()), -1., 1.))
172
+
173
+ if signed:
174
+ sign = 1 if det(self, other) >= 0 else -1
175
+ rad *= sign
176
+ return Angle.Rad(rad)
177
+
178
+ def distToLine(self, p1: Point, p2: Point):
179
+ if p1.isclose(p2):
180
+ return self.dist(p1)
181
+
182
+ return abs((p2 - p1).cross(p1 - self)) / (p2 - p1).norm()
183
+
184
+ def normalize(self):
185
+ return self / self.norm()
186
+
187
+ def numericalize(self, n=256):
188
+ self.pos = self.pos.round().clip(min=0, max=n-1)
189
+
190
+ def isclose(self, other: Point):
191
+ return np.allclose(self.pos, other.pos)
192
+
193
+ def iszero(self):
194
+ return np.all(self.pos == 0)
195
+
196
+ def pointwise_min(self, other: Point):
197
+ return Point(min(self.x, other.x), min(self.y, other.y))
198
+
199
+ def pointwise_max(self, other: Point):
200
+ return Point(max(self.x, other.x), max(self.y, other.y))
201
+
202
+
203
+ class Radius(Point):
204
+ def __init__(self, *args, **kwargs):
205
+ super().__init__(*args, **kwargs)
206
+
207
+ def copy(self):
208
+ return Radius(self.pos.copy())
209
+
210
+ def __repr__(self):
211
+ return f"Rad({self.pos[0]}, {self.pos[1]})"
212
+
213
+ def translate(self, vec: Point):
214
+ pass
215
+
216
+
217
+ class Size(Point):
218
+ def __init__(self, *args, **kwargs):
219
+ super().__init__(*args, **kwargs)
220
+
221
+ def copy(self):
222
+ return Size(self.pos.copy())
223
+
224
+ def __repr__(self):
225
+ return f"Size({self.pos[0]}, {self.pos[1]})"
226
+
227
+ def max(self):
228
+ return self.pos.max()
229
+
230
+ def min(self):
231
+ return self.pos.min()
232
+
233
+ def translate(self, vec: Point):
234
+ pass
235
+
236
+
237
+ ######### Coord
238
+ class Coord(Geom):
239
+ num_args = 1
240
+
241
+ class XY(Enum):
242
+ X = "x"
243
+ Y = "y"
244
+
245
+ def __init__(self, coord, xy: XY = XY.X):
246
+ self.coord = coord
247
+ self.xy = xy
248
+
249
+ def __repr__(self):
250
+ return f"{self.xy.value}({self.coord})"
251
+
252
+ def to_str(self):
253
+ return str(self.coord)
254
+
255
+ def to_tensor(self):
256
+ return torch.tensor([self.coord])
257
+
258
+ def __add__(self, other):
259
+ if isinstance(other, float_type):
260
+ return Coord(self.coord + other, self.xy)
261
+ elif isinstance(other, Coord):
262
+ if self.xy != other.xy:
263
+ raise ValueError()
264
+ return Coord(self.coord + other.coord, self.xy)
265
+ elif isinstance(other, Point):
266
+ return Coord(self.coord + getattr(other, self.xy.value), self.xy)
267
+ else:
268
+ raise ValueError()
269
+
270
+ def __sub__(self, other):
271
+ return self + other.__neg__()
272
+
273
+ def __mul__(self, lmbda):
274
+ assert isinstance(lmbda, float_type)
275
+ return Coord(lmbda * self.coord)
276
+
277
+ def __neg__(self):
278
+ return self * -1
279
+
280
+ def scale(self, factor):
281
+ self.coord *= factor
282
+
283
+ def translate(self, vec: Point):
284
+ self.coord += getattr(vec, self.xy.value)
285
+
286
+ def to_point(self, pos: Point, is_absolute=True):
287
+ point = pos.copy() if is_absolute else Point(0.)
288
+ point.pos[int(self.xy == Coord.XY.Y)] = self.coord
289
+ return point
290
+
291
+
292
+ class XCoord(Coord):
293
+ def __init__(self, coord):
294
+ super().__init__(coord, xy=Coord.XY.X)
295
+
296
+ def copy(self):
297
+ return XCoord(self.coord)
298
+
299
+
300
+ class YCoord(Coord):
301
+ def __init__(self, coord):
302
+ super().__init__(coord, xy=Coord.XY.Y)
303
+
304
+ def copy(self):
305
+ return YCoord(self.coord)
306
+
307
+
308
+ ######### Bbox
309
+ class Bbox(Geom):
310
+ num_args = 4
311
+
312
+ def __init__(self, x=None, y=None, w=None, h=None):
313
+ if isinstance(x, Point) and isinstance(y, Point):
314
+ self.xy = x
315
+ wh = y - x
316
+ self.wh = Size(wh.x, wh.y)
317
+ elif (isinstance(x, float_type) or x is None) and (isinstance(y, float_type) or y is None):
318
+ if x is None:
319
+ x = 0.
320
+ if y is None:
321
+ y = float(x)
322
+
323
+ if w is None and h is None:
324
+ w, h = float(x), float(y)
325
+ x, y = 0., 0.
326
+ self.xy = Point(x, y)
327
+ self.wh = Size(w, h)
328
+ else:
329
+ raise ValueError()
330
+
331
+ @property
332
+ def xy2(self):
333
+ return self.xy + self.wh
334
+
335
+ def copy(self):
336
+ bbox = Bbox()
337
+ bbox.xy = self.xy.copy()
338
+ bbox.wh = self.wh.copy()
339
+ return bbox
340
+
341
+ @property
342
+ def size(self):
343
+ return self.wh
344
+
345
+ @property
346
+ def center(self):
347
+ return self.xy + self.wh / 2
348
+
349
+ def __repr__(self):
350
+ return f"Bbox({self.xy.to_str()} {self.wh.to_str()})"
351
+
352
+ def to_str(self):
353
+ return f"{self.xy.to_str()} {self.wh.to_str()}"
354
+
355
+ def to_tensor(self):
356
+ return torch.tensor([*self.xy.to_tensor(), *self.wh.to_tensor()])
357
+
358
+ def make_square(self, min_size=None):
359
+ center = self.center
360
+ size = self.wh.max()
361
+
362
+ if min_size is not None:
363
+ size = max(size, min_size)
364
+
365
+ self.wh = Size(size, size)
366
+ self.xy = center - self.wh / 2
367
+
368
+ return self
369
+
370
+ def translate(self, vec):
371
+ self.xy.translate(vec)
372
+
373
+ def scale(self, factor):
374
+ self.xy.scale(factor)
375
+ self.wh.scale(factor)
376
+
377
+ def union(self, other: Bbox):
378
+ if other is None:
379
+ return self
380
+ return Bbox(self.xy.pointwise_min(other.xy), self.xy2.pointwise_max(other.xy2))
381
+
382
+ def intersect(self, other: Bbox):
383
+ if other is None:
384
+ return self
385
+
386
+ bbox = Bbox(self.xy.pointwise_max(other.xy), self.xy2.pointwise_min(other.xy2))
387
+ if bbox.wh.x < 0 or bbox.wh.y < 0:
388
+ return None
389
+
390
+ return bbox
391
+
392
+ @staticmethod
393
+ def from_points(points: List[Point]):
394
+ if not points:
395
+ return None
396
+ xy = xy2 = points[0]
397
+ for p in points[1:]:
398
+ xy = xy.pointwise_min(p)
399
+ xy2 = xy2.pointwise_max(p)
400
+ return Bbox(xy, xy2)
401
+
402
+ def to_rectangle(self, *args, **kwargs):
403
+ from .svg_primitive import SVGRectangle
404
+ return SVGRectangle(self.xy, self.wh, *args, **kwargs)
405
+
406
+ def area(self):
407
+ return self.wh.pos.prod()
408
+
409
+ def overlap(self, other):
410
+ inter = self.intersect(other)
411
+ if inter is None:
412
+ return 0.
413
+ return inter.area() / self.area()
414
+
415
+
416
+ ######### Angle
417
+ class Angle(Geom):
418
+ num_args = 1
419
+
420
+ def __init__(self, deg):
421
+ self.deg = deg
422
+
423
+ @property
424
+ def rad(self):
425
+ return np.deg2rad(self.deg)
426
+
427
+ def copy(self):
428
+ return Angle(self.deg)
429
+
430
+ def __repr__(self):
431
+ return f"α({self.deg})"
432
+
433
+ def to_str(self):
434
+ return str(self.deg)
435
+
436
+ def to_tensor(self):
437
+ return torch.tensor([self.deg])
438
+
439
+ @staticmethod
440
+ def from_tensor(vector: torch.Tensor):
441
+ return Angle(vector.item())
442
+
443
+ @staticmethod
444
+ def Rad(rad):
445
+ return Angle(np.rad2deg(rad))
446
+
447
+ def __add__(self, other: Angle):
448
+ return Angle(self.deg + other.deg)
449
+
450
+ def __sub__(self, other: Angle):
451
+ return self + other.__neg__()
452
+
453
+ def __mul__(self, lmbda):
454
+ assert isinstance(lmbda, float_type)
455
+ return Angle(lmbda * self.deg)
456
+
457
+ def __rmul__(self, lmbda):
458
+ assert isinstance(lmbda, float_type)
459
+ return self * lmbda
460
+
461
+ def __truediv__(self, lmbda):
462
+ assert isinstance(lmbda, float_type)
463
+ return self * (1 / lmbda)
464
+
465
+ def __neg__(self):
466
+ return self * -1
467
+
468
+
469
+ ######### Flag
470
+ class Flag(Geom):
471
+ num_args = 1
472
+
473
+ def __init__(self, flag):
474
+ self.flag = int(flag)
475
+
476
+ def copy(self):
477
+ return Flag(self.flag)
478
+
479
+ def __repr__(self):
480
+ return f"flag({self.flag})"
481
+
482
+ def to_str(self):
483
+ return str(self.flag)
484
+
485
+ def to_tensor(self):
486
+ return torch.tensor([self.flag])
487
+
488
+ def __invert__(self):
489
+ return Flag(1 - self.flag)
490
+
491
+ @staticmethod
492
+ def from_tensor(vector: torch.Tensor):
493
+ return Flag(vector.item())
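The primitives compose into small geometric computations; a smoke test of the arithmetic defined above (values chosen for easy checking):

p, q = Point(1., 2.), Point(4., 6.)
print(p + q, p.dist(q))          # P(5.0, 8.0) 5.0
print(Bbox.from_points([p, q]))  # Bbox(1.0 2.0 3.0 4.0)
print(Bbox(0., 0., 2., 2.).overlap(Bbox(1., 1., 2., 2.)))  # 0.25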
src/preprocessing/deepsvg/deepsvg_svglib/svg.py ADDED
@@ -0,0 +1,579 @@
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
+ """
+
+ from __future__ import annotations
+ from .geom import *
+ from xml.dom import expatbuilder
+ import torch
+ from typing import List, Union
+ import IPython.display as ipd
+ import cairosvg
+ from PIL import Image
+ import io
+ import os
+ from moviepy.editor import ImageClip, concatenate_videoclips, ipython_display
+ import math
+ import random
+ import networkx as nx
+
+ Num = Union[int, float]
+
+ from .svg_command import SVGCommandBezier
+ from .svg_path import SVGPath, Filling, Orientation
+ from .svg_primitive import SVGPathGroup, SVGRectangle, SVGCircle, SVGEllipse, SVGLine, SVGPolyline, SVGPolygon
+ from .geom import union_bbox
+
+
+ class SVG:
+     def __init__(self, svg_path_groups: List[SVGPathGroup], viewbox: Bbox = None):
+         if viewbox is None:
+             viewbox = Bbox(24)
+
+         self.svg_path_groups = svg_path_groups
+         self.viewbox = viewbox
+
+     def __add__(self, other: SVG):
+         svg = self.copy()
+         svg.svg_path_groups.extend(other.svg_path_groups)
+         return svg
+
+     @property
+     def paths(self):
+         for path_group in self.svg_path_groups:
+             for path in path_group.svg_paths:
+                 yield path
+
+     def __getitem__(self, idx):
+         if isinstance(idx, tuple):
+             assert len(idx) == 2, "Dimension out of range"
+             i, j = idx
+             return self.svg_path_groups[i][j]
+
+         return self.svg_path_groups[idx]
+
+     def __len__(self):
+         return len(self.svg_path_groups)
+
+     def total_length(self):
+         return sum([path_group.total_len() for path_group in self.svg_path_groups])
+
+     @property
+     def start_pos(self):
+         return Point(0.)
+
+     @property
+     def end_pos(self):
+         if not self.svg_path_groups:
+             return Point(0.)
+
+         return self.svg_path_groups[-1].end_pos
+
+     def copy(self):
+         return SVG([svg_path_group.copy() for svg_path_group in self.svg_path_groups], self.viewbox.copy())
+
+     @staticmethod
+     def load_svg(file_path):
+         with open(file_path, "r") as f:
+             return SVG.from_str(f.read())
+
+     @staticmethod
+     def load_splineset(spline_str: str, width, height, add_closing=True):
+         if "SplineSet" not in spline_str:
+             raise ValueError("Not a SplineSet")
+
+         spline = spline_str[spline_str.index('SplineSet') + 10:spline_str.index('EndSplineSet')]
+         svg_str = SVG._spline_to_svg_str(spline, height)
+
+         if not svg_str:
+             raise ValueError("Empty SplineSet")
+
+         svg_path_group = SVGPath.from_str(svg_str, add_closing=add_closing)
+         return SVG([svg_path_group], viewbox=Bbox(width, height))
+
+     @staticmethod
+     def _spline_to_svg_str(spline_str: str, height, replace_with_prev=False):
+         path = []
+         prev_xy = []
+         for line in spline_str.splitlines():
+             if not line:
+                 continue
+             tokens = line.split(' ')
+             cmd = tokens[-2]
+             if cmd not in 'cml':
+                 raise ValueError(f"Command not recognized: {cmd}")
+             args = tokens[:-2]
+             args = [float(x) for x in args if x]
+
+             if replace_with_prev and cmd in 'c':
+                 args[:2] = prev_xy
+             prev_xy = args[-2:]
+
+             new_y_args = []
+             for i, a in enumerate(args):
+                 if i % 2 == 1:
+                     new_y_args.append(str(height - a))
+                 else:
+                     new_y_args.append(str(a))
+
+             path.extend([cmd.upper()] + new_y_args)
+         return " ".join(path)
+
+     @staticmethod
+     def from_str(svg_str: str):
+         svg_path_groups = []
+         svg_dom = expatbuilder.parseString(svg_str, False)
+         svg_root = svg_dom.getElementsByTagName('svg')[0]
+
+         viewbox_list = list(map(float, svg_root.getAttribute("viewBox").split(" ")))
+         view_box = Bbox(*viewbox_list)
+
+         primitives = {
+             "path": SVGPath,
+             "rect": SVGRectangle,
+             "circle": SVGCircle, "ellipse": SVGEllipse,
+             "line": SVGLine,
+             "polyline": SVGPolyline, "polygon": SVGPolygon
+         }
+
+         for tag, Primitive in primitives.items():
+             for x in svg_dom.getElementsByTagName(tag):
+                 svg_path_groups.append(Primitive.from_xml(x))
+
+         return SVG(svg_path_groups, view_box)
+
+     def to_tensor(self, concat_groups=True, PAD_VAL=-1):
+         group_tensors = [p.to_tensor(PAD_VAL=PAD_VAL) for p in self.svg_path_groups]
+
+         if concat_groups:
+             return torch.cat(group_tensors, dim=0)
+
+         return group_tensors
+
+     def to_fillings(self):
+         return [p.path.filling for p in self.svg_path_groups]
+
+     @staticmethod
+     def from_tensor(tensor: torch.Tensor, viewbox: Bbox = None, allow_empty=False):
+         if viewbox is None:
+             viewbox = Bbox(24)
+
+         svg = SVG([SVGPath.from_tensor(tensor, allow_empty=allow_empty)], viewbox=viewbox)
+         return svg
+
+     @staticmethod
+     def from_tensors(tensors: List[torch.Tensor], viewbox: Bbox = None, allow_empty=False):
+         if viewbox is None:
+             viewbox = Bbox(24)
+
+         svg = SVG([SVGPath.from_tensor(t, allow_empty=allow_empty) for t in tensors], viewbox=viewbox)
+         return svg
+
+     def save_svg(self, file_path):
+         with open(file_path, "w") as f:
+             f.write(self.to_str())
+
+     def save_png(self, file_path):
+         cairosvg.svg2png(bytestring=self.to_str(), write_to=file_path)
+
+     def draw(self, fill=False, file_path=None, do_display=True, return_png=False,
+              with_points=False, with_handles=False, with_bboxes=False, with_markers=False, color_firstlast=False,
+              with_moves=True):
+         if file_path is not None:
+             _, file_extension = os.path.splitext(file_path)
+             if file_extension == ".svg":
+                 self.save_svg(file_path)
+             elif file_extension == ".png":
+                 self.save_png(file_path)
+             else:
+                 raise ValueError(f"Unsupported file_path extension {file_extension}")
+
+         svg_str = self.to_str(fill=fill, with_points=with_points, with_handles=with_handles, with_bboxes=with_bboxes,
+                               with_markers=with_markers, color_firstlast=color_firstlast, with_moves=with_moves)
+
+         if do_display:
+             ipd.display(ipd.SVG(svg_str))
+
+         if return_png:
+             if file_path is None:
+                 img_data = cairosvg.svg2png(bytestring=svg_str)
+                 return Image.open(io.BytesIO(img_data))
+             else:
+                 _, file_extension = os.path.splitext(file_path)
+
+                 if file_extension == ".svg":
+                     img_data = cairosvg.svg2png(url=file_path)
+                     return Image.open(io.BytesIO(img_data))
+                 else:
+                     return Image.open(file_path)
+
+     def draw_colored(self, *args, **kwargs):
+         self.copy().normalize().split_paths().set_color("random").draw(*args, **kwargs)
+
+     def __repr__(self):
+         return "SVG[{}](\n{}\n)".format(self.viewbox,
+                                         ",\n".join([f"\t{svg_path_group}" for svg_path_group in self.svg_path_groups]))
+
+     def _get_viz_elements(self, with_points=False, with_handles=False, with_bboxes=False, color_firstlast=False,
+                           with_moves=True):
+         viz_elements = []
+         for svg_path_group in self.svg_path_groups:
+             viz_elements.extend(
+                 svg_path_group._get_viz_elements(with_points, with_handles, with_bboxes, color_firstlast, with_moves))
+         return viz_elements
+
+     def _markers(self):
+         return ('<defs>'
+                 '<marker id="arrow" viewBox="0 0 10 10" markerWidth="4" markerHeight="4" refX="0" refY="3" orient="auto" markerUnits="strokeWidth">'
+                 '<path d="M0,0 L0,6 L9,3 z" fill="#f00" />'
+                 '</marker>'
+                 '</defs>')
+
+     def to_str(self, fill=False, with_points=False, with_handles=False, with_bboxes=False, with_markers=False,
+                color_firstlast=False, with_moves=True) -> str:
+         viz_elements = self._get_viz_elements(with_points, with_handles, with_bboxes, color_firstlast, with_moves)
+         newline = "\n"
+         return (
+             f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="{self.viewbox.to_str()}" height="200px" width="200px">'
+             f'{self._markers() if with_markers else ""}'
+             f'{newline.join(svg_path_group.to_str(fill=fill, with_markers=with_markers) for svg_path_group in [*self.svg_path_groups, *viz_elements])}'
+             '</svg>')
+
+     def _apply_to_paths(self, method, *args, **kwargs):
+         for path_group in self.svg_path_groups:
+             getattr(path_group, method)(*args, **kwargs)
+         return self
+
+     def split_paths(self):
+         path_groups = []
+         for path_group in self.svg_path_groups:
+             path_groups.extend(path_group.split_paths())
+         self.svg_path_groups = path_groups
+         return self
+
+     def merge_groups(self):
+         # Merge all paths into the first group
+         merged_group = self.svg_path_groups[0]
+         for path_group in self.svg_path_groups[1:]:
+             merged_group.svg_paths.extend(path_group.svg_paths)
+         self.svg_path_groups = [merged_group]
+         return self
+
+     def empty(self):
+         return len(self.svg_path_groups) == 0
+
+     def drop_z(self):
+         return self._apply_to_paths("drop_z")
+
+     def filter_empty(self):
+         self._apply_to_paths("filter_empty")
+         self.svg_path_groups = [path_group for path_group in self.svg_path_groups if path_group.svg_paths]
+         return self
+
+     def translate(self, vec: Point):
+         return self._apply_to_paths("translate", vec)
+
+     def rotate(self, angle: Angle, center: Point = None):
+         if center is None:
+             center = self.viewbox.center
+
+         self.translate(-self.viewbox.center)
+         self._apply_to_paths("rotate", angle)
+         self.translate(center)
+
+         return self
+
+     def zoom(self, factor, center: Point = None):
+         if center is None:
+             center = self.viewbox.center
+
+         self.translate(-self.viewbox.center)
+         self._apply_to_paths("scale", factor)
+         self.translate(center)
+
+         return self
+
+     def normalize(self, viewbox: Bbox = None):
+         if viewbox is None:
+             viewbox = Bbox(24)
+
+         size = self.viewbox.size
+         scale_factor = viewbox.size.min() / size.max()
+         self.zoom(scale_factor, viewbox.center)
+         self.viewbox = viewbox
+
+         return self
+
+     def compute_filling(self):
+         return self._apply_to_paths("compute_filling")
+
+     def recompute_origins(self):
+         origin = self.start_pos
+
+         for path_group in self.svg_path_groups:
+             path_group.set_origin(origin.copy())
+             origin = path_group.end_pos
+
+     def canonicalize_new(self, normalize=False):
+         self.to_path().simplify_arcs()
+
+         self.compute_filling()
+
+         if normalize:
+             self.normalize()
+
+         self.split_paths()
+
+         self.filter_consecutives()
+         self.filter_empty()
+         self._apply_to_paths("reorder")
+         self.svg_path_groups = sorted(self.svg_path_groups, key=lambda x: x.start_pos.tolist()[::-1])
+         self._apply_to_paths("canonicalize")
+         self.recompute_origins()
+
+         self.drop_z()
+
+         return self
+
+     def canonicalize(self, normalize=False):
+         self.to_path().simplify_arcs()
+
+         if normalize:
+             self.normalize()
+
+         self.split_paths()
+         self.filter_consecutives()
+         self.filter_empty()
+         self._apply_to_paths("reorder")
+         self.svg_path_groups = sorted(self.svg_path_groups, key=lambda x: x.start_pos.tolist()[::-1])
+         self._apply_to_paths("canonicalize")
+         self.recompute_origins()
+
+         self.drop_z()
+
+         return self
+
+     def reorder(self):
+         return self._apply_to_paths("reorder")
+
+     def canonicalize_old(self):
+         self.filter_empty()
+         self._apply_to_paths("reorder")
+         self.svg_path_groups = sorted(self.svg_path_groups, key=lambda x: x.start_pos.tolist()[::-1])
+         self._apply_to_paths("canonicalize")
+         self.split_paths()
+         self.recompute_origins()
+
+         self.drop_z()
+
+         return self
+
+     def to_video(self, wrapper, color="grey"):
+         clips, svg_commands = [], []
+
+         im = SVG([]).draw(do_display=False, return_png=True)
+         clips.append(wrapper(np.array(im)))
+
+         for svg_path in self.paths:
+             clips, svg_commands = svg_path.to_video(wrapper, clips, svg_commands, color=color)
+
+         im = self.draw(do_display=False, return_png=True)
+         clips.append(wrapper(np.array(im)))
+
+         return clips
+
+     def animate(self, file_path=None, frame_duration=0.1, do_display=True):
+         clips = self.to_video(lambda img: ImageClip(img).set_duration(frame_duration))
+
+         clip = concatenate_videoclips(clips, method="compose", bg_color=(255, 255, 255))
+
+         if file_path is not None:
+             clip.write_gif(file_path, fps=24, verbose=False, logger=None)
+
+         if do_display:
+             src = clip if file_path is None else file_path
+             ipd.display(ipython_display(src, fps=24, rd_kwargs=dict(logger=None), autoplay=1, loop=1))
+
+     def numericalize(self, n=256):
+         self.normalize(viewbox=Bbox(n))
+         return self._apply_to_paths("numericalize", n)
+
+     def simplify(self, tolerance=0.1, epsilon=0.1, angle_threshold=179., force_smooth=False):
+         self._apply_to_paths("simplify", tolerance=tolerance, epsilon=epsilon, angle_threshold=angle_threshold,
+                              force_smooth=force_smooth)
+         self.recompute_origins()
+         return self
+
+     def reverse(self):
+         self._apply_to_paths("reverse")
+         return self
+
+     def reverse_non_closed(self):
+         self._apply_to_paths("reverse_non_closed")
+         return self
+
+     def duplicate_extremities(self):
+         self._apply_to_paths("duplicate_extremities")
+         return self
+
+     def simplify_heuristic(self, tolerance=0.1, force_smooth=False):
+         return self.copy().split(max_dist=2, include_lines=False) \
+             .simplify(tolerance=tolerance, epsilon=0.2, angle_threshold=150, force_smooth=force_smooth) \
+             .split(max_dist=7.5)
+
+     def simplify_heuristic2(self):
+         return self.copy().split(max_dist=2, include_lines=False) \
+             .simplify(tolerance=0.2, epsilon=0.2, angle_threshold=150) \
+             .split(max_dist=7.5)
+
+     def split(self, n=None, max_dist=None, include_lines=True):
+         return self._apply_to_paths("split", n=n, max_dist=max_dist, include_lines=include_lines)
+
+     @staticmethod
+     def unit_circle():
+         d = 2 * (math.sqrt(2) - 1) / 3
+
+         circle = SVGPath([
+             SVGCommandBezier(Point(.5, 0.), Point(.5 + d, 0.), Point(1., .5 - d), Point(1., .5)),
+             SVGCommandBezier(Point(1., .5), Point(1., .5 + d), Point(.5 + d, 1.), Point(.5, 1.)),
+             SVGCommandBezier(Point(.5, 1.), Point(.5 - d, 1.), Point(0., .5 + d), Point(0., .5)),
+             SVGCommandBezier(Point(0., .5), Point(0., .5 - d), Point(.5 - d, 0.), Point(.5, 0.))
+         ]).to_group()
+
+         return SVG([circle], viewbox=Bbox(1))
+
+     @staticmethod
+     def unit_square():
+         square = SVGPath.from_str("m 0,0 h1 v1 h-1 v-1")
+         return SVG([square], viewbox=Bbox(1))
+
+     def add_path_group(self, path_group: SVGPathGroup):
+         path_group.set_origin(self.end_pos.copy())
+         self.svg_path_groups.append(path_group)
+
+         return self
+
+     def add_path_groups(self, path_groups: List[SVGPathGroup]):
+         for path_group in path_groups:
+             self.add_path_group(path_group)
+
+         return self
+
+     def simplify_arcs(self):
+         return self._apply_to_paths("simplify_arcs")
+
+     def to_path(self):
+         for i, path_group in enumerate(self.svg_path_groups):
+             self.svg_path_groups[i] = path_group.to_path()
+         return self
+
+     def filter_consecutives(self):
+         return self._apply_to_paths("filter_consecutives")
+
+     def filter_duplicates(self):
+         return self._apply_to_paths("filter_duplicates")
+
+     def set_color(self, color):
+         colors = ["deepskyblue", "lime", "deeppink", "gold", "coral", "darkviolet", "royalblue", "darkmagenta", "teal",
+                   "gold",
+                   "green", "maroon", "aqua", "grey", "steelblue", "lime", "orange"]
+
+         if color == "random_random":
+             random.shuffle(colors)
+
+         if isinstance(color, list):
+             colors = color
+
+         for i, path_group in enumerate(self.svg_path_groups):
+             if color == "random" or color == "random_random" or isinstance(color, list):
+                 c = colors[i % len(colors)]
+             else:
+                 c = color
+             path_group.color = c
+         return self
+
+     def bbox(self):
+         return union_bbox([path_group.bbox() for path_group in self.svg_path_groups])
+
+     def overlap_graph(self, threshold=0.95, draw=False):
+         G = nx.DiGraph()
+         shapes = [group.to_shapely() for group in self.svg_path_groups]
+
+         for i, group1 in enumerate(shapes):
+             G.add_node(i)
+
+             if self.svg_path_groups[i].path.filling != Filling.OUTLINE:
+
+                 for j, group2 in enumerate(shapes):
+                     if i != j and self.svg_path_groups[j].path.filling == Filling.FILL:
+                         overlap = group1.intersection(group2).area / group1.area
+                         if overlap > threshold:
+                             G.add_edge(j, i, weight=overlap)
+
+         if draw:
+             pos = nx.spring_layout(G)
+             nx.draw_networkx(G, pos, with_labels=True)
+             labels = nx.get_edge_attributes(G, 'weight')
+             nx.draw_networkx_edge_labels(G, pos, edge_labels=labels)
+         return G
+
+     def group_overlapping_paths(self):
+         G = self.overlap_graph()
+
+         path_groups = []
+         root_nodes = [i for i, d in G.in_degree() if d == 0]
+
+         for root in root_nodes:
+             if self[root].path.filling == Filling.FILL:
+                 current = [root]
+
+                 while current:
+                     n = current.pop(0)
+
+                     fill_neighbors, erase_neighbors = [], []
+                     for m in G.neighbors(n):
+                         if G.in_degree(m) == 1:
+                             if self[m].path.filling == Filling.ERASE:
+                                 erase_neighbors.append(m)
+                             else:
+                                 fill_neighbors.append(m)
+                     G.remove_node(n)
+
+                     path_group = SVGPathGroup([self[n].path.copy().set_orientation(Orientation.CLOCKWISE)], fill=True)
+                     if erase_neighbors:
+                         for n in erase_neighbors:
+                             neighbor = self[n].path.copy().set_orientation(Orientation.COUNTER_CLOCKWISE)
+                             path_group.append(neighbor)
+                         G.remove_nodes_from(erase_neighbors)
+
+                     path_groups.append(path_group)
+
+                     current.extend(fill_neighbors)
+
+         # Add outlines in the end
+         for path_group in self.svg_path_groups:
+             if path_group.path.filling == Filling.OUTLINE:
+                 path_groups.append(path_group)
+
+         return SVG(path_groups)
+
+     def to_points(self, sort=True):
+         points = np.concatenate([path_group.to_points() for path_group in self.svg_path_groups])
+
+         if sort:
+             ind = np.lexsort((points[:, 0], points[:, 1]))
+             points = points[ind]
+
+             # Remove duplicates
+             row_mask = np.append([True], np.any(np.diff(points, axis=0), 1))
+             points = points[row_mask]
+
+         return points
+
+     def permute(self, indices=None):
+         if indices is not None:
+             self.svg_path_groups = [self.svg_path_groups[i] for i in indices]
+         return self
+
+     def fill_(self, fill=True):
+         return self._apply_to_paths("fill_", fill)
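End to end, the class covers parsing, canonicalization and tensor export; a typical preprocessing sketch (the file name refers to the logo_0.svg shipped under src/postprocessing, and the exact per-row tensor layout is defined by SVGTensor):

svg = SVG.load_svg("src/postprocessing/logo_0.svg")
svg.canonicalize(normalize=True)              # to paths, split, reorder, rescale
svg.numericalize(n=256)                       # quantize coordinates to 8 bits
tensors = svg.to_tensor(concat_groups=False)  # one command tensor per path group
print(len(tensors), tensors[0].shape)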
src/preprocessing/deepsvg/deepsvg_svglib/svg_command.py ADDED
@@ -0,0 +1,531 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper >https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ from __future__ import annotations
7
+ from .geom import *
8
+ from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor
9
+ from .util_fns import get_roots
10
+ from enum import Enum
11
+ import torch
12
+ import math
13
+ from typing import List, Union
14
+ Num = Union[int, float]
15
+
16
+
17
+ class SVGCmdEnum(Enum):
18
+ MOVE_TO = "m"
19
+ LINE_TO = "l"
20
+ CUBIC_BEZIER = "c"
21
+ CLOSE_PATH = "z"
22
+ ELLIPTIC_ARC = "a"
23
+ QUAD_BEZIER = "q"
24
+ LINE_TO_HORIZONTAL = "h"
25
+ LINE_TO_VERTICAL = "v"
26
+ CUBIC_BEZIER_REFL = "s"
27
+ QUAD_BEZIER_REFL = "t"
28
+
29
+
30
+ svgCmdArgTypes = {
31
+ SVGCmdEnum.MOVE_TO.value: [Point],
32
+ SVGCmdEnum.LINE_TO.value: [Point],
33
+ SVGCmdEnum.CUBIC_BEZIER.value: [Point, Point, Point],
34
+ SVGCmdEnum.CLOSE_PATH.value: [],
35
+ SVGCmdEnum.ELLIPTIC_ARC.value: [Radius, Angle, Flag, Flag, Point],
36
+ SVGCmdEnum.QUAD_BEZIER.value: [Point, Point],
37
+ SVGCmdEnum.LINE_TO_HORIZONTAL.value: [XCoord],
38
+ SVGCmdEnum.LINE_TO_VERTICAL.value: [YCoord],
39
+ SVGCmdEnum.CUBIC_BEZIER_REFL.value: [Point, Point],
40
+ SVGCmdEnum.QUAD_BEZIER_REFL.value: [Point],
41
+ }
42
+
43
+
44
+ class SVGCommand:
45
+ def __init__(self, command: SVGCmdEnum, args: List[Geom], start_pos: Point, end_pos: Point):
46
+ self.command = command
47
+ self.args = args
48
+
49
+ self.start_pos = start_pos
50
+ self.end_pos = end_pos
51
+
52
+ def copy(self):
53
+ raise NotImplementedError
54
+
55
+ @staticmethod
56
+ def from_str(cmd_str: str, args_str: List[Num], pos=None, initial_pos=None, prev_command: SVGCommand = None):
57
+ if pos is None:
58
+ pos = Point(0.)
59
+ if initial_pos is None:
60
+ initial_pos = Point(0.)
61
+
62
+ cmd = SVGCmdEnum(cmd_str.lower())
63
+
64
+ # Implicit MoveTo commands are treated as LineTo
65
+ if cmd is SVGCmdEnum.MOVE_TO and len(args_str) > 2:
66
+ l_cmd_str = SVGCmdEnum.LINE_TO.value
67
+ if cmd_str.isupper():
68
+ l_cmd_str = l_cmd_str.upper()
69
+
70
+ l1, pos, initial_pos = SVGCommand.from_str(cmd_str, args_str[:2], pos, initial_pos)
71
+ l2, pos, initial_pos = SVGCommand.from_str(l_cmd_str, args_str[2:], pos, initial_pos)
72
+ return [*l1, *l2], pos, initial_pos
73
+
74
+ nb_args = len(args_str)
75
+
76
+ if cmd is SVGCmdEnum.CLOSE_PATH:
77
+ assert nb_args == 0, f"Expected no argument for command {cmd_str}: {nb_args} given"
78
+ return [SVGCommandClose(pos, initial_pos)], initial_pos, initial_pos
79
+
80
+ expected_nb_args = sum([ArgType.num_args for ArgType in svgCmdArgTypes[cmd.value]])
81
+ assert nb_args % expected_nb_args == 0, f"Expected a multiple of {expected_nb_args} arguments for command {cmd_str}: {nb_args} given"
82
+
83
+ l = []
84
+ i = 0
85
+ for _ in range(nb_args // expected_nb_args):
86
+ args = []
87
+ for ArgType in svgCmdArgTypes[cmd.value]:
88
+ num_args = ArgType.num_args
89
+ arg = ArgType(*args_str[i:i+num_args])
90
+
91
+ if cmd_str.islower():
92
+ arg.translate(pos)
93
+ if isinstance(arg, Coord):
94
+ arg = arg.to_point(pos)
95
+
96
+ args.append(arg)
97
+ i += num_args
98
+
99
+ if cmd is SVGCmdEnum.LINE_TO or cmd is SVGCmdEnum.LINE_TO_VERTICAL or cmd is SVGCmdEnum.LINE_TO_HORIZONTAL:
100
+ cmd_parsed = SVGCommandLine(pos, *args)
101
+ elif cmd is SVGCmdEnum.MOVE_TO:
102
+ cmd_parsed = SVGCommandMove(pos, *args)
103
+ elif cmd is SVGCmdEnum.ELLIPTIC_ARC:
104
+ cmd_parsed = SVGCommandArc(pos, *args)
105
+ elif cmd is SVGCmdEnum.CUBIC_BEZIER:
106
+ cmd_parsed = SVGCommandBezier(pos, *args)
107
+ elif cmd is SVGCmdEnum.QUAD_BEZIER:
108
+ cmd_parsed = SVGCommandBezier(pos, args[0], args[0], args[1])
109
+ elif cmd is SVGCmdEnum.QUAD_BEZIER_REFL or cmd is SVGCmdEnum.CUBIC_BEZIER_REFL:
110
+ if isinstance(prev_command, SVGCommandBezier):
111
+ control1 = pos * 2 - prev_command.control2
112
+ else:
113
+ control1 = pos
114
+ control2 = args[0] if cmd is SVGCmdEnum.CUBIC_BEZIER_REFL else control1
115
+ cmd_parsed = SVGCommandBezier(pos, control1, control2, args[-1])
116
+
117
+ prev_command = cmd_parsed
118
+ pos = cmd_parsed.end_pos
119
+
120
+ if cmd is SVGCmdEnum.MOVE_TO:
121
+ initial_pos = pos
122
+
123
+ l.append(cmd_parsed)
124
+
125
+ return l, pos, initial_pos
126
+
127
+ def __repr__(self):
128
+ cmd = self.command.value.upper()
129
+ return f"{cmd}{self.get_geoms()}"
130
+
131
+ def to_str(self):
132
+ cmd = self.command.value.upper()
133
+ return f"{cmd}{' '.join([arg.to_str() for arg in self.args])}"
134
+
135
+ def to_tensor(self, PAD_VAL=-1):
136
+ raise NotImplementedError
137
+
138
+ @staticmethod
139
+ def from_tensor(vector: torch.Tensor):
140
+ cmd_index, args = int(vector[0]), vector[1:]
141
+
142
+ cmd = SVGCmdEnum(SVGTensor.COMMANDS_SIMPLIFIED[cmd_index])
143
+ radius = Radius(*args[:2].tolist())
144
+ x_axis_rotation = Angle(*args[2:3].tolist())
145
+ large_arc_flag = Flag(args[3].item())
146
+ sweep_flag = Flag(args[4].item())
147
+ start_pos = Point(*args[5:7].tolist())
148
+ control1 = Point(*args[7:9].tolist())
149
+ control2 = Point(*args[9:11].tolist())
150
+ end_pos = Point(*args[11:].tolist())
151
+
152
+ return SVGCommand.from_args(cmd, radius, x_axis_rotation, large_arc_flag, sweep_flag, start_pos, control1, control2, end_pos)
153
+
154
+ @staticmethod
155
+ def from_args(command: SVGCmdEnum, radius: Radius, x_axis_rotation: Angle, large_arc_flag: Flag,
156
+ sweep_flag: Flag, start_pos: Point, control1: Point, control2: Point, end_pos: Point):
157
+ if command is SVGCmdEnum.MOVE_TO:
158
+ return SVGCommandMove(start_pos, end_pos)
159
+ elif command is SVGCmdEnum.LINE_TO:
160
+ return SVGCommandLine(start_pos, end_pos)
161
+ elif command is SVGCmdEnum.CUBIC_BEZIER:
162
+ return SVGCommandBezier(start_pos, control1, control2, end_pos)
163
+ elif command is SVGCmdEnum.CLOSE_PATH:
164
+ return SVGCommandClose(start_pos, end_pos)
165
+ elif command is SVGCmdEnum.ELLIPTIC_ARC:
166
+ return SVGCommandArc(start_pos, radius, x_axis_rotation, large_arc_flag, sweep_flag, end_pos)
167
+
168
+ def draw(self, *args, **kwargs):
169
+ from .svg_path import SVGPath
170
+ return SVGPath([self]).draw(*args, **kwargs)
171
+
172
+ def reverse(self):
173
+ raise NotImplementedError
174
+
175
+ def is_left_to(self, other: SVGCommand):
176
+ p1, p2 = self.start_pos, other.start_pos
177
+
178
+ if p1.y == p2.y:
179
+ return p1.x < p2.x
180
+
181
+ return p1.y < p2.y or (np.isclose(p1.norm(), p2.norm()) and p1.x < p2.x)
182
+
183
+ def numericalize(self, n=256):
184
+ raise NotImplementedError
185
+
186
+ def get_geoms(self):
187
+ return [self.start_pos, self.end_pos]
188
+
189
+ def get_points_viz(self, first=False, last=False):
190
+ from .svg_primitive import SVGCircle
191
+ color = "red" if first else "purple" if last else "deepskyblue" # "#C4C4C4"
192
+ opacity = 0.75 if first or last else 1.0
193
+ return [SVGCircle(self.end_pos, radius=Radius(0.4), color=color, fill=True, stroke_width=".1", opacity=opacity)]
194
+
195
+ def get_handles_viz(self):
196
+ return []
197
+
198
+ def sample_points(self, n=10, return_array=False):
199
+ return []
200
+
201
+ def split(self, n=2):
202
+ raise NotImplementedError
203
+
204
+ def length(self):
205
+ raise NotImplementedError
206
+
207
+ def bbox(self):
208
+ raise NotImplementedError
209
+
210
+
211
+ class SVGCommandLinear(SVGCommand):
212
+ def __init__(self, *args, **kwargs):
213
+ super().__init__(*args, **kwargs)
214
+
215
+ def to_tensor(self, PAD_VAL=-1):
216
+ cmd_index = SVGTensor.COMMANDS_SIMPLIFIED.index(self.command.value)
217
+ return torch.tensor([cmd_index,
218
+ *([PAD_VAL] * 5),
219
+ *self.start_pos.to_tensor(),
220
+ *([PAD_VAL] * 4),
221
+ *self.end_pos.to_tensor()])
222
+
223
+ def numericalize(self, n=256):
224
+ self.start_pos.numericalize(n)
225
+ self.end_pos.numericalize(n)
226
+
227
+ def copy(self):
228
+ return self.__class__(self.start_pos.copy(), self.end_pos.copy())
229
+
230
+ def reverse(self):
231
+ return self.__class__(self.end_pos, self.start_pos)
232
+
233
+ def split(self, n=2):
234
+ return [self]
235
+
236
+ def bbox(self):
237
+ return Bbox(self.start_pos, self.end_pos)
238
+
239
+
240
+ class SVGCommandMove(SVGCommandLinear):
241
+ def __init__(self, start_pos: Point, end_pos: Point=None):
242
+ if end_pos is None:
243
+ start_pos, end_pos = Point(0.), start_pos
244
+ super().__init__(SVGCmdEnum.MOVE_TO, [end_pos], start_pos, end_pos)
245
+
246
+ def get_points_viz(self, first=False, last=False):
247
+ from .svg_primitive import SVGLine
248
+ points_viz = super().get_points_viz(first, last)
249
+ points_viz.append(SVGLine(self.start_pos, self.end_pos, color="red", dasharray=0.5))
250
+ return points_viz
251
+
252
+ def bbox(self):
253
+ return Bbox(self.end_pos, self.end_pos)
254
+
255
+
256
+ class SVGCommandLine(SVGCommandLinear):
257
+ def __init__(self, start_pos: Point, end_pos: Point):
258
+ super().__init__(SVGCmdEnum.LINE_TO, [end_pos], start_pos, end_pos)
259
+
260
+ def sample_points(self, n=10, return_array=False):
261
+ z = np.linspace(0., 1., n)
262
+
263
+ if return_array:
264
+ points = (1-z)[:, None] * self.start_pos.pos[None] + z[:, None] * self.end_pos.pos[None]
265
+ return points
266
+
267
+ points = [(1 - alpha) * self.start_pos + alpha * self.end_pos for alpha in z]
268
+ return points
269
+
270
+ def split(self, n=2):
271
+ points = self.sample_points(n+1)
272
+ return [SVGCommandLine(p1, p2) for p1, p2 in zip(points[:-1], points[1:])]
273
+
274
+ def length(self):
275
+ return self.start_pos.dist(self.end_pos)
276
+
277
+
278
+ class SVGCommandClose(SVGCommandLinear):
279
+ def __init__(self, start_pos: Point, end_pos: Point):
280
+ super().__init__(SVGCmdEnum.CLOSE_PATH, [], start_pos, end_pos)
281
+
282
+ def get_points_viz(self, first=False, last=False):
283
+ return []
284
+
285
+
286
+ class SVGCommandBezier(SVGCommand):
287
+ def __init__(self, start_pos: Point, control1: Point, control2: Point, end_pos: Point):
288
+ if control2 is None:
289
+ control2 = control1.copy()
290
+ super().__init__(SVGCmdEnum.CUBIC_BEZIER, [control1, control2, end_pos], start_pos, end_pos)
291
+
292
+ self.control1 = control1
293
+ self.control2 = control2
294
+
295
+ @property
296
+ def p1(self):
297
+ return self.start_pos
298
+
299
+ @property
300
+ def p2(self):
301
+ return self.end_pos
302
+
303
+ @property
304
+ def q1(self):
305
+ return self.control1
306
+
307
+ @property
308
+ def q2(self):
309
+ return self.control2
310
+
311
+ def copy(self):
312
+ return SVGCommandBezier(self.start_pos.copy(), self.control1.copy(), self.control2.copy(), self.end_pos.copy())
313
+
314
+ def to_tensor(self, PAD_VAL=-1):
315
+ cmd_index = SVGTensor.COMMANDS_SIMPLIFIED.index(SVGCmdEnum.CUBIC_BEZIER.value)
316
+ return torch.tensor([cmd_index,
317
+ *([PAD_VAL] * 5),
318
+ *self.start_pos.to_tensor(),
319
+ *self.control1.to_tensor(),
320
+ *self.control2.to_tensor(),
321
+ *self.end_pos.to_tensor()])
322
+
323
+ def to_vector(self):
324
+ return np.array([
325
+ self.start_pos.tolist(),
326
+ self.control1.tolist(),
327
+ self.control2.tolist(),
328
+ self.end_pos.tolist()
329
+ ])
330
+
331
+ @staticmethod
332
+ def from_vector(vector):
333
+ return SVGCommandBezier(Point(vector[0]), Point(vector[1]), Point(vector[2]), Point(vector[3]))
334
+
335
+ def reverse(self):
336
+ return SVGCommandBezier(self.end_pos, self.control2, self.control1, self.start_pos)
337
+
338
+ def numericalize(self, n=256):
339
+ self.start_pos.numericalize(n)
340
+ self.control1.numericalize(n)
341
+ self.control2.numericalize(n)
342
+ self.end_pos.numericalize(n)
343
+
344
+ def get_geoms(self):
345
+ return [self.start_pos, self.control1, self.control2, self.end_pos]
346
+
347
+ def get_handles_viz(self):
348
+ from .svg_primitive import SVGLine, SVGCircle
349
+ anchor_1 = SVGCircle(self.control1, radius=Radius(0.4), color="lime", fill=True, stroke_width=".1")
350
+ anchor_2 = SVGCircle(self.control2, radius=Radius(0.4), color="lime", fill=True, stroke_width=".1")
351
+
352
+ handle_1 = SVGLine(self.start_pos, self.control1, color="grey", dasharray=0.5, stroke_width=".1")
353
+ handle_2 = SVGLine(self.end_pos, self.control2, color="grey", dasharray=0.5, stroke_width=".1")
354
+ return [handle_1, handle_2, anchor_1, anchor_2]
355
+
356
+ def eval(self, t):
357
+ return (1 - t)**3 * self.start_pos + 3 * (1 - t)**2 * t * self.control1 + 3 * (1 - t) * t**2 * self.control2 + t**3 * self.end_pos
358
+
359
+ def derivative(self, t, n=1):
360
+ if n == 1:
361
+ return 3 * (1 - t)**2 * (self.control1 - self.start_pos) + 6 * (1 - t) * t * (self.control2 - self.control1) + 3 * t**2 * (self.end_pos - self.control2)
362
+ elif n == 2:
363
+ return 6 * (1 - t) * (self.control2 - 2 * self.control1 + self.start_pos) + 6 * t * (self.end_pos - 2 * self.control2 + self.control1)
364
+
365
+ raise NotImplementedError
366
+
367
+ def angle(self, other: SVGCommandBezier):
368
+ t1, t2 = self.derivative(1.), -other.derivative(0.)
369
+ if np.isclose(t1.norm(), 0.) or np.isclose(t2.norm(), 0.):
370
+ return 0.
371
+ angle = np.arccos(np.clip(t1.normalize().dot(t2.normalize()), -1., 1.))
372
+ return np.rad2deg(angle)
373
+
374
+ def sample_points(self, n=10, return_array=False):
375
+ b = self.to_vector()
376
+
377
+ z = np.linspace(0., 1., n)
378
+ Z = np.stack([np.ones_like(z), z, z**2, z**3], axis=1)
379
+ Q = np.array([[1., 0., 0., 0.],
380
+ [-3, 3., 0., 0.],
381
+ [3., -6, 3., 0.],
382
+ [-1, 3., -3, 1]])
383
+
384
+ points = Z @ Q @ b
385
+
386
+ if return_array:
387
+ return points
388
+
389
+ return [Point(p) for p in points]
390
+
391
+ def _split_two(self, z=.5):
392
+ b = self.to_vector()
393
+
394
+ Q1 = np.array([[1, 0, 0, 0],
395
+ [-(z - 1), z, 0, 0],
396
+ [(z - 1) ** 2, -2 * (z - 1) * z, z ** 2, 0],
397
+ [-(z - 1) ** 3, 3 * (z - 1) ** 2 * z, -3 * (z - 1) * z ** 2, z ** 3]])
398
+ Q2 = np.array([[-(z - 1) ** 3, 3 * (z - 1) ** 2 * z, -3 * (z - 1) * z ** 2, z ** 3],
399
+ [0, (z - 1) ** 2, -2 * (z - 1) * z, z ** 2],
400
+ [0, 0, -(z - 1), z],
401
+ [0, 0, 0, 1]])
402
+
403
+ return SVGCommandBezier.from_vector(Q1 @ b), SVGCommandBezier.from_vector(Q2 @ b)
404
+
405
+ def split(self, n=2):
406
+ b_list = []
407
+ b = self
408
+
409
+ for i in range(n - 1):
410
+ z = 1. / (n - i)
411
+ b1, b = b._split_two(z)
412
+ b_list.append(b1)
413
+ b_list.append(b)
414
+ return b_list
415
+
416
+ def length(self):
417
+ p = self.sample_points(n=100, return_array=True)
418
+ return np.linalg.norm(p[1:] - p[:-1], axis=-1).sum()
419
+
420
+ def bbox(self):
421
+ return Bbox.from_points(self.find_extrema())
422
+
423
+ def find_roots(self):
424
+ a = 3 * (-self.p1 + 3 * self.q1 - 3 * self.q2 + self.p2)
425
+ b = 6 * (self.p1 - 2 * self.q1 + self.q2)
426
+ c = 3 * (self.q1 - self.p1)
427
+
428
+ x_roots, y_roots = get_roots(a.x, b.x, c.x), get_roots(a.y, b.y, c.y)
429
+ roots_cat = [*x_roots, *y_roots]
430
+ roots = [root for root in roots_cat if 0 <= root <= 1]
431
+ return roots
432
+
433
+ def find_extrema(self):
434
+ points = [self.start_pos, self.end_pos]
435
+ points.extend([self.eval(root) for root in self.find_roots()])
436
+ return points
437
+
438
+
439
+ class SVGCommandArc(SVGCommand):
440
+ def __init__(self, start_pos: Point, radius: Radius, x_axis_rotation: Angle, large_arc_flag: Flag, sweep_flag: Flag, end_pos: Point):
441
+ super().__init__(SVGCmdEnum.ELLIPTIC_ARC, [radius, x_axis_rotation, large_arc_flag, sweep_flag, end_pos], start_pos, end_pos)
442
+
443
+ self.radius = radius
444
+ self.x_axis_rotation = x_axis_rotation
445
+ self.large_arc_flag = large_arc_flag
446
+ self.sweep_flag = sweep_flag
447
+
448
+ def copy(self):
449
+ return SVGCommandArc(self.start_pos.copy(), self.radius.copy(), self.x_axis_rotation.copy(), self.large_arc_flag.copy(),
450
+ self.sweep_flag.copy(), self.end_pos.copy())
451
+
452
+ def to_tensor(self, PAD_VAL=-1):
453
+ cmd_index = SVGTensor.COMMANDS_SIMPLIFIED.index(SVGCmdEnum.ELLIPTIC_ARC.value)
454
+ return torch.tensor([cmd_index,
455
+ *self.radius.to_tensor(),
456
+ *self.x_axis_rotation.to_tensor(),
457
+ *self.large_arc_flag.to_tensor(),
458
+ *self.sweep_flag.to_tensor(),
459
+ *self.start_pos.to_tensor(),
460
+ *([PAD_VAL] * 4),
461
+ *self.end_pos.to_tensor()])
462
+
463
+ def _get_center_parametrization(self):
464
+ r = self.radius
465
+ p1, p2 = self.start_pos, self.end_pos
466
+
467
+ h, m = 0.5 * (p1 - p2), 0.5 * (p1 + p2)
468
+ p1_trans = h.rotate(-self.x_axis_rotation)
469
+
470
+ sign = -1 if self.large_arc_flag.flag == self.sweep_flag.flag else 1
471
+ x2, y2, rx2, ry2 = p1_trans.x**2, p1_trans.y**2, r.x**2, r.y**2
472
+ sqrt = math.sqrt(max((rx2*ry2 - rx2*y2 - ry2*x2) / (rx2*y2 + ry2*x2), 0.))
473
+ c_trans = sign * sqrt * Point(r.x * p1_trans.y / r.y, -r.y * p1_trans.x / r.x)
474
+
475
+ c = c_trans.rotate(self.x_axis_rotation) + m
476
+
477
+ d, ns = (p1_trans - c_trans) / r, -(p1_trans + c_trans) / r
478
+
479
+ theta_1 = Point(1, 0).angle(d, signed=True)
480
+
481
+ delta_theta = d.angle(ns, signed=True)
482
+ delta_theta.deg %= 360
483
+ if self.sweep_flag.flag == 0 and delta_theta.deg > 0:
484
+ delta_theta = delta_theta - Angle(360)
485
+ if self.sweep_flag.flag == 1 and delta_theta.deg < 0:
486
+ delta_theta = delta_theta + Angle(360)
487
+
488
+ return c, theta_1, delta_theta
489
+
490
+ def _get_point(self, c: Point, t: float_type):
491
+ r = self.radius
492
+ return c + Point(r.x * np.cos(t), r.y * np.sin(t)).rotate(self.x_axis_rotation)
493
+
494
+ def _get_derivative(self, t: float_type):
495
+ r = self.radius
496
+ return Point(-r.x * np.sin(t), r.y * np.cos(t)).rotate(self.x_axis_rotation)
497
+
498
+ def to_beziers(self):
499
+ """ References:
500
+ https://www.w3.org/TR/2018/CR-SVG2-20180807/implnote.html
501
+ https://mortoray.com/2017/02/16/rendering-an-svg-elliptical-arc-as-bezier-curves/
502
+ http://www.spaceroots.org/documents/ellipse/elliptical-arc.pdf """
503
+ beziers = []
504
+
505
+ c, theta_1, delta_theta = self._get_center_parametrization()
506
+ nb_curves = max(int(abs(delta_theta.deg) // 45), 1)
507
+ etas = [theta_1 + i * delta_theta / nb_curves for i in range(nb_curves+1)]
508
+ for eta_1, eta_2 in zip(etas[:-1], etas[1:]):
509
+ e1, e2 = eta_1.rad, eta_2.rad
510
+ alpha = np.sin(e2 - e1) * (math.sqrt(4 + 3 * np.tan(0.5 * (e2 - e1))**2) - 1) / 3
511
+ p1, p2 = self._get_point(c, e1), self._get_point(c, e2)
512
+ q1 = p1 + alpha * self._get_derivative(e1)
513
+ q2 = p2 - alpha * self._get_derivative(e2)
514
+ beziers.append(SVGCommandBezier(p1, q1, q2, p2))
515
+
516
+ return beziers
517
+
518
+ def reverse(self):
519
+ return SVGCommandArc(self.end_pos, self.radius, self.x_axis_rotation, self.large_arc_flag, ~self.sweep_flag, self.start_pos)
520
+
521
+ def numericalize(self, n=256):
522
+ raise NotImplementedError
523
+
524
+ def get_geoms(self):
525
+ return [self.start_pos, self.radius, self.x_axis_rotation, self.large_arc_flag, self.sweep_flag, self.end_pos]
526
+
527
+ def split(self, n=2):
528
+ raise NotImplementedError
529
+
530
+ def sample_points(self, n=10, return_array=False):
531
+ raise NotImplementedError
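
Usage sketch for the parser in this file (import paths as in this repository; an illustration, not a tested snippet): from_str threads the current pen position and the subpath's initial position through successive commands, so a relative "l" followed by four numbers yields two LineTo segments:

    from src.preprocessing.deepsvg.deepsvg_svglib.svg_command import SVGCommand

    cmds, pos, start = SVGCommand.from_str("M", [2., 2.])                      # absolute MoveTo
    more, pos, start = SVGCommand.from_str("l", [10., 0., 0., 10.], pos, start)
    print(cmds + more, pos)  # [M..., L..., L...], pos ends at (12, 12)
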
src/preprocessing/deepsvg/deepsvg_svglib/svg_path.py ADDED
@@ -0,0 +1,659 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ from __future__ import annotations
7
+ from .geom import *
8
+ import src.preprocessing.deepsvg.deepsvg_svglib.geom as geom
9
+ import re
10
+ import torch
11
+ from typing import List, Union
12
+ from xml.dom import minidom
13
+ import math
14
+ import shapely.geometry  # required by SVGPath.to_shapely() below
15
+ import numpy as np
16
+
17
+ from .geom import union_bbox
18
+ from .svg_command import SVGCommand, SVGCommandMove, SVGCommandClose, SVGCommandBezier, SVGCommandLine, SVGCommandArc
19
+
20
+
21
+ COMMANDS = "MmZzLlHhVvCcSsQqTtAa"
22
+ COMMAND_RE = re.compile(r"([MmZzLlHhVvCcSsQqTtAa])")
23
+ FLOAT_RE = re.compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?")
24
+
25
+
26
+ empty_command = SVGCommandMove(Point(0.))
27
+
28
+
29
+ class Orientation:
30
+ COUNTER_CLOCKWISE = 0
31
+ CLOCKWISE = 1
32
+
33
+
34
+ class Filling:
35
+ OUTLINE = 0
36
+ FILL = 1
37
+ ERASE = 2
38
+
39
+
40
+ class SVGPath:
41
+ def __init__(self, path_commands: List[SVGCommand] = None, origin: Point = None, closed=False, filling=Filling.OUTLINE):
42
+ self.origin = origin or Point(0.)
43
+ self.path_commands = path_commands
44
+ self.closed = closed
45
+
46
+ self.filling = filling
47
+
48
+ @property
49
+ def start_command(self):
50
+ return SVGCommandMove(self.origin, self.start_pos)
51
+
52
+ @property
53
+ def start_pos(self):
54
+ return self.path_commands[0].start_pos
55
+
56
+ @property
57
+ def end_pos(self):
58
+ return self.path_commands[-1].end_pos
59
+
60
+ def to_group(self, *args, **kwargs):
61
+ from .svg_primitive import SVGPathGroup
62
+ return SVGPathGroup([self], *args, **kwargs)
63
+
64
+ def set_filling(self, filling=True):
65
+ self.filling = Filling.FILL if filling else Filling.ERASE
66
+ return self
67
+
68
+ def __len__(self):
69
+ return 1 + len(self.path_commands)
70
+
71
+ def __getitem__(self, idx):
72
+ if idx == 0:
73
+ return self.start_command
74
+ return self.path_commands[idx-1]
75
+
76
+ def all_commands(self, with_close=True):
77
+ close_cmd = [SVGCommandClose(self.path_commands[-1].end_pos.copy(), self.start_pos.copy())] if self.closed and self.path_commands and with_close \
78
+ else ()
79
+ return [self.start_command, *self.path_commands, *close_cmd]
80
+
81
+ def copy(self):
82
+ return SVGPath([path_command.copy() for path_command in self.path_commands], self.origin.copy(), self.closed, filling=self.filling)
83
+
84
+ @staticmethod
85
+ def _tokenize_path(path_str):
86
+ cmd = None
87
+ for x in COMMAND_RE.split(path_str):
88
+ if x and x in COMMANDS:
89
+ cmd = x
90
+ elif cmd is not None:
91
+ yield cmd, list(map(float, FLOAT_RE.findall(x)))
92
+
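
The tokenizer pairs each command letter with the floats that follow it; illustratively:

    list(SVGPath._tokenize_path("M2 2l10 0 0 10z"))
    # -> [('M', [2.0, 2.0]), ('l', [10.0, 0.0, 0.0, 10.0]), ('z', [])]
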
93
+ @staticmethod
94
+ def from_xml(x: minidom.Element):
95
+ stroke = x.getAttribute('stroke')
96
+ dasharray = x.getAttribute('dasharray')
97
+ stroke_width = x.getAttribute('stroke-width')
98
+
99
+ fill = not x.hasAttribute("fill") or not x.getAttribute("fill") == "none"
100
+
101
+ filling = Filling.OUTLINE if not x.hasAttribute("filling") else int(x.getAttribute("filling"))
102
+
103
+ s = x.getAttribute('d')
104
+ return SVGPath.from_str(s, fill=fill, filling=filling)
105
+
106
+ @staticmethod
107
+ def from_str(s: str, fill=False, filling=Filling.OUTLINE, add_closing=False):
108
+ path_commands = []
109
+ pos = initial_pos = Point(0.)
110
+ prev_command = None
111
+ for cmd, args in SVGPath._tokenize_path(s):
112
+ cmd_parsed, pos, initial_pos = SVGCommand.from_str(cmd, args, pos, initial_pos, prev_command)
113
+ prev_command = cmd_parsed[-1]
114
+ path_commands.extend(cmd_parsed)
115
+
116
+ return SVGPath.from_commands(path_commands, fill=fill, filling=filling, add_closing=add_closing)
117
+
118
+ @staticmethod
119
+ def from_tensor(tensor: torch.Tensor, allow_empty=False):
120
+ return SVGPath.from_commands([SVGCommand.from_tensor(row) for row in tensor], allow_empty=allow_empty)
121
+
122
+ @staticmethod
123
+ def from_commands(path_commands: List[SVGCommand], fill=False, filling=Filling.OUTLINE, add_closing=False, allow_empty=False):
124
+ from .svg_primitive import SVGPathGroup
125
+
126
+ if not path_commands:
127
+ return SVGPathGroup([])
128
+
129
+ svg_paths = []
130
+ svg_path = None
131
+
132
+ for command in path_commands:
133
+ if isinstance(command, SVGCommandMove):
134
+ if svg_path is not None and (allow_empty or svg_path.path_commands): # SVGPath contains at least one command
135
+ if add_closing:
136
+ svg_path.closed = True
137
+ if not svg_path.path_commands:
138
+ svg_path.path_commands.append(empty_command)
139
+ svg_paths.append(svg_path)
140
+
141
+ svg_path = SVGPath([], command.start_pos.copy(), filling=filling)
142
+ else:
143
+ if svg_path is None:
144
+ # Ignore commands until the first moveTo commands
145
+ continue
146
+
147
+ if isinstance(command, SVGCommandClose):
148
+ if allow_empty or svg_path.path_commands: # SVGPath contains at least one command
149
+ svg_path.closed = True
150
+ if not svg_path.path_commands:
151
+ svg_path.path_commands.append(empty_command)
152
+ svg_paths.append(svg_path)
153
+ svg_path = None
154
+ else:
155
+ svg_path.path_commands.append(command)
156
+ if svg_path is not None and (allow_empty or svg_path.path_commands): # SVGPath contains at least one command
157
+ if add_closing:
158
+ svg_path.closed = True
159
+ if not svg_path.path_commands:
160
+ svg_path.path_commands.append(empty_command)
161
+ svg_paths.append(svg_path)
162
+ return SVGPathGroup(svg_paths, fill=fill)
163
+
164
+ def __repr__(self):
165
+ return "SVGPath({})".format(" ".join(command.__repr__() for command in self.all_commands()))
166
+
167
+ def to_str(self, fill=False):
168
+ return " ".join(command.to_str() for command in self.all_commands())
169
+
170
+ def to_tensor(self, PAD_VAL=-1):
171
+ return torch.stack([command.to_tensor(PAD_VAL=PAD_VAL) for command in self.all_commands()])
172
+
173
+ def _get_viz_elements(self, with_points=False, with_handles=False, with_bboxes=False, color_firstlast=False, with_moves=True):
174
+ points = self._get_points_viz(color_firstlast, with_moves) if with_points else ()
175
+ handles = self._get_handles_viz() if with_handles else ()
176
+ return [*points, *handles]
177
+
178
+ def draw(self, viewbox=Bbox(24), *args, **kwargs):
179
+ from .svg import SVG
180
+ return SVG([self.to_group()], viewbox=viewbox).draw(*args, **kwargs)
181
+
182
+ def _get_points_viz(self, color_firstlast=True, with_moves=True):
183
+ points = []
184
+ commands = self.all_commands(with_close=False)
185
+ n = len(commands)
186
+ for i, command in enumerate(commands):
187
+ if not isinstance(command, SVGCommandMove) or with_moves:
188
+ points_viz = command.get_points_viz(first=(color_firstlast and i <= 1), last=(color_firstlast and i >= n-2))
189
+ points.extend(points_viz)
190
+ return points
191
+
192
+ def _get_handles_viz(self):
193
+ handles = []
194
+ for command in self.path_commands:
195
+ handles.extend(command.get_handles_viz())
196
+ return handles
197
+
198
+ def _get_unique_geoms(self):
199
+ geoms = []
200
+ for command in self.all_commands():
201
+ geoms.extend(command.get_geoms())
202
+ return list(set(geoms))
203
+
204
+ def translate(self, vec):
205
+ for geom in self._get_unique_geoms():
206
+ geom.translate(vec)
207
+ return self
208
+
209
+ def rotate(self, angle):
210
+ for geom in self._get_unique_geoms():
211
+ geom.rotate_(angle)
212
+ return self
213
+
214
+ def scale(self, factor):
215
+ for geom in self._get_unique_geoms():
216
+ geom.scale(factor)
217
+ return self
218
+
219
+ def filter_consecutives(self):
220
+ path_commands = []
221
+ for command in self.path_commands:
222
+ if not command.start_pos.isclose(command.end_pos):
223
+ path_commands.append(command)
224
+ self.path_commands = path_commands
225
+ return self
226
+
227
+ def filter_duplicates(self, min_dist=0.2):
228
+ path_commands = []
229
+ current_command = None
230
+ for command in self.path_commands:
231
+ if current_command is None:
232
+ path_commands.append(command)
233
+ current_command = command
234
+
235
+ if command.end_pos.dist(current_command.end_pos) >= min_dist:
236
+ command.start_pos = current_command.end_pos
237
+ path_commands.append(command)
238
+ current_command = command
239
+
240
+ self.path_commands = path_commands
241
+ return self
242
+
243
+ def duplicate_extremities(self):
244
+ self.path_commands = [SVGCommandLine(self.start_pos, self.start_pos),
245
+ *self.path_commands,
246
+ SVGCommandLine(self.end_pos, self.end_pos)]
247
+ return self
248
+
249
+ def is_clockwise(self):
250
+ if len(self.path_commands) == 1:
251
+ cmd = self.path_commands[0]
252
+ return cmd.start_pos.tolist() <= cmd.end_pos.tolist()
253
+
254
+ det_total = 0.
255
+ for cmd in self.path_commands:
256
+ det_total += geom.det(cmd.start_pos, cmd.end_pos)
257
+ return det_total >= 0.
258
+
259
+ def set_orientation(self, orientation):
260
+ """
261
+ orientation: 1 (clockwise), 0 (counter-clockwise)
262
+ """
263
+ if orientation == self.is_clockwise():
264
+ return self
265
+ return self.reverse()
266
+
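
Note: is_clockwise() accumulates det(start_pos, end_pos) over the segments, i.e. the shoelace formula for twice the signed area of the closed polygon; in SVG's y-down coordinates a non-negative total means clockwise on screen. A standalone numeric check (names are illustrative):

    def det(p, q):
        return p[0] * q[1] - p[1] * q[0]

    square = [(0, 0), (4, 0), (4, 4), (0, 4)]  # right, down, left, up on screen
    area2 = sum(det(p, q) for p, q in zip(square, square[1:] + square[:1]))
    print(area2)  # +32 -> clockwise under the y-down convention
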
267
+ def set_closed(self, closed=True):
268
+ self.closed = closed
269
+ return self
270
+
271
+ def reverse(self):
272
+ path_commands = []
273
+
274
+ for command in reversed(self.path_commands):
275
+ path_commands.append(command.reverse())
276
+
277
+ self.path_commands = path_commands
278
+ return self
279
+
280
+ def reverse_non_closed(self):
281
+ if not self.start_pos.isclose(self.end_pos):
282
+ return self.reverse()
283
+ return self
284
+
285
+ def simplify_arcs(self):
286
+ path_commands = []
287
+ for command in self.path_commands:
288
+ if isinstance(command, SVGCommandArc):
289
+ if command.radius.iszero():
290
+ continue
291
+ if command.start_pos.isclose(command.end_pos):
292
+ continue
293
+ path_commands.extend(command.to_beziers())
294
+ else:
295
+ path_commands.append(command)
296
+
297
+ self.path_commands = path_commands
298
+ return self
299
+
300
+ def _get_topleftmost_command(self):
301
+ topleftmost_cmd = None
302
+ topleftmost_idx = 0
303
+
304
+ for i, cmd in enumerate(self.path_commands):
305
+ if topleftmost_cmd is None or cmd.is_left_to(topleftmost_cmd):
306
+ topleftmost_cmd = cmd
307
+ topleftmost_idx = i
308
+
309
+ return topleftmost_cmd, topleftmost_idx
310
+
311
+ def reorder(self):
312
+ if self.closed:
313
+ topleftmost_cmd, topleftmost_idx = self._get_topleftmost_command()
314
+
315
+ self.path_commands = [
316
+ *self.path_commands[topleftmost_idx:],
317
+ *self.path_commands[:topleftmost_idx]
318
+ ]
319
+
320
+ return self
321
+
322
+ def to_video(self, wrapper, clips=None, svg_commands=None, color="grey"):
323
+ from .svg import SVG
324
+ from .svg_primitive import SVGLine, SVGCircle
325
+
326
+ if clips is None:
327
+ clips = []
328
+ if svg_commands is None:
329
+ svg_commands = []
330
+ svg_dots, svg_moves = [], []
331
+
332
+ for command in self.all_commands():
333
+ start_pos, end_pos = command.start_pos, command.end_pos
334
+
335
+ if isinstance(command, SVGCommandMove):
336
+ move = SVGLine(start_pos, end_pos, color="teal", dasharray=0.5)
337
+ svg_moves.append(move)
338
+
339
+ dot = SVGCircle(end_pos, radius=Radius(0.1), color="red")
340
+ svg_dots.append(dot)
341
+
342
+ svg_path = SVGPath(svg_commands).to_group(color=color)
343
+ svg_new_path = SVGPath([SVGCommandMove(start_pos), command]).to_group(color="red")
344
+
345
+ svg_paths = [svg_path, svg_new_path] if svg_commands else [svg_new_path]
346
+ im = SVG([*svg_paths, *svg_moves, *svg_dots]).draw(do_display=False, return_png=True, with_points=False)
347
+ clips.append(wrapper(np.array(im)))
348
+
349
+ svg_dots[-1].color = "grey"
350
+ svg_commands.append(command)
351
+ svg_moves = []
352
+
353
+ return clips, svg_commands
354
+
355
+ def numericalize(self, n=256):
356
+ for command in self.all_commands():
357
+ command.numericalize(n)
358
+
359
+ def smooth(self):
360
+ # https://github.com/paperjs/paper.js/blob/c7d85b663edb728ec78fffa9f828435eaf78d9c9/src/path/Path.js#L1288
361
+ n = len(self.path_commands)
362
+ knots = [self.start_pos, *(path_command.end_pos for path_command in self.path_commands)]
363
+ r = [knots[0] + 2 * knots[1]]
364
+ f = [2]
365
+ p = [Point(0.)] * (n + 1)
366
+
367
+ # Solve with the Thomas algorithm
368
+ for i in range(1, n):
369
+ internal = i < n - 1
370
+ a = 1
371
+ b = 4 if internal else 2
372
+ u = 4 if internal else 3
373
+ v = 2 if internal else 0
374
+ m = a / f[i-1]
375
+
376
+ f.append(b-m)
377
+ r.append(u * knots[i] + v * knots[i + 1] - m * r[i-1])
378
+
379
+ p[n-1] = r[n-1] / f[n-1]
380
+ for i in range(n-2, -1, -1):
381
+ p[i] = (r[i] - p[i+1]) / f[i]
382
+ p[n] = (3 * knots[n] - p[n-1]) / 2
383
+
384
+ for i in range(n):
385
+ p1, p2 = knots[i], knots[i+1]
386
+ c1, c2 = p[i], 2 * p2 - p[i+1]
387
+ self.path_commands[i] = SVGCommandBezier(p1, c1, c2, p2)
388
+
389
+ return self
390
+
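
Note: the loop above is the Thomas algorithm (forward elimination plus back substitution) applied to the tridiagonal system that yields the first control points of a smooth cubic spline. A self-contained check of the same scheme against a dense solver (an illustrative 4x4 system, not tied to this class):

    import numpy as np

    A = np.diag([2., 4., 4., 2.]) + np.diag([1., 1., 1.], 1) + np.diag([1., 1., 1.], -1)
    r = np.array([3., 6., 9., 5.])

    f, rhs = [A[0, 0]], [r[0]]
    for i in range(1, 4):                      # forward sweep
        m = A[i, i - 1] / f[i - 1]
        f.append(A[i, i] - m * A[i - 1, i])
        rhs.append(r[i] - m * rhs[i - 1])

    x = np.zeros(4)
    x[-1] = rhs[-1] / f[-1]
    for i in range(2, -1, -1):                 # back substitution
        x[i] = (rhs[i] - A[i, i + 1] * x[i + 1]) / f[i]

    assert np.allclose(x, np.linalg.solve(A, r))
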
391
+ def simplify_heuristic(self):
392
+ return self.copy().split(max_dist=2, include_lines=False) \
393
+ .simplify(tolerance=0.1, epsilon=0.2, angle_threshold=150) \
394
+ .split(max_dist=7.5)
395
+
396
+ def simplify(self, tolerance=0.1, epsilon=0.1, angle_threshold=179., force_smooth=False):
397
+ # https://github.com/paperjs/paper.js/blob/c044b698c6b224c10a7747664b2a4cd00a416a25/src/path/PathFitter.js#L44
398
+ points = [self.start_pos, *(path_command.end_pos for path_command in self.path_commands)]
399
+
400
+ def subdivide_indices():
401
+ segments_list = []
402
+ current_segment = []
403
+ prev_command = None
404
+
405
+ for i, command in enumerate(self.path_commands):
406
+ if isinstance(command, SVGCommandLine):
407
+ if current_segment:
408
+ segments_list.append(current_segment)
409
+ current_segment = []
410
+ prev_command = None
411
+
412
+ continue
413
+
414
+ if prev_command is not None and prev_command.angle(command) < angle_threshold:
415
+ if current_segment:
416
+ segments_list.append(current_segment)
417
+ current_segment = []
418
+
419
+ current_segment.append(i)
420
+ prev_command = command
421
+
422
+ if current_segment:
423
+ segments_list.append(current_segment)
424
+
425
+ return segments_list
426
+
427
+ path_commands = []
428
+
429
+ def computeMaxError(first, last, curve: SVGCommandBezier, u):
430
+ maxDist = 0.
431
+ index = (last - first + 1) // 2
432
+ for i in range(1, last - first):
433
+ dist = curve.eval(u[i]).dist(points[first + i]) ** 2
434
+ if dist >= maxDist:
435
+ maxDist = dist
436
+ index = first + i
437
+ return maxDist, index
438
+
439
+ def chordLengthParametrize(first, last):
440
+ u = [0.]
441
+ for i in range(1, last - first + 1):
442
+ u.append(u[i-1] + points[first + i].dist(points[first + i-1]))
443
+
444
+ for i, _ in enumerate(u[1:], 1):
445
+ u[i] /= u[-1]
446
+
447
+ return u
448
+
449
+ def isMachineZero(val):
450
+ MACHINE_EPSILON = 1.12e-16
451
+ return val >= -MACHINE_EPSILON and val <= MACHINE_EPSILON
452
+
453
+ def findRoot(curve: SVGCommandBezier, point, u):
454
+ """
455
+ Newton's root finding algorithm calculates f(x)=0 by reiterating
456
+ x_n+1 = x_n - f(x_n)/f'(x_n)
457
+ We are trying to find curve parameter u for some point p that minimizes
458
+ the distance from that point to the curve. Distance point to curve is d=q(u)-p.
459
+ At minimum distance the point is perpendicular to the curve.
460
+ We are solving
461
+ f = q(u)-p * q'(u) = 0
462
+ with
463
+ f' = q'(u) * q'(u) + q(u)-p * q''(u)
464
+ gives
465
+ u_n+1 = u_n - |q(u_n)-p * q'(u_n)| / |q'(u_n)**2 + q(u_n)-p * q''(u_n)|
466
+ """
467
+ diff = curve.eval(u) - point
468
+ d1, d2 = curve.derivative(u, n=1), curve.derivative(u, n=2)
469
+ numerator = diff.dot(d1)
470
+ denominator = d1.dot(d1) + diff.dot(d2)
471
+
472
+ return u if isMachineZero(denominator) else u - numerator / denominator
473
+
474
+ def reparametrize(first, last, u, curve: SVGCommandBezier):
475
+ for i in range(0, last - first + 1):
476
+ u[i] = findRoot(curve, points[first + i], u[i])
477
+
478
+ for i in range(1, len(u)):
479
+ if u[i] <= u[i-1]:
480
+ return False
481
+
482
+ return True
483
+
484
+ def generateBezier(first, last, uPrime, tan1, tan2):
485
+ epsilon = 1e-12
486
+ p1, p2 = points[first], points[last]
487
+ C = np.zeros((2, 2))
488
+ X = np.zeros(2)
489
+
490
+ for i in range(last - first + 1):
491
+ u = uPrime[i]
492
+ t = 1 - u
493
+ b = 3 * u * t
494
+ b0 = t**3
495
+ b1 = b * t
496
+ b2 = b * u
497
+ b3 = u**3
498
+ a1 = tan1 * b1
499
+ a2 = tan2 * b2
500
+ tmp = points[first + i] - p1 * (b0 + b1) - p2 * (b2 + b3)
501
+
502
+ C[0, 0] += a1.dot(a1)
503
+ C[0, 1] += a1.dot(a2)
504
+ C[1, 0] = C[0, 1]
505
+ C[1, 1] += a2.dot(a2)
506
+ X[0] += a1.dot(tmp)
507
+ X[1] += a2.dot(tmp)
508
+
509
+ detC0C1 = C[0, 0] * C[1, 1] - C[1, 0] * C[0, 1]
510
+ if abs(detC0C1) > epsilon:
511
+ detC0X = C[0, 0] * X[1] - C[1, 0] * X[0]
512
+ detXC1 = X[0] * C[1, 1] - X[1] * C[0, 1]
513
+ alpha1 = detXC1 / detC0C1
514
+ alpha2 = detC0X / detC0C1
515
+ else:
516
+ c0 = C[0, 0] + C[0, 1]
517
+ c1 = C[1, 0] + C[1, 1]
518
+ alpha1 = alpha2 = X[0] / c0 if abs(c0) > epsilon else (X[1] / c1 if abs(c1) > epsilon else 0)
519
+
520
+ segLength = p2.dist(p1)
521
+ eps = epsilon * segLength
522
+ handle1 = handle2 = None
523
+
524
+ if alpha1 < eps or alpha2 < eps:
525
+ alpha1 = alpha2 = segLength / 3
526
+ else:
527
+ line = p2 - p1
528
+ handle1 = tan1 * alpha1
529
+ handle2 = tan2 * alpha2
530
+
531
+ if handle1.dot(line) - handle2.dot(line) > segLength**2:
532
+ alpha1 = alpha2 = segLength / 3
533
+ handle1 = handle2 = None
534
+
535
+ if handle1 is None or handle2 is None:
536
+ handle1 = tan1 * alpha1
537
+ handle2 = tan2 * alpha2
538
+
539
+ return SVGCommandBezier(p1, p1 + handle1, p2 + handle2, p2)
540
+
541
+ def computeLinearMaxError(first, last):
542
+ maxDist = 0.
543
+ index = (last - first + 1) // 2
544
+
545
+ p1, p2 = points[first], points[last]
546
+ for i in range(first + 1, last):
547
+ dist = points[i].distToLine(p1, p2)
548
+ if dist >= maxDist:
549
+ maxDist = dist
550
+ index = i
551
+ return maxDist, index
552
+
553
+ def ramerDouglasPeucker(first, last, epsilon):
554
+ max_error, split_index = computeLinearMaxError(first, last)
555
+
556
+ if max_error > epsilon:
557
+ ramerDouglasPeucker(first, split_index, epsilon)
558
+ ramerDouglasPeucker(split_index, last, epsilon)
559
+ else:
560
+ p1, p2 = points[first], points[last]
561
+ path_commands.append(SVGCommandLine(p1, p2))
562
+
563
+ def fitCubic(error, first, last, tan1=None, tan2=None):
564
+ # For convenience, compute extremity tangents if not provided
565
+ if tan1 is None and tan2 is None:
566
+ tan1 = (points[first + 1] - points[first]).normalize()
567
+ tan2 = (points[last - 1] - points[last]).normalize()
568
+
569
+ if last - first == 1:
570
+ p1, p2 = points[first], points[last]
571
+ dist = p1.dist(p2) / 3
572
+ path_commands.append(SVGCommandBezier(p1, p1 + dist * tan1, p2 + dist * tan2, p2))
573
+ return
574
+
575
+ uPrime = chordLengthParametrize(first, last)
576
+ maxError = max(error, error**2)
577
+ parametersInOrder = True
578
+
579
+ for i in range(5):
580
+ curve = generateBezier(first, last, uPrime, tan1, tan2)
581
+
582
+ max_error, split_index = computeMaxError(first, last, curve, uPrime)
583
+
584
+ if max_error < error and parametersInOrder:
585
+ path_commands.append(curve)
586
+ return
587
+
588
+ if max_error >= maxError:
589
+ break
590
+
591
+ parametersInOrder = reparametrize(first, last, uPrime, curve)
592
+ maxError = max_error
593
+
594
+ tanCenter = (points[split_index-1] - points[split_index+1]).normalize()
595
+ fitCubic(error, first, split_index, tan1, tanCenter)
596
+ fitCubic(error, split_index, last, -tanCenter, tan2)
597
+
598
+ segments_list = subdivide_indices()
599
+ if force_smooth:
600
+ fitCubic(tolerance, 0, len(points) - 1)
601
+ else:
602
+ if segments_list:
603
+ seg = segments_list[0]
604
+ ramerDouglasPeucker(0, seg[0], epsilon)
605
+
606
+ for seg, seg_next in zip(segments_list[:-1], segments_list[1:]):
607
+ fitCubic(tolerance, seg[0], seg[-1] + 1)
608
+ ramerDouglasPeucker(seg[-1] + 1, seg_next[0], epsilon)
609
+
610
+ seg = segments_list[-1]
611
+ fitCubic(tolerance, seg[0], seg[-1] + 1)
612
+ ramerDouglasPeucker(seg[-1] + 1, len(points) - 1, epsilon)
613
+ else:
614
+ ramerDouglasPeucker(0, len(points) - 1, epsilon)
615
+
616
+ self.path_commands = path_commands
617
+
618
+ return self
619
+
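
Note: simplify() therefore combines two classic pieces: Ramer-Douglas-Peucker on the polyline stretches and Schneider-style least-squares cubic fitting (with Newton reparametrization) on the smooth stretches, cutting at corners sharper than angle_threshold. A hedged usage sketch (import path as in this repository):

    from src.preprocessing.deepsvg.deepsvg_svglib.svg_path import SVGPath

    group = SVGPath.from_str("M0 0 C0 6 12 6 12 0 C12 -6 24 -6 24 0")
    simplified = group.path.simplify_heuristic()   # copy -> split -> simplify -> split
    print(len(simplified.path_commands))
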
620
+ def split(self, n=None, max_dist=None, include_lines=True):
621
+ path_commands = []
622
+
623
+ for command in self.path_commands:
624
+ if isinstance(command, SVGCommandLine) and not include_lines:
625
+ path_commands.append(command)
626
+ else:
627
+ l = command.length()
628
+ if max_dist is not None:
629
+ n = max(math.ceil(l / max_dist), 1)
630
+
631
+ path_commands.extend(command.split(n=n))
632
+
633
+ self.path_commands = path_commands
634
+
635
+ return self
636
+
637
+ def bbox(self):
638
+ return union_bbox([cmd.bbox() for cmd in self.path_commands])
639
+
640
+ def sample_points(self, max_dist=0.4):
641
+ points = []
642
+
643
+ for command in self.path_commands:
644
+ l = command.length()
645
+ n = max(math.ceil(l / max_dist), 1)
646
+ points.extend(command.sample_points(n=n, return_array=True)[None])
647
+ points = np.concatenate(points, axis=0)
648
+ return points
649
+
650
+ def to_shapely(self):
651
+ polygon = shapely.geometry.Polygon(self.sample_points())
652
+
653
+ if not polygon.is_valid:
654
+ polygon = polygon.buffer(0)
655
+
656
+ return polygon
657
+
658
+ def to_points(self):
659
+ return np.array([self.start_pos.pos, *(cmd.end_pos.pos for cmd in self.path_commands)])
src/preprocessing/deepsvg/deepsvg_svglib/svg_primitive.py ADDED
@@ -0,0 +1,452 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ from __future__ import annotations
7
+ from .geom import *
8
+ import torch
9
+ import re
10
+ from typing import List, Union
11
+ from xml.dom import minidom
12
+ from .svg_path import SVGPath
13
+ from .svg_command import SVGCommandLine, SVGCommandArc, SVGCommandBezier, SVGCommandClose
14
+ import shapely
15
+ import shapely.ops
16
+ import shapely.geometry
17
+ import networkx as nx
18
+
19
+
20
+ FLOAT_RE = re.compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?")
21
+
22
+
23
+ def extract_args(args):
24
+ return list(map(float, FLOAT_RE.findall(args)))
25
+
26
+
27
+ class SVGPrimitive:
28
+ """
29
+ Reference: https://developer.mozilla.org/en-US/docs/Web/SVG/Tutorial/Basic_Shapes
30
+ """
31
+ def __init__(self, color="black", fill=False, dasharray=None, stroke_width=".3", opacity=1.0):
32
+ self.color = color
33
+ self.dasharray = dasharray
34
+ self.stroke_width = stroke_width
35
+ self.opacity = opacity
36
+
37
+ self.fill = fill
38
+
39
+ def _get_fill_attr(self):
40
+ fill_attr = f'fill="{self.color}" fill-opacity="{self.opacity}"' if self.fill else f'fill="none" stroke="{self.color}" stroke-width="{self.stroke_width}" stroke-opacity="{self.opacity}"'
41
+ if self.dasharray is not None and not self.fill:
42
+ fill_attr += f' stroke-dasharray="{self.dasharray}"'
43
+ return fill_attr
44
+
45
+ @classmethod
46
+ def from_xml(cls, x: minidom.Element):
47
+ raise NotImplementedError
48
+
49
+ def draw(self, viewbox=Bbox(24), *args, **kwargs):
50
+ from .svg import SVG
51
+ return SVG([self], viewbox=viewbox).draw(*args, **kwargs)
52
+
53
+ def _get_viz_elements(self, with_points=False, with_handles=False, with_bboxes=False, color_firstlast=True, with_moves=True):
54
+ return []
55
+
56
+ def to_path(self):
57
+ raise NotImplementedError
58
+
59
+ def copy(self):
60
+ raise NotImplementedError
61
+
62
+ def bbox(self):
63
+ raise NotImplementedError
64
+
65
+ def fill_(self, fill=True):
66
+ self.fill = fill
67
+ return self
68
+
69
+
70
+ class SVGEllipse(SVGPrimitive):
71
+ def __init__(self, center: Point, radius: Radius, *args, **kwargs):
72
+ super().__init__(*args, **kwargs)
73
+
74
+ self.center = center
75
+ self.radius = radius
76
+
77
+ def __repr__(self):
78
+ return f'SVGEllipse(c={self.center} r={self.radius})'
79
+
80
+ def to_str(self, *args, **kwargs):
81
+ fill_attr = self._get_fill_attr()
82
+ return f'<ellipse {fill_attr} cx="{self.center.x}" cy="{self.center.y}" rx="{self.radius.x}" ry="{self.radius.y}"/>'
83
+
84
+ @classmethod
85
+ def from_xml(_, x: minidom.Element):
86
+ fill = not x.hasAttribute("fill") or not x.getAttribute("fill") == "none"
87
+
88
+ center = Point(float(x.getAttribute("cx")), float(x.getAttribute("cy")))
89
+ radius = Radius(float(x.getAttribute("rx")), float(x.getAttribute("ry")))
90
+ return SVGEllipse(center, radius, fill=fill)
91
+
92
+ def to_path(self):
93
+ p0, p1 = self.center + self.radius.xproj(), self.center + self.radius.yproj()
94
+ p2, p3 = self.center - self.radius.xproj(), self.center - self.radius.yproj()
95
+ commands = [
96
+ SVGCommandArc(p0, self.radius, Angle(0.), Flag(0.), Flag(1.), p1),
97
+ SVGCommandArc(p1, self.radius, Angle(0.), Flag(0.), Flag(1.), p2),
98
+ SVGCommandArc(p2, self.radius, Angle(0.), Flag(0.), Flag(1.), p3),
99
+ SVGCommandArc(p3, self.radius, Angle(0.), Flag(0.), Flag(1.), p0),
100
+ ]
101
+ return SVGPath(commands, closed=True).to_group(fill=self.fill)
102
+
103
+
104
+ class SVGCircle(SVGEllipse):
105
+ def __init__(self, *args, **kwargs):
106
+ super().__init__(*args, **kwargs)
107
+
108
+ def __repr__(self):
109
+ return f'SVGCircle(c={self.center} r={self.radius})'
110
+
111
+ def to_str(self, *args, **kwargs):
112
+ fill_attr = self._get_fill_attr()
113
+ return f'<circle {fill_attr} cx="{self.center.x}" cy="{self.center.y}" r="{self.radius.x}"/>'
114
+
115
+ @classmethod
116
+ def from_xml(_, x: minidom.Element):
117
+ fill = not x.hasAttribute("fill") or not x.getAttribute("fill") == "none"
118
+
119
+ center = Point(float(x.getAttribute("cx")), float(x.getAttribute("cy")))
120
+ radius = Radius(float(x.getAttribute("r")))
121
+ return SVGCircle(center, radius, fill=fill)
122
+
123
+
124
+ class SVGRectangle(SVGPrimitive):
125
+ def __init__(self, xy: Point, wh: Size, *args, **kwargs):
126
+ super().__init__(*args, **kwargs)
127
+
128
+ self.xy = xy
129
+ self.wh = wh
130
+
131
+ def __repr__(self):
132
+ return f'SVGRectangle(xy={self.xy} wh={self.wh})'
133
+
134
+ def to_str(self, *args, **kwargs):
135
+ fill_attr = self._get_fill_attr()
136
+ return f'<rect {fill_attr} x="{self.xy.x}" y="{self.xy.y}" width="{self.wh.x}" height="{self.wh.y}"/>'
137
+
138
+ @classmethod
139
+ def from_xml(_, x: minidom.Element):
140
+ fill = not x.hasAttribute("fill") or not x.getAttribute("fill") == "none"
141
+
142
+ xy = Point(0.)
143
+ if x.hasAttribute("x"):
144
+ xy.pos[0] = float(x.getAttribute("x"))
145
+ if x.hasAttribute("y"):
146
+ xy.pos[1] = float(x.getAttribute("y"))
147
+ wh = Size(float(x.getAttribute("width")), float(x.getAttribute("height")))
148
+ return SVGRectangle(xy, wh, fill=fill)
149
+
150
+ def to_path(self):
151
+ p0, p1, p2, p3 = self.xy, self.xy + self.wh.xproj(), self.xy + self.wh, self.xy + self.wh.yproj()
152
+ commands = [
153
+ SVGCommandLine(p0, p1),
154
+ SVGCommandLine(p1, p2),
155
+ SVGCommandLine(p2, p3),
156
+ SVGCommandLine(p3, p0)
157
+ ]
158
+ return SVGPath(commands, closed=True).to_group(fill=self.fill)
159
+
160
+
161
+ class SVGLine(SVGPrimitive):
162
+ def __init__(self, start_pos: Point, end_pos: Point, *args, **kwargs):
163
+ super().__init__(*args, **kwargs)
164
+
165
+ self.start_pos = start_pos
166
+ self.end_pos = end_pos
167
+
168
+ def __repr__(self):
169
+ return f'SVGLine(xy1={self.start_pos} xy2={self.end_pos})'
170
+
171
+ def to_str(self, *args, **kwargs):
172
+ fill_attr = self._get_fill_attr()
173
+ return f'<line {fill_attr} x1="{self.start_pos.x}" y1="{self.start_pos.y}" x2="{self.end_pos.x}" y2="{self.end_pos.y}"/>'
174
+
175
+ @classmethod
176
+ def from_xml(_, x: minidom.Element):
177
+ fill = not x.hasAttribute("fill") or not x.getAttribute("fill") == "none"
178
+
179
+ start_pos = Point(float(x.getAttribute("x1") or 0.), float(x.getAttribute("y1") or 0.))
180
+ end_pos = Point(float(x.getAttribute("x2") or 0.), float(x.getAttribute("y2") or 0.))
181
+ return SVGLine(start_pos, end_pos, fill=fill)
182
+
183
+ def to_path(self):
184
+ return SVGPath([SVGCommandLine(self.start_pos, self.end_pos)]).to_group(fill=self.fill)
185
+
186
+
187
+ class SVGPolyline(SVGPrimitive):
188
+ def __init__(self, points: List[Point], *args, **kwargs):
189
+ super().__init__(*args, **kwargs)
190
+
191
+ self.points = points
192
+
193
+ def __repr__(self):
194
+ return f'SVGPolyline(points={self.points})'
195
+
196
+ def to_str(self, *args, **kwargs):
197
+ fill_attr = self._get_fill_attr()
198
+ return '<polyline {} points="{}"/>'.format(fill_attr, ' '.join([p.to_str() for p in self.points]))
199
+
200
+ @classmethod
201
+ def from_xml(cls, x: minidom.Element):
202
+ fill = not x.hasAttribute("fill") or not x.getAttribute("fill") == "none"
203
+
204
+ args = extract_args(x.getAttribute("points"))
205
+ assert len(args) % 2 == 0, f"Expected even number of arguments for SVGPolyline: {len(args)} given"
206
+ points = [Point(x, args[2*i+1]) for i, x in enumerate(args[::2])]
207
+ return cls(points, fill=fill)
208
+
209
+ def to_path(self):
210
+ commands = [SVGCommandLine(p1, p2) for p1, p2 in zip(self.points[:-1], self.points[1:])]
211
+ is_closed = self.__class__.__name__ == "SVGPolygon"
212
+ return SVGPath(commands, closed=is_closed).to_group(fill=self.fill)
213
+
214
+
215
+ class SVGPolygon(SVGPolyline):
216
+ def __init__(self, *args, **kwargs):
217
+ super().__init__(*args, **kwargs)
218
+
219
+ def __repr__(self):
220
+ return f'SVGPolygon(points={self.points})'
221
+
222
+ def to_str(self, *args, **kwargs):
223
+ fill_attr = self._get_fill_attr()
224
+ return '<polygon {} points="{}"/>'.format(fill_attr, ' '.join([p.to_str() for p in self.points]))
225
+
226
+
227
+ class SVGPathGroup(SVGPrimitive):
228
+ def __init__(self, svg_paths: List[SVGPath] = None, origin=None, *args, **kwargs):
229
+ super().__init__(*args, **kwargs)
230
+ self.svg_paths = svg_paths
231
+
232
+ if origin is None:
233
+ origin = Point(0.)
234
+ self.origin = origin
235
+
236
+ # Alias
237
+ @property
238
+ def paths(self):
239
+ return self.svg_paths
240
+
241
+ @property
242
+ def path(self):
243
+ return self.svg_paths[0]
244
+
245
+ def __getitem__(self, idx):
246
+ return self.svg_paths[idx]
247
+
248
+ def __len__(self):
249
+ return len(self.paths)
250
+
251
+ def total_len(self):
252
+ return sum([len(path) for path in self.svg_paths])
253
+
254
+ @property
255
+ def start_pos(self):
256
+ return self.svg_paths[0].start_pos
257
+
258
+ @property
259
+ def end_pos(self):
260
+ last_path = self.svg_paths[-1]
261
+ if last_path.closed:
262
+ return last_path.start_pos
263
+ return last_path.end_pos
264
+
265
+ def set_origin(self, origin: Point):
266
+ self.origin = origin
267
+ if self.svg_paths:
268
+ self.svg_paths[0].origin = origin
269
+ self.recompute_origins()
270
+
271
+ def append(self, path: SVGPath):
272
+ self.svg_paths.append(path)
273
+
274
+ def copy(self):
275
+ return SVGPathGroup([svg_path.copy() for svg_path in self.svg_paths], self.origin.copy(),
276
+ self.color, self.fill, self.dasharray, self.stroke_width, self.opacity)
277
+
278
+ def __repr__(self):
279
+ return "SVGPathGroup({})".format(", ".join(svg_path.__repr__() for svg_path in self.svg_paths))
280
+
281
+ def _get_viz_elements(self, with_points=False, with_handles=False, with_bboxes=False, color_firstlast=True, with_moves=True):
282
+ viz_elements = []
283
+ for svg_path in self.svg_paths:
284
+ viz_elements.extend(svg_path._get_viz_elements(with_points, with_handles, with_bboxes, color_firstlast, with_moves))
285
+
286
+ if with_bboxes:
287
+ viz_elements.append(self._get_bbox_viz())
288
+
289
+ return viz_elements
290
+
291
+ def _get_bbox_viz(self):
292
+ color = "red" if self.color == "black" else self.color
293
+ bbox = self.bbox().to_rectangle(color=color)
294
+ return bbox
295
+
296
+ def to_path(self):
297
+ return self
298
+
299
+ def to_str(self, with_markers=False, *args, **kwargs):
300
+ fill_attr = self._get_fill_attr()
301
+ marker_attr = 'marker-start="url(#arrow)"' if with_markers else ''
302
+ return '<path {} {} filling="{}" d="{}"></path>'.format(fill_attr, marker_attr, self.path.filling,
303
+ " ".join(svg_path.to_str() for svg_path in self.svg_paths))
304
+
305
+ def to_tensor(self, PAD_VAL=-1):
306
+ return torch.cat([p.to_tensor(PAD_VAL=PAD_VAL) for p in self.svg_paths], dim=0)
307
+
308
+ def _apply_to_paths(self, method, *args, **kwargs):
309
+ for path in self.svg_paths:
310
+ getattr(path, method)(*args, **kwargs)
311
+ return self
312
+
313
+ def translate(self, vec):
314
+ return self._apply_to_paths("translate", vec)
315
+
316
+ def rotate(self, angle: Angle):
317
+ return self._apply_to_paths("rotate", angle)
318
+
319
+ def scale(self, factor):
320
+ return self._apply_to_paths("scale", factor)
321
+
322
+ def numericalize(self, n=256):
323
+ return self._apply_to_paths("numericalize", n)
324
+
325
+ def drop_z(self):
326
+ return self._apply_to_paths("set_closed", False)
327
+
328
+ def recompute_origins(self):
329
+ origin = self.origin
330
+ for path in self.svg_paths:
331
+ path.origin = origin.copy()
332
+ origin = path.end_pos
333
+ return self
334
+
335
+ def reorder(self):
336
+ self._apply_to_paths("reorder")
337
+ self.recompute_origins()
338
+ return self
339
+
340
+ def filter_empty(self):
341
+ self.svg_paths = [path for path in self.svg_paths if path.path_commands]
342
+ return self
343
+
344
+ def canonicalize(self):
345
+ self.svg_paths = sorted(self.svg_paths, key=lambda x: x.start_pos.tolist()[::-1])
346
+ if not self.svg_paths[0].is_clockwise():
347
+ self._apply_to_paths("reverse")
348
+
349
+ self.recompute_origins()
350
+ return self
351
+
352
+ def reverse(self):
353
+ self._apply_to_paths("reverse")
354
+
355
+ self.recompute_origins()
356
+ return self
357
+
358
+ def duplicate_extremities(self):
359
+ self._apply_to_paths("duplicate_extremities")
360
+ return self
361
+
362
+ def reverse_non_closed(self):
363
+ self._apply_to_paths("reverse_non_closed")
364
+
365
+ self.recompute_origins()
366
+ return self
367
+
368
+ def simplify(self, tolerance=0.1, epsilon=0.1, angle_threshold=179., force_smooth=False):
369
+ self._apply_to_paths("simplify", tolerance=tolerance, epsilon=epsilon, angle_threshold=angle_threshold,
370
+ force_smooth=force_smooth)
371
+ self.recompute_origins()
372
+ return self
373
+
374
+ def split_paths(self):
375
+ return [SVGPathGroup([svg_path], self.origin,
376
+ self.color, self.fill, self.dasharray, self.stroke_width, self.opacity)
377
+ for svg_path in self.svg_paths]
378
+
379
+ def split(self, n=None, max_dist=None, include_lines=True):
380
+ return self._apply_to_paths("split", n=n, max_dist=max_dist, include_lines=include_lines)
381
+
382
+ def simplify_arcs(self):
383
+ return self._apply_to_paths("simplify_arcs")
384
+
385
+ def filter_consecutives(self):
386
+ return self._apply_to_paths("filter_consecutives")
387
+
388
+ def filter_duplicates(self):
389
+ return self._apply_to_paths("filter_duplicates")
390
+
391
+ def bbox(self):
392
+ return union_bbox([path.bbox() for path in self.svg_paths])
393
+
394
+ def to_shapely(self):
395
+ return shapely.ops.unary_union([path.to_shapely() for path in self.svg_paths])
396
+
397
+ def compute_filling(self):
398
+ if self.fill:
399
+ G = self.overlap_graph()
400
+
401
+ root_nodes = [i for i, d in G.in_degree() if d == 0]
402
+
403
+ for root in root_nodes:
404
+ if not self.svg_paths[root].closed:
405
+ continue
406
+
407
+ current = [(1, root)]
408
+
409
+ while current:
410
+ visited = set()
411
+ neighbors = set()
412
+ for d, n in current:
413
+ self.svg_paths[n].set_filling(d != 0)
414
+
415
+ for n2 in G.neighbors(n):
416
+ if n2 not in visited:
417
+ d2 = d + (self.svg_paths[n2].is_clockwise() == self.svg_paths[n].is_clockwise()) * 2 - 1
418
+ visited.add(n2)
419
+ neighbors.add((d2, n2))
420
+
421
+ G.remove_nodes_from([n for d, n in current])
422
+
423
+ current = [(d, n) for d, n in neighbors if G.in_degree(n) == 0]
424
+
425
+ return self
426
+
427
+ def overlap_graph(self, threshold=0.9, draw=False):
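+ # Builds a containment graph over the closed paths: an edge j -> i is added
+ # when more than `threshold` of path i's area lies inside path j, so roots
+ # (in-degree 0) are the outermost shapes.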
428
+ G = nx.DiGraph()
429
+ shapes = [path.to_shapely() for path in self.svg_paths]
430
+
431
+ for i, path1 in enumerate(shapes):
432
+ G.add_node(i)
433
+
434
+ if self.svg_paths[i].closed:
435
+ for j, path2 in enumerate(shapes):
436
+ if i != j and self.svg_paths[j].closed:
437
+ overlap = path1.intersection(path2).area / path1.area
438
+ if overlap > threshold:
439
+ G.add_edge(j, i, weight=overlap)
440
+
441
+ if draw:
442
+ pos = nx.spring_layout(G)
443
+ nx.draw_networkx(G, pos, with_labels=True)
444
+ labels = nx.get_edge_attributes(G, 'weight')
445
+ nx.draw_networkx_edge_labels(G, pos, edge_labels=labels)
446
+ return G
447
+
448
+ def bbox_overlap(self, other: SVGPathGroup):
449
+ return self.bbox().overlap(other.bbox())
450
+
451
+ def to_points(self):
452
+ return np.concatenate([path.to_points() for path in self.svg_paths])
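Since each of these methods returns `self`, preprocessing steps chain fluently. A minimal sketch of typical use (hypothetical: `group` is an `SVGPathGroup` taken from a parsed `SVG`, and `Point` comes from the geom module):

    group.translate(Point(10, 0)).scale(0.5).filter_empty().canonicalize()
    tensor = group.to_tensor(PAD_VAL=-1)  # all paths concatenated into one command tensor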
src/preprocessing/deepsvg/deepsvg_svglib/svglib_utils.py ADDED
@@ -0,0 +1,95 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ import src.preprocessing.deepsvg.deepsvg_svglib.svg as svg_lib
7
+ from .geom import Bbox, Point
8
+ import math
9
+ import numpy as np
10
+ import IPython.display as ipd
11
+ from moviepy.editor import ImageClip, concatenate_videoclips, ipython_display
12
+
13
+
14
+ def make_grid(svgs, num_cols=3, grid_width=24):
15
+ """
16
+ svgs: List[svg_lib.SVG]
17
+ """
18
+ nb_rows = math.ceil(len(svgs) / num_cols)
19
+ grid = svg_lib.SVG([], viewbox=Bbox(grid_width * num_cols, grid_width * nb_rows))
20
+
21
+ for i, svg in enumerate(svgs):
22
+ row, col = i // num_cols, i % num_cols
23
+ svg = svg.copy().translate(Point(grid_width * col, grid_width * row))
24
+
25
+ grid.add_path_groups(svg.svg_path_groups)
26
+
27
+ return grid
28
+
29
+
30
+ def make_grid_grid(svg_grid, grid_width=24):
31
+ """
32
+ svg_grid: List[List[svg_lib.SVG]]
33
+ """
34
+ nb_rows = len(svg_grid)
35
+ num_cols = len(svg_grid[0])
36
+ grid = svg_lib.SVG([], viewbox=Bbox(grid_width * num_cols, grid_width * nb_rows))
37
+
38
+ for i, row in enumerate(svg_grid):
39
+ for j, svg in enumerate(row):
40
+ svg = svg.copy().translate(Point(grid_width * j, grid_width * i))
41
+
42
+ grid.add_path_groups(svg.svg_path_groups)
43
+
44
+ return grid
45
+
46
+
47
+ def make_grid_lines(svg_grid, grid_width=24):
48
+ """
49
+ svg_grid: List[List[svg_lib.SVG]]
50
+ """
51
+ nb_rows = len(svg_grid)
52
+ num_cols = max(len(r) for r in svg_grid)
53
+ grid = svg_lib.SVG([], viewbox=Bbox(grid_width * num_cols, grid_width * nb_rows))
54
+
55
+ for i, row in enumerate(svg_grid):
56
+ for j, svg in enumerate(row):
57
+ j_shift = (num_cols - len(row)) // 2
58
+ svg = svg.copy().translate(Point(grid_width * (j + j_shift), grid_width * i))
59
+
60
+ grid.add_path_groups(svg.svg_path_groups)
61
+
62
+ return grid
63
+
64
+
65
+ COLORS = ["aliceblue", "antiquewhite", "aqua", "aquamarine", "azure", "beige", "bisque", "black", "blanchedalmond",
66
+ "blue", "blueviolet", "brown", "burlywood", "cadetblue", "chartreuse", "chocolate", "coral", "cornflowerblue",
67
+ "cornsilk", "crimson", "cyan", "darkblue", "darkcyan", "darkgoldenrod", "darkgray", "darkgreen", "darkgrey",
68
+ "darkkhaki", "darkmagenta", "darkolivegreen", "darkorange", "darkorchid", "darkred", "darksalmon",
69
+ "darkseagreen", "darkslateblue", "darkslategray", "darkslategrey", "darkturquoise", "darkviolet", "deeppink",
70
+ "deepskyblue", "dimgray", "dimgrey", "dodgerblue", "firebrick", "floralwhite", "forestgreen", "fuchsia",
71
+ "gainsboro", "ghostwhite", "gold", "goldenrod", "gray", "green", "greenyellow", "grey", "honeydew", "hotpink",
72
+ "indianred", "indigo", "ivory", "khaki", "lavender", "lavenderblush", "lawngreen", "lemonchiffon",
73
+ "lightblue", "lightcoral", "lightcyan", "lightgoldenrodyellow", "lightgray", "lightgreen", "lightgrey",
74
+ "lightpink", "lightsalmon", "lightseagreen", "lightskyblue", "lightslategray", "lightslategrey",
75
+ "lightsteelblue", "lightyellow", "lime", "limegreen", "linen", "magenta", "maroon", "mediumaquamarine",
76
+ "mediumblue", "mediumorchid", "mediumpurple", "mediumseagreen", "mediumslateblue", "mediumspringgreen",
77
+ "mediumturquoise", "mediumvioletred", "midnightblue", "mintcream", "mistyrose", "moccasin", "navajowhite",
78
+ "navy", "oldlace", "olive", "olivedrab", "orange", "orangered", "orchid", "palegoldenrod", "palegreen",
79
+ "paleturquoise", "palevioletred", "papayawhip", "peachpuff", "peru", "pink", "plum", "powderblue", "purple",
80
+ "red", "rosybrown", "royalblue", "saddlebrown", "salmon", "sandybrown", "seagreen", "seashell", "sienna",
81
+ "silver", "skyblue", "slateblue", "slategray", "slategrey", "snow", "springgreen", "steelblue", "tan", "teal",
82
+ "thistle", "tomato", "turquoise", "violet", "wheat", "white", "whitesmoke", "yellow", "yellowgreen"]
83
+
84
+
85
+ def to_gif(img_list, file_path=None, frame_duration=0.1, do_display=True):
86
+ clips = [ImageClip(np.array(img)).set_duration(frame_duration) for img in img_list]
87
+
88
+ clip = concatenate_videoclips(clips, method="compose", bg_color=(255, 255, 255))
89
+
90
+ if file_path is not None:
91
+ clip.write_gif(file_path, fps=24, verbose=False, logger=None)
92
+
93
+ if do_display:
94
+ src = clip if file_path is None else file_path
95
+ ipd.display(ipython_display(src, fps=24, rd_kwargs=dict(logger=None), autoplay=1, loop=1))
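A sketch of how these helpers compose (hypothetical usage; assumes `svgs` is a list of `svg_lib.SVG` objects and that `SVG.draw` accepts `do_display`/`return_png` as in upstream deepsvg):

    grid = make_grid(svgs, num_cols=4, grid_width=24)   # 4 columns, 24-unit cells
    frames = [s.draw(do_display=False, return_png=True) for s in svgs]
    to_gif(frames, file_path='preview.gif', frame_duration=0.2, do_display=False)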
src/preprocessing/deepsvg/deepsvg_svglib/util_fns.py ADDED
@@ -0,0 +1,22 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ import math
7
+
8
+
9
+ def get_roots(a, b, c):
10
+ if a == 0:
11
+ if b == 0:
12
+ return []
13
+ return [-c / b]
14
+ r = b * b - 4 * a * c
15
+ if r < 0:
16
+ return []
17
+ elif r == 0:
18
+ x0 = -b / (2 * a)
19
+ return [x0]
20
+
21
+ x1, x2 = (-b - math.sqrt(r)) / (2 * a), (-b + math.sqrt(r)) / (2 * a)
22
+ return [x1, x2]  # list, for consistency with the other branches
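`get_roots` solves a·x² + b·x + c = 0, falling back to the linear case when a == 0 and returning nothing for a negative discriminant; a quick check of the three branches:

    print(get_roots(0, 2, -4))  # [2.0]          linear: 2x - 4 = 0
    print(get_roots(1, 0, -4))  # [-2.0, 2.0]    two real roots
    print(get_roots(1, 0, 4))   # []             negative discriminant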
src/preprocessing/deepsvg/deepsvg_utils/train_utils.py ADDED
@@ -0,0 +1,241 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ import shutil
7
+ import torch
8
+ import torch.nn as nn
9
+ import os
10
+ import random
11
+ import numpy as np
12
+ import glob
13
+
14
+
15
+ def save_ckpt(checkpoint_dir, model, cfg=None, optimizer=None, scheduler_lr=None, scheduler_warmup=None,
16
+ stats=None, train_vars=None):
17
+ if is_multi_gpu(model):
18
+ model = model.module
19
+
20
+ state = {
21
+ "model": model.state_dict()
22
+ }
23
+
24
+ if optimizer is not None:
25
+ state["optimizer"] = optimizer.state_dict()
26
+ if scheduler_lr is not None:
27
+ state["scheduler_lr"] = scheduler_lr.state_dict()
28
+ if scheduler_warmup is not None:
29
+ state["scheduler_warmup"] = scheduler_warmup.state_dict()
30
+ if cfg is not None:
31
+ state["cfg"] = cfg.to_dict()
32
+ if stats is not None:
33
+ state["stats"] = stats.to_dict()
34
+ if train_vars is not None:
35
+ state["train_vars"] = train_vars.to_dict()
36
+
37
+ checkpoint_path = os.path.join(checkpoint_dir, "{:06d}.pth.tar".format(stats.step))
38
+
39
+ if not os.path.exists(checkpoint_dir):
40
+ os.makedirs(checkpoint_dir)
41
+ torch.save(state, checkpoint_path)
42
+
43
+ if stats.is_best():
44
+ best_model_path = os.path.join(checkpoint_dir, "best.pth.tar")
45
+ shutil.copyfile(checkpoint_path, best_model_path)
46
+
47
+
48
+ def save_ckpt_list(checkpoint_dir, model, cfg=None, optimizers=None, scheduler_lrs=None, scheduler_warmups=None,
49
+ stats=None, train_vars=None):
50
+ if is_multi_gpu(model):
51
+ model = model.module
52
+
53
+ state = {
54
+ "model": model.state_dict()
55
+ }
56
+
57
+ if optimizers is not None:
58
+ state["optimizers"] = [optimizer.state_dict() if optimizer is not None else optimizer for optimizer in optimizers]
59
+ if scheduler_lrs is not None:
60
+ state["scheduler_lrs"] = [scheduler_lr.state_dict() if scheduler_lr is not None else scheduler_lr for scheduler_lr in scheduler_lrs]
61
+ if scheduler_warmups is not None:
62
+ state["scheduler_warmups"] = [scheduler_warmup.state_dict() if scheduler_warmup is not None else None for scheduler_warmup in scheduler_warmups]
63
+ if cfg is not None:
64
+ state["cfg"] = cfg.to_dict()
65
+ if stats is not None:
66
+ state["stats"] = stats.to_dict()
67
+ if train_vars is not None:
68
+ state["train_vars"] = train_vars.to_dict()
69
+
70
+ checkpoint_path = os.path.join(checkpoint_dir, "{:06d}.pth.tar".format(stats.step))
71
+
72
+ if not os.path.exists(checkpoint_dir):
73
+ os.makedirs(checkpoint_dir)
74
+ torch.save(state, checkpoint_path)
75
+
76
+ if stats.is_best():
77
+ best_model_path = os.path.join(checkpoint_dir, "best.pth.tar")
78
+ shutil.copyfile(checkpoint_path, best_model_path)
79
+
80
+
81
+ def load_ckpt(checkpoint_dir, model, cfg=None, optimizer=None, scheduler_lr=None, scheduler_warmup=None,
82
+ stats=None, train_vars=None):
83
+ if not os.path.exists(checkpoint_dir):
84
+ return False
85
+
86
+ if os.path.isfile(checkpoint_dir):
87
+ checkpoint_path = checkpoint_dir
88
+ else:
89
+ ckpts_paths = sorted(glob.glob(os.path.join(checkpoint_dir, "./[0-9]*.pth.tar")))
90
+ if not ckpts_paths:
91
+ return False
92
+ checkpoint_path = ckpts_paths[-1]
93
+
94
+ state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
95
+
96
+ if is_multi_gpu(model):
97
+ model = model.module
98
+ model.load_state_dict(state["model"], strict=False)
99
+
100
+ if optimizer is not None:
101
+ optimizer.load_state_dict(state["optimizer"])
102
+ if scheduler_lr is not None:
103
+ scheduler_lr.load_state_dict(state["scheduler_lr"])
104
+ if scheduler_warmup is not None:
105
+ scheduler_warmup.load_state_dict(state["scheduler_warmup"])
106
+ if cfg is not None:
107
+ cfg.load_dict(state["cfg"])
108
+ if stats is not None:
109
+ stats.load_dict(state["stats"])
110
+ if train_vars is not None:
111
+ train_vars.load_dict(state["train_vars"])
112
+
113
+ return True
114
+
115
+
116
+ def load_ckpt_list(checkpoint_dir, model, cfg=None, optimizers=None, scheduler_lrs=None, scheduler_warmups=None,
117
+ stats=None, train_vars=None):
118
+ if not os.path.exists(checkpoint_dir):
119
+ return False
120
+
121
+ if os.path.isfile(checkpoint_dir):
122
+ checkpoint_path = checkpoint_dir
123
+ else:
124
+ ckpts_paths = sorted(glob.glob(os.path.join(checkpoint_dir, "./[0-9]*.pth.tar")))
125
+ if not ckpts_paths:
126
+ return False
127
+ checkpoint_path = ckpts_paths[-1]
128
+
129
+ state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
130
+
131
+ if is_multi_gpu(model):
132
+ model = model.module
133
+ model.load_state_dict(state["model"], strict=False)
134
+
135
+ for optimizer, scheduler_lr, scheduler_warmup, optimizer_sd, scheduler_lr_sd, scheduler_warmups_sd in zip(optimizers, scheduler_lrs, scheduler_warmups, state["optimizers"], state["scheduler_lrs"], state["scheduler_warmups"]):
136
+ if optimizer is not None and optimizer_sd is not None:
137
+ optimizer.load_state_dict(optimizer_sd)
138
+ if scheduler_lr is not None and scheduler_lr_sd is not None:
139
+ scheduler_lr.load_state_dict(scheduler_lr_sd)
140
+ if scheduler_warmup is not None and scheduler_warmups_sd is not None:
141
+ scheduler_warmup.load_state_dict(scheduler_warmups_sd)
142
+ if cfg is not None and state["cfg"] is not None:
143
+ cfg.load_dict(state["cfg"])
144
+ if stats is not None and state["stats"] is not None:
145
+ stats.load_dict(state["stats"])
146
+ if train_vars is not None and state["train_vars"] is not None:
147
+ train_vars.load_dict(state["train_vars"])
148
+
149
+ return True
150
+
151
+
152
+ def load_model(checkpoint_path, model):
153
+ state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
154
+
155
+ if is_multi_gpu(model):
156
+ model = model.module
157
+ model.load_state_dict(state["model"], strict=False)
158
+
159
+
160
+ def is_multi_gpu(model):
161
+ return isinstance(model, nn.DataParallel)
162
+
163
+
164
+ def count_parameters(model):
165
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
166
+
167
+
168
+ def pad_sequence(sequences, batch_first=False, padding_value=0, max_len=None):
169
+ r"""Pad a list of variable length Tensors with ``padding_value``
170
+
171
+ ``pad_sequence`` stacks a list of Tensors along a new dimension,
172
+ and pads them to equal length. For example, if the input is a list of
174
+ sequences with size ``L x *``, the output has size ``T x B x *`` if
175
+ ``batch_first`` is False, and ``B x T x *`` otherwise.
175
+
176
+ `B` is batch size. It is equal to the number of elements in ``sequences``.
177
+ `T` is the padded length: ``max_len`` if given, otherwise the length of the longest sequence.
178
+ `L` is length of the sequence.
179
+ `*` is any number of trailing dimensions, including none.
180
+
181
+ Example:
182
+ >>> from torch.nn.utils.rnn import pad_sequence
183
+ >>> a = torch.ones(25, 300)
184
+ >>> b = torch.ones(22, 300)
185
+ >>> c = torch.ones(15, 300)
186
+ >>> pad_sequence([a, b, c]).size()
187
+ torch.Size([25, 3, 300])
188
+
189
+ Note:
190
+ This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
191
+ where `T` is the length of the longest sequence. This function assumes
192
+ trailing dimensions and type of all the Tensors in sequences are same.
193
+
194
+ Arguments:
195
+ sequences (list[Tensor]): list of variable length sequences.
196
+ batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
197
+ ``T x B x *`` otherwise
198
+ padding_value (float, optional): value for padded elements. Default: 0.
199
+
200
+ Returns:
201
+ Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
202
+ Tensor of size ``B x T x *`` otherwise
203
+ """
204
+
205
+ # assuming trailing dimensions and type of all the Tensors
206
+ # in sequences are same and fetching those from sequences[0]
207
+ max_size = sequences[0].size()
208
+ trailing_dims = max_size[1:]
209
+
210
+ if max_len is None:
211
+ max_len = max([s.size(0) for s in sequences])
212
+ if batch_first:
213
+ out_dims = (len(sequences), max_len) + trailing_dims
214
+ else:
215
+ out_dims = (max_len, len(sequences)) + trailing_dims
216
+
217
+ out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
218
+ for i, tensor in enumerate(sequences):
219
+ length = tensor.size(0)
220
+ # use index notation to prevent duplicate references to the tensor
221
+ if batch_first:
222
+ out_tensor[i, :length, ...] = tensor
223
+ else:
224
+ out_tensor[:length, i, ...] = tensor
225
+
226
+ return out_tensor
227
+
228
+
229
+ def set_seed(_seed=42):
230
+ random.seed(_seed)
231
+ np.random.seed(_seed)
232
+ torch.manual_seed(_seed)
233
+ torch.cuda.manual_seed(_seed)
234
+ torch.cuda.manual_seed_all(_seed)
235
+ os.environ['PYTHONHASHSEED'] = str(_seed)
236
+
237
+
238
+ def infinite_range(start_idx=0):
239
+ while True:
240
+ yield start_idx
241
+ start_idx += 1
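The main difference from `torch.nn.utils.rnn.pad_sequence` is the extra `max_len` argument, which pads every batch to a fixed length instead of the length of the longest element. A minimal sketch:

    a, b = torch.ones(5, 3), torch.ones(2, 3)
    out = pad_sequence([a, b], batch_first=True, padding_value=-1, max_len=8)
    print(out.shape)  # torch.Size([2, 8, 3]); entries past each length are -1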
src/preprocessing/deepsvg/deepsvg_utils/utils.py ADDED
@@ -0,0 +1,54 @@
1
+ """This code is taken from <https://github.com/alexandre01/deepsvg>
2
+ by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
3
+ from the paper <https://arxiv.org/pdf/2007.11301.pdf>
4
+ """
5
+
6
+ import torch
7
+
8
+
9
+ def linear(a, b, x, min_x, max_x):
10
+ """
11
+ b ___________
12
+ /|
13
+ / |
14
+ a _______/ |
15
+ | |
16
+ min_x max_x
17
+ """
18
+ return a + min(max((x - min_x) / (max_x - min_x), 0), 1) * (b - a)
19
+
20
+
21
+ def batchify(data, device):
22
+ return (d.unsqueeze(0).to(device) for d in data)
23
+
24
+
25
+ def _make_seq_first(*args):
26
+ # N, G, S, ... -> S, G, N, ...
27
+ if len(args) == 1:
28
+ arg, = args
29
+ return arg.permute(2, 1, 0, *range(3, arg.dim())) if arg is not None else None
30
+ return (*(arg.permute(2, 1, 0, *range(3, arg.dim())) if arg is not None else None for arg in args),)
31
+
32
+
33
+ def _make_batch_first(*args):
34
+ # S, G, N, ... -> N, G, S, ...
35
+ if len(args) == 1:
36
+ arg, = args
37
+ return arg.permute(2, 1, 0, *range(3, arg.dim())) if arg is not None else None
38
+ return (*(arg.permute(2, 1, 0, *range(3, arg.dim())) if arg is not None else None for arg in args),)
39
+
40
+
41
+ def _pack_group_batch(*args):
42
+ # S, G, N, ... -> S, G * N, ...
43
+ if len(args) == 1:
44
+ arg, = args
45
+ return arg.reshape(arg.size(0), arg.size(1) * arg.size(2), *arg.shape[3:]) if arg is not None else None
46
+ return (*(arg.reshape(arg.size(0), arg.size(1) * arg.size(2), *arg.shape[3:]) if arg is not None else None for arg in args),)
47
+
48
+
49
+ def _unpack_group_batch(N, *args):
50
+ # S, G * N, ... -> S, G, N, ...
51
+ if len(args) == 1:
52
+ arg, = args
53
+ return arg.reshape(arg.size(0), -1, N, *arg.shape[2:]) if arg is not None else None
54
+ return (*(arg.reshape(arg.size(0), -1, N, *arg.shape[2:]) if arg is not None else None for arg in args),)
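These helpers reorder the (batch N, group G, sequence S) axes for the sequence-first transformer and fold groups into the batch dimension; a shape round-trip with illustrative sizes:

    x = torch.zeros(2, 8, 30, 14)             # N, G, S, features
    x_seq = _make_seq_first(x)                # (30, 8, 2, 14)
    x_flat = _pack_group_batch(x_seq)         # (30, 16, 14)
    x_back = _unpack_group_batch(2, x_flat)   # (30, 8, 2, 14)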
src/preprocessing/preprocessing.py ADDED
@@ -0,0 +1,157 @@
1
+ # Imports
2
+ import os
3
+ import copy
4
+ import torch
5
+ import glob
6
+ import pandas as pd
7
+ import pickle
8
+ from xml.dom import minidom
9
+ from svgpathtools import svg2paths2
10
+ from svgpathtools import wsvg
11
+ import sys
12
+ sys.path.append(os.getcwd())
13
+ from src.preprocessing.deepsvg.deepsvg_svglib.svg import SVG
14
+ from src.preprocessing.deepsvg.deepsvg_config import config_hierarchical_ordered
15
+ from src.preprocessing.deepsvg.deepsvg_utils import train_utils
16
+ from src.preprocessing.deepsvg.deepsvg_utils import utils
17
+ from src.preprocessing.deepsvg.deepsvg_dataloader import svg_dataset
18
+
19
+ # ---- Methods for embedding logos ----
20
+
21
+ def compute_embedding_folder(folder_path: str, model_path: str, save: str = None) -> pd.DataFrame:
22
+ data_list = []
23
+ for file in os.listdir(folder_path):
24
+ print('File: ' + file)
25
+ try:
26
+ embedding = compute_embedding(os.path.join(folder_path, file), model_path)
27
+ embedding['filename'] = file
28
+ data_list.append(embedding)
29
+ except:
30
+ print('Embedding failed')
31
+ print('Concatenating')
32
+ data = pd.concat(data_list)
33
+ if save is not None:
34
+ output = open(os.path.join(save, 'svg_embedding_5000.pkl'), 'wb')
35
+ pickle.dump(data, output)
36
+ output.close()
37
+ return data
38
+
39
+
40
+ def compute_embedding(path: str, model_path: str, save: str = None) -> pd.DataFrame:
41
+ # Convert all primitives to SVG paths - TODO: text elements are not handled yet
42
+ paths, attributes, svg_attributes = svg2paths2(path) # In the previous project this step was performed at the end
43
+ wsvg(paths, attributes=attributes, svg_attributes=svg_attributes, filename=path)
44
+
45
+ svg = SVG.load_svg(path)
46
+ svg.normalize() # Using DeepSVG normalize instead of expanding the viewbox - TODO: check whether this is equivalent
47
+ svg_str = svg.to_str()
48
+
49
+ # Assign animation id to every path - TODO this changes the original logo!
50
+ document = minidom.parseString(svg_str)
51
+ paths = document.getElementsByTagName('path')
52
+ for i in range(len(paths)):
53
+ paths[i].setAttribute('animation_id', str(i))
54
+ with open(path, 'wb') as svg_file:
55
+ svg_file.write(document.toxml(encoding='iso-8859-1'))
56
+
57
+ # Decompose SVGs
58
+
59
+ decomposed_svgs = {}
60
+
61
+ for i in range(len(paths)):
62
+ doc_temp = copy.deepcopy(document)
63
+ paths_temp = doc_temp.getElementsByTagName('path')
64
+ current_path = paths_temp[i]
65
+ # Iteratively choose path i and remove all others
66
+ remove_temp = paths_temp[:i] + paths_temp[i+1:]
67
+ for path in remove_temp:
68
+ if path.parentNode.nodeName != 'clipPath':
69
+ path.parentNode.removeChild(path)
70
+ # Check for style attributes; add in case there are none
71
+ if not current_path.getAttribute('style'):
72
+ current_path.setAttribute('stroke', 'black')
73
+ current_path.setAttribute('stroke-width', '2')
74
+ id = current_path.getAttribute('animation_id')
75
+ decomposed_svgs[id] = doc_temp.toprettyxml(encoding='iso-8859-1')
76
+ doc_temp.unlink()
77
+ #print(decomposed_svgs)
78
+ meta = {}
79
+ for id in decomposed_svgs:
80
+ svg_d_str = decomposed_svgs[id]
81
+ # Load into SVG and canonicalize
82
+ current_svg = SVG.from_str(svg_d_str)
83
+ # Canonicalize
84
+ current_svg.canonicalize() # Applies DeepSVG canonicalize; previously custom methods were used
85
+ decomposed_svgs[id] = current_svg.to_str()
86
+ if not os.path.exists('data/temp_svg'):
87
+ os.mkdir('data/temp_svg')
88
+ with open(('data/temp_svg/path_' + str(id)) + '.svg', 'w') as svg_file:
89
+ svg_file.write(decomposed_svgs[id])
90
+
91
+ # Collect metadata
92
+ len_groups = [path_group.total_len() for path_group in current_svg.svg_path_groups]
93
+ start_pos = [path_group.svg_paths[0].start_pos for path_group in current_svg.svg_path_groups]
94
+ try:
95
+ total_len = sum(len_groups)
96
+ nb_groups = len(len_groups)
97
+ max_len_group = max(len_groups)
98
+ except ValueError: # max() on an empty len_groups
99
+ total_len = 0
100
+ nb_groups = 0
101
+ max_len_group = 0
102
+
103
+ meta[id] = {
104
+ 'id': id,
105
+ 'total_len': total_len,
106
+ 'nb_groups': nb_groups,
107
+ 'len_groups': len_groups,
108
+ 'max_len_group': max_len_group,
109
+ 'start_pos': start_pos
110
+ }
111
+ metadata = pd.DataFrame(meta.values())
112
+ #print(metadata)
113
+ if not os.path.exists('data/metadata'):
114
+ os.mkdir('data/metadata')
115
+ metadata.to_csv('data/metadata/metadata.csv', index=False)
116
+ # Load pretrained DeepSVG model
117
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
118
+ cfg = config_hierarchical_ordered.Config()
119
+ model = cfg.make_model().to(device)
120
+ train_utils.load_model(model_path, model)
121
+ model.eval()
122
+ # Load dataset
123
+ cfg.data_dir = 'data/temp_svg/'
124
+ cfg.meta_filepath = 'data/metadata/metadata.csv'
125
+ dataset = svg_dataset.load_dataset(cfg)
126
+ svg_files = glob.glob('data/temp_svg/*.svg')
127
+ #print(svg_files)
128
+ svg_list = []
129
+ for svg_file in svg_files:
130
+ id = os.path.splitext(os.path.basename(svg_file))[0].split('_')[1] # portable: avoid Windows-only '\\' splitting
131
+ # Preprocessing
132
+ svg = SVG.load_svg(svg_file)
133
+ svg = dataset.simplify(svg)
134
+ svg = dataset.preprocess(svg, augment=False)
135
+ data = dataset.get(svg=svg)
136
+ # Get embedding
137
+ model_args = utils.batchify((data[key] for key in cfg.model_args), device)
138
+ with torch.no_grad():
139
+ z = model(*model_args, encode_mode=True).cpu().numpy()[0][0][0]
140
+ dict_data = {
141
+ 'animation_id': id,
142
+ 'embedding': z
143
+ }
144
+ svg_list.append(dict_data)
145
+ data = pd.DataFrame.from_records(svg_list, index='animation_id')['embedding'].apply(pd.Series)
146
+ data.reset_index(level=0, inplace=True)
147
+ data.dropna(inplace=True)
148
+ data.reset_index(drop=True, inplace=True)
149
+ if save is not None:
150
+ output = open(os.path.join(save, 'svg_embedding_5000.pkl'), 'wb')
151
+ pickle.dump(data, output)
152
+ output.close()
153
+ print('Embedding computed')
154
+ return data
155
+
156
+
157
+ #compute_embedding_folder('data/raw_dataset', 'src/preprocessing/deepsvg/deepsvg_models/deepSVG_hierarchical_ordered.pth.tar', 'data/embedding')
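For a single logo, `compute_embedding` can be called directly (a sketch; the SVG path is illustrative, the checkpoint is the one shipped with this commit):

    df = compute_embedding('data/raw_dataset/logo_0.svg',
                           'src/preprocessing/deepsvg/deepsvg_models/deepSVG_hierarchical_ordered.pth.tar')
    print(df.head())  # one row per path: animation_id plus embedding components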