Commit e17e8cc
Daniel Gil-U Fuhge committed
Parent(s): 2f22ac0

add model files
- AnimationTransformer.py +272 -0
- models/animation_transformer.pth +3 -0
- models/reward_function_mode_state_dict.pth +3 -0
- src/postprocessing/__init__.py +0 -0
- src/postprocessing/get_style_attributes.py +318 -0
- src/postprocessing/get_svg_color_tendency.py +19 -0
- src/postprocessing/get_svg_size_pos.py +268 -0
- src/postprocessing/insert_animation.py +333 -0
- src/postprocessing/logo_0.svg +809 -0
- src/postprocessing/postprocessing.py +604 -0
- src/postprocessing/transform_animation_predictor_output.py +78 -0
- src/preprocessing/deepsvg/deepsvg_config/config.py +106 -0
- src/preprocessing/deepsvg/deepsvg_config/config_hierarchical_ordered.py +29 -0
- src/preprocessing/deepsvg/deepsvg_config/default_icons.py +102 -0
- src/preprocessing/deepsvg/deepsvg_dataloader/svg_dataset.py +239 -0
- src/preprocessing/deepsvg/deepsvg_difflib/tensor.py +305 -0
- src/preprocessing/deepsvg/deepsvg_models/basic_blocks.py +70 -0
- src/preprocessing/deepsvg/deepsvg_models/deepSVG_hierarchical_ordered.pth.tar +0 -0
- src/preprocessing/deepsvg/deepsvg_models/layers/attention.py +166 -0
- src/preprocessing/deepsvg/deepsvg_models/layers/functional.py +261 -0
- src/preprocessing/deepsvg/deepsvg_models/layers/improved_transformer.py +146 -0
- src/preprocessing/deepsvg/deepsvg_models/layers/positional_encoding.py +48 -0
- src/preprocessing/deepsvg/deepsvg_models/layers/transformer.py +398 -0
- src/preprocessing/deepsvg/deepsvg_models/loss.py +70 -0
- src/preprocessing/deepsvg/deepsvg_models/model.py +484 -0
- src/preprocessing/deepsvg/deepsvg_models/model_config.py +113 -0
- src/preprocessing/deepsvg/deepsvg_models/model_utils.py +89 -0
- src/preprocessing/deepsvg/deepsvg_schedulers/warmup.py +68 -0
- src/preprocessing/deepsvg/deepsvg_svglib/geom.py +493 -0
- src/preprocessing/deepsvg/deepsvg_svglib/svg.py +579 -0
- src/preprocessing/deepsvg/deepsvg_svglib/svg_command.py +531 -0
- src/preprocessing/deepsvg/deepsvg_svglib/svg_path.py +659 -0
- src/preprocessing/deepsvg/deepsvg_svglib/svg_primitive.py +452 -0
- src/preprocessing/deepsvg/deepsvg_svglib/svglib_utils.py +95 -0
- src/preprocessing/deepsvg/deepsvg_svglib/util_fns.py +22 -0
- src/preprocessing/deepsvg/deepsvg_utils/train_utils.py +241 -0
- src/preprocessing/deepsvg/deepsvg_utils/utils.py +54 -0
- src/preprocessing/preprocessing.py +157 -0
AnimationTransformer.py
ADDED
@@ -0,0 +1,272 @@
import math
import time

import torch
import torch.nn as nn

import dataset_helper


class AnimationTransformer(nn.Module):
    def __init__(
            self,
            dim_model,  # hidden_size; corresponds to embedding length
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            dropout_p,
            use_positional_encoder=True
    ):
        super().__init__()

        self.model_type = "Transformer"
        self.dim_model = dim_model

        # TODO: Currently left out, as the input sequence is shuffled. Later check if its use is beneficial.
        self.use_positional_encoder = use_positional_encoder
        self.positional_encoder = PositionalEncoding(
            dim_model=dim_model,
            dropout_p=dropout_p
        )

        self.transformer = nn.Transformer(
            d_model=dim_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout_p,
            batch_first=True
        )

    def forward(self, src, tgt, tgt_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None):
        # Src size must be (batch_size, src sequence length, dim_model)
        # Tgt size must be (batch_size, tgt sequence length, dim_model)

        if self.use_positional_encoder:
            src = self.positional_encoder(src)
            tgt = self.positional_encoder(tgt)

        # Transformer blocks - out size = (batch_size, sequence length, dim_model) since batch_first=True
        out = self.transformer(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_key_padding_mask,
                               tgt_key_padding_mask=tgt_key_padding_mask)
        return out


def get_tgt_mask(size) -> torch.Tensor:
    # Generates a square matrix where each row allows one token more to be seen
    mask = torch.tril(torch.ones(size, size) == 1)  # Lower triangular matrix
    mask = mask.float()
    mask = mask.masked_fill(mask == 0, float('-inf'))  # Convert zeros to -inf
    mask = mask.masked_fill(mask == 1, float(0.0))  # Convert ones to 0

    # Example for size=5:
    # [[0., -inf, -inf, -inf, -inf],
    #  [0.,   0., -inf, -inf, -inf],
    #  [0.,   0.,   0., -inf, -inf],
    #  [0.,   0.,   0.,   0., -inf],
    #  [0.,   0.,   0.,   0.,   0.]]

    return mask


def create_pad_mask(matrix: torch.Tensor) -> torch.Tensor:
    pad_masks = []

    # Iterate over each sequence in the batch.
    for i in range(0, matrix.size(0)):
        sequence = []

        # Iterate over each element in the sequence and append True if it is a padding value.
        for j in range(0, matrix.size(1)):
            sequence.append(matrix[i, j, 0] == dataset_helper.PADDING_VALUE)

        pad_masks.append(sequence)

    return torch.tensor(pad_masks)


def _transformer_call_in_loops(model, batch, device, loss_function):
    source, target = batch[0], batch[1]
    source, target = source.to(device), target.to(device)

    # First index selects all batch entries, second the sequence position.
    target_input = target[:, :-1]  # decoder input is offset by one (starts at SOS, excludes EOS)
    target_expected = target[:, 1:]  # expected output is offset by one (excludes SOS)

    # SOS - 1 - 2 - 3 - 4 - EOS - PAD - PAD        // target_input
    #       1 - 2 - 3 - 4 - EOS - PAD - PAD - PAD  // target_expected

    # Get mask to mask out the next tokens
    tgt_mask = get_tgt_mask(target_input.size(1)).to(device)

    # Standard training, except that we pass in target_input and tgt_mask
    prediction = model(source, target_input,
                       tgt_mask=tgt_mask,
                       src_key_padding_mask=create_pad_mask(source).to(device),
                       # Mask with expected, as EOS is no input (see above)
                       tgt_key_padding_mask=create_pad_mask(target_expected).to(device))

    return loss_function(prediction, target_expected, create_pad_mask(target_expected).to(device))


def train_loop(model, opt, loss_function, dataloader, device):
    model.train()
    total_loss = 0

    t0 = time.time()
    i = 1
    for batch in dataloader:
        loss = _transformer_call_in_loops(model, batch, device, loss_function)

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.detach().item()

        if i == 1 or i % 10 == 0:
            elapsed_time = time.time() - t0
            total_expected = elapsed_time / i * len(dataloader)
            print(f">> {i}: Time per Batch {elapsed_time / i : .2f}s | "
                  f"Total expected {total_expected / 60 : .2f} min | "
                  f"Remaining {(total_expected - elapsed_time) / 60 : .2f} min ")
        i += 1

    print(f">> Epoch time: {(time.time() - t0) / 60:.2f} min")
    return total_loss / len(dataloader)


def validation_loop(model, loss_function, dataloader, device):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in dataloader:
            loss = _transformer_call_in_loops(model, batch, device, loss_function)
            total_loss += loss.detach().item()

    return total_loss / len(dataloader)


def fit(model, optimizer, loss_function, train_dataloader, val_dataloader, epochs, device):
    train_loss_list, validation_loss_list = [], []

    print("Training and validating model")
    for epoch in range(epochs):
        print("-" * 25, f"Epoch {epoch + 1}", "-" * 25)

        train_loss = train_loop(model, optimizer, loss_function, train_dataloader, device)
        train_loss_list += [train_loss]

        validation_loss = validation_loop(model, loss_function, val_dataloader, device)
        validation_loss_list += [validation_loss]

        print(f"Training loss: {train_loss:.4f}")
        print(f"Validation loss: {validation_loss:.4f}")
        print()

    return train_loss_list, validation_loss_list


def predict(model, source_sequence, sos_token: torch.Tensor, device, max_length=32, eos_scaling=1,
            backpropagate=False, showResult=True):
    if backpropagate:
        model.train()
    else:
        model.eval()

    source_sequence = source_sequence.float().to(device)
    y_input = torch.unsqueeze(sos_token, dim=0).float().to(device)

    i = 0
    while i < max_length:
        prediction = model(source_sequence.unsqueeze(0), y_input.unsqueeze(0),  # un-squeeze for batch
                           # tgt_mask=get_tgt_mask(y_input.size(0)).to(device),  # intentionally unused here
                           src_key_padding_mask=create_pad_mask(source_sequence.unsqueeze(0)).to(device))

        next_embedding = prediction[0, -1, :]  # prediction for the last token
        pred_deep_svg, pred_type, pred_parameters = dataset_helper.unpack_embedding(next_embedding, dim=0)
        pred_deep_svg, pred_type, pred_parameters = (pred_deep_svg.to(device), pred_type.to(device),
                                                     pred_parameters.to(device))

        # === TYPE ===
        # Apply softmax and scale down the EOS probability
        type_softmax = torch.softmax(pred_type, dim=0)
        type_softmax[0] = type_softmax[0] * eos_scaling
        animation_type = torch.argmax(type_softmax, dim=0)

        # Break if EOS is most likely
        if animation_type == 0:
            print("END OF ANIMATION")
            y_input = torch.cat((y_input, sos_token.unsqueeze(0).to(device)), dim=0)
            return y_input

        pred_type = torch.zeros(11)
        pred_type[animation_type] = 1

        # === DEEP SVG ===
        # Find the closest path embedding in the source sequence
        distances = [torch.norm(pred_deep_svg - embedding[:-26]) for embedding in source_sequence]
        closest_index = distances.index(min(distances))
        closest_token = source_sequence[closest_index]

        # === PARAMETERS ===
        # Overwrite parameters that the predicted animation type does not use
        for j in range(len(pred_parameters)):
            if j in dataset_helper.ANIMATION_PARAMETER_INDICES[int(animation_type)]:
                continue
            pred_parameters[j] = -1

        # === SEQUENCE ===
        y_new = torch.concat([closest_token[:-26], pred_type.to(device), pred_parameters], dim=0)
        y_input = torch.cat((y_input, y_new.unsqueeze(0)), dim=0)

        # === INFO PRINT ===
        if showResult:
            print(f"{int(y_input.size(0))}: Path {closest_index} ({round(float(distances[closest_index]), 3)}) "
                  f"got animation {animation_type} ({round(float(type_softmax[animation_type]), 3)}%) "
                  f"with parameters {[round(num, 2) for num in pred_parameters.tolist()]}")

        i += 1

    return y_input


class PositionalEncoding(nn.Module):
    def __init__(self, dim_model, dropout_p, max_len=5000):
        """
        Initializes the PositionalEncoding module, which injects information about the relative or absolute
        position of the tokens in the sequence. The positional encodings have the same dimension as the
        embeddings so that the two can be summed. Uses a sinusoidal pattern for positional encoding.

        Args:
            dim_model (int): The dimension of the embeddings and the expected dimension of the positional encoding.
            dropout_p (float): Dropout probability applied to the summed embeddings and positional encodings.
            max_len (int): The maximum sequence length for which positional encodings are precomputed and stored.
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout_p)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0) / dim_model))
        pos_encoding = torch.zeros(max_len, 1, dim_model)
        pos_encoding[:, 0, 0::2] = torch.sin(position * div_term)
        pos_encoding[:, 0, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pos_encoding', pos_encoding)

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        """
        Applies positional encoding to the input embeddings, followed by dropout.

        Args:
            embedding (torch.Tensor): The input embeddings with shape [batch_size, seq_len, dim_model].

        Returns:
            torch.Tensor: The embeddings with positional encoding and dropout applied, with the same shape
            as the input token embeddings.
        """
        return self.dropout(embedding + self.pos_encoding[:embedding.size(0), :])
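For orientation, a minimal usage sketch of the module above. The hyperparameter values, the loss function, and the dataloaders are illustrative assumptions, not values from this commit; `dim_model` is guessed from the 26 animation dimensions the code slices off (11 type entries plus parameters) on top of a 256-dimensional DeepSVG path embedding.

import torch
import AnimationTransformer as at

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = at.AnimationTransformer(
    dim_model=282,      # assumption: 256 DeepSVG dims + 26 animation dims
    num_heads=6,        # must divide dim_model
    num_encoder_layers=4,
    num_decoder_layers=4,
    dropout_p=0.1,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# loss_function must accept (prediction, target_expected, pad_mask),
# matching the call in _transformer_call_in_loops above:
# train_losses, val_losses = at.fit(model, optimizer, loss_function,
#                                   train_dataloader, val_dataloader,
#                                   epochs=10, device=device)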
models/animation_transformer.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:12ae92d0b1a5ada8a8681122f76ea7c4e6b3fdf0169dd4b3a5d908899e563f86
size 60658902
models/reward_function_mode_state_dict.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:49ea58f01ad6a281e005b9b2793e53f901037402c96411f444ed9630fff05fbf
size 111027985
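Both checkpoints are stored as Git LFS pointers. A hedged sketch of loading them: it assumes `animation_transformer.pth` holds a full pickled module and, as the filename suggests, `reward_function_mode_state_dict.pth` holds only weights whose model class is outside this commit.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assumption: the transformer checkpoint is a full pickled nn.Module.
model = torch.load("models/animation_transformer.pth", map_location=device)
model.eval()

# Assumption: the reward checkpoint holds only a state_dict; its model class
# (hypothetical name below) must be instantiated first.
# reward_model = RewardFunctionModel()
# reward_model.load_state_dict(
#     torch.load("models/reward_function_mode_state_dict.pth", map_location=device))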
src/postprocessing/__init__.py
ADDED
File without changes
src/postprocessing/get_style_attributes.py
ADDED
@@ -0,0 +1,318 @@
from svgpathtools import svg2paths
import pandas as pd
import numpy as np
from xml.dom import minidom

pd.options.mode.chained_assignment = None  # default='warn'


def get_style_attributes_svg(file):
    """ Get style attributes of an SVG.

    Args:
        file (str): Path of SVG file.

    Returns:
        pd.DataFrame: Dataframe containing the attributes of each path.

    """
    local_styles = get_local_style_attributes(file)
    global_styles = get_global_style_attributes(file)
    global_group_styles = get_global_group_style_attributes(file)
    return combine_style_attributes(local_styles, global_styles, global_group_styles)


def get_style_attributes_path(file, animation_id, attribute):
    """ Get a style attribute of a specific path in an SVG.

    Args:
        file (str): Path of SVG file.
        animation_id (int): ID of element.
        attribute (str): One of the following: fill, stroke, stroke_width, opacity, stroke_opacity.

    Returns:
        str: Attribute of the specific path.

    """
    styles = get_style_attributes_svg(file)
    styles_animation_id = styles[styles["animation_id"] == str(animation_id)]
    return styles_animation_id.iloc[0][attribute]


def parse_svg(file):
    """ Parse an SVG file.

    Args:
        file (str): Path of SVG file.

    Returns:
        list, list: List of path objects, list of dictionaries containing the attributes of each path.

    """
    paths, attrs = svg2paths(file)
    return paths, attrs


def get_local_style_attributes(file):
    """ Generate dataframe containing local style attributes of an SVG.

    Args:
        file (str): Path of SVG file.

    Returns:
        pd.DataFrame: Dataframe containing filename, animation_id, class, fill, stroke, stroke_width, opacity,
        stroke_opacity.

    """
    return pd.DataFrame.from_records(_get_local_style_attributes(file))


def _get_local_style_attributes(file):
    try:
        _, attributes = parse_svg(file)
    except Exception:
        print(f'{file}: Attributes not defined.')
        return  # parsing failed; stop the generator instead of using undefined attributes

    for i, attr in enumerate(attributes):
        animation_id = attr['animation_id']
        class_ = ''
        fill = '#000000'
        stroke = '#000000'
        stroke_width = '0'
        opacity = '1.0'
        stroke_opacity = '1.0'

        if 'style' in attr:
            a = attr['style']
            if a.find('fill') != -1:
                fill = a.split('fill:', 1)[-1].split(';', 1)[0]
            if a.find('stroke') != -1:
                stroke = a.split('stroke:', 1)[-1].split(';', 1)[0]
            if a.find('stroke-width') != -1:
                stroke_width = a.split('stroke-width:', 1)[-1].split(';', 1)[0]
            if a.find('opacity') != -1:
                opacity = a.split('opacity:', 1)[-1].split(';', 1)[0]
            if a.find('stroke-opacity') != -1:
                stroke_opacity = a.split('stroke-opacity:', 1)[-1].split(';', 1)[0]
        else:
            if 'fill' in attr:
                fill = attr['fill']
            if 'stroke' in attr:
                stroke = attr['stroke']
            if 'stroke-width' in attr:
                stroke_width = attr['stroke-width']
            if 'opacity' in attr:
                opacity = attr['opacity']
            if 'stroke-opacity' in attr:
                stroke_opacity = attr['stroke-opacity']

        if 'class' in attr:
            class_ = attr['class']

        # transform 'none' and RGB to hex
        if '#' not in fill and fill != '':
            fill = transform_to_hex(fill)
        if '#' not in stroke and stroke != '':
            stroke = transform_to_hex(stroke)

        yield dict(filename=file.split('.svg')[0], animation_id=animation_id, class_=class_, fill=fill, stroke=stroke,
                   stroke_width=stroke_width, opacity=opacity, stroke_opacity=stroke_opacity)


def get_global_style_attributes(file):
    """ Generate dataframe containing global style attributes of an SVG.

    Args:
        file (str): Path of SVG file.

    Returns:
        pd.DataFrame: Dataframe containing filename, class, fill, stroke, stroke_width, opacity, stroke_opacity.

    """
    return pd.DataFrame.from_records(_get_global_style_attributes(file))


def _get_global_style_attributes(file):
    doc = minidom.parse(file)
    style = doc.getElementsByTagName('style')
    for i, attr in enumerate(style):
        a = attr.toxml()
        for j in range(0, len(a.split(';}')) - 1):
            fill = ''
            stroke = ''
            stroke_width = ''
            opacity = ''
            stroke_opacity = ''
            attr = a.split(';}')[j]
            class_ = attr.split('.', 1)[-1].split('{', 1)[0]
            if attr.find('fill:') != -1:
                fill = attr.split('fill:', 1)[-1].split(';', 1)[0]
            if attr.find('stroke:') != -1:
                stroke = attr.split('stroke:', 1)[-1].split(';', 1)[0]
            if attr.find('stroke-width:') != -1:
                stroke_width = attr.split('stroke-width:', 1)[-1].split(';', 1)[0]
            if attr.find('opacity:') != -1:
                opacity = attr.split('opacity:', 1)[-1].split(';', 1)[0]
            if attr.find('stroke-opacity:') != -1:
                stroke_opacity = attr.split('stroke-opacity:', 1)[-1].split(';', 1)[0]

            # transform 'none' and RGB to hex
            if '#' not in fill and fill != '':
                fill = transform_to_hex(fill)
            if '#' not in stroke and stroke != '':
                stroke = transform_to_hex(stroke)

            yield dict(filename=file.split('.svg')[0], class_=class_, fill=fill, stroke=stroke,
                       stroke_width=stroke_width, opacity=opacity, stroke_opacity=stroke_opacity)


def get_global_group_style_attributes(file):
    """ Generate dataframe containing global style attributes defined through <g> tags of an SVG.

    Args:
        file (str): Path of SVG file.

    Returns:
        pd.DataFrame: Dataframe containing filename, href, animation_id, fill, stroke, stroke_width, opacity,
        stroke_opacity.

    """
    df_group_animation_id_matching = pd.DataFrame.from_records(_get_group_animation_id_matching(file))

    df_group_attributes = pd.DataFrame.from_records(_get_global_group_style_attributes(file))
    df_group_attributes.drop_duplicates(inplace=True)
    df_group_attributes.replace("", float("NaN"), inplace=True)
    df_group_attributes.dropna(thresh=3, inplace=True)

    if "href" in df_group_attributes.columns:
        df_group_attributes.dropna(subset=["href"], inplace=True)

    if df_group_attributes.empty:
        return df_group_attributes
    else:
        return df_group_animation_id_matching.merge(df_group_attributes, how='left', on=['filename', 'href'])


def _get_global_group_style_attributes(file):
    doc = minidom.parse(file)
    groups = doc.getElementsByTagName('g')
    for i, _ in enumerate(groups):
        style = groups[i].getAttribute('style')
        href = ''
        fill = ''
        stroke = ''
        stroke_width = ''
        opacity = ''
        stroke_opacity = ''
        if len(groups[i].getElementsByTagName('use')) != 0:
            href = groups[i].getElementsByTagName('use')[0].getAttribute('xlink:href')
        if style != '':
            attributes = style.split(';')
            for j, _ in enumerate(attributes):
                attr = attributes[j]
                if attr.find('fill:') != -1:
                    fill = attr.split('fill:', 1)[-1].split(';', 1)[0]
                if attr.find('stroke:') != -1:
                    stroke = attr.split('stroke:', 1)[-1].split(';', 1)[0]
                if attr.find('stroke-width:') != -1:
                    stroke_width = attr.split('stroke-width:', 1)[-1].split(';', 1)[0]
                if attr.find('opacity:') != -1:
                    opacity = attr.split('opacity:', 1)[-1].split(';', 1)[0]
                if attr.find('stroke-opacity:') != -1:
                    stroke_opacity = attr.split('stroke-opacity:', 1)[-1].split(';', 1)[0]
        else:
            fill = groups[i].getAttribute('fill')
            stroke = groups[i].getAttribute('stroke')
            stroke_width = groups[i].getAttribute('stroke-width')
            opacity = groups[i].getAttribute('opacity')
            stroke_opacity = groups[i].getAttribute('stroke-opacity')

        # transform 'none' and RGB to hex
        if '#' not in fill and fill != '':
            fill = transform_to_hex(fill)
        if '#' not in stroke and stroke != '':
            stroke = transform_to_hex(stroke)

        yield dict(filename=file.split('.svg')[0], href=href.replace('#', ''), fill=fill, stroke=stroke,
                   stroke_width=stroke_width, opacity=opacity, stroke_opacity=stroke_opacity)


def _get_group_animation_id_matching(file):
    doc = minidom.parse(file)
    try:
        symbol = doc.getElementsByTagName('symbol')
        for i, _ in enumerate(symbol):
            href = symbol[i].getAttribute('id')
            animation_id = symbol[i].getElementsByTagName('path')[0].getAttribute('animation_id')
            yield dict(filename=file.split('.svg')[0], href=href, animation_id=animation_id)
    except Exception:
        # fall back to symbols nested inside <defs> with a <clipPath> structure
        defs = doc.getElementsByTagName('defs')
        for i, _ in enumerate(defs):
            href = defs[i].getElementsByTagName('symbol')[0].getAttribute('id')
            animation_id = defs[i].getElementsByTagName('clipPath')[0].getElementsByTagName('path')[0] \
                .getAttribute('animation_id')
            yield dict(filename=file.split('.svg')[0], href=href, animation_id=animation_id)


def combine_style_attributes(df_local, df_global, df_global_groups):
    """ Combine local and global style attributes. Global attributes have priority.

    Args:
        df_local (pd.DataFrame): Dataframe with local style attributes.
        df_global (pd.DataFrame): Dataframe with global style attributes.
        df_global_groups (pd.DataFrame): Dataframe with global style attributes defined through <g> tags.

    Returns:
        pd.DataFrame: Dataframe with all style attributes.

    """
    if df_global.empty and df_global_groups.empty:
        df_local.insert(loc=3, column='href', value="")
        return df_local

    if not df_global.empty:
        df = df_local.merge(df_global, how='left', on=['filename', 'class_'])
        df_styles = df[["filename", "animation_id", "class_"]]
        df_styles["fill"] = _combine_columns(df, "fill")
        df_styles["stroke"] = _combine_columns(df, "stroke")
        df_styles["stroke_width"] = _combine_columns(df, "stroke_width")
        df_styles["opacity"] = _combine_columns(df, "opacity")
        df_styles["stroke_opacity"] = _combine_columns(df, "stroke_opacity")
        df_local = df_styles.copy(deep=True)
    if not df_global_groups.empty:
        df = df_local.merge(df_global_groups, how='left', on=['filename', 'animation_id'])
        df_styles = df[["filename", "animation_id", "class_", "href"]]
        df_styles["href"] = df_styles["href"].fillna('')
        df_styles["fill"] = _combine_columns(df, "fill")
        df_styles["stroke"] = _combine_columns(df, "stroke")
        df_styles["stroke_width"] = _combine_columns(df, "stroke_width")
        df_styles["opacity"] = _combine_columns(df, "opacity")
        df_styles["stroke_opacity"] = _combine_columns(df, "stroke_opacity")

    return df_styles


def _combine_columns(df, col_name):
    """ Prefer the global value (merge suffix _y) over the local one (suffix _x) where it is defined. """
    col = np.where(~df[f"{col_name}_y"].astype(str).isin(["", "nan"]),
                   df[f"{col_name}_y"], df[f"{col_name}_x"])
    return col


def transform_to_hex(rgb):
    """ Transform RGB to hex.

    Args:
        rgb (str): RGB code.

    Returns:
        str: Hex code.

    """
    if rgb == 'none':
        return '#000000'
    if 'rgb' in rgb:
        rgb = rgb.replace('rgb(', '').replace(')', '')
        if '%' in rgb:
            rgb = rgb.replace('%', '')
            rgb_list = rgb.split(',')
            r_value, g_value, b_value = [int(float(i) / 100 * 255) for i in rgb_list]
        else:
            rgb_list = rgb.split(',')
            r_value, g_value, b_value = [int(float(i)) for i in rgb_list]
        return '#%02x%02x%02x' % (r_value, g_value, b_value)
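A short usage sketch for the helpers above, run against the sample logo shipped in this commit; it assumes the SVG's elements carry `animation_id` attributes, as all functions here require.

from src.postprocessing.get_style_attributes import get_style_attributes_svg, get_style_attributes_path

svg = "src/postprocessing/logo_0.svg"
styles = get_style_attributes_svg(svg)            # one row per path
print(styles[["animation_id", "fill", "opacity"]].head())
print(get_style_attributes_path(svg, 0, "fill"))  # fill of the path with animation_id 0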
src/postprocessing/get_svg_color_tendency.py
ADDED
@@ -0,0 +1,19 @@
from src.postprocessing.get_style_attributes import get_style_attributes_svg


def get_svg_color_tendencies(file):
    """ Get the two most frequent fill colors in an SVG. White is excluded; black is appended as a fallback
    in case fewer than two colors remain.

    Args:
        file (str): Path of SVG file.

    Returns:
        list: List of the two most frequent colors in the SVG.

    """
    df = get_style_attributes_svg(file)
    df = df[~df['fill'].isin(['#FFFFFF', '#ffffff'])]  # exclude white fills
    colour_tendencies_list = df["fill"].value_counts()[:2].index.tolist()
    colour_tendencies_list.append("#000000")  # fallback if fewer than two colors are present
    return colour_tendencies_list[:2]
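For example (again assuming the sample logo), the helper returns at most two dominant fill colors, with black padding the list when fewer remain after filtering.

from src.postprocessing.get_svg_color_tendency import get_svg_color_tendencies

print(get_svg_color_tendencies("src/postprocessing/logo_0.svg"))
# e.g. ['#1a73e8', '#000000'] -- the actual values depend on the file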
src/postprocessing/get_svg_size_pos.py
ADDED
@@ -0,0 +1,268 @@
from xml.dom import minidom
from svgpathtools import svg2paths


def get_svg_size(file):
    """ Get width and height of an SVG.

    Args:
        file (str): Path of SVG file.

    Returns:
        float, float: Width and height of SVG.

    """
    doc = minidom.parse(file)
    width = doc.getElementsByTagName('svg')[0].getAttribute('width')
    height = doc.getElementsByTagName('svg')[0].getAttribute('height')

    if width != "" and height != "":
        if not width[-1].isdigit():
            width = width.replace('px', '').replace('pt', '')
        if not height[-1].isdigit():
            height = height.replace('px', '').replace('pt', '')

    if width == "" or height == "" or not width[-1].isdigit() or not height[-1].isdigit():
        # fall back to the bounding box of the SVG
        xmin_svg, xmax_svg, ymin_svg, ymax_svg = 100, -100, 100, -100
        paths, _ = svg2paths(file)
        for path in paths:
            xmin, xmax, ymin, ymax = path.bbox()
            if xmin < xmin_svg:
                xmin_svg = xmin
            if xmax > xmax_svg:
                xmax_svg = xmax
            if ymin < ymin_svg:
                ymin_svg = ymin
            if ymax > ymax_svg:
                ymax_svg = ymax
        width = xmax_svg - xmin_svg
        height = ymax_svg - ymin_svg

    return float(width), float(height)


def get_svg_bbox(file):
    """ Get bounding box coordinates of an SVG.

    xmin, ymin: Upper left corner.
    xmax, ymax: Lower right corner.

    Args:
        file (str): Path of SVG file.

    Returns:
        float, float, float, float: Bounding box of SVG (xmin, xmax, ymin, ymax).

    """
    try:
        paths, _ = svg2paths(file)
    except Exception as e:
        print(f"{file}: svg2paths fails. SVG bbox is computed by using get_svg_size. {e}")
        width, height = get_svg_size(file)
        return 0, width, 0, height

    xmin_svg, xmax_svg, ymin_svg, ymax_svg = 100, -100, 100, -100
    for path in paths:
        try:
            xmin, xmax, ymin, ymax = path.bbox()
            if xmin < xmin_svg:
                xmin_svg = xmin
            if xmax > xmax_svg:
                xmax_svg = xmax
            if ymin < ymin_svg:
                ymin_svg = ymin
            if ymax > ymax_svg:
                ymax_svg = ymax
        except Exception:
            pass  # skip paths without a computable bounding box

    return xmin_svg, xmax_svg, ymin_svg, ymax_svg


def get_path_bbox(file, animation_id):
    """ Get bounding box coordinates of a path in an SVG.

    Args:
        file (str): Path of SVG file.
        animation_id (int): ID of element.

    Returns:
        float, float, float, float: Bounding box of path (xmin, xmax, ymin, ymax).

    """
    try:
        paths, attributes = svg2paths(file)
    except Exception as e1:
        print(f"{file}, animation ID {animation_id}: svg2paths fails and path bbox cannot be computed. {e1}")
        return 0, 0, 0, 0

    for i, path in enumerate(paths):
        if attributes[i]["animation_id"] == str(animation_id):
            try:
                xmin, xmax, ymin, ymax = path.bbox()
                return xmin, xmax, ymin, ymax
            except Exception as e2:
                print(f"{file}, animation ID {animation_id}: svg2paths fails and path bbox cannot be computed. {e2}")
                return 0, 0, 0, 0

    return 0, 0, 0, 0  # no path with this animation_id found


def get_midpoint_of_path_bbox(file, animation_id):
    """ Get midpoint of bounding box of path.

    Args:
        file (str): Path of SVG file.
        animation_id (int): ID of element.

    Returns:
        float, float: Midpoint of bounding box of path (x_midpoint, y_midpoint).

    """
    try:
        xmin, xmax, ymin, ymax = get_path_bbox(file, animation_id)
        x_midpoint = (xmin + xmax) / 2
        y_midpoint = (ymin + ymax) / 2

        return x_midpoint, y_midpoint
    except Exception as e:
        print(f'Could not get midpoint for file {file} and animation ID {animation_id}: {e}')
        return 0, 0


def get_bbox_of_multiple_paths(file, animation_ids):
    """ Get bounding box of multiple paths in an SVG.

    Args:
        file (str): Path of SVG file.
        animation_ids (list(int)): List of element IDs.

    Returns:
        float, float, float, float: Bounding box of given paths (xmin, xmax, ymin, ymax).

    """
    try:
        paths, attributes = svg2paths(file)
    except Exception as e1:
        print(f"{file}: svg2paths fails and bbox of multiple paths cannot be computed. {e1}")
        return 0, 0, 0, 0

    xmin_paths, xmax_paths, ymin_paths, ymax_paths = 100, -100, 100, -100

    for i, path in enumerate(paths):
        if attributes[i]["animation_id"] in list(map(str, animation_ids)):
            try:
                xmin, xmax, ymin, ymax = path.bbox()
                if xmin < xmin_paths:
                    xmin_paths = xmin
                if xmax > xmax_paths:
                    xmax_paths = xmax
                if ymin < ymin_paths:
                    ymin_paths = ymin
                if ymax > ymax_paths:
                    ymax_paths = ymax
            except Exception:
                pass  # skip paths without a computable bounding box

    return xmin_paths, xmax_paths, ymin_paths, ymax_paths


def get_relative_path_pos(file, animation_id):
    """ Get relative position of a path in an SVG.

    Args:
        file (string): Path of SVG file.
        animation_id (int): ID of element.

    Returns:
        float, float: Relative x- and y-position of path.

    """
    path_midpoint_x, path_midpoint_y = get_midpoint_of_path_bbox(file, animation_id)
    svg_xmin, svg_xmax, svg_ymin, svg_ymax = get_svg_bbox(file)
    rel_x_position = (path_midpoint_x - svg_xmin) / (svg_xmax - svg_xmin)
    rel_y_position = (path_midpoint_y - svg_ymin) / (svg_ymax - svg_ymin)
    return rel_x_position, rel_y_position


def get_relative_pos_to_bounding_box_of_animated_paths(file, animation_id, animated_animation_ids):
    """ Get relative position of a path to the bounding box of all animated paths.

    Args:
        file (str): Path of SVG file.
        animation_id (int): ID of element.
        animated_animation_ids (list(int)): List of animated element IDs.

    Returns:
        float, float: Relative x- and y-position of path to bounding box of all animated paths.

    """
    path_midpoint_x, path_midpoint_y = get_midpoint_of_path_bbox(file, animation_id)
    xmin, xmax, ymin, ymax = get_bbox_of_multiple_paths(file, animated_animation_ids)
    try:
        rel_x_position = (path_midpoint_x - xmin) / (xmax - xmin)
    except Exception as e1:
        rel_x_position = 0.5
        print(f"{file}, animation_id {animation_id}, animated_animation_ids {animated_animation_ids}: "
              f"rel_x_position not defined and set to 0.5. {e1}")
    try:
        rel_y_position = (path_midpoint_y - ymin) / (ymax - ymin)
    except Exception as e2:
        rel_y_position = 0.5
        print(f"{file}, animation_id {animation_id}, animated_animation_ids {animated_animation_ids}: "
              f"rel_y_position not defined and set to 0.5. {e2}")

    return rel_x_position, rel_y_position


def get_relative_path_size(file, animation_id):
    """ Get relative size of a path in an SVG.

    Args:
        file (str): Path of SVG file.
        animation_id (int): ID of element.

    Returns:
        float, float: Relative width and height of path.

    """
    svg_xmin, svg_xmax, svg_ymin, svg_ymax = get_svg_bbox(file)
    svg_width = float(svg_xmax - svg_xmin)
    svg_height = float(svg_ymax - svg_ymin)

    path_xmin, path_xmax, path_ymin, path_ymax = get_path_bbox(file, animation_id)
    path_width = float(path_xmax - path_xmin)
    path_height = float(path_ymax - path_ymin)

    rel_width = path_width / svg_width
    rel_height = path_height / svg_height

    return rel_width, rel_height


def get_begin_values_by_starting_pos(file, animation_ids, start=1, step=0.5):
    """ Get begin values by sorting from left to right.

    Args:
        file (str): Path of SVG file.
        animation_ids (list(int)): List of element IDs.
        start (float): First begin value.
        step (float): Time between begin values.

    Returns:
        list: Begin values of element IDs.

    """
    starting_point_list = []
    begin_list = []
    begin = start
    for i in range(len(animation_ids)):
        x, _, _, _ = get_path_bbox(file, animation_ids[i])  # x value of the upper left corner
        starting_point_list.append(x)
        begin_list.append(begin)
        begin = begin + step

    animation_id_order = [z for _, z in sorted(zip(starting_point_list, range(len(starting_point_list))))]
    begin_values = [z for _, z in sorted(zip(animation_id_order, begin_list))]

    return begin_values
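A sketch tying the geometry helpers together; the element IDs are hypothetical and the sample logo is assumed to contain them.

from src.postprocessing.get_svg_size_pos import (get_svg_size, get_relative_path_pos,
                                                 get_relative_path_size, get_begin_values_by_starting_pos)

svg = "src/postprocessing/logo_0.svg"
print(get_svg_size(svg))               # overall width and height
print(get_relative_path_pos(svg, 0))   # midpoint of path 0 relative to the SVG bbox
print(get_relative_path_size(svg, 0))  # size of path 0 relative to the SVG bbox
print(get_begin_values_by_starting_pos(svg, [0, 1, 2]))  # staggered begin times, left to right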
src/postprocessing/insert_animation.py
ADDED
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from xml.dom import minidom
|
3 |
+
from pathlib import Path
|
4 |
+
from src.postprocessing.get_svg_size_pos import get_midpoint_of_path_bbox, get_begin_values_by_starting_pos
|
5 |
+
from src.postprocessing.transform_animation_predictor_output import transform_animation_predictor_output
|
6 |
+
|
7 |
+
|
8 |
+
def create_animated_svg(file, animation_ids, model_output, filename_suffix="", save=True):
|
9 |
+
""" Insert multiple animation statements.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
file (str): Path of SVG file.
|
13 |
+
animation_ids (list[int]): List of element IDs that get animated.
|
14 |
+
model_output (ndarray): Array of 13 dimensional arrays with animation predictor model output.
|
15 |
+
filename_suffix (str): Suffix of animated SVG.
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
list(float): List of begin values of elements in SVG.
|
19 |
+
xml.dom.minidom.Document: Parsed file with inserted animation statements.
|
20 |
+
|
21 |
+
"""
|
22 |
+
doc = svg_to_doc(file)
|
23 |
+
begin_values = get_begin_values_by_starting_pos(file, animation_ids, start=1, step=0.25)
|
24 |
+
for i in range(len(animation_ids)):
|
25 |
+
if not (model_output[i][:6] == np.array([0] * 6)).all():
|
26 |
+
try: # there are some paths that can't be embedded and don't have style attributes
|
27 |
+
output_dict = transform_animation_predictor_output(file, animation_ids[i], model_output[i])
|
28 |
+
output_dict["begin"] = begin_values[i]
|
29 |
+
if output_dict["type"] == "translate":
|
30 |
+
doc = insert_translate_statement(doc, animation_ids[i], output_dict)
|
31 |
+
if output_dict["type"] == "scale":
|
32 |
+
doc = insert_scale_statement(doc, animation_ids[i], output_dict, file)
|
33 |
+
if output_dict["type"] == "rotate":
|
34 |
+
doc = insert_rotate_statement(doc, animation_ids[i], output_dict)
|
35 |
+
if output_dict["type"] in ["skewX", "skewY"]:
|
36 |
+
doc = insert_skew_statement(doc, animation_ids[i], output_dict)
|
37 |
+
if output_dict["type"] == "fill":
|
38 |
+
doc = insert_fill_statement(doc, animation_ids[i], output_dict)
|
39 |
+
if output_dict["type"] in ["opacity"]:
|
40 |
+
doc = insert_opacity_statement(doc, animation_ids[i], output_dict)
|
41 |
+
except Exception as e:
|
42 |
+
print(f"File {file}, animation ID {animation_ids[i]} can't be animated. {e}")
|
43 |
+
pass
|
44 |
+
|
45 |
+
if save:
|
46 |
+
filename = file.split('/')[-1].replace(".svg", "") + "_animated"
|
47 |
+
save_animated_svg(doc, filename)
|
48 |
+
|
49 |
+
return begin_values, doc
|
50 |
+
|
51 |
+
|
52 |
+
def svg_to_doc(file):
|
53 |
+
""" Parse an SVG file.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
file (string): Path of SVG file.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
xml.dom.minidom.Document: Parsed file with inserted animation statement.
|
60 |
+
|
61 |
+
"""
|
62 |
+
return minidom.parse(file)
|
63 |
+
|
64 |
+
|
65 |
+
def save_animated_svg(doc, filename):
|
66 |
+
""" Save animated SVGs to folder animated_svgs.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
doc (xml.dom.minidom.Document): Parsed file.
|
70 |
+
filename (str): Name of output file.
|
71 |
+
|
72 |
+
"""
|
73 |
+
Path("data/animated_svgs").mkdir(parents=True, exist_ok=True)
|
74 |
+
|
75 |
+
with open('data/animated_svgs/' + filename + '.svg', 'wb') as f:
|
76 |
+
f.write(doc.toprettyxml(encoding="iso-8859-1"))
|
77 |
+
|
78 |
+
|
79 |
+
def insert_translate_statement(doc, animation_id, model_output_dict):
|
80 |
+
""" Insert translate statement.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
doc (xml.dom.minidom.Document): Parsed file.
|
84 |
+
animation_id (int): ID of element that gets animated.
|
85 |
+
model_output_dict (dict): Dictionary containing animation statement.
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
xml.dom.minidom.Document: Parsed file with inserted animation statement.
|
89 |
+
|
90 |
+
"""
|
91 |
+
pre_animations = []
|
92 |
+
opacity_dict_1, opacity_dict_2 = create_opacity_pre_animation_dicts(model_output_dict)
|
93 |
+
pre_animations.append(create_animation_statement(opacity_dict_1))
|
94 |
+
pre_animations.append(create_animation_statement(opacity_dict_2))
|
95 |
+
|
96 |
+
animation = create_animation_statement(model_output_dict)
|
97 |
+
doc = insert_animation(doc, animation_id, animation, pre_animations)
|
98 |
+
return doc
|
99 |
+
|
100 |
+
|
101 |
+
def insert_scale_statement(doc, animation_id, model_output_dict, file):
|
102 |
+
""" Insert scale statement.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
doc (xml.dom.minidom.Document): Parsed file.
|
106 |
+
animation_id (int): ID of element that gets animated.
|
107 |
+
model_output_dict (dict): Dictionary containing animation statement.
|
108 |
+
file (str): Path of SVG file. Needed to get midpoint of path bbox to suppress simultaneous translate movement.
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
xml.dom.minidom.Document: Parsed file with inserted animation statement.
|
112 |
+
|
113 |
+
"""
|
114 |
+
pre_animations = []
|
115 |
+
opacity_dict_1, opacity_dict_2 = create_opacity_pre_animation_dicts(model_output_dict)
|
116 |
+
pre_animations.append(create_animation_statement(opacity_dict_1))
|
117 |
+
pre_animations.append(create_animation_statement(opacity_dict_2))
|
118 |
+
|
119 |
+
x_midpoint, y_midpoint = get_midpoint_of_path_bbox(file, animation_id)
|
120 |
+
if model_output_dict["from_"] > 1:
|
121 |
+
model_output_dict["from_"] = 2
|
122 |
+
pre_animation_from = f"-{x_midpoint} -{y_midpoint}" # negative midpoint
|
123 |
+
else:
|
124 |
+
model_output_dict["from_"] = 0
|
125 |
+
pre_animation_from = f"{x_midpoint} {y_midpoint}" # positive midpoint
|
126 |
+
|
127 |
+
translate_pre_animation_dict = {"type": "translate",
|
128 |
+
"begin": model_output_dict["begin"],
|
129 |
+
"dur": model_output_dict["dur"],
|
130 |
+
"from_": pre_animation_from,
|
131 |
+
"to": "0 0",
|
132 |
+
"fill": "freeze"}
|
133 |
+
pre_animations.append(create_animation_statement(translate_pre_animation_dict))
|
134 |
+
|
135 |
+
animation = create_animation_statement(model_output_dict) + ' additive="sum" '
|
136 |
+
doc = insert_animation(doc, animation_id, animation, pre_animations)
|
137 |
+
return doc
|
138 |
+
|
139 |
+
|
140 |
+
def insert_rotate_statement(doc, animation_id, model_output_dict):
|
141 |
+
""" Insert rotate statement.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
doc (xml.dom.minidom.Document): Parsed file.
|
145 |
+
animation_id (int): ID of element that gets animated.
|
146 |
+
model_output_dict (dict): Dictionary containing animation statement.
|
147 |
+
|
148 |
+
Returns:
|
149 |
+
xml.dom.minidom.Document: Parsed file with inserted animation statement.
|
150 |
+
|
151 |
+
"""
|
152 |
+
pre_animations = []
|
153 |
+
opacity_dict_1, opacity_dict_2 = create_opacity_pre_animation_dicts(model_output_dict)
|
154 |
+
pre_animations.append(create_animation_statement(opacity_dict_1))
|
155 |
+
pre_animations.append(create_animation_statement(opacity_dict_2))
|
156 |
+
|
157 |
+
animation = create_animation_statement(model_output_dict)
|
158 |
+
doc = insert_animation(doc, animation_id, animation, pre_animations)
|
159 |
+
return doc
|
160 |
+
|
161 |
+
|
162 |
+
def insert_skew_statement(doc, animation_id, model_output_dict):
|
163 |
+
""" Insert skew statement.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
doc (xml.dom.minidom.Document): Parsed file.
|
167 |
+
animation_id (int): ID of element that gets animated.
|
168 |
+
model_output_dict (dict): Dictionary containing animation statement.
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
xml.dom.minidom.Document: Parsed file with inserted animation statement.
|
172 |
+
|
173 |
+
"""
|
174 |
+
pre_animations = []
|
175 |
+
opacity_dict_1, opacity_dict_2 = create_opacity_pre_animation_dicts(model_output_dict)
|
176 |
+
pre_animations.append(create_animation_statement(opacity_dict_1))
|
177 |
+
pre_animations.append(create_animation_statement(opacity_dict_2))
|
178 |
+
|
179 |
+
animation = create_animation_statement(model_output_dict)
|
180 |
+
doc = insert_animation(doc, animation_id, animation, pre_animations)
|
181 |
+
return doc
|
182 |
+
|
183 |
+
|
184 |
+
def insert_fill_statement(doc, animation_id, model_output_dict):
|
185 |
+
""" Insert fill statement.
|
186 |
+
|
187 |
+
Args:
|
188 |
+
doc (xml.dom.minidom.Document): Parsed file
|
189 |
+
animation_id (int): ID of element that gets animated.
|
190 |
+
model_output_dict (dict): Dictionary containing animation statement.
|
191 |
+
|
192 |
+
Returns:
|
193 |
+
xml.dom.minidom.Document: Parsed file with inserted animation statement.
|
194 |
+
|
195 |
+
"""
|
196 |
+
pre_animations = []
|
197 |
+
model_output_dict['dur'] = 2
|
198 |
+
if model_output_dict['begin'] < 2:
|
199 |
+
model_output_dict['begin'] = 0
|
200 |
+
else: # Wave
|
201 |
+
pre_animation_dict = {"type": "fill",
|
202 |
+
"begin": 0,
|
203 |
+
"dur": model_output_dict["begin"],
|
204 |
+
"from_": model_output_dict["to"],
|
205 |
+
"to": model_output_dict["from_"],
|
206 |
+
"fill": "remove"}
|
207 |
+
pre_animations.append(create_animation_statement(pre_animation_dict))
|
208 |
+
|
209 |
+
animation = create_animation_statement(model_output_dict)
|
210 |
+
doc = insert_animation(doc, animation_id, animation, pre_animations)
|
211 |
+
return doc
|
212 |
+
|
213 |
+
|
214 |
+
def insert_opacity_statement(doc, animation_id, model_output_dict):
|
215 |
+
""" Insert opacity statement.
|
216 |
+
|
217 |
+
Args:
|
218 |
+
doc (xml.dom.minidom.Document): Parsed file.
|
219 |
+
animation_id (int): ID of element that gets animated.
|
220 |
+
model_output_dict (dict): Dictionary containing animation statement.
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
xml.dom.minidom.Document: Parsed file with inserted animation statement.
|
224 |
+
|
225 |
+
"""
|
226 |
+
pre_animations = []
|
227 |
+
opacity_pre_animation_dict = {"type": "opacity",
|
228 |
+
"begin": "0",
|
229 |
+
"dur": model_output_dict["begin"],
|
230 |
+
"from_": "0",
|
231 |
+
"to": "0",
|
232 |
+
"fill": "remove"}
|
233 |
+
pre_animations.append(create_animation_statement(opacity_pre_animation_dict))
|
234 |
+
|
235 |
+
animation = create_animation_statement(model_output_dict)
|
236 |
+
doc = insert_animation(doc, animation_id, animation, pre_animations)
|
237 |
+
return doc
|
238 |
+
|
239 |
+
|
240 |
+
def insert_animation(doc, animation_id, animation, pre_animations=None):
|
241 |
+
""" Insert animation statements including pre-animation statements.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
doc (xml.dom.minidom.Document): Parsed file.
|
245 |
+
        animation_id (int): ID of element that gets animated.
        animation (string): Animation that needs to be inserted.
        pre_animations (list): List of animations that need to be inserted before the actual animation.

    Returns:
        xml.dom.minidom.Document: Parsed file with inserted animation statement.

    """
    elements = doc.getElementsByTagName('path') + doc.getElementsByTagName('circle') + doc.getElementsByTagName(
        'ellipse') + doc.getElementsByTagName('line') + doc.getElementsByTagName(
        'polygon') + doc.getElementsByTagName('polyline') + doc.getElementsByTagName(
        'rect') + doc.getElementsByTagName('text')

    for element in elements:
        if element.getAttribute('animation_id') == str(animation_id):
            if pre_animations is not None:
                for pre_animation in pre_animations:
                    element.appendChild(doc.createElement(pre_animation))
            element.appendChild(doc.createElement(animation))

    return doc


def create_animation_statement(animation_dict):
    """ Set up an animation statement from a dictionary.

    Args:
        animation_dict (dict): Dictionary that is transformed into an animation statement.

    Returns:
        str: Animation statement.

    """
    if animation_dict["type"] in ["translate", "scale", "rotate", "skewX", "skewY"]:
        return _create_animate_transform_statement(animation_dict)
    elif animation_dict["type"] in ["fill", "opacity"]:
        return _create_animate_statement(animation_dict)


def _create_animate_transform_statement(animation_dict):
    """ Set up animation statement from model output for ANIMATETRANSFORM animations """
    animation = f'animateTransform attributeName = "transform" attributeType = "XML" ' \
                f'type = "{animation_dict["type"]}" ' \
                f'begin = "{str(animation_dict["begin"])}" ' \
                f'dur = "{str(animation_dict["dur"])}" ' \
                f'from = "{str(animation_dict["from_"])}" ' \
                f'to = "{str(animation_dict["to"])}" ' \
                f'fill = "{str(animation_dict["fill"])}"'

    return animation


def _create_animate_statement(animation_dict):
    """ Set up animation statement from model output for ANIMATE animations """
    animation = f'animate attributeName = "{animation_dict["type"]}" ' \
                f'begin = "{str(animation_dict["begin"])}" ' \
                f'dur = "{str(animation_dict["dur"])}" ' \
                f'from = "{str(animation_dict["from_"])}" ' \
                f'to = "{str(animation_dict["to"])}" ' \
                f'fill = "{str(animation_dict["fill"])}"'

    return animation


def create_opacity_pre_animation_dicts(animation_dict):
    """ Set up opacity pre-animation statements.

    Args:
        animation_dict (dict): Dictionary of the animation for which opacity pre-animations are needed.

    Returns:
        tuple: Two opacity pre-animation dictionaries (element hidden until begin, then fading in).

    """
    opacity_pre_animation_dict_1 = {"type": "opacity",
                                    "begin": "0",
                                    "dur": animation_dict["begin"],
                                    "from_": "0",
                                    "to": "0",
                                    "fill": "remove"}

    opacity_pre_animation_dict_2 = {"type": "opacity",
                                    "begin": animation_dict["begin"],
                                    "dur": "0.5",
                                    "from_": "0",
                                    "to": "1",
                                    "fill": "remove"}

    return opacity_pre_animation_dict_1, opacity_pre_animation_dict_2
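

# A minimal end-to-end sketch of these helpers (values are illustrative; the
# parsed document and the elements' animation_id attributes are assumed to be
# set up as expected by the insertion function above):
#
# animation_dict = {"type": "translate", "begin": 1, "dur": 4,
#                   "from_": "50 0", "to": "0 0", "fill": "freeze"}
# statement = create_animation_statement(animation_dict)
# pre_dict_1, pre_dict_2 = create_opacity_pre_animation_dicts(animation_dict)
# pre_statements = [create_animation_statement(d) for d in (pre_dict_1, pre_dict_2)]
# The statements can then be appended to the element with the matching
# animation_id via the insertion function above.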
src/postprocessing/logo_0.svg
ADDED
src/postprocessing/postprocessing.py
ADDED
@@ -0,0 +1,604 @@
import pandas as pd
import numpy as np
import random
import os
import sys
from xml.dom import minidom
from collections import defaultdict

sys.path.append(os.getcwd())
from src.postprocessing.get_svg_size_pos import get_svg_bbox, get_path_bbox, get_midpoint_of_path_bbox
from src.postprocessing.get_style_attributes import get_style_attributes_path

random.seed(0)

filter_id = 0


def animate_logo(model_output: pd.DataFrame, logo_path: str):
    logo_xmin, logo_xmax, logo_ymin, logo_ymax = get_svg_bbox(logo_path)
    # ---- Normalize model output ----
    # Structure animations by animation id
    animations_by_id = defaultdict(list)
    for _, row in model_output.iterrows():
        animations_by_id[row['animation_id']].append(row['model_output'])
    total_animations = []
    for animation_id in animations_by_id.keys():
        print(animation_id)
        path_xmin, path_xmax, path_ymin, path_ymax = get_path_bbox(logo_path, animation_id)
        # Maximum offsets that keep the path inside the logo's bounding box
        xmin = logo_xmin - path_xmin
        xmax = logo_xmax - path_xmax
        ymin = logo_ymin - path_ymin
        ymax = logo_ymax - path_ymax
        # Structure animations by type (encoded one-hot in the first 10 parameters)
        animations_by_type = defaultdict(list)
        for animation in animations_by_id[animation_id]:
            if animation[0] == 1:
                # EOS
                continue
            try:
                animation_type = animation[1:10].index(1)
                animations_by_type[animation_type].append(animation)
            except ValueError:
                # No animation type found
                print('Model output invalid: no animation type found')
                return

        for animation_type in animations_by_type.keys():
            # Set up list of animations for later distribution
            current_animations = []
            # Sort animations by begin (parameter 10)
            animations_by_type[animation_type].sort(key=lambda l: l[10])
            # For every animation, check consistency of begin and duration, then set parameters
            for i in range(len(animations_by_type[animation_type])):
                # If several animations share the same begin, distribute their begins
                # evenly up to the next animation with a different begin time
                if len(animations_by_type[animation_type]) > 1:
                    j = i + 1
                    while j < len(animations_by_type[animation_type]) and \
                            animations_by_type[animation_type][i][10] == animations_by_type[animation_type][j][10]:
                        j += 1
                    if j > i + 1 and j < len(animations_by_type[animation_type]):
                        # Interpolate between this begin and the next differing begin
                        difference = animations_by_type[animation_type][j][10] - animations_by_type[animation_type][i][10]
                        interval = difference / (j - i)
                        factor = 0
                        for a in range(i, j):
                            animations_by_type[animation_type][a][10] = animations_by_type[animation_type][i][10] + interval * factor
                            factor += 1
                # Check if duration and begin of next animation are consistent - if not, shorten duration
                if i < len(animations_by_type[animation_type]) - 1:
                    max_duration = animations_by_type[animation_type][i+1][10] - animations_by_type[animation_type][i][10]
                    if animations_by_type[animation_type][i][11] > max_duration:
                        animations_by_type[animation_type][i][11] = max_duration

                # Get general parameters: begin (parameter 10) and duration (parameter 11)
                begin = animations_by_type[animation_type][i][10]
                dur = animations_by_type[animation_type][i][11]
                # Check type and call the matching builder
                if animation_type == 1:
                    # animation: translate
                    from_x = animations_by_type[animation_type][i][12]
                    from_y = animations_by_type[animation_type][i][13]
                    # Check if there is a next translate animation
                    if i < len(animations_by_type[animation_type]) - 1:
                        # animation endpoint is next translate animation's starting point
                        to_x = animations_by_type[animation_type][i+1][12]
                        to_y = animations_by_type[animation_type][i+1][13]
                    else:
                        # animation endpoint is final position of object
                        to_x = 0
                        to_y = 0
                    # Clamp parameters to the logo's bounding box
                    from_x = min(max(from_x, xmin), xmax)
                    from_y = min(max(from_y, ymin), ymax)
                    to_x = min(max(to_x, xmin), xmax)
                    to_y = min(max(to_y, ymin), ymax)
                    # Append animation to list
                    current_animations.append(_animation_translate(animation_id, begin, dur, from_x, from_y, to_x, to_y))
                elif animation_type == 2:
                    # animation: motion along a curve
                    from_x = animations_by_type[animation_type][i][12]
                    from_y = animations_by_type[animation_type][i][13]
                    via_x = animations_by_type[animation_type][i][14]
                    via_y = animations_by_type[animation_type][i][15]
                    # Check if there is a next curve animation
                    if i < len(animations_by_type[animation_type]) - 1:
                        # animation endpoint is next curve animation's starting point
                        to_x = animations_by_type[animation_type][i+1][12]
                        to_y = animations_by_type[animation_type][i+1][13]
                    else:
                        # animation endpoint is final position of object
                        to_x = 0
                        to_y = 0
                    # Clamp parameters to the logo's bounding box
                    from_x = min(max(from_x, xmin), xmax)
                    from_y = min(max(from_y, ymin), ymax)
                    via_x = min(max(via_x, xmin), xmax)
                    via_y = min(max(via_y, ymin), ymax)
                    to_x = min(max(to_x, xmin), xmax)
                    to_y = min(max(to_y, ymin), ymax)
                    # Append animation to list
                    current_animations.append(_animation_curve(animation_id, begin, dur, from_x, from_y, via_x, via_y, to_x, to_y))
                elif animation_type == 3:
                    # animation: scale
                    from_f = animations_by_type[animation_type][i][16]
                    # Check if there is a next scale animation
                    if i < len(animations_by_type[animation_type]) - 1:
                        # animation endpoint is next scale animation's starting point
                        to_f = animations_by_type[animation_type][i+1][16]
                    else:
                        # animation endpoint is the object's original scale
                        to_f = 1
                    current_animations.append(_animation_scale(animation_id, begin, dur, from_f, to_f))
                elif animation_type == 4:
                    # animation: rotate
                    from_degree = animations_by_type[animation_type][i][17]
                    # Get midpoint of the path's bounding box as rotation center
                    midpoints = get_midpoint_of_path_bbox(logo_path, animation_id)
                    # Check if there is a next rotate animation
                    if i < len(animations_by_type[animation_type]) - 1:
                        # animation endpoint is next rotate animation's starting point
                        to_degree = animations_by_type[animation_type][i+1][17]
                    else:
                        # animation endpoint is a full rotation
                        to_degree = 360
                    current_animations.append(_animation_rotate(animation_id, begin, dur, from_degree, to_degree, midpoints))
                elif animation_type == 5:
                    # animation: skewX
                    from_x = animations_by_type[animation_type][i][18]
                    # Check if there is a next skewX animation
                    if i < len(animations_by_type[animation_type]) - 1:
                        # animation endpoint is next skewX animation's starting point
                        to_x = animations_by_type[animation_type][i+1][18]
                    else:
                        # animation endpoint is final position of object
                        to_x = 1
                    # Clamp parameters to the logo's bounding box
                    from_x = min(max(from_x, xmin), xmax)
                    to_x = min(max(to_x, xmin), xmax)
                    current_animations.append(_animation_skewX(animation_id, begin, dur, from_x, to_x))
                elif animation_type == 6:
                    # animation: skewY
                    from_y = animations_by_type[animation_type][i][19]
                    # Check if there is a next skewY animation
                    if i < len(animations_by_type[animation_type]) - 1:
                        # animation endpoint is next skewY animation's starting point
                        to_y = animations_by_type[animation_type][i+1][19]
                    else:
                        # animation endpoint is final position of object
                        to_y = 1
                    # Clamp parameters to the logo's bounding box
                    from_y = min(max(from_y, ymin), ymax)
                    to_y = min(max(to_y, ymin), ymax)
                    current_animations.append(_animation_skewY(animation_id, begin, dur, from_y, to_y))
                elif animation_type == 7:
                    # animation: fill (RGB channels in parameters 20-22)
                    from_rgb = '#' + _convert_to_hex_str(animations_by_type[animation_type][i][20]) \
                               + _convert_to_hex_str(animations_by_type[animation_type][i][21]) \
                               + _convert_to_hex_str(animations_by_type[animation_type][i][22])
                    # Check if there is a next fill animation
                    if i < len(animations_by_type[animation_type]) - 1:
                        # animation endpoint is next fill animation's starting point
                        to_rgb = '#' + _convert_to_hex_str(animations_by_type[animation_type][i+1][20]) \
                                 + _convert_to_hex_str(animations_by_type[animation_type][i+1][21]) \
                                 + _convert_to_hex_str(animations_by_type[animation_type][i+1][22])
                    else:
                        # animation endpoint is the element's original color (fill, or stroke if fill is unset)
                        fill_style = get_style_attributes_path(logo_path, animation_id, "fill")
                        stroke_style = get_style_attributes_path(logo_path, animation_id, "stroke")
                        if fill_style == "none" and stroke_style != "none":
                            color_hex = stroke_style
                        else:
                            color_hex = fill_style
                        to_rgb = color_hex
                    current_animations.append(_animation_fill(animation_id, begin, dur, from_rgb, to_rgb))
                elif animation_type == 8:
                    # animation: opacity
                    from_f = animations_by_type[animation_type][i][23] / 100  # percent
                    # Check if there is a next opacity animation
                    if i < len(animations_by_type[animation_type]) - 1:
                        # animation endpoint is next opacity animation's starting point
                        to_f = animations_by_type[animation_type][i+1][23] / 100  # percent
                    else:
                        # animation endpoint is full opacity
                        to_f = 1
                    current_animations.append(_animation_opacity(animation_id, begin, dur, from_f, to_f))
                elif animation_type == 9:
                    # animation: blur
                    from_f = animations_by_type[animation_type][i][24]
                    # Check if there is a next blur animation
                    if i < len(animations_by_type[animation_type]) - 1:
                        # animation endpoint is next blur animation's starting point
                        to_f = animations_by_type[animation_type][i+1][24]
                    else:
                        # animation endpoint: stdDeviation of 1
                        to_f = 1
                    current_animations.append(_animation_blur(animation_id, begin, dur, from_f, to_f))
            total_animations += current_animations
    # Shift all begins so that the earliest animation starts at 0 - TODO test
    min_b = np.inf
    for animation in total_animations:
        print(animation["begin"], min_b)
        if float(animation["begin"]) < float(min_b):
            min_b = animation["begin"]
    for animation in total_animations:
        animation["begin"] = float(animation["begin"]) - float(min_b)

    _insert_animations(total_animations, logo_path, logo_path)
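

# A minimal invocation sketch (illustrative values; assumes the vector layout
# read by the loop above: EOS flag at index 0, type one-hot in indices 1-9,
# begin/dur at 10/11, translate offsets at 12/13):
#
# vec = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# model_output = pd.DataFrame([{'animation_id': 0, 'model_output': vec}])
# animate_logo(model_output, 'src/postprocessing/logo_0.svg')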


def _convert_to_hex_str(i: int):
    """ Convert an integer in [0, 255] to a two-digit hex string. """
    h = hex(i)[2:]
    if i < 16:
        h = '0' + h
    return h
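

# For instance, _convert_to_hex_str(255) returns 'ff' and _convert_to_hex_str(10)
# returns '0a', so the fill branch in animate_logo turns an RGB triple such as
# (255, 10, 0) into the color string '#ff0a00'.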


def _animation_translate(animation_id: int, begin: float, dur: float, from_x: int, from_y: int, to_x: int, to_y: int):
    print('animation: translate')
    animation_dict = {}
    animation_dict['animation_id'] = animation_id
    animation_dict['animation_type'] = 'animate_transform'
    animation_dict['attributeName'] = 'transform'
    animation_dict['attributeType'] = 'XML'
    animation_dict['type'] = 'translate'
    animation_dict['begin'] = str(begin)
    animation_dict['dur'] = str(dur)
    animation_dict['fill'] = 'freeze'
    animation_dict['from'] = f'{from_x} {from_y}'
    animation_dict['to'] = f'{to_x} {to_y}'
    return animation_dict
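

# For reference, _animation_translate(0, 1, 2, 50, 0, 0, 0) produces:
# {'animation_id': 0, 'animation_type': 'animate_transform',
#  'attributeName': 'transform', 'attributeType': 'XML', 'type': 'translate',
#  'begin': '1', 'dur': '2', 'fill': 'freeze', 'from': '50 0', 'to': '0 0'}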


def _animation_curve(animation_id: int, begin: float, dur: float, from_x: int, from_y: int, via_x: int, via_y: int, to_x: int, to_y: int):
    print('animation: curve')
    animation_dict = {}
    animation_dict['animation_id'] = animation_id
    animation_dict['animation_type'] = 'animate_motion'
    animation_dict['begin'] = str(begin)
    animation_dict['dur'] = str(dur)
    animation_dict['fill'] = 'freeze'
    animation_dict['from'] = f'{from_x} {from_y}'
    animation_dict['via'] = f'{via_x} {via_y}'
    animation_dict['to'] = f'{to_x} {to_y}'
    return animation_dict


def _animation_scale(animation_id: int, begin: float, dur: float, from_f: float, to_f: float):
    print('animation: scale')
    animation_dict = {}
    animation_dict['animation_id'] = animation_id
    animation_dict['animation_type'] = 'animate_transform'
    animation_dict['attributeName'] = 'transform'
    animation_dict['attributeType'] = 'XML'
    animation_dict['type'] = 'scale'
    animation_dict['begin'] = str(begin)
    animation_dict['dur'] = str(dur)
    animation_dict['fill'] = 'freeze'
    animation_dict['from'] = str(from_f)
    animation_dict['to'] = str(to_f)
    return animation_dict


def _animation_rotate(animation_id: int, begin: float, dur: float, from_degree: int, to_degree: int, midpoints: list):
    print('animation: rotate')
    animation_dict = {}
    animation_dict['animation_id'] = animation_id
    animation_dict['animation_type'] = 'animate_transform'
    animation_dict['attributeName'] = 'transform'
    animation_dict['attributeType'] = 'XML'
    animation_dict['type'] = 'rotate'
    animation_dict['begin'] = str(begin)
    animation_dict['dur'] = str(dur)
    animation_dict['fill'] = 'freeze'
    animation_dict['from'] = f'{from_degree} {midpoints[0]} {midpoints[1]}'
    animation_dict['to'] = f'{to_degree} {midpoints[0]} {midpoints[1]}'
    return animation_dict


def _animation_skewX(animation_id: int, begin: float, dur: float, from_i: int, to_i: int):
    print('animation: skewX')
    animation_dict = {}
    animation_dict['animation_id'] = animation_id
    animation_dict['animation_type'] = 'animate_transform'
    animation_dict['attributeName'] = 'transform'
    animation_dict['attributeType'] = 'XML'
    animation_dict['type'] = 'skewX'
    animation_dict['begin'] = str(begin)
    animation_dict['dur'] = str(dur)
    animation_dict['fill'] = 'freeze'
    animation_dict['from'] = f'{from_i}'
    animation_dict['to'] = f'{to_i}'
    return animation_dict


def _animation_skewY(animation_id: int, begin: float, dur: float, from_i: int, to_i: int):
    print('animation: skewY')
    animation_dict = {}
    animation_dict['animation_id'] = animation_id
    animation_dict['animation_type'] = 'animate_transform'
    animation_dict['attributeName'] = 'transform'
    animation_dict['attributeType'] = 'XML'
    animation_dict['type'] = 'skewY'
    animation_dict['begin'] = str(begin)
    animation_dict['dur'] = str(dur)
    animation_dict['fill'] = 'freeze'
    animation_dict['from'] = f'{from_i}'
    animation_dict['to'] = f'{to_i}'
    return animation_dict


def _animation_fill(animation_id: int, begin: float, dur: float, from_rgb: str, to_rgb: str):
    print('animation: fill')
    animation_dict = {}
    animation_dict['animation_id'] = animation_id
    animation_dict['animation_type'] = 'animate'
    animation_dict['attributeName'] = 'fill'
    animation_dict['attributeType'] = 'XML'
    animation_dict['type'] = 'fill'
    animation_dict['begin'] = str(begin)
    animation_dict['dur'] = str(dur)
    animation_dict['fill'] = 'freeze'
    animation_dict['from'] = from_rgb
    animation_dict['to'] = to_rgb
    return animation_dict


def _animation_opacity(animation_id: int, begin: float, dur: float, from_f: float, to_f: float):
    print('animation: opacity')
    animation_dict = {}
    animation_dict['animation_id'] = animation_id
    animation_dict['animation_type'] = 'animate'
    animation_dict['attributeName'] = 'opacity'
    animation_dict['attributeType'] = 'XML'
    animation_dict['type'] = 'opacity'
    animation_dict['begin'] = str(begin)
    animation_dict['dur'] = str(dur)
    animation_dict['fill'] = 'freeze'
    animation_dict['from'] = str(from_f)
    animation_dict['to'] = str(to_f)
    return animation_dict


def _animation_blur(animation_id: int, begin: float, dur: float, from_f: float, to_f: float):
    print('animation: blur')
    animation_dict = {}
    animation_dict['animation_id'] = animation_id
    animation_dict['animation_type'] = 'animate_filter'
    animation_dict['attributeName'] = 'transform'
    animation_dict['attributeType'] = 'XML'
    animation_dict['type'] = 'blur'
    animation_dict['begin'] = str(begin)
    animation_dict['dur'] = str(dur)
    animation_dict['fill'] = 'freeze'
    animation_dict['from'] = str(from_f)
    animation_dict['to'] = str(to_f)
    return animation_dict


def _insert_animations(animations: list, path: str, target_path: str):
    print('Insert animations')
    # Load XML
    document = minidom.parse(path)
    # Collect all animatable elements
    elements = document.getElementsByTagName('path') + document.getElementsByTagName('circle') + document.getElementsByTagName(
        'ellipse') + document.getElementsByTagName('line') + document.getElementsByTagName(
        'polygon') + document.getElementsByTagName('polyline') + document.getElementsByTagName(
        'rect') + document.getElementsByTagName('text')
    # Create statements and append them to their elements
    for animation in animations:
        # Search for element
        current_element = None
        for element in elements:
            if element.getAttribute('animation_id') == str(animation['animation_id']):
                current_element = element
        if current_element is None:
            # Animation id not found - take next animation
            continue
        if animation['animation_type'] == 'animate_transform':
            animate_statement = _create_animate_transform_statement(animation)
            current_element.appendChild(document.createElement(animate_statement))
        elif animation['animation_type'] == 'animate_motion':
            animate_statement = _create_animate_motion_statement(animation)
            current_element.appendChild(document.createElement(animate_statement))
        elif animation['animation_type'] == 'animate':
            animate_statement = _create_animate_statement(animation)
            current_element.appendChild(document.createElement(animate_statement))
        elif animation['animation_type'] == 'animate_filter':
            filter_element, fe_element, animate_statement = _create_animate_filter_statement(animation, document)
            defs = document.getElementsByTagName('defs')
            # Check if a defs tag exists; create one otherwise
            if len(defs) == 0:
                svg = document.getElementsByTagName('svg')[0]
                current_defs = document.createElement('defs')
                svg.appendChild(current_defs)
            else:
                current_defs = defs[0]
            # Append the filter if it was newly created
            if filter_element is not None:
                print('append filter')
                current_defs.appendChild(filter_element)
            # Append the filter effect if it was newly created
            if fe_element is not None:
                print('create fe statement')
                # If the filter was not newly created, search for the existing one
                if filter_element is None:
                    id = 'filter_' + str(animation['animation_id'])
                    for f in document.getElementsByTagName('filter'):
                        if f.getAttribute('id') == id:
                            filter_element = f
                # Create FE
                filter_element.appendChild(fe_element)
            current_defs.appendChild(document.createElement(animate_statement))
            current_element.setAttribute('filter', f'url(#filter_{animation["animation_id"]})')

    # Save XML to target path
    with open(target_path, 'wb') as f:
        f.write(document.toprettyxml(encoding="iso-8859-1"))


def _create_animate_transform_statement(animation_dict: dict):
    """ Set up animation statement from model output for ANIMATETRANSFORM animations
    (adapted from AnimateSVG)
    """
    animation = f'animateTransform attributeName="transform" attributeType="XML" ' \
                f'type="{animation_dict["type"]}" ' \
                f'begin="{str(animation_dict["begin"])}" ' \
                f'dur="{str(animation_dict["dur"])}" ' \
                f'from="{str(animation_dict["from"])}" ' \
                f'to="{str(animation_dict["to"])}" ' \
                f'fill="{str(animation_dict["fill"])}" ' \
                'additive="sum"'

    return animation
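

# Fed the translate dictionary shown above, this builder returns the string
#   animateTransform attributeName="transform" attributeType="XML" type="translate"
#   begin="1" dur="2" from="50 0" to="0 0" fill="freeze" additive="sum"
# which _insert_animations passes wholesale to minidom's createElement; minidom
# does not validate tag names, so the string survives verbatim and the node
# serializes as <animateTransform ... additive="sum"/>.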


def _create_animate_statement(animation_dict: dict):
    """ Set up animation statement from model output for ANIMATE animations
    (adapted from AnimateSVG)
    """
    animation = f'animate attributeName="{animation_dict["type"]}" ' \
                f'begin="{str(animation_dict["begin"])}" ' \
                f'dur="{str(animation_dict["dur"])}" ' \
                f'from="{str(animation_dict["from"])}" ' \
                f'to="{str(animation_dict["to"])}" ' \
                f'fill="{str(animation_dict["fill"])}" ' \
                'additive="sum"'

    return animation


def _create_animate_motion_statement(animation_dict: dict):
    """ Set up animate motion statement from model output for ANIMATE_MOTION animations
    """
    animation = f'animateMotion ' \
                f'begin="{str(animation_dict["begin"])}" ' \
                f'dur="{str(animation_dict["dur"])}" ' \
                f'path="M{animation_dict["from"]} Q{animation_dict["via"]} {animation_dict["to"]}" ' \
                f'fill="{str(animation_dict["fill"])}" ' \
                'additive="sum"'
    return animation


def _create_animate_filter_statement(animation_dict: dict, document: minidom.Document):
    global filter_id
    filter_id += 1
    filter_element = None
    fe_element = None
    animate_statement = None
    if animation_dict['type'] == 'blur':
        # Check if the filter and filter effect already exist for this element
        filters = document.getElementsByTagName('filter')
        current_filter = None
        current_fe = None
        for f in filters:
            if f.getAttribute('id') == f'filter_{str(animation_dict["animation_id"])}':
                current_filter = f
        fe_elements = document.getElementsByTagName('feGaussianBlur')
        for fe in fe_elements:
            if fe.getAttribute('id') == f'filter_blur_{str(animation_dict["animation_id"])}':
                current_fe = fe
        if current_filter is None:
            filter_element = document.createElement('filter')
            filter_element.setAttribute('id', f'filter_{str(animation_dict["animation_id"])}')
        if current_fe is None:
            fe_element = document.createElement('feGaussianBlur')
            fe_element.setAttribute('id', f'filter_blur_{str(animation_dict["animation_id"])}')
            fe_element.setAttribute('stdDeviation', '0')
        animate_statement = f'animate href="#filter_blur_{str(animation_dict["animation_id"])}" ' \
                            f'attributeName="stdDeviation" ' \
                            f'begin="{str(animation_dict["begin"])}" ' \
                            f'dur="{str(animation_dict["dur"])}" ' \
                            f'from="{str(animation_dict["from"])}" ' \
                            f'to="{str(animation_dict["to"])}" ' \
                            f'fill="{str(animation_dict["fill"])}" ' \
                            'additive="sum"'
    return filter_element, fe_element, animate_statement
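

# After insertion, the defs section for a blurred element with animation_id 0
# ends up roughly as follows (sketch; the animate element targets the
# feGaussianBlur via href, which current browsers accept for SMIL targets):
#
# <defs>
#   <filter id="filter_0">
#     <feGaussianBlur id="filter_blur_0" stdDeviation="0"/>
#   </filter>
#   <animate href="#filter_blur_0" attributeName="stdDeviation"
#            begin="0" dur="2" from="5" to="1" fill="freeze" additive="sum"/>
# </defs>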


def randomly_animate_logo(logo_path: str, target_path: str, number_of_animations: int, previously_generated: pd.DataFrame = None):
    """ Create the given number of random model outputs and distribute them randomly over the paths. """
    # Assign an animation id to every path - TODO this changes the original logo!
    document = minidom.parse(logo_path)
    paths = document.getElementsByTagName('path') + document.getElementsByTagName('circle') + document.getElementsByTagName(
        'ellipse') + document.getElementsByTagName('line') + document.getElementsByTagName(
        'polygon') + document.getElementsByTagName('polyline') + document.getElementsByTagName(
        'rect') + document.getElementsByTagName('text')
    for i in range(len(paths)):
        paths[i].setAttribute('animation_id', str(i))
    with open(target_path, 'wb') as svg_file:
        svg_file.write(document.toxml(encoding='iso-8859-1'))
    # Create random animations
    for i in range(0, number_of_animations):
        animation_type = random.randint(0, 8)  # Determine animation type (as of now only primitive animation types)
        model_output = np.zeros(18)
        model_output[animation_type] = 1  # Set animation type
        # Set animation parameters
        # TODO: parameter sampling is not implemented yet


# model_output = [
#     {
#         'animation_id': 1,
#         'model_output': [0, 0, 0, 0, 0, 0, 0, 1, 1, 10, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10]
#     },
#     {
#         'animation_id': 1,
#         'model_output': [0, 0, 0, 0, 0, 0, 0, 1, 5, 3, 4, 5, 2, 1, 2, 3, 4, 5, 6, 7, 1000, 20]
#     }
# ]
# model_output = pd.DataFrame(model_output)
# # print(model_output)
# path = 'src/postprocessing/logo_0.svg'
# # Assign animation id to every path - TODO this changes the original logo!
# document = minidom.parse(path)
# paths = document.getElementsByTagName('path')
# for i in range(len(paths)):
#     paths[i].setAttribute('animation_id', str(i))
# with open(path, 'wb') as svg_file:
#     svg_file.write(document.toxml(encoding='iso-8859-1'))
# # print('Inserted animation id')
# animate_logo(model_output, path)
src/postprocessing/transform_animation_predictor_output.py
ADDED
@@ -0,0 +1,78 @@
from src.postprocessing.get_svg_size_pos import get_svg_size, get_midpoint_of_path_bbox
from src.postprocessing.get_style_attributes import get_style_attributes_path
from src.postprocessing.get_svg_color_tendency import get_svg_color_tendencies


def transform_animation_predictor_output(file, animation_id, output):
    """ Translate the numeric model output into animation commands.

    Example: transform_animation_predictor_output("data/svgs/logo_1.svg", 0, [0,0,1,0,0,0,-1,-1,-1,0.42,-1,-1])

    Args:
        file (str): Path of SVG file.
        animation_id (int): ID of element in SVG that gets animated.
        output (list): 12-dimensional list of numeric values. The first 6 determine the animation to be used,
            the last 6 determine the from-attribute. Format: [translate, scale, rotate, skew, fill, opacity,
            translate_from_1, translate_from_2, scale_from, rotate_from, skew_from_1, skew_from_2].

    Returns:
        dict: Animation statement as dictionary.

    """
    animation = {}
    width, height = get_svg_size(file)
    x_midpoint, y_midpoint = get_midpoint_of_path_bbox(file, animation_id)
    fill_style = get_style_attributes_path(file, animation_id, "fill")
    stroke_style = get_style_attributes_path(file, animation_id, "stroke")
    opacity_style = get_style_attributes_path(file, animation_id, "opacity")
    color_1, color_2 = get_svg_color_tendencies(file)

    if output[0] == 1:
        animation["type"] = "translate"
        x = (output[6] * 2 - 1) * width  # between -width and width
        y = (output[7] * 2 - 1) * height  # between -height and height
        animation["from_"] = f"{str(x)} {str(y)}"
        animation["to"] = "0 0"

    elif output[1] == 1:
        animation["type"] = "scale"
        animation["from_"] = output[8] * 2  # between 0 and 2
        animation["to"] = 1

    elif output[2] == 1:
        animation["type"] = "rotate"
        degree = int(output[9] * 720) - 360  # between -360 and 360
        animation["from_"] = f"{str(degree)} {str(x_midpoint)} {str(y_midpoint)}"
        animation["to"] = f"0 {str(x_midpoint)} {str(y_midpoint)}"

    elif output[3] == 1:
        if output[10] > 0.5:
            animation["type"] = "skewX"
            animation["from_"] = (output[11] * 2 - 1) * width / 20  # between -width/20 and width/20
        else:
            animation["type"] = "skewY"
            animation["from_"] = (output[11] * 2 - 1) * height / 20  # between -height/20 and height/20
        animation["to"] = 0

    elif output[4] == 1:
        animation["type"] = "fill"
        if fill_style == "none" and stroke_style != "none":
            color_hex = stroke_style
        else:
            color_hex = fill_style
        animation["to"] = color_hex

        # Start from the logo's dominant color that differs from the target color
        if color_hex != color_1:
            color_from = color_1
        else:
            color_from = color_2
        animation["from_"] = color_from

    elif output[5] == 1:
        animation["type"] = "opacity"
        animation["from_"] = 0
        animation["to"] = opacity_style

    animation["dur"] = 4
    animation["begin"] = 1
    animation["fill"] = "freeze"

    return animation
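

# Working through the docstring example: output [0,0,1,0,0,0,-1,-1,-1,0.42,-1,-1]
# selects rotate, and output[9] = 0.42 maps to int(0.42 * 720) - 360 = -58, so
# the element rotates from -58 degrees around its bounding-box midpoint back to
# 0 degrees over 4 seconds, starting at second 1.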
src/preprocessing/deepsvg/deepsvg_config/config.py
ADDED
@@ -0,0 +1,106 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

import torch.optim as optim
from src.preprocessing.deepsvg.deepsvg_schedulers.warmup import GradualWarmupScheduler


class _Config:
    """
    Training config.
    """
    def __init__(self, num_gpus=1):

        self.num_gpus = num_gpus

        self.dataloader_module = "deepsvg.svgtensor_dataset"
        self.collate_fn = None
        self.data_dir = "./data/svgs_tensors/"
        self.meta_filepath = "./data/svgs_meta.csv"
        self.loader_num_workers = 0

        self.pretrained_path = "./models/hierarchical_ordered.pth.tar"

        self.model_cfg = None

        self.num_epochs = None
        self.num_steps = None
        self.learning_rate = 1e-3
        self.batch_size = 100
        self.warmup_steps = 500

        # Dataset
        self.train_ratio = 1.0
        self.nb_augmentations = 1

        self.max_num_groups = 15
        self.max_seq_len = 30
        self.max_total_len = None

        self.filter_uni = None
        self.filter_category = None
        self.filter_platform = None

        self.filter_labels = None

        self.grad_clip = None

        self.log_every = 20
        self.val_every = 1000
        self.ckpt_every = 1000

        self.stats_to_print = {
            "train": ["lr", "time"]
        }

        self.model_args = []
        self.optimizer_starts = [0]

    # Overridable methods
    def make_model(self):
        raise NotImplementedError

    def make_losses(self):
        raise NotImplementedError

    def make_optimizers(self, model):
        return [optim.AdamW(model.parameters(), self.learning_rate)]

    def make_schedulers(self, optimizers, epoch_size):
        return [None] * len(optimizers)

    def make_warmup_schedulers(self, optimizers, scheduler_lrs):
        return [GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=self.warmup_steps, after_scheduler=scheduler_lr)
                for optimizer, scheduler_lr in zip(optimizers, scheduler_lrs)]

    def get_params(self, step, epoch):
        return {}

    def get_weights(self, step, epoch):
        return {}

    def set_train_vars(self, train_vars, dataloader):
        pass

    def visualize(self, model, output, train_vars, step, epoch, summary_writer, visualization_dir):
        pass

    # Utility methods
    def values(self):
        for key in dir(self):
            if not key.startswith("__") and not callable(getattr(self, key)):
                yield key, getattr(self, key)

    def to_dict(self):
        return {key: val for key, val in self.values()}

    def load_dict(self, dict):
        for key, val in dict.items():
            setattr(self, key, val)

    def print_params(self):
        for key, val in self.values():
            print(f"  {key} = {val}")
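

# A quick sketch of the utility methods (illustrative):
#
# cfg = _Config()
# cfg.print_params()        # prints every non-callable attribute
# state = cfg.to_dict()     # e.g. {'batch_size': 100, 'learning_rate': 0.001, ...}
# state['batch_size'] = 32
# cfg.load_dict(state)      # writes the values back as attributes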
src/preprocessing/deepsvg/deepsvg_config/config_hierarchical_ordered.py
ADDED
@@ -0,0 +1,29 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

from .default_icons import *


class ModelConfig(Hierarchical):
    def __init__(self):
        super().__init__()

        self.label_condition = False
        self.use_vae = False


class Config(Config):
    # Subclasses the Config imported from default_icons; the name is rebound
    # to this specialization.
    def __init__(self, num_gpus=1):
        super().__init__(num_gpus=num_gpus)

        self.model_cfg = ModelConfig()
        self.model_args = self.model_cfg.get_model_args()

        self.filter_category = None

        self.learning_rate = 1e-3 * num_gpus
        self.batch_size = 20  # 60 * num_gpus

        self.val_every = 10  # 2000
src/preprocessing/deepsvg/deepsvg_config/default_icons.py
ADDED
@@ -0,0 +1,102 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

from src.preprocessing.deepsvg.deepsvg_config.config import _Config
from src.preprocessing.deepsvg.deepsvg_models.model import SVGTransformer
from src.preprocessing.deepsvg.deepsvg_models.loss import SVGLoss
from src.preprocessing.deepsvg.deepsvg_models.model_config import *
from src.preprocessing.deepsvg.deepsvg_svglib.svg import SVG
from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor
from src.preprocessing.deepsvg.deepsvg_svglib.svglib_utils import make_grid
from src.preprocessing.deepsvg.deepsvg_svglib.geom import Bbox
from src.preprocessing.deepsvg.deepsvg_utils.utils import batchify, linear

import torchvision.transforms.functional as TF
import torch.optim.lr_scheduler as lr_scheduler
import random


class ModelConfig(Hierarchical):
    """
    Overriding default model config.
    """
    def __init__(self):
        super().__init__()


class Config(_Config):
    """
    Overriding default training config.
    """
    def __init__(self, num_gpus=1):
        super().__init__(num_gpus=num_gpus)

        # Model
        self.model_cfg = ModelConfig()
        self.model_args = self.model_cfg.get_model_args()

        # Dataset
        self.filter_category = None

        self.train_ratio = 1.0

        self.max_num_groups = 8
        self.max_total_len = 50

        # Dataloader
        self.loader_num_workers = 4 * num_gpus

        # Training
        self.num_epochs = 50
        self.val_every = 1000

        # Optimization
        self.learning_rate = 1e-3 * num_gpus
        self.batch_size = 60 * num_gpus
        self.grad_clip = 1.0

    def make_schedulers(self, optimizers, epoch_size):
        optimizer, = optimizers
        return [lr_scheduler.StepLR(optimizer, step_size=2.5 * epoch_size, gamma=0.9)]

    def make_model(self):
        return SVGTransformer(self.model_cfg)

    def make_losses(self):
        return [SVGLoss(self.model_cfg)]

    def get_weights(self, step, epoch):
        return {
            "kl_tolerance": 0.1,
            "loss_kl_weight": linear(0, 10, step, 0, 10000),
            "loss_hierarch_weight": 1.0,
            "loss_cmd_weight": 1.0,
            "loss_args_weight": 2.0,
            "loss_visibility_weight": 1.0
        }

    def set_train_vars(self, train_vars, dataloader):
        train_vars.x_inputs_train = [dataloader.dataset.get(idx, [*self.model_args, "tensor_grouped"])
                                     for idx in random.sample(range(len(dataloader.dataset)), k=10)]

    def visualize(self, model, output, train_vars, step, epoch, summary_writer, visualization_dir):
        device = next(model.parameters()).device

        # Reconstruction
        for i, data in enumerate(train_vars.x_inputs_train):
            model_args = batchify((data[key] for key in self.model_args), device)
            commands_y, args_y = model.module.greedy_sample(*model_args)
            tensor_pred = SVGTensor.from_cmd_args(commands_y[0].cpu(), args_y[0].cpu())

            try:
                svg_path_sample = SVG.from_tensor(tensor_pred.data, viewbox=Bbox(256), allow_empty=True).normalize().split_paths().set_color("random")
            except Exception:
                continue

            tensor_target = data["tensor_grouped"][0].copy().drop_sos().unpad()
            svg_path_gt = SVG.from_tensor(tensor_target.data, viewbox=Bbox(256)).normalize().split_paths().set_color("random")

            img = make_grid([svg_path_sample, svg_path_gt]).draw(do_display=False, return_png=True, fill=False, with_points=False)
            summary_writer.add_image(f"reconstructions_train/{i}", TF.to_tensor(img), step)
src/preprocessing/deepsvg/deepsvg_dataloader/svg_dataset.py
ADDED
@@ -0,0 +1,239 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

from src.preprocessing.deepsvg.deepsvg_config.config import _Config
from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor
from src.preprocessing.deepsvg.deepsvg_svglib.svg import SVG
from src.preprocessing.deepsvg.deepsvg_svglib.geom import Point

import math
import torch
import torch.utils.data
import random
from typing import List, Union
import pandas as pd
import os
import pickle

Num = Union[int, float]


class SVGDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, meta_filepath, model_args, max_num_groups, max_seq_len, max_total_len=None,
                 filter_uni=None, filter_platform=None, filter_category=None, train_ratio=1.0, df=None, PAD_VAL=-1,
                 nb_augmentations=1, already_preprocessed=True):
        self.data_dir = data_dir

        self.already_preprocessed = already_preprocessed

        self.MAX_NUM_GROUPS = max_num_groups
        self.MAX_SEQ_LEN = max_seq_len
        self.MAX_TOTAL_LEN = max_total_len

        if max_total_len is None:
            self.MAX_TOTAL_LEN = max_num_groups * max_seq_len

        if df is None:
            df = pd.read_csv(meta_filepath)

        if len(df) > 0:
            if filter_uni is not None:
                df = df[df.uni.isin(filter_uni)]

            if filter_platform is not None:
                df = df[df.platform.isin(filter_platform)]

            if filter_category is not None:
                df = df[df.category.isin(filter_category)]

            df = df[(df.nb_groups <= max_num_groups) & (df.max_len_group <= max_seq_len)]
            if max_total_len is not None:
                df = df[df.total_len <= max_total_len]

        self.df = df.sample(frac=train_ratio) if train_ratio < 1.0 else df

        self.model_args = model_args

        self.PAD_VAL = PAD_VAL

        self.nb_augmentations = nb_augmentations

    def search_name(self, name):
        return self.df[self.df.commonName.str.contains(name)]

    def _filter_categories(self, filter_category):
        self.df = self.df[self.df.category.isin(filter_category)]

    @staticmethod
    def _uni_to_label(uni):
        if 48 <= uni <= 57:
            return uni - 48
        elif 65 <= uni <= 90:
            return uni - 65 + 10
        return uni - 97 + 36

    @staticmethod
    def _label_to_uni(label_id):
        if 0 <= label_id <= 9:
            return label_id + 48
        elif 10 <= label_id <= 35:
            return label_id + 65 - 10
        return label_id + 97 - 36

    @staticmethod
    def _category_to_label(category):
        categories = ['characters', 'free-icons', 'logos', 'alphabet', 'animals', 'arrows', 'astrology', 'baby', 'beauty',
                      'business', 'cinema', 'city', 'clothing', 'computer-hardware', 'crime', 'cultures', 'data', 'diy',
                      'drinks', 'ecommerce', 'editing', 'files', 'finance', 'folders', 'food', 'gaming', 'hands', 'healthcare',
                      'holidays', 'household', 'industry', 'maps', 'media-controls', 'messaging', 'military', 'mobile',
                      'music', 'nature', 'network', 'photo-video', 'plants', 'printing', 'profile', 'programming', 'science',
                      'security', 'shopping', 'social-networks', 'sports', 'time-and-date', 'transport', 'travel', 'user-interface',
                      'users', 'weather', 'flags', 'emoji', 'men', 'women']
        return categories.index(category)

    def get_label(self, idx=0, entry=None):
        if entry is None:
            entry = self.df.iloc[idx]

        if "uni" in self.df.columns:  # Font dataset
            label = self._uni_to_label(entry.uni)
            return torch.tensor(label)
        elif "category" in self.df.columns:  # Icons dataset
            label = self._category_to_label(entry.category)
            return torch.tensor(label)

        return None

    def idx_to_id(self, idx):
        return self.df.iloc[idx].id

    def entry_from_id(self, id):
        return self.df[self.df.id == str(id)].iloc[0]

    def _load_svg(self, icon_id):
        svg = SVG.load_svg(os.path.join(self.data_dir, f"{icon_id}.svg"))

        if not self.already_preprocessed:
            svg.fill_(False)
            svg.normalize().zoom(0.9)
            svg.canonicalize()
            svg = svg.simplify_heuristic()

        return svg

    def __len__(self):
        return len(self.df) * self.nb_augmentations

    def random_icon(self):
        return self[random.randrange(0, len(self))]

    def random_id(self):
        idx = random.randrange(0, len(self)) % len(self.df)
        return self.idx_to_id(idx)

    def random_id_by_uni(self, uni):
        df = self.df[self.df.uni == uni]
        return df.id.sample().iloc[0]

    def __getitem__(self, idx):
        return self.get(idx, self.model_args)

    @staticmethod
    def _augment(svg, mean=False):
        dx, dy = (0, 0) if mean else (5 * random.random() - 2.5, 5 * random.random() - 2.5)
        factor = 0.7 if mean else 0.2 * random.random() + 0.6

        return svg.zoom(factor).translate(Point(dx, dy))

    @staticmethod
    def simplify(svg, normalize=True):
        svg.canonicalize(normalize=normalize)
        svg = svg.simplify_heuristic()
        return svg.normalize()

    @staticmethod
    def preprocess(svg, augment=True, numericalize=True, mean=False):
        if augment:
            svg = SVGDataset._augment(svg, mean=mean)
        if numericalize:
            return svg.numericalize(256)
        return svg

    def get(self, idx=0, model_args=None, random_aug=True, id=None, svg: SVG = None):
        if id is None:
            idx = idx % len(self.df)
            id = self.idx_to_id(idx)

        if svg is None:
            svg = self._load_svg(id)

        svg = SVGDataset.preprocess(svg, augment=random_aug)

        t_sep, fillings = svg.to_tensor(concat_groups=False, PAD_VAL=self.PAD_VAL), svg.to_fillings()

        # Note: DeepSVG can only handle 8 paths per SVG and 30 segments per path
        if len(t_sep) > 8:
            # print(f"SVG {id} has more than 8 paths.")
            t_sep = t_sep[0:8]
            fillings = fillings[0:8]

        for i in range(len(t_sep)):
            if len(t_sep[i]) > 30:
                # print(f"SVG {id}: Path nr {i} has more than 30 segments.")
                t_sep[i] = t_sep[i][0:30]

        label = self.get_label(idx)

        return self.get_data(t_sep, fillings, model_args=model_args, label=label)

    def get_data(self, t_sep, fillings, model_args=None, label=None):
        res = {}

        if model_args is None:
            model_args = self.model_args

        pad_len = max(self.MAX_NUM_GROUPS - len(t_sep), 0)

        t_sep.extend([torch.empty(0, 14)] * pad_len)
        fillings.extend([0] * pad_len)

        t_grouped = [SVGTensor.from_data(torch.cat(t_sep, dim=0), PAD_VAL=self.PAD_VAL).add_eos().add_sos().pad(
            seq_len=self.MAX_TOTAL_LEN + 2)]

        t_sep = [SVGTensor.from_data(t, PAD_VAL=self.PAD_VAL, filling=f).add_eos().add_sos().pad(seq_len=self.MAX_SEQ_LEN + 2) for
                 t, f in zip(t_sep, fillings)]

        for arg in set(model_args):
            if "_grouped" in arg:
                arg_ = arg.split("_grouped")[0]
                t_list = t_grouped
            else:
                arg_ = arg
                t_list = t_sep

            if arg_ == "tensor":
                res[arg] = t_list

            if arg_ == "commands":
                res[arg] = torch.stack([t.cmds() for t in t_list])

            if arg_ == "args_rel":
                res[arg] = torch.stack([t.get_relative_args() for t in t_list])
            if arg_ == "args":
                res[arg] = torch.stack([t.args() for t in t_list])

        if "filling" in model_args:
            res["filling"] = torch.stack([torch.tensor(t.filling) for t in t_sep]).unsqueeze(-1)

        if "label" in model_args:
            res["label"] = label

        return res


def load_dataset(cfg: _Config, already_preprocessed=True):
    dataset = SVGDataset(cfg.data_dir, cfg.meta_filepath, cfg.model_args, cfg.max_num_groups, cfg.max_seq_len, cfg.max_total_len,
                         cfg.filter_uni, cfg.filter_platform, cfg.filter_category, cfg.train_ratio,
                         nb_augmentations=cfg.nb_augmentations, already_preprocessed=already_preprocessed)
    return dataset
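

# A typical loading sketch (illustrative; relies on the paths and filters
# preconfigured in _Config and its Config subclass from default_icons.py):
#
# from src.preprocessing.deepsvg.deepsvg_config.default_icons import Config
#
# cfg = Config()
# dataset = load_dataset(cfg, already_preprocessed=True)
# sample = dataset[0]   # dict keyed by cfg.model_args, e.g. "commands" and "args"
# loader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size,
#                                      num_workers=cfg.loader_num_workers)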
src/preprocessing/deepsvg/deepsvg_difflib/tensor.py
ADDED
@@ -0,0 +1,305 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

from __future__ import annotations
import torch
import torch.utils.data
from typing import Union

Num = Union[int, float]


class AnimationTensor:

    COMMANDS_SIMPLIFIED = ['a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6', 'a7', 'a8', 'a9']

    # Per-command mask over the argument columns (duration, from, begin)
    CMD_ARGS_MASK = torch.tensor([[0, 0, 0],  # a0
                                  [0, 0, 0],  # a1
                                  [0, 0, 0],  # a2
                                  [1, 1, 1],  # a3
                                  [0, 0, 0],  # a4
                                  [0, 0, 0],  # a5
                                  [0, 0, 0],  # a6
                                  [0, 0, 0],  # a7
                                  [1, 1, 1],  # a8
                                  [0, 0, 0]])  # a9

    class Index:
        COMMAND = 0
        DURATION = 1
        FROM = 2
        BEGIN = 3

    class IndexArgs:
        DURATION = 0
        FROM = 1
        BEGIN = 2

    all_arg_keys = ['duration', 'from', 'begin']
    cmd_arg_keys = ["commands", *all_arg_keys]
    all_keys = ["commands", *all_arg_keys]

    def __init__(self, commands, duration, from_, begin,
                 seq_len=None, label=None, PAD_VAL=-1, ARGS_DIM=256, filling=0):

        self.commands = commands.reshape(-1, 1).float()

        self.duration = duration.float()
        self.from_ = from_.float()
        self.begin = begin.float()

        self.seq_len = torch.tensor(len(commands)) if seq_len is None else seq_len
        self.label = label

        self.PAD_VAL = PAD_VAL
        self.ARGS_DIM = ARGS_DIM

        # self.sos_token = torch.Tensor([self.COMMANDS_SIMPLIFIED.index("SOS")]).unsqueeze(-1)
        # self.eos_token = self.pad_token = torch.Tensor([self.COMMANDS_SIMPLIFIED.index("EOS")]).unsqueeze(-1)

        self.filling = filling
|
62 |
+
|
63 |
+
|
64 |
+
class SVGTensor:
|
65 |
+
# 0 1 2 3 4 5 6
|
66 |
+
COMMANDS_SIMPLIFIED = ["m", "l", "c", "a", "EOS", "SOS", "z"]
|
67 |
+
|
68 |
+
# rad x lrg sw ctrl ctrl end
|
69 |
+
# ius axs arc eep 1 2 pos
|
70 |
+
# rot fg fg
|
71 |
+
CMD_ARGS_MASK = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1], # m
|
72 |
+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1], # l
|
73 |
+
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], # c
|
74 |
+
[1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1], # a
|
75 |
+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # EOS
|
76 |
+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # SOS
|
77 |
+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) # z
|
78 |
+
|
79 |
+
class Index:
|
80 |
+
COMMAND = 0
|
81 |
+
RADIUS = slice(1, 3)
|
82 |
+
X_AXIS_ROT = 3
|
83 |
+
LARGE_ARC_FLG = 4
|
84 |
+
SWEEP_FLG = 5
|
85 |
+
START_POS = slice(6, 8)
|
86 |
+
CONTROL1 = slice(8, 10)
|
87 |
+
CONTROL2 = slice(10, 12)
|
88 |
+
END_POS = slice(12, 14)
|
89 |
+
|
90 |
+
class IndexArgs:
|
91 |
+
RADIUS = slice(0, 2)
|
92 |
+
X_AXIS_ROT = 2
|
93 |
+
LARGE_ARC_FLG = 3
|
94 |
+
SWEEP_FLG = 4
|
95 |
+
CONTROL1 = slice(5, 7)
|
96 |
+
CONTROL2 = slice(7, 9)
|
97 |
+
END_POS = slice(9, 11)
|
98 |
+
|
99 |
+
position_keys = ["control1", "control2", "end_pos"]
|
100 |
+
all_position_keys = ["start_pos", *position_keys]
|
101 |
+
arg_keys = ["radius", "x_axis_rot", "large_arc_flg", "sweep_flg", *position_keys]
|
102 |
+
all_arg_keys = [*arg_keys[:4], "start_pos", *arg_keys[4:]]
|
103 |
+
cmd_arg_keys = ["commands", *arg_keys]
|
104 |
+
all_keys = ["commands", *all_arg_keys]
|
105 |
+
|
106 |
+
def __init__(self, commands, radius, x_axis_rot, large_arc_flg, sweep_flg, control1, control2, end_pos,
|
107 |
+
seq_len=None, label=None, PAD_VAL=-1, ARGS_DIM=256, filling=0):
|
108 |
+
|
109 |
+
self.commands = commands.reshape(-1, 1).float()
|
110 |
+
|
111 |
+
self.radius = radius.float()
|
112 |
+
self.x_axis_rot = x_axis_rot.reshape(-1, 1).float()
|
113 |
+
self.large_arc_flg = large_arc_flg.reshape(-1, 1).float()
|
114 |
+
self.sweep_flg = sweep_flg.reshape(-1, 1).float()
|
115 |
+
|
116 |
+
self.control1 = control1.float()
|
117 |
+
self.control2 = control2.float()
|
118 |
+
self.end_pos = end_pos.float()
|
119 |
+
|
120 |
+
self.seq_len = torch.tensor(len(commands)) if seq_len is None else seq_len
|
121 |
+
self.label = label
|
122 |
+
|
123 |
+
self.PAD_VAL = PAD_VAL
|
124 |
+
self.ARGS_DIM = ARGS_DIM
|
125 |
+
|
126 |
+
self.sos_token = torch.Tensor([self.COMMANDS_SIMPLIFIED.index("SOS")]).unsqueeze(-1)
|
127 |
+
self.eos_token = self.pad_token = torch.Tensor([self.COMMANDS_SIMPLIFIED.index("EOS")]).unsqueeze(-1)
|
128 |
+
|
129 |
+
self.filling = filling
|
130 |
+
|
131 |
+
@property
|
132 |
+
def start_pos(self):
|
133 |
+
start_pos = self.end_pos[:-1]
|
134 |
+
|
135 |
+
return torch.cat([
|
136 |
+
start_pos.new_zeros(1, 2),
|
137 |
+
start_pos
|
138 |
+
])
|
139 |
+
|
140 |
+
@staticmethod
|
141 |
+
def from_data(data, *args, **kwargs):
|
142 |
+
return SVGTensor(data[:, SVGTensor.Index.COMMAND], data[:, SVGTensor.Index.RADIUS], data[:, SVGTensor.Index.X_AXIS_ROT],
|
143 |
+
data[:, SVGTensor.Index.LARGE_ARC_FLG], data[:, SVGTensor.Index.SWEEP_FLG], data[:, SVGTensor.Index.CONTROL1],
|
144 |
+
data[:, SVGTensor.Index.CONTROL2], data[:, SVGTensor.Index.END_POS], *args, **kwargs)
|
145 |
+
|
146 |
+
@staticmethod
|
147 |
+
def from_cmd_args(commands, args, *nargs, **kwargs):
|
148 |
+
return SVGTensor(commands, args[:, SVGTensor.IndexArgs.RADIUS], args[:, SVGTensor.IndexArgs.X_AXIS_ROT],
|
149 |
+
args[:, SVGTensor.IndexArgs.LARGE_ARC_FLG], args[:, SVGTensor.IndexArgs.SWEEP_FLG], args[:, SVGTensor.IndexArgs.CONTROL1],
|
150 |
+
args[:, SVGTensor.IndexArgs.CONTROL2], args[:, SVGTensor.IndexArgs.END_POS], *nargs, **kwargs)
|
151 |
+
|
152 |
+
def get_data(self, keys):
|
153 |
+
return torch.cat([self.__getattribute__(key) for key in keys], dim=-1)
|
154 |
+
|
155 |
+
@property
|
156 |
+
def data(self):
|
157 |
+
return self.get_data(self.all_keys)
|
158 |
+
|
159 |
+
def copy(self):
|
160 |
+
return SVGTensor(*[self.__getattribute__(key).clone() for key in self.cmd_arg_keys],
|
161 |
+
seq_len=self.seq_len.clone(), label=self.label, PAD_VAL=self.PAD_VAL, ARGS_DIM=self.ARGS_DIM,
|
162 |
+
filling=self.filling)
|
163 |
+
|
164 |
+
def add_sos(self):
|
165 |
+
self.commands = torch.cat([self.sos_token, self.commands])
|
166 |
+
|
167 |
+
for key in self.arg_keys:
|
168 |
+
v = self.__getattribute__(key)
|
169 |
+
self.__setattr__(key, torch.cat([v.new_full((1, v.size(-1)), self.PAD_VAL), v]))
|
170 |
+
|
171 |
+
self.seq_len += 1
|
172 |
+
return self
|
173 |
+
|
174 |
+
def drop_sos(self):
|
175 |
+
for key in self.cmd_arg_keys:
|
176 |
+
self.__setattr__(key, self.__getattribute__(key)[1:])
|
177 |
+
|
178 |
+
self.seq_len -= 1
|
179 |
+
return self
|
180 |
+
|
181 |
+
def add_eos(self):
|
182 |
+
self.commands = torch.cat([self.commands, self.eos_token])
|
183 |
+
|
184 |
+
for key in self.arg_keys:
|
185 |
+
v = self.__getattribute__(key)
|
186 |
+
self.__setattr__(key, torch.cat([v, v.new_full((1, v.size(-1)), self.PAD_VAL)]))
|
187 |
+
|
188 |
+
return self
|
189 |
+
|
190 |
+
def pad(self, seq_len=51):
|
191 |
+
pad_len = max(seq_len - len(self.commands), 0)
|
192 |
+
|
193 |
+
self.commands = torch.cat([self.commands, self.pad_token.repeat(pad_len, 1)])
|
194 |
+
|
195 |
+
for key in self.arg_keys:
|
196 |
+
v = self.__getattribute__(key)
|
197 |
+
self.__setattr__(key, torch.cat([v, v.new_full((pad_len, v.size(-1)), self.PAD_VAL)]))
|
198 |
+
|
199 |
+
return self
|
200 |
+
|
201 |
+
def unpad(self):
|
202 |
+
# Remove EOS + padding
|
203 |
+
for key in self.cmd_arg_keys:
|
204 |
+
self.__setattr__(key, self.__getattribute__(key)[:self.seq_len])
|
205 |
+
return self
|
206 |
+
|
207 |
+
def draw(self, *args, **kwags):
|
208 |
+
from deepsvg.svglib.svg import SVGPath
|
209 |
+
return SVGPath.from_tensor(self.data).draw(*args, **kwags)
|
210 |
+
|
211 |
+
def cmds(self):
|
212 |
+
return self.commands.reshape(-1)
|
213 |
+
|
214 |
+
def args(self, with_start_pos=False):
|
215 |
+
if with_start_pos:
|
216 |
+
return self.get_data(self.all_arg_keys)
|
217 |
+
|
218 |
+
return self.get_data(self.arg_keys)
|
219 |
+
|
220 |
+
def _get_real_commands_mask(self):
|
221 |
+
mask = self.cmds() < self.COMMANDS_SIMPLIFIED.index("EOS")
|
222 |
+
return mask
|
223 |
+
|
224 |
+
def _get_args_mask(self):
|
225 |
+
mask = SVGTensor.CMD_ARGS_MASK[self.cmds().long()].bool()
|
226 |
+
return mask
|
227 |
+
|
228 |
+
def get_relative_args(self):
|
229 |
+
data = self.args().clone()
|
230 |
+
|
231 |
+
real_commands = self._get_real_commands_mask()
|
232 |
+
data_real_commands = data[real_commands]
|
233 |
+
|
234 |
+
start_pos = data_real_commands[:-1, SVGTensor.IndexArgs.END_POS].clone()
|
235 |
+
|
236 |
+
data_real_commands[1:, SVGTensor.IndexArgs.CONTROL1] -= start_pos
|
237 |
+
data_real_commands[1:, SVGTensor.IndexArgs.CONTROL2] -= start_pos
|
238 |
+
data_real_commands[1:, SVGTensor.IndexArgs.END_POS] -= start_pos
|
239 |
+
data[real_commands] = data_real_commands
|
240 |
+
|
241 |
+
mask = self._get_args_mask()
|
242 |
+
data[mask] += self.ARGS_DIM - 1
|
243 |
+
data[~mask] = self.PAD_VAL
|
244 |
+
|
245 |
+
return data
|
246 |
+
|
247 |
+
def sample_points(self, n=10):
|
248 |
+
device = self.commands.device
|
249 |
+
|
250 |
+
z = torch.linspace(0, 1, n, device=device)
|
251 |
+
Z = torch.stack([torch.ones_like(z), z, z.pow(2), z.pow(3)], dim=1)
|
252 |
+
|
253 |
+
Q = torch.tensor([
|
254 |
+
[[0., 0., 0., 0.], # "m"
|
255 |
+
[0., 0., 0., 0.],
|
256 |
+
[0., 0., 0., 0.],
|
257 |
+
[0., 0., 0., 0.]],
|
258 |
+
|
259 |
+
[[1., 0., 0., 0.], # "l"
|
260 |
+
[-1, 0., 0., 1.],
|
261 |
+
[0., 0., 0., 0.],
|
262 |
+
[0., 0., 0., 0.]],
|
263 |
+
|
264 |
+
[[1., 0., 0., 0.], # "c"
|
265 |
+
[-3, 3., 0., 0.],
|
266 |
+
[3., -6, 3., 0.],
|
267 |
+
[-1, 3., -3, 1.]],
|
268 |
+
|
269 |
+
torch.zeros(4, 4), # "a", no support yet
|
270 |
+
|
271 |
+
torch.zeros(4, 4), # "EOS"
|
272 |
+
torch.zeros(4, 4), # "SOS"
|
273 |
+
torch.zeros(4, 4), # "z"
|
274 |
+
], device=device)
|
275 |
+
|
276 |
+
commands, pos = self.commands.reshape(-1).long(), self.get_data(self.all_position_keys).reshape(-1, 4, 2)
|
277 |
+
inds = (commands == self.COMMANDS_SIMPLIFIED.index("l")) | (commands == self.COMMANDS_SIMPLIFIED.index("c"))
|
278 |
+
commands, pos = commands[inds], pos[inds]
|
279 |
+
|
280 |
+
Z_coeffs = torch.matmul(Q[commands], pos)
|
281 |
+
|
282 |
+
# Last point being first point of next command, we drop last point except the one from the last command
|
283 |
+
sample_points = torch.matmul(Z, Z_coeffs)
|
284 |
+
sample_points = torch.cat([sample_points[:, :-1].reshape(-1, 2), sample_points[-1, -1].unsqueeze(0)])
|
285 |
+
|
286 |
+
return sample_points
|
287 |
+
|
288 |
+
@staticmethod
|
289 |
+
def get_length_distribution(p, normalize=True):
|
290 |
+
start, end = p[:-1], p[1:]
|
291 |
+
length_distr = torch.norm(end - start, dim=-1).cumsum(dim=0)
|
292 |
+
length_distr = torch.cat([length_distr.new_zeros(1), length_distr])
|
293 |
+
if normalize:
|
294 |
+
length_distr = length_distr / length_distr[-1]
|
295 |
+
return length_distr
|
296 |
+
|
297 |
+
def sample_uniform_points(self, n=100):
|
298 |
+
p = self.sample_points(n=n)
|
299 |
+
|
300 |
+
distr_unif = torch.linspace(0., 1., n).to(p.device)
|
301 |
+
distr = self.get_length_distribution(p, normalize=True)
|
302 |
+
d = torch.cdist(distr_unif.unsqueeze(-1), distr.unsqueeze(-1))
|
303 |
+
matching = d.argmin(dim=-1)
|
304 |
+
|
305 |
+
return p[matching]
|
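A minimal sketch of how SVGTensor is typically constructed and queried, based only on the class above; the data values are made up for illustration:

import torch

# Two commands in the 14-column layout from SVGTensor.Index: a move ("m",
# index 0) and a cubic Bezier ("c", index 2). Unused argument slots carry
# the pad value -1; the start_pos columns (6-7) are ignored by from_data.
data = torch.tensor([
    [0., -1, -1, -1, -1, -1, 0, 0, -1, -1, -1, -1, 10, 10],
    [2., -1, -1, -1, -1, -1, 10, 10, 20, 10, 30, 20, 40, 40],
])

t = SVGTensor.from_data(data)
print(t.cmds())          # tensor([0., 2.])
print(t.args().shape)    # torch.Size([2, 11])
t.add_eos().pad(seq_len=8)
print(t.commands.shape)  # torch.Size([8, 1])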
src/preprocessing/deepsvg/deepsvg_models/basic_blocks.py
ADDED
@@ -0,0 +1,70 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

import torch
import torch.nn as nn


class FCN(nn.Module):
    def __init__(self, d_model, n_commands, n_args, args_dim=256):
        super().__init__()

        self.n_args = n_args
        self.args_dim = args_dim

        self.command_fcn = nn.Linear(d_model, n_commands)
        self.args_fcn = nn.Linear(d_model, n_args * args_dim)

    def forward(self, out):
        S, N, _ = out.shape

        command_logits = self.command_fcn(out)  # Shape [S, N, n_commands]

        args_logits = self.args_fcn(out)  # Shape [S, N, n_args * args_dim]
        args_logits = args_logits.reshape(S, N, self.n_args, self.args_dim)  # Shape [S, N, n_args, args_dim]

        return command_logits, args_logits


class HierarchFCN(nn.Module):
    def __init__(self, d_model, dim_z):
        super().__init__()

        self.visibility_fcn = nn.Linear(d_model, 2)
        self.z_fcn = nn.Linear(d_model, dim_z)

    def forward(self, out):
        G, N, _ = out.shape

        visibility_logits = self.visibility_fcn(out)  # Shape [G, N, 2]
        z = self.z_fcn(out)  # Shape [G, N, dim_z]

        return visibility_logits.unsqueeze(0), z.unsqueeze(0)


class ResNet(nn.Module):
    def __init__(self, d_model):
        super().__init__()

        self.linear1 = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU()
        )
        self.linear2 = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU()
        )
        self.linear3 = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU()
        )
        self.linear4 = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU()
        )

    def forward(self, z):
        z = z + self.linear1(z)
        z = z + self.linear2(z)
        z = z + self.linear3(z)
        z = z + self.linear4(z)

        return z
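A quick shape check for the prediction heads above (sizes are arbitrary; 7 commands and 11 arguments match the SVGTensor layout in this commit):

import torch

fcn = FCN(d_model=256, n_commands=7, n_args=11)
out = torch.randn(51, 4, 256)  # [S, N, d_model]
cmd_logits, args_logits = fcn(out)
print(cmd_logits.shape)        # torch.Size([51, 4, 7])
print(args_logits.shape)       # torch.Size([51, 4, 11, 256])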
src/preprocessing/deepsvg/deepsvg_models/deepSVG_hierarchical_ordered.pth.tar
ADDED
File without changes
src/preprocessing/deepsvg/deepsvg_models/layers/attention.py
ADDED
@@ -0,0 +1,166 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

import torch
from torch.nn import Linear
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

from .functional import multi_head_attention_forward


class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.

    Note: if kdim and vdim are None, they will be set to embed_dim such that
    query, key, and value have the same number of features.

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    __annotations__ = {
        'bias_k': torch._jit_internal.Optional[torch.Tensor],
        'bias_v': torch._jit_internal.Optional[torch.Tensor],
    }
    __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. This is a binary mask. When the value is True,
                the corresponding value on the attention layer will be filled with -inf.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. This is an additive mask
                (i.e. the values will be added to the attention layer). A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length.

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
        """
        if not self._qkv_same_embed_dim:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
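A small self-contained shape check for the module above (dimensions are arbitrary; it assumes this file and its .functional sibling are importable together):

import torch

attn = MultiheadAttention(embed_dim=64, num_heads=8)
q = torch.randn(10, 2, 64)   # [L, N, E], target sequence
kv = torch.randn(15, 2, 64)  # [S, N, E], source sequence
out, weights = attn(q, kv, kv)
print(out.shape)      # torch.Size([10, 2, 64])
print(weights.shape)  # torch.Size([2, 10, 15]), averaged over heads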
src/preprocessing/deepsvg/deepsvg_models/layers/functional.py
ADDED
@@ -0,0 +1,261 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

from __future__ import division


import torch
import torch.nn.functional as F


def multi_head_attention_forward(query,                           # type: Tensor
                                 key,                             # type: Tensor
                                 value,                           # type: Tensor
                                 embed_dim_to_check,              # type: int
                                 num_heads,                       # type: int
                                 in_proj_weight,                  # type: Tensor
                                 in_proj_bias,                    # type: Tensor
                                 bias_k,                          # type: Optional[Tensor]
                                 bias_v,                          # type: Optional[Tensor]
                                 add_zero_attn,                   # type: bool
                                 dropout_p,                       # type: float
                                 out_proj_weight,                 # type: Tensor
                                 out_proj_bias,                   # type: Tensor
                                 training=True,                   # type: bool
                                 key_padding_mask=None,           # type: Optional[Tensor]
                                 need_weights=True,               # type: bool
                                 attn_mask=None,                  # type: Optional[Tensor]
                                 use_separate_proj_weight=False,  # type: bool
                                 q_proj_weight=None,              # type: Optional[Tensor]
                                 k_proj_weight=None,              # type: Optional[Tensor]
                                 v_proj_weight=None,              # type: Optional[Tensor]
                                 static_k=None,                   # type: Optional[Tensor]
                                 static_v=None                    # type: Optional[Tensor]
                                 ):
    # type: (...) -> Tuple[Tensor, Optional[Tensor]]
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. This is an additive mask
            (i.e. the values will be added to the attention layer). A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        use_separate_proj_weight: the function accepts the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
    """

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    assert key.size() == value.size()

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    if not use_separate_proj_weight:
        if torch.equal(query, key) and torch.equal(key, value):
            # self-attention
            q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        elif torch.equal(key, value):
            # encoder-decoder attention
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = F.linear(query, _w, _b)

            if key is None:
                assert value is None
                k = None
                v = None
            else:

                # This is inline in_proj function with in_proj_weight and in_proj_bias
                _b = in_proj_bias
                _start = embed_dim
                _end = None
                _w = in_proj_weight[_start:, :]
                if _b is not None:
                    _b = _b[_start:]
                k, v = F.linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = F.linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = F.linear(key, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = F.linear(value, _w, _b)
    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
            k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
        else:
            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    if attn_mask is not None:
        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 2D attn_mask is not correct.')
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 3D attn_mask is not correct.')
        else:
            raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
        # attn_mask's dim is 3 now.

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = F.pad(key_padding_mask, (0, 1))
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert bias_k is None
        assert bias_v is None

    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = F.pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = F.pad(key_padding_mask, (0, 1))

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

    if attn_mask is not None:
        attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

    attn_output_weights = F.softmax(
        attn_output_weights, dim=-1)
    attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None
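The same computation can be driven directly through this functional form; a sketch (sizes arbitrary) that borrows the projection weights from an initialized MultiheadAttention module from attention.py above:

import torch

m = MultiheadAttention(embed_dim=32, num_heads=4)
x = torch.randn(5, 3, 32)  # passing the same tensor as q/k/v takes the self-attention path

out, w = multi_head_attention_forward(
    x, x, x, 32, 4,
    m.in_proj_weight, m.in_proj_bias,
    m.bias_k, m.bias_v, m.add_zero_attn,
    m.dropout, m.out_proj.weight, m.out_proj.bias,
    training=False)
print(out.shape, w.shape)  # torch.Size([5, 3, 32]) torch.Size([3, 5, 5])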
src/preprocessing/deepsvg/deepsvg_models/layers/improved_transformer.py
ADDED
@@ -0,0 +1,146 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

import torch
import copy

from torch.nn import functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.container import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm

from .attention import MultiheadAttention
from .transformer import _get_activation_fn


class TransformerEncoderLayerImproved(Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", d_global2=None):
        super(TransformerEncoderLayerImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        if d_global2 is not None:
            self.linear_global2 = Linear(d_global2, d_model)

        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2_2 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayerImproved, self).__setstate__(state)

    def forward(self, src, memory2=None, src_mask=None, src_key_padding_mask=None):
        src1 = self.norm1(src)
        src2 = self.self_attn(src1, src1, src1, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)

        if memory2 is not None:
            src2_2 = self.linear_global2(memory2)
            src = src + self.dropout2_2(src2_2)

        src1 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src1))))
        src = src + self.dropout2(src2)
        return src


class TransformerDecoderLayerImproved(Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerDecoderLayerImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayerImproved, self).__setstate__(state)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        tgt1 = self.norm1(tgt)
        tgt2 = self.self_attn(tgt1, tgt1, tgt1, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        tgt1 = self.norm2(tgt)
        tgt2 = self.multihead_attn(tgt1, memory, memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)

        tgt1 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt1))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt


class TransformerDecoderLayerGlobalImproved(Module):
    def __init__(self, d_model, d_global, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", d_global2=None):
        super(TransformerDecoderLayerGlobalImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        self.linear_global = Linear(d_global, d_model)

        if d_global2 is not None:
            self.linear_global2 = Linear(d_global2, d_model)

        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout2_2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayerGlobalImproved, self).__setstate__(state)

    def forward(self, tgt, memory, memory2=None, tgt_mask=None, tgt_key_padding_mask=None, *args, **kwargs):
        tgt1 = self.norm1(tgt)
        tgt2 = self.self_attn(tgt1, tgt1, tgt1, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        tgt2 = self.linear_global(memory)
        tgt = tgt + self.dropout2(tgt2)  # implicit broadcast

        if memory2 is not None:
            tgt2_2 = self.linear_global2(memory2)
            tgt = tgt + self.dropout2_2(tgt2_2)

        tgt1 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt1))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt
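Unlike the stock layers in transformer.py below, these variants normalize before the attention and feedforward blocks (pre-norm), which generally makes deep stacks easier to train. A minimal shape check (sizes arbitrary; it assumes _get_activation_fn is available from the sibling transformer module):

import torch

layer = TransformerEncoderLayerImproved(d_model=64, nhead=4, dim_feedforward=128)
src = torch.randn(10, 2, 64)  # [S, N, d_model]
print(layer(src).shape)       # torch.Size([10, 2, 64])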
src/preprocessing/deepsvg/deepsvg_models/layers/positional_encoding.py
ADDED
@@ -0,0 +1,48 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

import math
import torch
import torch.nn as nn


class PositionalEncodingSinCos(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=250):
        super(PositionalEncodingSinCos, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class PositionalEncodingLUT(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=250):
        super(PositionalEncodingLUT, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.long).unsqueeze(1)
        self.register_buffer('position', position)

        self.pos_embed = nn.Embedding(max_len, d_model)

        self._init_embeddings()

    def _init_embeddings(self):
        nn.init.kaiming_normal_(self.pos_embed.weight, mode="fan_in")

    def forward(self, x):
        pos = self.position[:x.size(0)]
        x = x + self.pos_embed(pos)
        return self.dropout(x)
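Both modules add a position signal of shape [S, 1, d_model] that broadcasts across the batch; the first is the fixed sinusoidal table from "Attention Is All You Need", the second a learned lookup. A minimal check (sizes arbitrary):

import torch

x = torch.randn(20, 2, 64)  # [S, N, d_model]
fixed = PositionalEncodingSinCos(d_model=64)
learned = PositionalEncodingLUT(d_model=64)
print(fixed(x).shape, learned(x).shape)  # both torch.Size([20, 2, 64])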
src/preprocessing/deepsvg/deepsvg_models/layers/transformer.py
ADDED
@@ -0,0 +1,398 @@
1 |
+
"""This code is taken from <https://github.com/alexandre01/deepsvg>
|
2 |
+
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
|
3 |
+
from the paper >https://arxiv.org/pdf/2007.11301.pdf>
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import copy
|
8 |
+
|
9 |
+
from torch.nn import functional as F
|
10 |
+
from torch.nn.modules.module import Module
|
11 |
+
from torch.nn.modules.container import ModuleList
|
12 |
+
from torch.nn.init import xavier_uniform_
|
13 |
+
from torch.nn.modules.dropout import Dropout
|
14 |
+
from torch.nn.modules.linear import Linear
|
15 |
+
from torch.nn.modules.normalization import LayerNorm
|
16 |
+
|
17 |
+
from .attention import MultiheadAttention
|
18 |
+
|
19 |
+
|
20 |
+
class Transformer(Module):
|
21 |
+
r"""A transformer model. User is able to modify the attributes as needed. The architecture
|
22 |
+
is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
|
23 |
+
Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
|
24 |
+
Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
|
25 |
+
Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805)
|
26 |
+
model with corresponding parameters.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
d_model: the number of expected features in the encoder/decoder inputs (default=512).
|
30 |
+
nhead: the number of heads in the multiheadattention models (default=8).
|
31 |
+
num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
|
32 |
+
num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
|
33 |
+
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
34 |
+
dropout: the dropout value (default=0.1).
|
35 |
+
activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu).
|
36 |
+
custom_encoder: custom encoder (default=None).
|
37 |
+
custom_decoder: custom decoder (default=None).
|
38 |
+
|
39 |
+
Examples::
|
40 |
+
>>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
|
41 |
+
>>> src = torch.rand((10, 32, 512))
|
42 |
+
>>> tgt = torch.rand((20, 32, 512))
|
43 |
+
>>> out = transformer_model(src, tgt)
|
44 |
+
|
45 |
+
Note: A full example to apply nn.Transformer module for the word language model is available in
|
46 |
+
https://github.com/pytorch/examples/tree/master/word_language_model
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
|
50 |
+
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
|
51 |
+
activation="relu", custom_encoder=None, custom_decoder=None):
|
52 |
+
super(Transformer, self).__init__()
|
53 |
+
|
54 |
+
if custom_encoder is not None:
|
55 |
+
self.encoder = custom_encoder
|
56 |
+
else:
|
57 |
+
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
|
58 |
+
encoder_norm = LayerNorm(d_model)
|
59 |
+
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
60 |
+
|
61 |
+
if custom_decoder is not None:
|
62 |
+
self.decoder = custom_decoder
|
63 |
+
else:
|
64 |
+
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
|
65 |
+
decoder_norm = LayerNorm(d_model)
|
66 |
+
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
|
67 |
+
|
68 |
+
self._reset_parameters()
|
69 |
+
|
70 |
+
self.d_model = d_model
|
71 |
+
self.nhead = nhead
|
72 |
+
|
73 |
+
def forward(self, src, tgt, src_mask=None, tgt_mask=None,
|
74 |
+
memory_mask=None, src_key_padding_mask=None,
|
75 |
+
tgt_key_padding_mask=None, memory_key_padding_mask=None):
|
76 |
+
# type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor # noqa
|
77 |
+
r"""Take in and process masked source/target sequences.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
src: the sequence to the encoder (required).
|
81 |
+
tgt: the sequence to the decoder (required).
|
82 |
+
src_mask: the additive mask for the src sequence (optional).
|
83 |
+
tgt_mask: the additive mask for the tgt sequence (optional).
|
84 |
+
memory_mask: the additive mask for the encoder output (optional).
|
85 |
+
src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
|
86 |
+
tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
|
87 |
+
memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).
|
88 |
+
|
89 |
+
Shape:
|
90 |
+
- src: :math:`(S, N, E)`.
|
91 |
+
- tgt: :math:`(T, N, E)`.
|
92 |
+
- src_mask: :math:`(S, S)`.
|
93 |
+
- tgt_mask: :math:`(T, T)`.
|
94 |
+
- memory_mask: :math:`(T, S)`.
|
95 |
+
- src_key_padding_mask: :math:`(N, S)`.
|
96 |
+
- tgt_key_padding_mask: :math:`(N, T)`.
|
97 |
+
- memory_key_padding_mask: :math:`(N, S)`.
|
98 |
+
|
99 |
+
Note: [src/tgt/memory]_mask should be filled with
|
100 |
+
float('-inf') for the masked positions and float(0.0) else. These masks
|
101 |
+
ensure that predictions for position i depend only on the unmasked positions
|
102 |
+
j and are applied identically for each sequence in a batch.
|
103 |
+
[src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions
|
104 |
+
that should be masked with float('-inf') and False values will be unchanged.
|
105 |
+
This mask ensures that no information will be taken from position i if
|
106 |
+
it is masked, and has a separate mask for each sequence in a batch.
|
107 |
+
|
108 |
+
- output: :math:`(T, N, E)`.
|
109 |
+
|
110 |
+
Note: Due to the multi-head attention architecture in the transformer model,
|
111 |
+
the output sequence length of a transformer is same as the input sequence
|
112 |
+
(i.e. target) length of the decode.
|
113 |
+
|
114 |
+
where S is the source sequence length, T is the target sequence length, N is the
|
115 |
+
batch size, E is the feature number
|
116 |
+
|
117 |
+
Examples:
|
118 |
+
>>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
|
119 |
+
"""
|
120 |
+
|
121 |
+
if src.size(1) != tgt.size(1):
|
122 |
+
raise RuntimeError("the batch number of src and tgt must be equal")
|
123 |
+
|
124 |
+
if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
|
125 |
+
raise RuntimeError("the feature number of src and tgt must be equal to d_model")
|
126 |
+
|
127 |
+
memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
|
128 |
+
output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
|
129 |
+
tgt_key_padding_mask=tgt_key_padding_mask,
|
130 |
+
memory_key_padding_mask=memory_key_padding_mask)
|
131 |
+
return output
|
132 |
+
|
133 |
+
|
134 |
+
def generate_square_subsequent_mask(self, sz):
|
135 |
+
r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
|
136 |
+
Unmasked positions are filled with float(0.0).
|
137 |
+
"""
|
138 |
+
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
139 |
+
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
140 |
+
return mask
|
141 |
+
|
142 |
+
|
143 |
+
def _reset_parameters(self):
|
144 |
+
r"""Initiate parameters in the transformer model."""
|
145 |
+
|
146 |
+
for p in self.parameters():
|
147 |
+
if p.dim() > 1:
|
148 |
+
xavier_uniform_(p)
|
149 |
+
|
150 |
+
|
151 |
+
class TransformerEncoder(Module):
|
152 |
+
r"""TransformerEncoder is a stack of N encoder layers
|
153 |
+
|
154 |
+
Args:
|
155 |
+
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
|
156 |
+
num_layers: the number of sub-encoder-layers in the encoder (required).
|
157 |
+
norm: the layer normalization component (optional).
|
158 |
+
|
159 |
+
Examples::
|
160 |
+
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
|
161 |
+
>>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
|
162 |
+
>>> src = torch.rand(10, 32, 512)
|
163 |
+
>>> out = transformer_encoder(src)
|
164 |
+
"""
|
165 |
+
__constants__ = ['norm']
|
166 |
+
|
167 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
|
168 |
+
super(TransformerEncoder, self).__init__()
|
169 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
170 |
+
self.num_layers = num_layers
|
171 |
+
self.norm = norm
|
172 |
+
|
173 |
+
def forward(self, src, memory2=None, mask=None, src_key_padding_mask=None):
|
174 |
+
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
|
175 |
+
r"""Pass the input through the encoder layers in turn.
|
176 |
+
|
177 |
+
Args:
|
178 |
+
src: the sequence to the encoder (required).
|
179 |
+
mask: the mask for the src sequence (optional).
|
180 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
181 |
+
|
182 |
+
Shape:
|
183 |
+
see the docs in Transformer class.
|
184 |
+
"""
|
185 |
+
output = src
|
186 |
+
|
187 |
+
for mod in self.layers:
|
188 |
+
output = mod(output, memory2=memory2, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
|
189 |
+
|
190 |
+
if self.norm is not None:
|
191 |
+
output = self.norm(output)
|
192 |
+
|
193 |
+
return output
|
194 |
+
|
195 |
+
|
196 |
+
class TransformerDecoder(Module):
    r"""TransformerDecoder is a stack of N decoder layers

    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, memory)
    """
    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt, memory, memory2=None, tgt_mask=None,
                memory_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
        r"""Pass the inputs (and mask) through the decoder layer in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = tgt

        for mod in self.layers:
            output = mod(output, memory, memory2=memory2, tgt_mask=tgt_mask,
                         memory_mask=memory_mask,
                         tgt_key_padding_mask=tgt_key_padding_mask,
                         memory_key_padding_mask=memory_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output

class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

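# Note: each sublayer above follows the post-norm residual pattern of the
# original paper: x = LayerNorm(x + Dropout(Sublayer(x))), applied first to
# self-attention and then to the position-wise feed-forward network.
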
class TransformerDecoderLayer(Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayer, self).__setstate__(state)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
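A short sketch of how this encoder stack is wired elsewhere in the repo (model.py below pairs it with TransformerEncoderLayerImproved from layers/improved_transformer.py; dimensions here are illustrative):

    encoder_layer = TransformerEncoderLayerImproved(256, 8, 512, 0.1, d_global2=None)
    encoder = TransformerEncoder(encoder_layer, num_layers=4, norm=LayerNorm(256))
    src = torch.rand(30, 16, 256)   # (seq_len, batch, d_model)
    memory = encoder(src, mask=None, src_key_padding_mask=None, memory2=None)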
src/preprocessing/deepsvg/deepsvg_models/loss.py
ADDED
@@ -0,0 +1,70 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
|
2 |
+
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
|
3 |
+
from the paper >https://arxiv.org/pdf/2007.11301.pdf>
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor
|
10 |
+
from .model_utils import _get_padding_mask, _get_visibility_mask
|
11 |
+
from .model_config import _DefaultConfig
|
12 |
+
|
13 |
+
|
14 |
+
class SVGLoss(nn.Module):
|
15 |
+
def __init__(self, cfg: _DefaultConfig):
|
16 |
+
super().__init__()
|
17 |
+
|
18 |
+
self.cfg = cfg
|
19 |
+
|
20 |
+
self.args_dim = 2 * cfg.args_dim if cfg.rel_targets else cfg.args_dim + 1
|
21 |
+
|
22 |
+
self.register_buffer("cmd_args_mask", SVGTensor.CMD_ARGS_MASK)
|
23 |
+
|
24 |
+
def forward(self, output, labels, weights):
|
25 |
+
loss = 0.
|
26 |
+
res = {}
|
27 |
+
|
28 |
+
# VAE
|
29 |
+
if self.cfg.use_vae:
|
30 |
+
mu, logsigma = output["mu"], output["logsigma"]
|
31 |
+
loss_kl = -0.5 * torch.mean(1 + logsigma - mu.pow(2) - torch.exp(logsigma))
|
32 |
+
loss_kl = loss_kl.clamp(min=weights["kl_tolerance"])
|
33 |
+
|
34 |
+
loss += weights["loss_kl_weight"] * loss_kl
|
35 |
+
res["loss_kl"] = loss_kl
|
36 |
+
|
37 |
+
# Target & predictions
|
38 |
+
tgt_commands, tgt_args = output["tgt_commands"], output["tgt_args"]
|
39 |
+
|
40 |
+
visibility_mask = _get_visibility_mask(tgt_commands, seq_dim=-1)
|
41 |
+
padding_mask = _get_padding_mask(tgt_commands, seq_dim=-1, extended=True) * visibility_mask.unsqueeze(-1)
|
42 |
+
|
43 |
+
command_logits, args_logits = output["command_logits"], output["args_logits"]
|
44 |
+
|
45 |
+
# 2-stage visibility
|
46 |
+
if self.cfg.decode_stages == 2:
|
47 |
+
visibility_logits = output["visibility_logits"]
|
48 |
+
loss_visibility = F.cross_entropy(visibility_logits.reshape(-1, 2), visibility_mask.reshape(-1).long())
|
49 |
+
|
50 |
+
loss += weights["loss_visibility_weight"] * loss_visibility
|
51 |
+
res["loss_visibility"] = loss_visibility
|
52 |
+
|
53 |
+
# Commands & args
|
54 |
+
tgt_commands, tgt_args, padding_mask = tgt_commands[..., 1:], tgt_args[..., 1:, :], padding_mask[..., 1:]
|
55 |
+
|
56 |
+
mask = self.cmd_args_mask[tgt_commands.long()]
|
57 |
+
|
58 |
+
loss_cmd = F.cross_entropy(command_logits[padding_mask.bool()].reshape(-1, self.cfg.n_commands), tgt_commands[padding_mask.bool()].reshape(-1).long())
|
59 |
+
loss_args = F.cross_entropy(args_logits[mask.bool()].reshape(-1, self.args_dim), tgt_args[mask.bool()].reshape(-1).long() + 1) # shift due to -1 PAD_VAL
|
60 |
+
|
61 |
+
loss += weights["loss_cmd_weight"] * loss_cmd \
|
62 |
+
+ weights["loss_args_weight"] * loss_args
|
63 |
+
|
64 |
+
res.update({
|
65 |
+
"loss": loss,
|
66 |
+
"loss_cmd": loss_cmd,
|
67 |
+
"loss_args": loss_args
|
68 |
+
})
|
69 |
+
|
70 |
+
return res
|
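For reference, the loss_kl term above is the standard closed-form KL divergence between the approximate posterior N(mu, sigma^2) and the unit Gaussian prior, with logsigma = log(sigma^2):

    KL(N(mu, sigma^2) || N(0, 1)) = -1/2 * (1 + log(sigma^2) - mu^2 - sigma^2)

averaged over all dimensions by torch.mean, then clamped from below by kl_tolerance (a free-bits-style floor that stops the KL term from collapsing to zero).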
src/preprocessing/deepsvg/deepsvg_models/model.py
ADDED
@@ -0,0 +1,484 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
|
2 |
+
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
|
3 |
+
from the paper >https://arxiv.org/pdf/2007.11301.pdf>
|
4 |
+
"""
|
5 |
+
|
6 |
+
from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor
|
7 |
+
from src.preprocessing.deepsvg.deepsvg_utils.utils import _pack_group_batch, _unpack_group_batch, _make_seq_first, _make_batch_first
|
8 |
+
|
9 |
+
from .layers.transformer import *
|
10 |
+
from .layers.improved_transformer import *
|
11 |
+
from .layers.positional_encoding import *
|
12 |
+
from .basic_blocks import FCN, HierarchFCN, ResNet
|
13 |
+
from .model_config import _DefaultConfig
|
14 |
+
from .model_utils import (_get_padding_mask, _get_key_padding_mask, _get_group_mask, _get_visibility_mask,
|
15 |
+
_get_key_visibility_mask, _generate_square_subsequent_mask, _sample_categorical, _threshold_sample)
|
16 |
+
|
17 |
+
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
|
18 |
+
from scipy.optimize import linear_sum_assignment
|
19 |
+
|
20 |
+
|
21 |
+
class SVGEmbedding(nn.Module):
|
22 |
+
def __init__(self, cfg: _DefaultConfig, seq_len, rel_args=False, use_group=True, group_len=None):
|
23 |
+
super().__init__()
|
24 |
+
|
25 |
+
self.cfg = cfg
|
26 |
+
|
27 |
+
self.command_embed = nn.Embedding(cfg.n_commands, cfg.d_model)
|
28 |
+
|
29 |
+
args_dim = 2 * cfg.args_dim if rel_args else cfg.args_dim + 1
|
30 |
+
self.arg_embed = nn.Embedding(args_dim, 64)
|
31 |
+
self.embed_fcn = nn.Linear(64 * cfg.n_args, cfg.d_model)
|
32 |
+
|
33 |
+
self.use_group = use_group
|
34 |
+
if use_group:
|
35 |
+
if group_len is None:
|
36 |
+
group_len = cfg.max_num_groups
|
37 |
+
self.group_embed = nn.Embedding(group_len+2, cfg.d_model)
|
38 |
+
|
39 |
+
self.pos_encoding = PositionalEncodingLUT(cfg.d_model, max_len=seq_len+2)
|
40 |
+
|
41 |
+
self._init_embeddings()
|
42 |
+
|
43 |
+
def _init_embeddings(self):
|
44 |
+
nn.init.kaiming_normal_(self.command_embed.weight, mode="fan_in")
|
45 |
+
nn.init.kaiming_normal_(self.arg_embed.weight, mode="fan_in")
|
46 |
+
nn.init.kaiming_normal_(self.embed_fcn.weight, mode="fan_in")
|
47 |
+
|
48 |
+
if self.use_group:
|
49 |
+
nn.init.kaiming_normal_(self.group_embed.weight, mode="fan_in")
|
50 |
+
|
51 |
+
def forward(self, commands, args, groups=None):
|
52 |
+
S, GN = commands.shape
|
53 |
+
|
54 |
+
src = self.command_embed(commands.long()) + \
|
55 |
+
self.embed_fcn(self.arg_embed((args + 1).long()).view(S, GN, -1)) # shift due to -1 PAD_VAL
|
56 |
+
|
57 |
+
if self.use_group:
|
58 |
+
src = src + self.group_embed(groups.long())
|
59 |
+
|
60 |
+
src = self.pos_encoding(src)
|
61 |
+
|
62 |
+
return src
|
63 |
+
|
64 |
+
|
65 |
+
class ConstEmbedding(nn.Module):
|
66 |
+
def __init__(self, cfg: _DefaultConfig, seq_len):
|
67 |
+
super().__init__()
|
68 |
+
|
69 |
+
self.cfg = cfg
|
70 |
+
|
71 |
+
self.seq_len = seq_len
|
72 |
+
|
73 |
+
self.PE = PositionalEncodingLUT(cfg.d_model, max_len=seq_len)
|
74 |
+
|
75 |
+
def forward(self, z):
|
76 |
+
N = z.size(1)
|
77 |
+
src = self.PE(z.new_zeros(self.seq_len, N, self.cfg.d_model))
|
78 |
+
return src
|
79 |
+
|
80 |
+
|
81 |
+
class LabelEmbedding(nn.Module):
|
82 |
+
def __init__(self, cfg: _DefaultConfig):
|
83 |
+
super().__init__()
|
84 |
+
|
85 |
+
self.label_embedding = nn.Embedding(cfg.n_labels, cfg.dim_label)
|
86 |
+
|
87 |
+
self._init_embeddings()
|
88 |
+
|
89 |
+
def _init_embeddings(self):
|
90 |
+
nn.init.kaiming_normal_(self.label_embedding.weight, mode="fan_in")
|
91 |
+
|
92 |
+
def forward(self, label):
|
93 |
+
src = self.label_embedding(label)
|
94 |
+
return src
|
95 |
+
|
96 |
+
|
97 |
+
class Encoder(nn.Module):
|
98 |
+
def __init__(self, cfg: _DefaultConfig):
|
99 |
+
super().__init__()
|
100 |
+
|
101 |
+
self.cfg = cfg
|
102 |
+
|
103 |
+
seq_len = cfg.max_seq_len if cfg.encode_stages == 2 else cfg.max_total_len
|
104 |
+
self.use_group = cfg.encode_stages == 1
|
105 |
+
self.embedding = SVGEmbedding(cfg, seq_len, use_group=self.use_group)
|
106 |
+
|
107 |
+
if cfg.label_condition:
|
108 |
+
self.label_embedding = LabelEmbedding(cfg)
|
109 |
+
dim_label = cfg.dim_label if cfg.label_condition else None
|
110 |
+
|
111 |
+
if cfg.model_type == "transformer":
|
112 |
+
encoder_layer = TransformerEncoderLayerImproved(cfg.d_model, cfg.n_heads, cfg.dim_feedforward, cfg.dropout, d_global2=dim_label)
|
113 |
+
encoder_norm = LayerNorm(cfg.d_model)
|
114 |
+
self.encoder = TransformerEncoder(encoder_layer, cfg.n_layers, encoder_norm)
|
115 |
+
else: # "lstm"
|
116 |
+
self.encoder = nn.LSTM(cfg.d_model, cfg.d_model // 2, dropout=cfg.dropout, bidirectional=True)
|
117 |
+
|
118 |
+
if cfg.encode_stages == 2:
|
119 |
+
if not cfg.self_match:
|
120 |
+
self.hierarchical_PE = PositionalEncodingLUT(cfg.d_model, max_len=cfg.max_num_groups)
|
121 |
+
|
122 |
+
hierarchical_encoder_layer = TransformerEncoderLayerImproved(cfg.d_model, cfg.n_heads, cfg.dim_feedforward, cfg.dropout, d_global2=dim_label)
|
123 |
+
hierarchical_encoder_norm = LayerNorm(cfg.d_model)
|
124 |
+
self.hierarchical_encoder = TransformerEncoder(hierarchical_encoder_layer, cfg.n_layers, hierarchical_encoder_norm)
|
125 |
+
|
126 |
+
def forward(self, commands, args, label=None):
|
127 |
+
S, G, N = commands.shape
|
128 |
+
l = self.label_embedding(label).unsqueeze(0).unsqueeze(0).repeat(1, commands.size(1), 1, 1) if self.cfg.label_condition else None
|
129 |
+
|
130 |
+
if self.cfg.encode_stages == 2:
|
131 |
+
visibility_mask, key_visibility_mask = _get_visibility_mask(commands, seq_dim=0), _get_key_visibility_mask(commands, seq_dim=0)
|
132 |
+
|
133 |
+
commands, args, l = _pack_group_batch(commands, args, l)
|
134 |
+
padding_mask, key_padding_mask = _get_padding_mask(commands, seq_dim=0), _get_key_padding_mask(commands, seq_dim=0)
|
135 |
+
group_mask = _get_group_mask(commands, seq_dim=0) if self.use_group else None
|
136 |
+
|
137 |
+
src = self.embedding(commands, args, group_mask)
|
138 |
+
|
139 |
+
if self.cfg.model_type == "transformer":
|
140 |
+
memory = self.encoder(src, mask=None, src_key_padding_mask=key_padding_mask, memory2=l)
|
141 |
+
|
142 |
+
z = (memory * padding_mask).sum(dim=0, keepdim=True) / padding_mask.sum(dim=0, keepdim=True)
|
143 |
+
else: # "lstm"
|
144 |
+
hidden_cell = (src.new_zeros(2, N, self.cfg.d_model // 2),
|
145 |
+
src.new_zeros(2, N, self.cfg.d_model // 2))
|
146 |
+
sequence_lengths = padding_mask.sum(dim=0).squeeze(-1)
|
147 |
+
x = pack_padded_sequence(src, sequence_lengths, enforce_sorted=False)
|
148 |
+
|
149 |
+
packed_output, _ = self.encoder(x, hidden_cell)
|
150 |
+
|
151 |
+
memory, _ = pad_packed_sequence(packed_output)
|
152 |
+
idx = (sequence_lengths - 1).long().view(1, -1, 1).repeat(1, 1, self.cfg.d_model)
|
153 |
+
z = memory.gather(dim=0, index=idx)
|
154 |
+
|
155 |
+
z = _unpack_group_batch(N, z)
|
156 |
+
|
157 |
+
if self.cfg.encode_stages == 2:
|
158 |
+
src = z.transpose(0, 1)
|
159 |
+
src = _pack_group_batch(src)
|
160 |
+
l = self.label_embedding(label).unsqueeze(0) if self.cfg.label_condition else None
|
161 |
+
|
162 |
+
if not self.cfg.self_match:
|
163 |
+
src = self.hierarchical_PE(src)
|
164 |
+
|
165 |
+
memory = self.hierarchical_encoder(src, mask=None, src_key_padding_mask=key_visibility_mask, memory2=l)
|
166 |
+
z = (memory * visibility_mask).sum(dim=0, keepdim=True) / visibility_mask.sum(dim=0, keepdim=True)
|
167 |
+
z = _unpack_group_batch(N, z)
|
168 |
+
|
169 |
+
return z
|
170 |
+
|
171 |
+
|
172 |
+
class VAE(nn.Module):
|
173 |
+
def __init__(self, cfg: _DefaultConfig):
|
174 |
+
super(VAE, self).__init__()
|
175 |
+
|
176 |
+
self.enc_mu_fcn = nn.Linear(cfg.d_model, cfg.dim_z)
|
177 |
+
self.enc_sigma_fcn = nn.Linear(cfg.d_model, cfg.dim_z)
|
178 |
+
|
179 |
+
self._init_embeddings()
|
180 |
+
|
181 |
+
def _init_embeddings(self):
|
182 |
+
nn.init.normal_(self.enc_mu_fcn.weight, std=0.001)
|
183 |
+
nn.init.constant_(self.enc_mu_fcn.bias, 0)
|
184 |
+
nn.init.normal_(self.enc_sigma_fcn.weight, std=0.001)
|
185 |
+
nn.init.constant_(self.enc_sigma_fcn.bias, 0)
|
186 |
+
|
187 |
+
def forward(self, z):
|
188 |
+
mu, logsigma = self.enc_mu_fcn(z), self.enc_sigma_fcn(z)
|
189 |
+
sigma = torch.exp(logsigma / 2.)
|
190 |
+
z = mu + sigma * torch.randn_like(sigma)
|
191 |
+
|
192 |
+
return z, mu, logsigma
|
193 |
+
|
194 |
+
|
195 |
+
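# The VAE above uses the reparameterization trick: instead of sampling
# z ~ N(mu, sigma^2) directly, it draws eps ~ N(0, I) via torch.randn_like and
# computes z = mu + sigma * eps, which keeps the sampling step differentiable
# with respect to mu and logsigma (here logsigma = log(sigma^2), hence
# sigma = exp(logsigma / 2)).
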
class Bottleneck(nn.Module):
    def __init__(self, cfg: _DefaultConfig):
        super(Bottleneck, self).__init__()

        self.bottleneck = nn.Linear(cfg.d_model, cfg.dim_z)

    def forward(self, z):
        return self.bottleneck(z)


class Decoder(nn.Module):
    def __init__(self, cfg: _DefaultConfig):
        super(Decoder, self).__init__()

        self.cfg = cfg

        if cfg.label_condition:
            self.label_embedding = LabelEmbedding(cfg)
        dim_label = cfg.dim_label if cfg.label_condition else None

        if cfg.decode_stages == 2:
            self.hierarchical_embedding = ConstEmbedding(cfg, cfg.num_groups_proposal)

            hierarchical_decoder_layer = TransformerDecoderLayerGlobalImproved(cfg.d_model, cfg.dim_z, cfg.n_heads, cfg.dim_feedforward, cfg.dropout, d_global2=dim_label)
            hierarchical_decoder_norm = LayerNorm(cfg.d_model)
            self.hierarchical_decoder = TransformerDecoder(hierarchical_decoder_layer, cfg.n_layers_decode, hierarchical_decoder_norm)
            self.hierarchical_fcn = HierarchFCN(cfg.d_model, cfg.dim_z)

        if cfg.pred_mode == "autoregressive":
            self.embedding = SVGEmbedding(cfg, cfg.max_total_len, rel_args=cfg.rel_targets, use_group=True, group_len=cfg.max_total_len)

            square_subsequent_mask = _generate_square_subsequent_mask(self.cfg.max_total_len+1)
            self.register_buffer("square_subsequent_mask", square_subsequent_mask)
        else:  # "one_shot"
            seq_len = cfg.max_seq_len+1 if cfg.decode_stages == 2 else cfg.max_total_len+1
            self.embedding = ConstEmbedding(cfg, seq_len)

        if cfg.model_type == "transformer":
            decoder_layer = TransformerDecoderLayerGlobalImproved(cfg.d_model, cfg.dim_z, cfg.n_heads, cfg.dim_feedforward, cfg.dropout, d_global2=dim_label)
            decoder_norm = LayerNorm(cfg.d_model)
            self.decoder = TransformerDecoder(decoder_layer, cfg.n_layers_decode, decoder_norm)
        else:  # "lstm"
            self.fc_hc = nn.Linear(cfg.dim_z, 2 * cfg.d_model)
            self.decoder = nn.LSTM(cfg.d_model, cfg.d_model, dropout=cfg.dropout)

        args_dim = 2 * cfg.args_dim if cfg.rel_targets else cfg.args_dim + 1
        self.fcn = FCN(cfg.d_model, cfg.n_commands, cfg.n_args, args_dim)

    def _get_initial_state(self, z):
        hidden, cell = torch.split(torch.tanh(self.fc_hc(z)), self.cfg.d_model, dim=2)
        hidden_cell = hidden.contiguous(), cell.contiguous()
        return hidden_cell

    def forward(self, z, commands, args, label=None, hierarch_logits=None, return_hierarch=False):
        N = z.size(2)
        l = self.label_embedding(label).unsqueeze(0) if self.cfg.label_condition else None
        if hierarch_logits is None:
            z = _pack_group_batch(z)

        if self.cfg.decode_stages == 2:
            if hierarch_logits is None:
                src = self.hierarchical_embedding(z)
                out = self.hierarchical_decoder(src, z, tgt_mask=None, tgt_key_padding_mask=None, memory2=l)
                hierarch_logits, z = self.hierarchical_fcn(out)

            if self.cfg.label_condition: l = l.unsqueeze(0).repeat(1, z.size(1), 1, 1)

            hierarch_logits, z, l = _pack_group_batch(hierarch_logits, z, l)

        if return_hierarch:
            return _unpack_group_batch(N, hierarch_logits, z)

        if self.cfg.pred_mode == "autoregressive":
            S = commands.size(0)
            commands, args = _pack_group_batch(commands, args)

            group_mask = _get_group_mask(commands, seq_dim=0)

            src = self.embedding(commands, args, group_mask)

            if self.cfg.model_type == "transformer":
                key_padding_mask = _get_key_padding_mask(commands, seq_dim=0)
                out = self.decoder(src, z, tgt_mask=self.square_subsequent_mask[:S, :S], tgt_key_padding_mask=key_padding_mask, memory2=l)
            else:  # "lstm"
                hidden_cell = self._get_initial_state(z)
                out, _ = self.decoder(src, hidden_cell)

        else:  # "one_shot"
            src = self.embedding(z)
            out = self.decoder(src, z, tgt_mask=None, tgt_key_padding_mask=None, memory2=l)

        command_logits, args_logits = self.fcn(out)

        out_logits = (command_logits, args_logits) + ((hierarch_logits,) if self.cfg.decode_stages == 2 else ())

        return _unpack_group_batch(N, *out_logits)


class SVGTransformer(nn.Module):
    def __init__(self, cfg: _DefaultConfig):
        super(SVGTransformer, self).__init__()

        self.cfg = cfg
        self.args_dim = 2 * cfg.args_dim if cfg.rel_targets else cfg.args_dim + 1

        if self.cfg.encode_stages > 0:

            self.encoder = Encoder(cfg)

            if cfg.use_resnet:
                self.resnet = ResNet(cfg.d_model)

            if cfg.use_vae:
                self.vae = VAE(cfg)
            else:
                self.bottleneck = Bottleneck(cfg)

        self.decoder = Decoder(cfg)

        self.register_buffer("cmd_args_mask", SVGTensor.CMD_ARGS_MASK)

    def perfect_matching(self, command_logits, args_logits, hierarch_logits, tgt_commands, tgt_args):
        with torch.no_grad():
            N, G, S, n_args = tgt_args.shape
            visibility_mask = _get_visibility_mask(tgt_commands, seq_dim=-1)
            padding_mask = _get_padding_mask(tgt_commands, seq_dim=-1, extended=True) * visibility_mask.unsqueeze(-1)

            # Unsqueeze
            tgt_commands, tgt_args, tgt_hierarch = tgt_commands.unsqueeze(2), tgt_args.unsqueeze(2), visibility_mask.unsqueeze(2)
            command_logits, args_logits, hierarch_logits = command_logits.unsqueeze(1), args_logits.unsqueeze(1), hierarch_logits.unsqueeze(1).squeeze(-2)

            # Loss
            tgt_hierarch, hierarch_logits = tgt_hierarch.repeat(1, 1, self.cfg.num_groups_proposal), hierarch_logits.repeat(1, G, 1, 1)
            tgt_commands, command_logits = tgt_commands.repeat(1, 1, self.cfg.num_groups_proposal, 1), command_logits.repeat(1, G, 1, 1, 1)
            tgt_args, args_logits = tgt_args.repeat(1, 1, self.cfg.num_groups_proposal, 1, 1), args_logits.repeat(1, G, 1, 1, 1, 1)

            padding_mask, mask = padding_mask.unsqueeze(2).repeat(1, 1, self.cfg.num_groups_proposal, 1), self.cmd_args_mask[tgt_commands.long()]

            loss_args = F.cross_entropy(args_logits.reshape(-1, self.args_dim), tgt_args.reshape(-1).long() + 1, reduction="none").reshape(N, G, self.cfg.num_groups_proposal, S, n_args)  # shift due to -1 PAD_VAL
            loss_cmd = F.cross_entropy(command_logits.reshape(-1, self.cfg.n_commands), tgt_commands.reshape(-1).long(), reduction="none").reshape(N, G, self.cfg.num_groups_proposal, S)
            loss_hierarch = F.cross_entropy(hierarch_logits.reshape(-1, 2), tgt_hierarch.reshape(-1).long(), reduction="none").reshape(N, G, self.cfg.num_groups_proposal)

            loss_args = (loss_args * mask).sum(dim=[-1, -2]) / mask.sum(dim=[-1, -2])
            loss_cmd = (loss_cmd * padding_mask).sum(dim=-1) / padding_mask.sum(dim=-1)

            loss = 2.0 * loss_args + 1.0 * loss_cmd + 1.0 * loss_hierarch

            # Iterate over the batch-dimension
            assignment_list = []

            full_set = set(range(self.cfg.num_groups_proposal))
            for i in range(N):
                costs = loss[i]
                mask = visibility_mask[i]
                _, assign = linear_sum_assignment(costs[mask].cpu())
                assign = assign.tolist()
                assignment_list.append(assign + list(full_set - set(assign)))

            assignment = torch.tensor(assignment_list, device=command_logits.device)

        return assignment.unsqueeze(-1).unsqueeze(-1)

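    # perfect_matching above uses scipy's linear_sum_assignment (the Hungarian
    # algorithm) to find, per sample, the pairing of predicted path slots to
    # ground-truth paths that minimizes the total matching cost. A toy
    # illustration of that primitive:
    #
    #     from scipy.optimize import linear_sum_assignment
    #     import numpy as np
    #     cost = np.array([[4, 1, 3],
    #                      [2, 0, 5],
    #                      [3, 2, 2]])
    #     rows, cols = linear_sum_assignment(cost)  # cols == [1, 0, 2], total cost 1 + 2 + 2 = 5
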
    def forward(self, commands_enc, args_enc, commands_dec, args_dec, label=None,
                z=None, hierarch_logits=None,
                return_tgt=True, params=None, encode_mode=False, return_hierarch=False):
        commands_enc, args_enc = _make_seq_first(commands_enc, args_enc)  # Possibly None, None
        commands_dec_, args_dec_ = _make_seq_first(commands_dec, args_dec)

        if z is None:
            z = self.encoder(commands_enc, args_enc, label)

            if self.cfg.use_resnet:
                z = self.resnet(z)

            if self.cfg.use_vae:
                z, mu, logsigma = self.vae(z)
            else:
                z = self.bottleneck(z)
        else:
            z = _make_seq_first(z)

        if encode_mode: return z

        if return_tgt:  # Train mode
            commands_dec_, args_dec_ = commands_dec_[:-1], args_dec_[:-1]

        out_logits = self.decoder(z, commands_dec_, args_dec_, label, hierarch_logits=hierarch_logits,
                                  return_hierarch=return_hierarch)

        if return_hierarch:
            return out_logits

        out_logits = _make_batch_first(*out_logits)

        if return_tgt and self.cfg.self_match:  # Assignment
            assert self.cfg.decode_stages == 2  # Self-matching expects two-stage decoder
            command_logits, args_logits, hierarch_logits = out_logits

            assignment = self.perfect_matching(command_logits, args_logits, hierarch_logits, commands_dec[..., 1:], args_dec[..., 1:, :])

            command_logits = torch.gather(command_logits, dim=1, index=assignment.expand_as(command_logits))
            args_logits = torch.gather(args_logits, dim=1, index=assignment.unsqueeze(-1).expand_as(args_logits))
            hierarch_logits = torch.gather(hierarch_logits, dim=1, index=assignment.expand_as(hierarch_logits))

            out_logits = (command_logits, args_logits, hierarch_logits)

        res = {
            "command_logits": out_logits[0],
            "args_logits": out_logits[1]
        }

        if self.cfg.decode_stages == 2:
            res["visibility_logits"] = out_logits[2]

        if return_tgt:
            res["tgt_commands"] = commands_dec
            res["tgt_args"] = args_dec

            if self.cfg.use_vae:
                res["mu"] = _make_batch_first(mu)
                res["logsigma"] = _make_batch_first(logsigma)

        return res

    def greedy_sample(self, commands_enc=None, args_enc=None, commands_dec=None, args_dec=None, label=None,
                      z=None, hierarch_logits=None,
                      concat_groups=True, temperature=0.0001):
        if self.cfg.pred_mode == "one_shot":
            res = self.forward(commands_enc, args_enc, commands_dec, args_dec, label=label, z=z, hierarch_logits=hierarch_logits, return_tgt=False)
            commands_y, args_y = _sample_categorical(temperature, res["command_logits"], res["args_logits"])
            args_y -= 1  # shift due to -1 PAD_VAL
            visibility_y = _threshold_sample(res["visibility_logits"], threshold=0.7).bool().squeeze(-1) if self.cfg.decode_stages == 2 else None
            commands_y, args_y = self._make_valid(commands_y, args_y, visibility_y)
        else:
            if z is None:
                z = self.forward(commands_enc, args_enc, None, None, label=label, encode_mode=True)

            PAD_VAL = -1
            commands_y, args_y = z.new_zeros(1, 1, 1).fill_(SVGTensor.COMMANDS_SIMPLIFIED.index("SOS")).long(), z.new_ones(1, 1, 1, self.cfg.n_args).fill_(PAD_VAL).long()

            for i in range(self.cfg.max_total_len):
                res = self.forward(None, None, commands_y, args_y, label=label, z=z, hierarch_logits=hierarch_logits, return_tgt=False)
                commands_new_y, args_new_y = _sample_categorical(temperature, res["command_logits"], res["args_logits"])
                args_new_y -= 1  # shift due to -1 PAD_VAL
                _, args_new_y = self._make_valid(commands_new_y, args_new_y)

                commands_y, args_y = torch.cat([commands_y, commands_new_y[..., -1:]], dim=-1), torch.cat([args_y, args_new_y[..., -1:, :]], dim=-2)

            commands_y, args_y = commands_y[..., 1:], args_y[..., 1:, :]  # Discard SOS token

        if self.cfg.rel_targets:
            args_y = self._make_absolute(commands_y, args_y)

        if concat_groups:
            N = commands_y.size(0)
            padding_mask_y = _get_padding_mask(commands_y, seq_dim=-1).bool()
            commands_y, args_y = commands_y[padding_mask_y].reshape(N, -1), args_y[padding_mask_y].reshape(N, -1, self.cfg.n_args)

        return commands_y, args_y

    def _make_valid(self, commands_y, args_y, visibility_y=None, PAD_VAL=-1):
        if visibility_y is not None:
            S = commands_y.size(-1)
            commands_y[~visibility_y] = commands_y.new_tensor([SVGTensor.COMMANDS_SIMPLIFIED.index("m"), *[SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")] * (S - 1)])
            args_y[~visibility_y] = PAD_VAL

        mask = self.cmd_args_mask[commands_y.long()].bool()
        args_y[~mask] = PAD_VAL

        return commands_y, args_y

    def _make_absolute(self, commands_y, args_y):

        mask = self.cmd_args_mask[commands_y.long()].bool()
        args_y[mask] -= self.cfg.args_dim - 1

        real_commands = commands_y < SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")

        args_real_commands = args_y[real_commands]
        end_pos = args_real_commands[:-1, SVGTensor.IndexArgs.END_POS].cumsum(dim=0)

        args_real_commands[1:, SVGTensor.IndexArgs.CONTROL1] += end_pos
        args_real_commands[1:, SVGTensor.IndexArgs.CONTROL2] += end_pos
        args_real_commands[1:, SVGTensor.IndexArgs.END_POS] += end_pos

        args_y[real_commands] = args_real_commands

        _, args_y = self._make_valid(commands_y, args_y)

        return args_y
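A rough inference sketch (illustrative only; commands_enc and args_enc stand in for preprocessed command/argument tensors shaped as the dataloader produces them, and the weights would come from a trained checkpoint):

    cfg = Hierarchical()             # from model_config.py below
    model = SVGTransformer(cfg)
    model.eval()
    with torch.no_grad():
        commands_y, args_y = model.greedy_sample(commands_enc, args_enc)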
src/preprocessing/deepsvg/deepsvg_models/model_config.py
ADDED
@@ -0,0 +1,113 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
|
2 |
+
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
|
3 |
+
from the paper >https://arxiv.org/pdf/2007.11301.pdf>
|
4 |
+
"""
|
5 |
+
|
6 |
+
from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor
|
7 |
+
|
8 |
+
|
9 |
+
class _DefaultConfig:
|
10 |
+
"""
|
11 |
+
Model config.
|
12 |
+
"""
|
13 |
+
def __init__(self):
|
14 |
+
self.args_dim = 256 # Coordinate numericalization, default: 256 (8-bit)
|
15 |
+
self.n_args = 11 # Tensor nb of arguments, default: 11 (rx,ry,phi,fA,fS,qx1,qy1,qx2,qy2,x1,x2)
|
16 |
+
self.n_commands = len(SVGTensor.COMMANDS_SIMPLIFIED) # m, l, c, a, EOS, SOS, z
|
17 |
+
|
18 |
+
self.dropout = 0.1 # Dropout rate used in basic layers and Transformers
|
19 |
+
|
20 |
+
self.model_type = "transformer" # "transformer" ("lstm" implementation is work in progress)
|
21 |
+
|
22 |
+
self.encode_stages = 1 # One-stage or two-stage: 1 | 2
|
23 |
+
self.decode_stages = 1 # One-stage or two-stage: 1 | 2
|
24 |
+
|
25 |
+
self.use_resnet = True # Use extra fully-connected residual blocks after Encoder
|
26 |
+
|
27 |
+
self.use_vae = True # Sample latent vector (with reparametrization trick) or use encodings directly
|
28 |
+
|
29 |
+
self.pred_mode = "one_shot" # Feed-forward (one-shot) or autogressive: "one_shot" | "autoregressive"
|
30 |
+
self.rel_targets = False # Predict coordinates in relative or absolute format
|
31 |
+
|
32 |
+
self.label_condition = False # Make all blocks conditional on the label
|
33 |
+
self.n_labels = 100 # Number of labels (when used)
|
34 |
+
self.dim_label = 64 # Label embedding dimensionality
|
35 |
+
|
36 |
+
self.self_match = False # Use Hungarian (self-match) or Ordered assignment
|
37 |
+
|
38 |
+
self.n_layers = 4 # Number of Encoder blocks
|
39 |
+
self.n_layers_decode = 4 # Number of Decoder blocks
|
40 |
+
self.n_heads = 8 # Transformer config: number of heads
|
41 |
+
self.dim_feedforward = 512 # Transformer config: FF dimensionality
|
42 |
+
self.d_model = 256 # Transformer config: model dimensionality
|
43 |
+
|
44 |
+
self.dim_z = 256 # Latent vector dimensionality
|
45 |
+
|
46 |
+
self.max_num_groups = 8 # Number of paths (N_P)
|
47 |
+
self.max_seq_len = 30 # Number of commands (N_C)
|
48 |
+
self.max_total_len = self.max_num_groups * self.max_seq_len # Concatenated sequence length for baselines
|
49 |
+
|
50 |
+
self.num_groups_proposal = self.max_num_groups # Number of predicted paths, default: N_P
|
51 |
+
|
52 |
+
def get_model_args(self):
|
53 |
+
model_args = []
|
54 |
+
|
55 |
+
model_args += ["commands_grouped", "args_grouped"] if self.encode_stages <= 1 else ["commands", "args"]
|
56 |
+
|
57 |
+
if self.rel_targets:
|
58 |
+
model_args += ["commands_grouped", "args_rel_grouped"] if self.decode_stages == 1 else ["commands", "args_rel"]
|
59 |
+
else:
|
60 |
+
model_args += ["commands_grouped", "args_grouped"] if self.decode_stages == 1 else ["commands", "args"]
|
61 |
+
|
62 |
+
if self.label_condition:
|
63 |
+
model_args.append("label")
|
64 |
+
|
65 |
+
return model_args
|
66 |
+
|
67 |
+
|
68 |
+
class SketchRNN(_DefaultConfig):
|
69 |
+
# LSTM - Autoregressive - One-stage
|
70 |
+
def __init__(self):
|
71 |
+
super().__init__()
|
72 |
+
|
73 |
+
self.model_type = "lstm"
|
74 |
+
|
75 |
+
self.pred_mode = "autoregressive"
|
76 |
+
self.rel_targets = True
|
77 |
+
|
78 |
+
|
79 |
+
class Sketchformer(_DefaultConfig):
|
80 |
+
# Transformer - Autoregressive - One-stage
|
81 |
+
def __init__(self):
|
82 |
+
super().__init__()
|
83 |
+
|
84 |
+
self.pred_mode = "autoregressive"
|
85 |
+
self.rel_targets = True
|
86 |
+
|
87 |
+
|
88 |
+
class OneStageOneShot(_DefaultConfig):
|
89 |
+
# Transformer - One-shot - One-stage
|
90 |
+
def __init__(self):
|
91 |
+
super().__init__()
|
92 |
+
|
93 |
+
self.encode_stages = 1
|
94 |
+
self.decode_stages = 1
|
95 |
+
|
96 |
+
|
97 |
+
class Hierarchical(_DefaultConfig):
|
98 |
+
# Transformer - One-shot - Two-stage - Ordered
|
99 |
+
def __init__(self):
|
100 |
+
super().__init__()
|
101 |
+
|
102 |
+
self.encode_stages = 2
|
103 |
+
self.decode_stages = 2
|
104 |
+
|
105 |
+
|
106 |
+
class HierarchicalSelfMatching(_DefaultConfig):
|
107 |
+
# Transformer - One-shot - Two-stage - Hungarian
|
108 |
+
def __init__(self):
|
109 |
+
super().__init__()
|
110 |
+
|
111 |
+
self.encode_stages = 2
|
112 |
+
self.decode_stages = 2
|
113 |
+
self.self_match = True
|
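For illustration, instantiating the two-stage Hierarchical config resolves get_model_args() to the ungrouped tensor names for both encoder and decoder:

    cfg = Hierarchical()
    cfg.get_model_args()   # ["commands", "args", "commands", "args"]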
src/preprocessing/deepsvg/deepsvg_models/model_utils.py
ADDED
@@ -0,0 +1,89 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
|
2 |
+
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
|
3 |
+
from the paper >https://arxiv.org/pdf/2007.11301.pdf>
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor
|
8 |
+
from torch.distributions.categorical import Categorical
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
|
12 |
+
def _get_key_padding_mask(commands, seq_dim=0):
|
13 |
+
"""
|
14 |
+
Args:
|
15 |
+
commands: Shape [S, ...]
|
16 |
+
"""
|
17 |
+
with torch.no_grad():
|
18 |
+
key_padding_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).cumsum(dim=seq_dim) > 0
|
19 |
+
|
20 |
+
if seq_dim == 0:
|
21 |
+
return key_padding_mask.transpose(0, 1)
|
22 |
+
return key_padding_mask
|
23 |
+
|
24 |
+
|
25 |
+
def _get_padding_mask(commands, seq_dim=0, extended=False):
|
26 |
+
with torch.no_grad():
|
27 |
+
padding_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).cumsum(dim=seq_dim) == 0
|
28 |
+
padding_mask = padding_mask.float()
|
29 |
+
|
30 |
+
if extended:
|
31 |
+
# padding_mask doesn't include the final EOS, extend by 1 position to include it in the loss
|
32 |
+
S = commands.size(seq_dim)
|
33 |
+
torch.narrow(padding_mask, seq_dim, 3, S-3).add_(torch.narrow(padding_mask, seq_dim, 0, S-3)).clamp_(max=1)
|
34 |
+
|
35 |
+
if seq_dim == 0:
|
36 |
+
return padding_mask.unsqueeze(-1)
|
37 |
+
return padding_mask
|
38 |
+
|
39 |
+
|
40 |
+
def _get_group_mask(commands, seq_dim=0):
|
41 |
+
"""
|
42 |
+
Args:
|
43 |
+
commands: Shape [S, ...]
|
44 |
+
"""
|
45 |
+
with torch.no_grad():
|
46 |
+
group_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("m")).cumsum(dim=seq_dim)
|
47 |
+
return group_mask
|
48 |
+
|
49 |
+
|
50 |
+
def _get_visibility_mask(commands, seq_dim=0):
|
51 |
+
"""
|
52 |
+
Args:
|
53 |
+
commands: Shape [S, ...]
|
54 |
+
"""
|
55 |
+
S = commands.size(seq_dim)
|
56 |
+
with torch.no_grad():
|
57 |
+
visibility_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).sum(dim=seq_dim) < S - 1
|
58 |
+
|
59 |
+
if seq_dim == 0:
|
60 |
+
return visibility_mask.unsqueeze(-1)
|
61 |
+
return visibility_mask
|
62 |
+
|
63 |
+
|
64 |
+
def _get_key_visibility_mask(commands, seq_dim=0):
|
65 |
+
S = commands.size(seq_dim)
|
66 |
+
with torch.no_grad():
|
67 |
+
key_visibility_mask = (commands == SVGTensor.COMMANDS_SIMPLIFIED.index("EOS")).sum(dim=seq_dim) >= S - 1
|
68 |
+
|
69 |
+
if seq_dim == 0:
|
70 |
+
return key_visibility_mask.transpose(0, 1)
|
71 |
+
return key_visibility_mask
|
72 |
+
|
73 |
+
|
74 |
+
def _generate_square_subsequent_mask(sz):
|
75 |
+
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
76 |
+
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
77 |
+
return mask
|
78 |
+
|
79 |
+
|
80 |
+
def _sample_categorical(temperature=0.0001, *args_logits):
|
81 |
+
if len(args_logits) == 1:
|
82 |
+
arg_logits, = args_logits
|
83 |
+
return Categorical(logits=arg_logits / temperature).sample()
|
84 |
+
return (*(Categorical(logits=arg_logits / temperature).sample() for arg_logits in args_logits),)
|
85 |
+
|
86 |
+
|
87 |
+
def _threshold_sample(arg_logits, threshold=0.5, temperature=1.0):
|
88 |
+
scores = F.softmax(arg_logits / temperature, dim=-1)[..., 1]
|
89 |
+
return scores > threshold
|
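To make the mask conventions above concrete (illustrative; EOS has index 4 in SVGTensor.COMMANDS_SIMPLIFIED, which orders the commands m, l, c, a, EOS, SOS, z per model_config.py):

    commands = torch.tensor([[0], [1], [1], [4], [4]])   # one sequence of length 5: m, l, l, EOS, EOS
    _get_padding_mask(commands, seq_dim=0)       # 1.0 at the three positions before the first EOS, 0.0 after
    _get_key_padding_mask(commands, seq_dim=0)   # True at and after the first EOS: [[False, False, False, True, True]]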
src/preprocessing/deepsvg/deepsvg_schedulers/warmup.py
ADDED
@@ -0,0 +1,68 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
|
2 |
+
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
|
3 |
+
from the paper >https://arxiv.org/pdf/2007.11301.pdf>
|
4 |
+
"""
|
5 |
+
|
6 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
7 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
8 |
+
|
9 |
+
|
10 |
+
class GradualWarmupScheduler(_LRScheduler):
|
11 |
+
""" Gradually warm-up(increasing) learning rate in optimizer.
|
12 |
+
Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
|
13 |
+
Args:
|
14 |
+
optimizer (Optimizer): Wrapped optimizer.
|
15 |
+
multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
|
16 |
+
total_epoch: target learning rate is reached at total_epoch, gradually
|
17 |
+
after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
|
21 |
+
self.multiplier = multiplier
|
22 |
+
if self.multiplier < 1.:
|
23 |
+
raise ValueError('multiplier should be greater thant or equal to 1.')
|
24 |
+
self.total_epoch = total_epoch
|
25 |
+
self.after_scheduler = after_scheduler
|
26 |
+
self.finished = False
|
27 |
+
super(GradualWarmupScheduler, self).__init__(optimizer)
|
28 |
+
|
29 |
+
def get_lr(self):
|
30 |
+
if self.last_epoch > self.total_epoch:
|
31 |
+
if self.after_scheduler:
|
32 |
+
if not self.finished:
|
33 |
+
self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
|
34 |
+
self.finished = True
|
35 |
+
return self.after_scheduler.get_last_lr()
|
36 |
+
return [base_lr * self.multiplier for base_lr in self.base_lrs]
|
37 |
+
|
38 |
+
if self.multiplier == 1.0:
|
39 |
+
return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
|
40 |
+
else:
|
41 |
+
return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
|
42 |
+
|
43 |
+
def step_ReduceLROnPlateau(self, metrics, epoch=None):
|
44 |
+
if epoch is None:
|
45 |
+
epoch = self.last_epoch + 1
|
46 |
+
self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
|
47 |
+
if self.last_epoch <= self.total_epoch:
|
48 |
+
warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
|
49 |
+
for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
|
50 |
+
param_group['lr'] = lr
|
51 |
+
else:
|
52 |
+
if epoch is None:
|
53 |
+
self.after_scheduler.step(metrics, None)
|
54 |
+
else:
|
55 |
+
self.after_scheduler.step(metrics, epoch - self.total_epoch)
|
56 |
+
|
57 |
+
def step(self, epoch=None, metrics=None):
|
58 |
+
if type(self.after_scheduler) != ReduceLROnPlateau:
|
59 |
+
if self.finished and self.after_scheduler:
|
60 |
+
if epoch is None:
|
61 |
+
self.after_scheduler.step(None)
|
62 |
+
else:
|
63 |
+
self.after_scheduler.step(epoch - self.total_epoch)
|
64 |
+
self._last_lr = self.after_scheduler.get_last_lr()
|
65 |
+
else:
|
66 |
+
return super(GradualWarmupScheduler, self).step(epoch)
|
67 |
+
else:
|
68 |
+
self.step_ReduceLROnPlateau(metrics, epoch)
|
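A minimal usage sketch (illustrative; model and train_one_epoch are placeholders for an actual module and training step):

    import torch
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # Linear warm-up from 0 to the base lr over the first 5 epochs
    scheduler = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=5)
    for epoch in range(20):
        train_one_epoch(model, optimizer)
        scheduler.step()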
src/preprocessing/deepsvg/deepsvg_svglib/geom.py
ADDED
@@ -0,0 +1,493 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
|
2 |
+
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
|
3 |
+
from the paper >https://arxiv.org/pdf/2007.11301.pdf>
|
4 |
+
"""
|
5 |
+
|
6 |
+
from __future__ import annotations
|
7 |
+
import numpy as np
|
8 |
+
from enum import Enum
|
9 |
+
import torch
|
10 |
+
from typing import List, Union
|
11 |
+
Num = Union[int, float]
|
12 |
+
float_type = (int, float, np.float32)
|
13 |
+
|
14 |
+
|
15 |
+
def det(a: Point, b: Point):
|
16 |
+
return a.pos[0] * b.pos[1] - a.pos[1] * b.pos[0]
|
17 |
+
|
18 |
+
|
19 |
+
def get_rotation_matrix(angle: Union[Angle, float]):
|
20 |
+
if isinstance(angle, Angle):
|
21 |
+
theta = angle.rad
|
22 |
+
else:
|
23 |
+
theta = angle
|
24 |
+
c, s = np.cos(theta), np.sin(theta)
|
25 |
+
rot_m = np.array([[c, -s],
|
26 |
+
[s, c]], dtype=np.float32)
|
27 |
+
return rot_m
|
28 |
+
|
29 |
+
|
30 |
+
def union_bbox(bbox_list: List[Bbox]):
|
31 |
+
res = None
|
32 |
+
for bbox in bbox_list:
|
33 |
+
res = bbox.union(res)
|
34 |
+
return res
|
35 |
+
|
36 |
+
|
37 |
+
class Geom:
|
38 |
+
def copy(self):
|
39 |
+
raise NotImplementedError
|
40 |
+
|
41 |
+
def to_str(self):
|
42 |
+
raise NotImplementedError
|
43 |
+
|
44 |
+
def to_tensor(self):
|
45 |
+
raise NotImplementedError
|
46 |
+
|
47 |
+
@staticmethod
|
48 |
+
def from_tensor(vector: torch.Tensor):
|
49 |
+
raise NotImplementedError
|
50 |
+
|
51 |
+
def scale(self, factor):
|
52 |
+
pass
|
53 |
+
|
54 |
+
def translate(self, vec):
|
55 |
+
pass
|
56 |
+
|
57 |
+
def rotate(self, angle: Union[Angle, float]):
|
58 |
+
pass
|
59 |
+
|
60 |
+
def numericalize(self, n=256):
|
61 |
+
raise NotImplementedError
|
62 |
+
|
63 |
+
|
64 |
+
######### Point
|
65 |
+
class Point(Geom):
|
66 |
+
num_args = 2
|
67 |
+
|
68 |
+
def __init__(self, x=None, y=None):
|
69 |
+
if isinstance(x, np.ndarray):
|
70 |
+
self.pos = x.astype(np.float32)
|
71 |
+
elif x is None and y is None:
|
72 |
+
self.pos = np.array([0., 0.], dtype=np.float32)
|
73 |
+
elif (isinstance(x, float_type) or x is None) and (isinstance(y, float_type) or y is None):
|
74 |
+
if x is None:
|
75 |
+
x = y
|
76 |
+
if y is None:
|
77 |
+
y = x
|
78 |
+
self.pos = np.array([x, y], dtype=np.float32)
|
79 |
+
else:
|
80 |
+
raise ValueError()
|
81 |
+
|
82 |
+
def copy(self):
|
83 |
+
return Point(self.pos.copy())
|
84 |
+
|
85 |
+
@property
|
86 |
+
def x(self):
|
87 |
+
return self.pos[0]
|
88 |
+
|
89 |
+
@property
|
90 |
+
def y(self):
|
91 |
+
return self.pos[1]
|
92 |
+
|
93 |
+
def xproj(self):
|
94 |
+
return Point(self.x, 0.)
|
95 |
+
|
96 |
+
def yproj(self):
|
97 |
+
return Point(0., self.y)
|
98 |
+
|
99 |
+
def __add__(self, other):
|
100 |
+
return Point(self.pos + other.pos)
|
101 |
+
|
102 |
+
def __sub__(self, other):
|
103 |
+
return self + other.__neg__()
|
104 |
+
|
105 |
+
def __mul__(self, lmbda):
|
106 |
+
if isinstance(lmbda, Point):
|
107 |
+
return Point(self.pos * lmbda.pos)
|
108 |
+
|
109 |
+
assert isinstance(lmbda, float_type)
|
110 |
+
return Point(lmbda * self.pos)
|
111 |
+
|
112 |
+
def __rmul__(self, lmbda):
|
113 |
+
return self * lmbda
|
114 |
+
|
115 |
+
def __truediv__(self, lmbda):
|
116 |
+
if isinstance(lmbda, Point):
|
117 |
+
return Point(self.pos / lmbda.pos)
|
118 |
+
|
119 |
+
assert isinstance(lmbda, float_type)
|
120 |
+
return self * (1 / lmbda)
|
121 |
+
|
122 |
+
def __neg__(self):
|
123 |
+
return self * -1
|
124 |
+
|
125 |
+
def __repr__(self):
|
126 |
+
return f"P({self.x}, {self.y})"
|
127 |
+
|
128 |
+
def to_str(self):
|
129 |
+
return f"{self.x} {self.y}"
|
130 |
+
|
131 |
+
def tolist(self):
|
132 |
+
return self.pos.tolist()
|
133 |
+
|
134 |
+
def to_tensor(self):
|
135 |
+
return torch.tensor(self.pos)
|
136 |
+
|
137 |
+
@staticmethod
|
138 |
+
def from_tensor(vector: torch.Tensor):
|
139 |
+
return Point(*vector.tolist())
|
140 |
+
|
141 |
+
def translate(self, vec: Point):
|
142 |
+
self.pos += vec.pos
|
143 |
+
|
144 |
+
def matmul(self, m):
|
145 |
+
return Point(m @ self.pos)
|
146 |
+
|
147 |
+
def rotate(self, angle: Union[Angle, float]):
|
148 |
+
rot_m = get_rotation_matrix(angle)
|
149 |
+
return self.matmul(rot_m)
|
150 |
+
|
151 |
+
def rotate_(self, angle: Union[Angle, float]):
|
152 |
+
rot_m = get_rotation_matrix(angle)
|
153 |
+
self.pos = rot_m @ self.pos
|
154 |
+
|
155 |
+
def scale(self, factor):
|
156 |
+
self.pos *= factor
|
157 |
+
|
158 |
+
def dot(self, other: Point):
|
159 |
+
return self.pos.dot(other.pos)
|
160 |
+
|
161 |
+
def norm(self):
|
162 |
+
return float(np.linalg.norm(self.pos))
|
163 |
+
|
164 |
+
def cross(self, other: Point):
|
165 |
+
return np.cross(self.pos, other.pos)
|
166 |
+
|
167 |
+
def dist(self, other: Point):
|
168 |
+
return (self - other).norm()
|
169 |
+
|
170 |
+
def angle(self, other: Point, signed=False):
|
171 |
+
rad = np.arccos(np.clip(self.normalize().dot(other.normalize()), -1., 1.))
|
172 |
+
|
173 |
+
if signed:
|
174 |
+
sign = 1 if det(self, other) >= 0 else -1
|
175 |
+
rad *= sign
|
176 |
+
return Angle.Rad(rad)
|
177 |
+
|
178 |
+
def distToLine(self, p1: Point, p2: Point):
|
179 |
+
if p1.isclose(p2):
|
180 |
+
return self.dist(p1)
|
181 |
+
|
182 |
+
return abs((p2 - p1).cross(p1 - self)) / (p2 - p1).norm()
|
183 |
+
|
184 |
+
def normalize(self):
|
185 |
+
return self / self.norm()
|
186 |
+
|
187 |
+
def numericalize(self, n=256):
|
188 |
+
self.pos = self.pos.round().clip(min=0, max=n-1)
|
189 |
+
|
190 |
+
def isclose(self, other: Point):
|
191 |
+
return np.allclose(self.pos, other.pos)
|
192 |
+
|
193 |
+
def iszero(self):
|
194 |
+
return np.all(self.pos == 0)
|
195 |
+
|
196 |
+
def pointwise_min(self, other: Point):
|
197 |
+
return Point(min(self.x, other.x), min(self.y, other.y))
|
198 |
+
|
199 |
+
def pointwise_max(self, other: Point):
|
200 |
+
return Point(max(self.x, other.x), max(self.y, other.y))
|
201 |
+
|
202 |
+
|
203 |
+
class Radius(Point):
|
204 |
+
def __init__(self, *args, **kwargs):
|
205 |
+
super().__init__(*args, **kwargs)
|
206 |
+
|
207 |
+
def copy(self):
|
208 |
+
return Radius(self.pos.copy())
|
209 |
+
|
210 |
+
def __repr__(self):
|
211 |
+
return f"Rad({self.pos[0]}, {self.pos[1]})"
|
212 |
+
|
213 |
+
def translate(self, vec: Point):
|
214 |
+
pass
|
215 |
+
|
216 |
+
|
217 |
+
class Size(Point):
|
218 |
+
def __init__(self, *args, **kwargs):
|
219 |
+
super().__init__(*args, **kwargs)
|
220 |
+
|
221 |
+
def copy(self):
|
222 |
+
return Size(self.pos.copy())
|
223 |
+
|
224 |
+
def __repr__(self):
|
225 |
+
return f"Size({self.pos[0]}, {self.pos[1]})"
|
226 |
+
|
227 |
+
def max(self):
|
228 |
+
return self.pos.max()
|
229 |
+
|
230 |
+
def min(self):
|
231 |
+
return self.pos.min()
|
232 |
+
|
233 |
+
def translate(self, vec: Point):
|
234 |
+
pass
|
235 |
+
|
236 |
+
|
237 |
+
######### Coord
|
238 |
+
class Coord(Geom):
|
239 |
+
num_args = 1
|
240 |
+
|
241 |
+
class XY(Enum):
|
242 |
+
X = "x"
|
243 |
+
Y = "y"
|
244 |
+
|
245 |
+
def __init__(self, coord, xy: XY = XY.X):
|
246 |
+
self.coord = coord
|
247 |
+
self.xy = xy
|
248 |
+
|
249 |
+
def __repr__(self):
|
250 |
+
return f"{self.xy.value}({self.coord})"
|
251 |
+
|
252 |
+
def to_str(self):
|
253 |
+
return str(self.coord)
|
254 |
+
|
255 |
+
def to_tensor(self):
|
256 |
+
return torch.tensor([self.coord])
|
257 |
+
|
258 |
+
def __add__(self, other):
|
259 |
+
if isinstance(other, float_type):
|
260 |
+
return Coord(self.coord + other, self.xy)
|
261 |
+
elif isinstance(other, Coord):
|
262 |
+
if self.xy != other.xy:
|
263 |
+
raise ValueError()
|
264 |
+
return Coord(self.coord + other.coord, self.xy)
|
265 |
+
elif isinstance(other, Point):
|
266 |
+
return Coord(self.coord + getattr(other, self.xy.value), self.xy)
|
267 |
+
else:
|
268 |
+
raise ValueError()
|
269 |
+
|
270 |
+
def __sub__(self, other):
|
271 |
+
return self + other.__neg__()
|
272 |
+
|
273 |
+
def __mul__(self, lmbda):
|
274 |
+
assert isinstance(lmbda, float_type)
|
275 |
+
return Coord(lmbda * self.coord)
|
276 |
+
|
277 |
+
def __neg__(self):
|
278 |
+
return self * -1
|
279 |
+
|
280 |
+
def scale(self, factor):
|
281 |
+
self.coord *= factor
|
282 |
+
|
283 |
+
def translate(self, vec: Point):
|
284 |
+
self.coord += getattr(vec, self.xy.value)
|
285 |
+
|
286 |
+
def to_point(self, pos: Point, is_absolute=True):
|
287 |
+
point = pos.copy() if is_absolute else Point(0.)
|
288 |
+
point.pos[int(self.xy == Coord.XY.Y)] = self.coord
|
289 |
+
return point
|
290 |
+
|
291 |
+
|
292 |
+
class XCoord(Coord):
|
293 |
+
def __init__(self, coord):
|
294 |
+
super().__init__(coord, xy=Coord.XY.X)
|
295 |
+
|
296 |
+
def copy(self):
|
297 |
+
return XCoord(self.coord)
|
298 |
+
|
299 |
+
|
300 |
+
class YCoord(Coord):
|
301 |
+
def __init__(self, coord):
|
302 |
+
super().__init__(coord, xy=Coord.XY.Y)
|
303 |
+
|
304 |
+
def copy(self):
|
305 |
+
return YCoord(self.coord)
|
306 |
+
|
307 |
+
|
308 |
+
######### Bbox
|
309 |
+
class Bbox(Geom):
|
310 |
+
num_args = 4
|
311 |
+
|
312 |
+
def __init__(self, x=None, y=None, w=None, h=None):
|
313 |
+
if isinstance(x, Point) and isinstance(y, Point):
|
314 |
+
self.xy = x
|
315 |
+
wh = y - x
|
316 |
+
self.wh = Size(wh.x, wh.y)
|
317 |
+
elif (isinstance(x, float_type) or x is None) and (isinstance(y, float_type) or y is None):
|
318 |
+
if x is None:
|
319 |
+
x = 0.
|
320 |
+
if y is None:
|
321 |
+
y = float(x)
|
322 |
+
|
323 |
+
if w is None and h is None:
|
324 |
+
w, h = float(x), float(y)
|
325 |
+
x, y = 0., 0.
|
326 |
+
self.xy = Point(x, y)
|
327 |
+
self.wh = Size(w, h)
|
328 |
+
else:
|
329 |
+
raise ValueError()
|
330 |
+
|
331 |
+
@property
|
332 |
+
def xy2(self):
|
333 |
+
return self.xy + self.wh
|
334 |
+
|
335 |
+
def copy(self):
|
336 |
+
bbox = Bbox()
|
337 |
+
bbox.xy = self.xy.copy()
|
338 |
+
bbox.wh = self.wh.copy()
|
339 |
+
return bbox
|
340 |
+
|
341 |
+
@property
|
342 |
+
def size(self):
|
343 |
+
return self.wh
|
344 |
+
|
345 |
+
@property
|
346 |
+
def center(self):
|
347 |
+
return self.xy + self.wh / 2
|
348 |
+
|
349 |
+
def __repr__(self):
|
350 |
+
return f"Bbox({self.xy.to_str()} {self.wh.to_str()})"
|
351 |
+
|
352 |
+
def to_str(self):
|
353 |
+
return f"{self.xy.to_str()} {self.wh.to_str()}"
|
354 |
+
|
355 |
+
def to_tensor(self):
|
356 |
+
return torch.tensor([*self.xy.to_tensor(), *self.wh.to_tensor()])
|
357 |
+
|
358 |
+
def make_square(self, min_size=None):
|
359 |
+
center = self.center
|
360 |
+
size = self.wh.max()
|
361 |
+
|
362 |
+
if min_size is not None:
|
363 |
+
size = max(size, min_size)
|
364 |
+
|
365 |
+
self.wh = Size(size, size)
|
366 |
+
self.xy = center - self.wh / 2
|
367 |
+
|
368 |
+
return self
|
369 |
+
|
370 |
+
def translate(self, vec):
|
371 |
+
self.xy.translate(vec)
|
372 |
+
|
373 |
+
def scale(self, factor):
|
374 |
+
self.xy.scale(factor)
|
375 |
+
self.wh.scale(factor)
|
376 |
+
|
377 |
+
def union(self, other: Bbox):
|
378 |
+
if other is None:
|
379 |
+
return self
|
380 |
+
return Bbox(self.xy.pointwise_min(other.xy), self.xy2.pointwise_max(other.xy2))
|
381 |
+
|
382 |
+
def intersect(self, other: Bbox):
|
383 |
+
if other is None:
|
384 |
+
return self
|
385 |
+
|
386 |
+
bbox = Bbox(self.xy.pointwise_max(other.xy), self.xy2.pointwise_min(other.xy2))
|
387 |
+
if bbox.wh.x < 0 or bbox.wh.y < 0:
|
388 |
+
return None
|
389 |
+
|
390 |
+
return bbox
|
391 |
+
|
392 |
+
@staticmethod
|
393 |
+
def from_points(points: List[Point]):
|
394 |
+
if not points:
|
395 |
+
return None
|
396 |
+
xy = xy2 = points[0]
|
397 |
+
for p in points[1:]:
|
398 |
+
xy = xy.pointwise_min(p)
|
399 |
+
xy2 = xy2.pointwise_max(p)
|
400 |
+
return Bbox(xy, xy2)
|
401 |
+
|
402 |
+
def to_rectangle(self, *args, **kwargs):
|
403 |
+
from .svg_primitive import SVGRectangle
|
404 |
+
return SVGRectangle(self.xy, self.wh, *args, **kwargs)
|
405 |
+
|
406 |
+
def area(self):
|
407 |
+
return self.wh.pos.prod()
|
408 |
+
|
409 |
+
def overlap(self, other):
|
410 |
+
inter = self.intersect(other)
|
411 |
+
if inter is None:
|
412 |
+
return 0.
|
413 |
+
return inter.area() / self.area()
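A short sketch of how these Bbox operations compose (hypothetical boxes; Bbox(x, y, w, h) as defined above):

    a = Bbox(0., 0., 4., 4.)
    b = Bbox(2., 2., 4., 4.)
    a.union(b)      # Bbox(0.0 0.0 6.0 6.0)
    a.intersect(b)  # Bbox(2.0 2.0 2.0 2.0)
    a.overlap(b)    # 4 / 16 = 0.25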
|
414 |
+
|
415 |
+
|
416 |
+
######### Angle
|
417 |
+
class Angle(Geom):
|
418 |
+
num_args = 1
|
419 |
+
|
420 |
+
def __init__(self, deg):
|
421 |
+
self.deg = deg
|
422 |
+
|
423 |
+
@property
|
424 |
+
def rad(self):
|
425 |
+
return np.deg2rad(self.deg)
|
426 |
+
|
427 |
+
def copy(self):
|
428 |
+
return Angle(self.deg)
|
429 |
+
|
430 |
+
def __repr__(self):
|
431 |
+
return f"α({self.deg})"
|
432 |
+
|
433 |
+
def to_str(self):
|
434 |
+
return str(self.deg)
|
435 |
+
|
436 |
+
def to_tensor(self):
|
437 |
+
return torch.tensor([self.deg])
|
438 |
+
|
439 |
+
@staticmethod
|
440 |
+
def from_tensor(vector: torch.Tensor):
|
441 |
+
return Angle(vector.item())
|
442 |
+
|
443 |
+
@staticmethod
|
444 |
+
def Rad(rad):
|
445 |
+
return Angle(np.rad2deg(rad))
|
446 |
+
|
447 |
+
def __add__(self, other: Angle):
|
448 |
+
return Angle(self.deg + other.deg)
|
449 |
+
|
450 |
+
def __sub__(self, other: Angle):
|
451 |
+
return self + other.__neg__()
|
452 |
+
|
453 |
+
def __mul__(self, lmbda):
|
454 |
+
assert isinstance(lmbda, float_type)
|
455 |
+
return Angle(lmbda * self.deg)
|
456 |
+
|
457 |
+
def __rmul__(self, lmbda):
|
458 |
+
assert isinstance(lmbda, float_type)
|
459 |
+
return self * lmbda
|
460 |
+
|
461 |
+
def __truediv__(self, lmbda):
|
462 |
+
assert isinstance(lmbda, float_type)
|
463 |
+
return self * (1 / lmbda)
|
464 |
+
|
465 |
+
def __neg__(self):
|
466 |
+
return self * -1
|
467 |
+
|
468 |
+
|
469 |
+
######### Flag
|
470 |
+
class Flag(Geom):
|
471 |
+
num_args = 1
|
472 |
+
|
473 |
+
def __init__(self, flag):
|
474 |
+
self.flag = int(flag)
|
475 |
+
|
476 |
+
def copy(self):
|
477 |
+
return Flag(self.flag)
|
478 |
+
|
479 |
+
def __repr__(self):
|
480 |
+
return f"flag({self.flag})"
|
481 |
+
|
482 |
+
def to_str(self):
|
483 |
+
return str(self.flag)
|
484 |
+
|
485 |
+
def to_tensor(self):
|
486 |
+
return torch.tensor([self.flag])
|
487 |
+
|
488 |
+
def __invert__(self):
|
489 |
+
return Flag(1 - self.flag)
|
490 |
+
|
491 |
+
@staticmethod
|
492 |
+
def from_tensor(vector: torch.Tensor):
|
493 |
+
return Flag(vector.item())
|
src/preprocessing/deepsvg/deepsvg_svglib/svg.py
ADDED
@@ -0,0 +1,579 @@
1 |
+
"""This code is taken from <https://github.com/alexandre01/deepsvg>
|
2 |
+
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
|
3 |
+
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
|
4 |
+
"""
|
5 |
+
|
6 |
+
from __future__ import annotations
|
7 |
+
from .geom import *
|
8 |
+
from xml.dom import expatbuilder
|
9 |
+
import torch
|
10 |
+
from typing import List, Union
|
11 |
+
import IPython.display as ipd
|
12 |
+
import cairosvg
|
13 |
+
from PIL import Image
|
14 |
+
import io
|
15 |
+
import os
|
16 |
+
from moviepy.editor import ImageClip, concatenate_videoclips, ipython_display
|
17 |
+
import math
|
18 |
+
import random
|
19 |
+
import networkx as nx
|
20 |
+
|
21 |
+
Num = Union[int, float]
|
22 |
+
|
23 |
+
from .svg_command import SVGCommandBezier
|
24 |
+
from .svg_path import SVGPath, Filling, Orientation
|
25 |
+
from .svg_primitive import SVGPathGroup, SVGRectangle, SVGCircle, SVGEllipse, SVGLine, SVGPolyline, SVGPolygon
|
26 |
+
from .geom import union_bbox
|
27 |
+
|
28 |
+
|
29 |
+
class SVG:
|
30 |
+
def __init__(self, svg_path_groups: List[SVGPathGroup], viewbox: Bbox = None):
|
31 |
+
if viewbox is None:
|
32 |
+
viewbox = Bbox(24)
|
33 |
+
|
34 |
+
self.svg_path_groups = svg_path_groups
|
35 |
+
self.viewbox = viewbox
|
36 |
+
|
37 |
+
def __add__(self, other: SVG):
|
38 |
+
svg = self.copy()
|
39 |
+
svg.svg_path_groups.extend(other.svg_path_groups)
|
40 |
+
return svg
|
41 |
+
|
42 |
+
@property
|
43 |
+
def paths(self):
|
44 |
+
for path_group in self.svg_path_groups:
|
45 |
+
for path in path_group.svg_paths:
|
46 |
+
yield path
|
47 |
+
|
48 |
+
def __getitem__(self, idx):
|
49 |
+
if isinstance(idx, tuple):
|
50 |
+
assert len(idx) == 2, "Dimension out of range"
|
51 |
+
i, j = idx
|
52 |
+
return self.svg_path_groups[i][j]
|
53 |
+
|
54 |
+
return self.svg_path_groups[idx]
|
55 |
+
|
56 |
+
def __len__(self):
|
57 |
+
return len(self.svg_path_groups)
|
58 |
+
|
59 |
+
def total_length(self):
|
60 |
+
return sum([path_group.total_len() for path_group in self.svg_path_groups])
|
61 |
+
|
62 |
+
@property
|
63 |
+
def start_pos(self):
|
64 |
+
return Point(0.)
|
65 |
+
|
66 |
+
@property
|
67 |
+
def end_pos(self):
|
68 |
+
if not self.svg_path_groups:
|
69 |
+
return Point(0.)
|
70 |
+
|
71 |
+
return self.svg_path_groups[-1].end_pos
|
72 |
+
|
73 |
+
def copy(self):
|
74 |
+
return SVG([svg_path_group.copy() for svg_path_group in self.svg_path_groups], self.viewbox.copy())
|
75 |
+
|
76 |
+
@staticmethod
|
77 |
+
def load_svg(file_path):
|
78 |
+
with open(file_path, "r") as f:
|
79 |
+
return SVG.from_str(f.read())
|
80 |
+
|
81 |
+
@staticmethod
|
82 |
+
def load_splineset(spline_str: str, width, height, add_closing=True):
|
83 |
+
if "SplineSet" not in spline_str:
|
84 |
+
raise ValueError("Not a SplineSet")
|
85 |
+
|
86 |
+
spline = spline_str[spline_str.index('SplineSet') + 10:spline_str.index('EndSplineSet')]
|
87 |
+
svg_str = SVG._spline_to_svg_str(spline, height)
|
88 |
+
|
89 |
+
if not svg_str:
|
90 |
+
raise ValueError("Empty SplineSet")
|
91 |
+
|
92 |
+
svg_path_group = SVGPath.from_str(svg_str, add_closing=add_closing)
|
93 |
+
return SVG([svg_path_group], viewbox=Bbox(width, height))
|
94 |
+
|
95 |
+
@staticmethod
|
96 |
+
def _spline_to_svg_str(spline_str: str, height, replace_with_prev=False):
|
97 |
+
path = []
|
98 |
+
prev_xy = []
|
99 |
+
for line in spline_str.splitlines():
|
100 |
+
if not line:
|
101 |
+
continue
|
102 |
+
tokens = line.split(' ')
|
103 |
+
cmd = tokens[-2]
|
104 |
+
if cmd not in 'cml':
|
105 |
+
raise ValueError(f"Command not recognized: {cmd}")
|
106 |
+
args = tokens[:-2]
|
107 |
+
args = [float(x) for x in args if x]
|
108 |
+
|
109 |
+
if replace_with_prev and cmd in 'c':
|
110 |
+
args[:2] = prev_xy
|
111 |
+
prev_xy = args[-2:]
|
112 |
+
|
113 |
+
new_y_args = []
|
114 |
+
for i, a in enumerate(args):
|
115 |
+
if i % 2 == 1:
|
116 |
+
new_y_args.append(str(height - a))
|
117 |
+
else:
|
118 |
+
new_y_args.append(str(a))
|
119 |
+
|
120 |
+
path.extend([cmd.upper()] + new_y_args)
|
121 |
+
return " ".join(path)
|
122 |
+
|
123 |
+
@staticmethod
|
124 |
+
def from_str(svg_str: str):
|
125 |
+
svg_path_groups = []
|
126 |
+
svg_dom = expatbuilder.parseString(svg_str, False)
|
127 |
+
svg_root = svg_dom.getElementsByTagName('svg')[0]
|
128 |
+
|
129 |
+
viewbox_list = list(map(float, svg_root.getAttribute("viewBox").split(" ")))
|
130 |
+
view_box = Bbox(*viewbox_list)
|
131 |
+
|
132 |
+
primitives = {
|
133 |
+
"path": SVGPath,
|
134 |
+
"rect": SVGRectangle,
|
135 |
+
"circle": SVGCircle, "ellipse": SVGEllipse,
|
136 |
+
"line": SVGLine,
|
137 |
+
"polyline": SVGPolyline, "polygon": SVGPolygon
|
138 |
+
}
|
139 |
+
|
140 |
+
for tag, Primitive in primitives.items():
|
141 |
+
for x in svg_dom.getElementsByTagName(tag):
|
142 |
+
svg_path_groups.append(Primitive.from_xml(x))
|
143 |
+
|
144 |
+
return SVG(svg_path_groups, view_box)
|
145 |
+
|
146 |
+
def to_tensor(self, concat_groups=True, PAD_VAL=-1):
|
147 |
+
group_tensors = [p.to_tensor(PAD_VAL=PAD_VAL) for p in self.svg_path_groups]
|
148 |
+
|
149 |
+
if concat_groups:
|
150 |
+
return torch.cat(group_tensors, dim=0)
|
151 |
+
|
152 |
+
return group_tensors
|
153 |
+
|
154 |
+
def to_fillings(self):
|
155 |
+
return [p.path.filling for p in self.svg_path_groups]
|
156 |
+
|
157 |
+
@staticmethod
|
158 |
+
def from_tensor(tensor: torch.Tensor, viewbox: Bbox = None, allow_empty=False):
|
159 |
+
if viewbox is None:
|
160 |
+
viewbox = Bbox(24)
|
161 |
+
|
162 |
+
svg = SVG([SVGPath.from_tensor(tensor, allow_empty=allow_empty)], viewbox=viewbox)
|
163 |
+
return svg
|
164 |
+
|
165 |
+
@staticmethod
|
166 |
+
def from_tensors(tensors: List[torch.Tensor], viewbox: Bbox = None, allow_empty=False):
|
167 |
+
if viewbox is None:
|
168 |
+
viewbox = Bbox(24)
|
169 |
+
|
170 |
+
svg = SVG([SVGPath.from_tensor(t, allow_empty=allow_empty) for t in tensors], viewbox=viewbox)
|
171 |
+
return svg
|
172 |
+
|
173 |
+
def save_svg(self, file_path):
|
174 |
+
with open(file_path, "w") as f:
|
175 |
+
f.write(self.to_str())
|
176 |
+
|
177 |
+
def save_png(self, file_path):
|
178 |
+
cairosvg.svg2png(bytestring=self.to_str(), write_to=file_path)
|
179 |
+
|
180 |
+
def draw(self, fill=False, file_path=None, do_display=True, return_png=False,
|
181 |
+
with_points=False, with_handles=False, with_bboxes=False, with_markers=False, color_firstlast=False,
|
182 |
+
with_moves=True):
|
183 |
+
if file_path is not None:
|
184 |
+
_, file_extension = os.path.splitext(file_path)
|
185 |
+
if file_extension == ".svg":
|
186 |
+
self.save_svg(file_path)
|
187 |
+
elif file_extension == ".png":
|
188 |
+
self.save_png(file_path)
|
189 |
+
else:
|
190 |
+
raise ValueError(f"Unsupported file_path extension {file_extension}")
|
191 |
+
|
192 |
+
svg_str = self.to_str(fill=fill, with_points=with_points, with_handles=with_handles, with_bboxes=with_bboxes,
|
193 |
+
with_markers=with_markers, color_firstlast=color_firstlast, with_moves=with_moves)
|
194 |
+
|
195 |
+
if do_display:
|
196 |
+
ipd.display(ipd.SVG(svg_str))
|
197 |
+
|
198 |
+
if return_png:
|
199 |
+
if file_path is None:
|
200 |
+
img_data = cairosvg.svg2png(bytestring=svg_str)
|
201 |
+
return Image.open(io.BytesIO(img_data))
|
202 |
+
else:
|
203 |
+
_, file_extension = os.path.splitext(file_path)
|
204 |
+
|
205 |
+
if file_extension == ".svg":
|
206 |
+
img_data = cairosvg.svg2png(url=file_path)
|
207 |
+
return Image.open(io.BytesIO(img_data))
|
208 |
+
else:
|
209 |
+
return Image.open(file_path)
|
210 |
+
|
211 |
+
def draw_colored(self, *args, **kwargs):
|
212 |
+
self.copy().normalize().split_paths().set_color("random").draw(*args, **kwargs)
|
213 |
+
|
214 |
+
def __repr__(self):
|
215 |
+
return "SVG[{}](\n{}\n)".format(self.viewbox,
|
216 |
+
",\n".join([f"\t{svg_path_group}" for svg_path_group in self.svg_path_groups]))
|
217 |
+
|
218 |
+
def _get_viz_elements(self, with_points=False, with_handles=False, with_bboxes=False, color_firstlast=False,
|
219 |
+
with_moves=True):
|
220 |
+
viz_elements = []
|
221 |
+
for svg_path_group in self.svg_path_groups:
|
222 |
+
viz_elements.extend(
|
223 |
+
svg_path_group._get_viz_elements(with_points, with_handles, with_bboxes, color_firstlast, with_moves))
|
224 |
+
return viz_elements
|
225 |
+
|
226 |
+
def _markers(self):
|
227 |
+
return ('<defs>'
|
228 |
+
'<marker id="arrow" viewBox="0 0 10 10" markerWidth="4" markerHeight="4" refX="0" refY="3" orient="auto" markerUnits="strokeWidth">'
|
229 |
+
'<path d="M0,0 L0,6 L9,3 z" fill="#f00" />'
|
230 |
+
'</marker>'
|
231 |
+
'</defs>')
|
232 |
+
|
233 |
+
def to_str(self, fill=False, with_points=False, with_handles=False, with_bboxes=False, with_markers=False,
|
234 |
+
color_firstlast=False, with_moves=True) -> str:
|
235 |
+
viz_elements = self._get_viz_elements(with_points, with_handles, with_bboxes, color_firstlast, with_moves)
|
236 |
+
newline = "\n"
|
237 |
+
return (
|
238 |
+
f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="{self.viewbox.to_str()}" height="200px" width="200px">'
|
239 |
+
f'{self._markers() if with_markers else ""}'
|
240 |
+
f'{newline.join(svg_path_group.to_str(fill=fill, with_markers=with_markers) for svg_path_group in [*self.svg_path_groups, *viz_elements])}'
|
241 |
+
'</svg>')
|
242 |
+
|
243 |
+
def _apply_to_paths(self, method, *args, **kwargs):
|
244 |
+
for path_group in self.svg_path_groups:
|
245 |
+
getattr(path_group, method)(*args, **kwargs)
|
246 |
+
return self
|
247 |
+
|
248 |
+
def split_paths(self):
|
249 |
+
path_groups = []
|
250 |
+
for path_group in self.svg_path_groups:
|
251 |
+
path_groups.extend(path_group.split_paths())
|
252 |
+
self.svg_path_groups = path_groups
|
253 |
+
return self
|
254 |
+
|
255 |
+
def merge_groups(self):
|
256 |
+
path_group = self.svg_path_groups[0]
|
257 |
+
for path_group in self.svg_path_groups[1:]:
|
258 |
+
path_group.svg_paths.extend(path_group.svg_paths)
|
259 |
+
self.svg_path_groups = [path_group]
|
260 |
+
return self
|
261 |
+
|
262 |
+
def empty(self):
|
263 |
+
return len(self.svg_path_groups) == 0
|
264 |
+
|
265 |
+
def drop_z(self):
|
266 |
+
return self._apply_to_paths("drop_z")
|
267 |
+
|
268 |
+
def filter_empty(self):
|
269 |
+
self._apply_to_paths("filter_empty")
|
270 |
+
self.svg_path_groups = [path_group for path_group in self.svg_path_groups if path_group.svg_paths]
|
271 |
+
return self
|
272 |
+
|
273 |
+
def translate(self, vec: Point):
|
274 |
+
return self._apply_to_paths("translate", vec)
|
275 |
+
|
276 |
+
def rotate(self, angle: Angle, center: Point = None):
|
277 |
+
if center is None:
|
278 |
+
center = self.viewbox.center
|
279 |
+
|
280 |
+
self.translate(-self.viewbox.center)
|
281 |
+
self._apply_to_paths("rotate", angle)
|
282 |
+
self.translate(center)
|
283 |
+
|
284 |
+
return self
|
285 |
+
|
286 |
+
def zoom(self, factor, center: Point = None):
|
287 |
+
if center is None:
|
288 |
+
center = self.viewbox.center
|
289 |
+
|
290 |
+
self.translate(-self.viewbox.center)
|
291 |
+
self._apply_to_paths("scale", factor)
|
292 |
+
self.translate(center)
|
293 |
+
|
294 |
+
return self
|
295 |
+
|
296 |
+
def normalize(self, viewbox: Bbox = None):
|
297 |
+
if viewbox is None:
|
298 |
+
viewbox = Bbox(24)
|
299 |
+
|
300 |
+
size = self.viewbox.size
|
301 |
+
scale_factor = viewbox.size.min() / size.max()
|
302 |
+
self.zoom(scale_factor, viewbox.center)
|
303 |
+
self.viewbox = viewbox
|
304 |
+
|
305 |
+
return self
|
306 |
+
|
307 |
+
def compute_filling(self):
|
308 |
+
return self._apply_to_paths("compute_filling")
|
309 |
+
|
310 |
+
def recompute_origins(self):
|
311 |
+
origin = self.start_pos
|
312 |
+
|
313 |
+
for path_group in self.svg_path_groups:
|
314 |
+
path_group.set_origin(origin.copy())
|
315 |
+
origin = path_group.end_pos
|
316 |
+
|
317 |
+
def canonicalize_new(self, normalize=False):
|
318 |
+
self.to_path().simplify_arcs()
|
319 |
+
|
320 |
+
self.compute_filling()
|
321 |
+
|
322 |
+
if normalize:
|
323 |
+
self.normalize()
|
324 |
+
|
325 |
+
self.split_paths()
|
326 |
+
|
327 |
+
self.filter_consecutives()
|
328 |
+
self.filter_empty()
|
329 |
+
self._apply_to_paths("reorder")
|
330 |
+
self.svg_path_groups = sorted(self.svg_path_groups, key=lambda x: x.start_pos.tolist()[::-1])
|
331 |
+
self._apply_to_paths("canonicalize")
|
332 |
+
self.recompute_origins()
|
333 |
+
|
334 |
+
self.drop_z()
|
335 |
+
|
336 |
+
return self
|
337 |
+
|
338 |
+
def canonicalize(self, normalize=False):
|
339 |
+
self.to_path().simplify_arcs()
|
340 |
+
|
341 |
+
if normalize:
|
342 |
+
self.normalize()
|
343 |
+
|
344 |
+
self.split_paths()
|
345 |
+
self.filter_consecutives()
|
346 |
+
self.filter_empty()
|
347 |
+
self._apply_to_paths("reorder")
|
348 |
+
self.svg_path_groups = sorted(self.svg_path_groups, key=lambda x: x.start_pos.tolist()[::-1])
|
349 |
+
self._apply_to_paths("canonicalize")
|
350 |
+
self.recompute_origins()
|
351 |
+
|
352 |
+
self.drop_z()
|
353 |
+
|
354 |
+
return self
|
355 |
+
|
356 |
+
def reorder(self):
|
357 |
+
return self._apply_to_paths("reorder")
|
358 |
+
|
359 |
+
def canonicalize_old(self):
|
360 |
+
self.filter_empty()
|
361 |
+
self._apply_to_paths("reorder")
|
362 |
+
self.svg_path_groups = sorted(self.svg_path_groups, key=lambda x: x.start_pos.tolist()[::-1])
|
363 |
+
self._apply_to_paths("canonicalize")
|
364 |
+
self.split_paths()
|
365 |
+
self.recompute_origins()
|
366 |
+
|
367 |
+
self.drop_z()
|
368 |
+
|
369 |
+
return self
|
370 |
+
|
371 |
+
def to_video(self, wrapper, color="grey"):
|
372 |
+
clips, svg_commands = [], []
|
373 |
+
|
374 |
+
im = SVG([]).draw(do_display=False, return_png=True)
|
375 |
+
clips.append(wrapper(np.array(im)))
|
376 |
+
|
377 |
+
for svg_path in self.paths:
|
378 |
+
clips, svg_commands = svg_path.to_video(wrapper, clips, svg_commands, color=color)
|
379 |
+
|
380 |
+
im = self.draw(do_display=False, return_png=True)
|
381 |
+
clips.append(wrapper(np.array(im)))
|
382 |
+
|
383 |
+
return clips
|
384 |
+
|
385 |
+
def animate(self, file_path=None, frame_duration=0.1, do_display=True):
|
386 |
+
clips = self.to_video(lambda img: ImageClip(img).set_duration(frame_duration))
|
387 |
+
|
388 |
+
clip = concatenate_videoclips(clips, method="compose", bg_color=(255, 255, 255))
|
389 |
+
|
390 |
+
if file_path is not None:
|
391 |
+
clip.write_gif(file_path, fps=24, verbose=False, logger=None)
|
392 |
+
|
393 |
+
if do_display:
|
394 |
+
src = clip if file_path is None else file_path
|
395 |
+
ipd.display(ipython_display(src, fps=24, rd_kwargs=dict(logger=None), autoplay=1, loop=1))
|
396 |
+
|
397 |
+
def numericalize(self, n=256):
|
398 |
+
self.normalize(viewbox=Bbox(n))
|
399 |
+
return self._apply_to_paths("numericalize", n)
|
400 |
+
|
401 |
+
def simplify(self, tolerance=0.1, epsilon=0.1, angle_threshold=179., force_smooth=False):
|
402 |
+
self._apply_to_paths("simplify", tolerance=tolerance, epsilon=epsilon, angle_threshold=angle_threshold,
|
403 |
+
force_smooth=force_smooth)
|
404 |
+
self.recompute_origins()
|
405 |
+
return self
|
406 |
+
|
407 |
+
def reverse(self):
|
408 |
+
self._apply_to_paths("reverse")
|
409 |
+
return self
|
410 |
+
|
411 |
+
def reverse_non_closed(self):
|
412 |
+
self._apply_to_paths("reverse_non_closed")
|
413 |
+
return self
|
414 |
+
|
415 |
+
def duplicate_extremities(self):
|
416 |
+
self._apply_to_paths("duplicate_extremities")
|
417 |
+
return self
|
418 |
+
|
419 |
+
def simplify_heuristic(self, tolerance=0.1, force_smooth=False):
|
420 |
+
return self.copy().split(max_dist=2, include_lines=False) \
|
421 |
+
.simplify(tolerance=tolerance, epsilon=0.2, angle_threshold=150, force_smooth=force_smooth) \
|
422 |
+
.split(max_dist=7.5)
|
423 |
+
|
424 |
+
def simplify_heuristic2(self):
|
425 |
+
return self.copy().split(max_dist=2, include_lines=False) \
|
426 |
+
.simplify(tolerance=0.2, epsilon=0.2, angle_threshold=150) \
|
427 |
+
.split(max_dist=7.5)
|
428 |
+
|
429 |
+
def split(self, n=None, max_dist=None, include_lines=True):
|
430 |
+
return self._apply_to_paths("split", n=n, max_dist=max_dist, include_lines=include_lines)
|
431 |
+
|
432 |
+
@staticmethod
|
433 |
+
def unit_circle():
|
434 |
+
d = 2 * (math.sqrt(2) - 1) / 3
|
435 |
+
|
436 |
+
circle = SVGPath([
|
437 |
+
SVGCommandBezier(Point(.5, 0.), Point(.5 + d, 0.), Point(1., .5 - d), Point(1., .5)),
|
438 |
+
SVGCommandBezier(Point(1., .5), Point(1., .5 + d), Point(.5 + d, 1.), Point(.5, 1.)),
|
439 |
+
SVGCommandBezier(Point(.5, 1.), Point(.5 - d, 1.), Point(0., .5 + d), Point(0., .5)),
|
440 |
+
SVGCommandBezier(Point(0., .5), Point(0., .5 - d), Point(.5 - d, 0.), Point(.5, 0.))
|
441 |
+
]).to_group()
|
442 |
+
|
443 |
+
return SVG([circle], viewbox=Bbox(1))
|
444 |
+
|
445 |
+
@staticmethod
|
446 |
+
def unit_square():
|
447 |
+
square = SVGPath.from_str("m 0,0 h1 v1 h-1 v-1")
|
448 |
+
return SVG([square], viewbox=Bbox(1))
|
449 |
+
|
450 |
+
def add_path_group(self, path_group: SVGPathGroup):
|
451 |
+
path_group.set_origin(self.end_pos.copy())
|
452 |
+
self.svg_path_groups.append(path_group)
|
453 |
+
|
454 |
+
return self
|
455 |
+
|
456 |
+
def add_path_groups(self, path_groups: List[SVGPathGroup]):
|
457 |
+
for path_group in path_groups:
|
458 |
+
self.add_path_group(path_group)
|
459 |
+
|
460 |
+
return self
|
461 |
+
|
462 |
+
def simplify_arcs(self):
|
463 |
+
return self._apply_to_paths("simplify_arcs")
|
464 |
+
|
465 |
+
def to_path(self):
|
466 |
+
for i, path_group in enumerate(self.svg_path_groups):
|
467 |
+
self.svg_path_groups[i] = path_group.to_path()
|
468 |
+
return self
|
469 |
+
|
470 |
+
def filter_consecutives(self):
|
471 |
+
return self._apply_to_paths("filter_consecutives")
|
472 |
+
|
473 |
+
def filter_duplicates(self):
|
474 |
+
return self._apply_to_paths("filter_duplicates")
|
475 |
+
|
476 |
+
def set_color(self, color):
|
477 |
+
colors = ["deepskyblue", "lime", "deeppink", "gold", "coral", "darkviolet", "royalblue", "darkmagenta", "teal",
|
478 |
+
"gold",
|
479 |
+
"green", "maroon", "aqua", "grey", "steelblue", "lime", "orange"]
|
480 |
+
|
481 |
+
if color == "random_random":
|
482 |
+
random.shuffle(colors)
|
483 |
+
|
484 |
+
if isinstance(color, list):
|
485 |
+
colors = color
|
486 |
+
|
487 |
+
for i, path_group in enumerate(self.svg_path_groups):
|
488 |
+
if color == "random" or color == "random_random" or isinstance(color, list):
|
489 |
+
c = colors[i % len(colors)]
|
490 |
+
else:
|
491 |
+
c = color
|
492 |
+
path_group.color = c
|
493 |
+
return self
|
494 |
+
|
495 |
+
def bbox(self):
|
496 |
+
return union_bbox([path_group.bbox() for path_group in self.svg_path_groups])
|
497 |
+
|
498 |
+
def overlap_graph(self, threshold=0.95, draw=False):
|
499 |
+
G = nx.DiGraph()
|
500 |
+
shapes = [group.to_shapely() for group in self.svg_path_groups]
|
501 |
+
|
502 |
+
for i, group1 in enumerate(shapes):
|
503 |
+
G.add_node(i)
|
504 |
+
|
505 |
+
if self.svg_path_groups[i].path.filling != Filling.OUTLINE:
|
506 |
+
|
507 |
+
for j, group2 in enumerate(shapes):
|
508 |
+
if i != j and self.svg_path_groups[j].path.filling == Filling.FILL:
|
509 |
+
overlap = group1.intersection(group2).area / group1.area
|
510 |
+
if overlap > threshold:
|
511 |
+
G.add_edge(j, i, weight=overlap)
|
512 |
+
|
513 |
+
if draw:
|
514 |
+
pos = nx.spring_layout(G)
|
515 |
+
nx.draw_networkx(G, pos, with_labels=True)
|
516 |
+
labels = nx.get_edge_attributes(G, 'weight')
|
517 |
+
nx.draw_networkx_edge_labels(G, pos, edge_labels=labels)
|
518 |
+
return G
|
519 |
+
|
520 |
+
def group_overlapping_paths(self):
|
521 |
+
G = self.overlap_graph()
|
522 |
+
|
523 |
+
path_groups = []
|
524 |
+
root_nodes = [i for i, d in G.in_degree() if d == 0]
|
525 |
+
|
526 |
+
for root in root_nodes:
|
527 |
+
if self[root].path.filling == Filling.FILL:
|
528 |
+
current = [root]
|
529 |
+
|
530 |
+
while current:
|
531 |
+
n = current.pop(0)
|
532 |
+
|
533 |
+
fill_neighbors, erase_neighbors = [], []
|
534 |
+
for m in G.neighbors(n):
|
535 |
+
if G.in_degree(m) == 1:
|
536 |
+
if self[m].path.filling == Filling.ERASE:
|
537 |
+
erase_neighbors.append(m)
|
538 |
+
else:
|
539 |
+
fill_neighbors.append(m)
|
540 |
+
G.remove_node(n)
|
541 |
+
|
542 |
+
path_group = SVGPathGroup([self[n].path.copy().set_orientation(Orientation.CLOCKWISE)], fill=True)
|
543 |
+
if erase_neighbors:
|
544 |
+
for n in erase_neighbors:
|
545 |
+
neighbor = self[n].path.copy().set_orientation(Orientation.COUNTER_CLOCKWISE)
|
546 |
+
path_group.append(neighbor)
|
547 |
+
G.remove_nodes_from(erase_neighbors)
|
548 |
+
|
549 |
+
path_groups.append(path_group)
|
550 |
+
|
551 |
+
current.extend(fill_neighbors)
|
552 |
+
|
553 |
+
# Add outlines in the end
|
554 |
+
for path_group in self.svg_path_groups:
|
555 |
+
if path_group.path.filling == Filling.OUTLINE:
|
556 |
+
path_groups.append(path_group)
|
557 |
+
|
558 |
+
return SVG(path_groups)
|
559 |
+
|
560 |
+
def to_points(self, sort=True):
|
561 |
+
points = np.concatenate([path_group.to_points() for path_group in self.svg_path_groups])
|
562 |
+
|
563 |
+
if sort:
|
564 |
+
ind = np.lexsort((points[:, 0], points[:, 1]))
|
565 |
+
points = points[ind]
|
566 |
+
|
567 |
+
# Remove duplicates
|
568 |
+
row_mask = np.append([True], np.any(np.diff(points, axis=0), 1))
|
569 |
+
points = points[row_mask]
|
570 |
+
|
571 |
+
return points
|
572 |
+
|
573 |
+
def permute(self, indices=None):
|
574 |
+
if indices is not None:
|
575 |
+
self.svg_path_groups = [self.svg_path_groups[i] for i in indices]
|
576 |
+
return self
|
577 |
+
|
578 |
+
def fill_(self, fill=True):
|
579 |
+
return self._apply_to_paths("fill_", fill)
|
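Taken together, a typical flow with this class looks roughly like the following sketch (the file name is hypothetical; every method used here is defined above):

    svg = SVG.load_svg("logo.svg")              # parse primitives and viewBox
    svg.canonicalize(normalize=True)            # simplify arcs, normalize viewbox, order paths
    tensor = svg.to_tensor(concat_groups=True)  # one (num_commands, 14) command tensor
    svg.save_svg("logo_canonical.svg")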
src/preprocessing/deepsvg/deepsvg_svglib/svg_command.py
ADDED
@@ -0,0 +1,531 @@
1 |
+
"""This code is taken from <https://github.com/alexandre01/deepsvg>
|
2 |
+
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
|
3 |
+
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
|
4 |
+
"""
|
5 |
+
|
6 |
+
from __future__ import annotations
|
7 |
+
from .geom import *
|
8 |
+
from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor
|
9 |
+
from .util_fns import get_roots
|
10 |
+
from enum import Enum
|
11 |
+
import torch
|
12 |
+
import math
|
13 |
+
from typing import List, Union
|
14 |
+
Num = Union[int, float]
|
15 |
+
|
16 |
+
|
17 |
+
class SVGCmdEnum(Enum):
|
18 |
+
MOVE_TO = "m"
|
19 |
+
LINE_TO = "l"
|
20 |
+
CUBIC_BEZIER = "c"
|
21 |
+
CLOSE_PATH = "z"
|
22 |
+
ELLIPTIC_ARC = "a"
|
23 |
+
QUAD_BEZIER = "q"
|
24 |
+
LINE_TO_HORIZONTAL = "h"
|
25 |
+
LINE_TO_VERTICAL = "v"
|
26 |
+
CUBIC_BEZIER_REFL = "s"
|
27 |
+
QUAD_BEZIER_REFL = "t"
|
28 |
+
|
29 |
+
|
30 |
+
svgCmdArgTypes = {
|
31 |
+
SVGCmdEnum.MOVE_TO.value: [Point],
|
32 |
+
SVGCmdEnum.LINE_TO.value: [Point],
|
33 |
+
SVGCmdEnum.CUBIC_BEZIER.value: [Point, Point, Point],
|
34 |
+
SVGCmdEnum.CLOSE_PATH.value: [],
|
35 |
+
SVGCmdEnum.ELLIPTIC_ARC.value: [Radius, Angle, Flag, Flag, Point],
|
36 |
+
SVGCmdEnum.QUAD_BEZIER.value: [Point, Point],
|
37 |
+
SVGCmdEnum.LINE_TO_HORIZONTAL.value: [XCoord],
|
38 |
+
SVGCmdEnum.LINE_TO_VERTICAL.value: [YCoord],
|
39 |
+
SVGCmdEnum.CUBIC_BEZIER_REFL.value: [Point, Point],
|
40 |
+
SVGCmdEnum.QUAD_BEZIER_REFL.value: [Point],
|
41 |
+
}
|
42 |
+
|
43 |
+
|
44 |
+
class SVGCommand:
|
45 |
+
def __init__(self, command: SVGCmdEnum, args: List[Geom], start_pos: Point, end_pos: Point):
|
46 |
+
self.command = command
|
47 |
+
self.args = args
|
48 |
+
|
49 |
+
self.start_pos = start_pos
|
50 |
+
self.end_pos = end_pos
|
51 |
+
|
52 |
+
def copy(self):
|
53 |
+
raise NotImplementedError
|
54 |
+
|
55 |
+
@staticmethod
|
56 |
+
def from_str(cmd_str: str, args_str: List[Num], pos=None, initial_pos=None, prev_command: SVGCommand = None):
|
57 |
+
if pos is None:
|
58 |
+
pos = Point(0.)
|
59 |
+
if initial_pos is None:
|
60 |
+
initial_pos = Point(0.)
|
61 |
+
|
62 |
+
cmd = SVGCmdEnum(cmd_str.lower())
|
63 |
+
|
64 |
+
# Implicit MoveTo commands are treated as LineTo
|
65 |
+
if cmd is SVGCmdEnum.MOVE_TO and len(args_str) > 2:
|
66 |
+
l_cmd_str = SVGCmdEnum.LINE_TO.value
|
67 |
+
if cmd_str.isupper():
|
68 |
+
l_cmd_str = l_cmd_str.upper()
|
69 |
+
|
70 |
+
l1, pos, initial_pos = SVGCommand.from_str(cmd_str, args_str[:2], pos, initial_pos)
|
71 |
+
l2, pos, initial_pos = SVGCommand.from_str(l_cmd_str, args_str[2:], pos, initial_pos)
|
72 |
+
return [*l1, *l2], pos, initial_pos
|
73 |
+
|
74 |
+
nb_args = len(args_str)
|
75 |
+
|
76 |
+
if cmd is SVGCmdEnum.CLOSE_PATH:
|
77 |
+
assert nb_args == 0, f"Expected no argument for command {cmd_str}: {nb_args} given"
|
78 |
+
return [SVGCommandClose(pos, initial_pos)], initial_pos, initial_pos
|
79 |
+
|
80 |
+
expected_nb_args = sum([ArgType.num_args for ArgType in svgCmdArgTypes[cmd.value]])
|
81 |
+
assert nb_args % expected_nb_args == 0, f"Expected {expected_nb_args} arguments for command {cmd_str}: {nb_args} given"
|
82 |
+
|
83 |
+
l = []
|
84 |
+
i = 0
|
85 |
+
for _ in range(nb_args // expected_nb_args):
|
86 |
+
args = []
|
87 |
+
for ArgType in svgCmdArgTypes[cmd.value]:
|
88 |
+
num_args = ArgType.num_args
|
89 |
+
arg = ArgType(*args_str[i:i+num_args])
|
90 |
+
|
91 |
+
if cmd_str.islower():
|
92 |
+
arg.translate(pos)
|
93 |
+
if isinstance(arg, Coord):
|
94 |
+
arg = arg.to_point(pos)
|
95 |
+
|
96 |
+
args.append(arg)
|
97 |
+
i += num_args
|
98 |
+
|
99 |
+
if cmd is SVGCmdEnum.LINE_TO or cmd is SVGCmdEnum.LINE_TO_VERTICAL or cmd is SVGCmdEnum.LINE_TO_HORIZONTAL:
|
100 |
+
cmd_parsed = SVGCommandLine(pos, *args)
|
101 |
+
elif cmd is SVGCmdEnum.MOVE_TO:
|
102 |
+
cmd_parsed = SVGCommandMove(pos, *args)
|
103 |
+
elif cmd is SVGCmdEnum.ELLIPTIC_ARC:
|
104 |
+
cmd_parsed = SVGCommandArc(pos, *args)
|
105 |
+
elif cmd is SVGCmdEnum.CUBIC_BEZIER:
|
106 |
+
cmd_parsed = SVGCommandBezier(pos, *args)
|
107 |
+
elif cmd is SVGCmdEnum.QUAD_BEZIER:
|
108 |
+
cmd_parsed = SVGCommandBezier(pos, args[0], args[0], args[1])
|
109 |
+
elif cmd is SVGCmdEnum.QUAD_BEZIER_REFL or cmd is SVGCmdEnum.CUBIC_BEZIER_REFL:
|
110 |
+
if isinstance(prev_command, SVGCommandBezier):
|
111 |
+
control1 = pos * 2 - prev_command.control2
|
112 |
+
else:
|
113 |
+
control1 = pos
|
114 |
+
control2 = args[0] if cmd is SVGCmdEnum.CUBIC_BEZIER_REFL else control1
|
115 |
+
cmd_parsed = SVGCommandBezier(pos, control1, control2, args[-1])
|
116 |
+
|
117 |
+
prev_command = cmd_parsed
|
118 |
+
pos = cmd_parsed.end_pos
|
119 |
+
|
120 |
+
if cmd is SVGCmdEnum.MOVE_TO:
|
121 |
+
initial_pos = pos
|
122 |
+
|
123 |
+
l.append(cmd_parsed)
|
124 |
+
|
125 |
+
return l, pos, initial_pos
|
126 |
+
|
127 |
+
def __repr__(self):
|
128 |
+
cmd = self.command.value.upper()
|
129 |
+
return f"{cmd}{self.get_geoms()}"
|
130 |
+
|
131 |
+
def to_str(self):
|
132 |
+
cmd = self.command.value.upper()
|
133 |
+
return f"{cmd}{' '.join([arg.to_str() for arg in self.args])}"
|
134 |
+
|
135 |
+
def to_tensor(self, PAD_VAL=-1):
|
136 |
+
raise NotImplementedError
|
137 |
+
|
138 |
+
@staticmethod
|
139 |
+
def from_tensor(vector: torch.Tensor):
|
140 |
+
cmd_index, args = int(vector[0]), vector[1:]
|
141 |
+
|
142 |
+
cmd = SVGCmdEnum(SVGTensor.COMMANDS_SIMPLIFIED[cmd_index])
|
143 |
+
radius = Radius(*args[:2].tolist())
|
144 |
+
x_axis_rotation = Angle(*args[2:3].tolist())
|
145 |
+
large_arc_flag = Flag(args[3].item())
|
146 |
+
sweep_flag = Flag(args[4].item())
|
147 |
+
start_pos = Point(*args[5:7].tolist())
|
148 |
+
control1 = Point(*args[7:9].tolist())
|
149 |
+
control2 = Point(*args[9:11].tolist())
|
150 |
+
end_pos = Point(*args[11:].tolist())
|
151 |
+
|
152 |
+
return SVGCommand.from_args(cmd, radius, x_axis_rotation, large_arc_flag, sweep_flag, start_pos, control1, control2, end_pos)
|
153 |
+
|
154 |
+
@staticmethod
|
155 |
+
def from_args(command: SVGCmdEnum, radius: Radius, x_axis_rotation: Angle, large_arc_flag: Flag,
|
156 |
+
sweep_flag: Flag, start_pos: Point, control1: Point, control2: Point, end_pos: Point):
|
157 |
+
if command is SVGCmdEnum.MOVE_TO:
|
158 |
+
return SVGCommandMove(start_pos, end_pos)
|
159 |
+
elif command is SVGCmdEnum.LINE_TO:
|
160 |
+
return SVGCommandLine(start_pos, end_pos)
|
161 |
+
elif command is SVGCmdEnum.CUBIC_BEZIER:
|
162 |
+
return SVGCommandBezier(start_pos, control1, control2, end_pos)
|
163 |
+
elif command is SVGCmdEnum.CLOSE_PATH:
|
164 |
+
return SVGCommandClose(start_pos, end_pos)
|
165 |
+
elif command is SVGCmdEnum.ELLIPTIC_ARC:
|
166 |
+
return SVGCommandArc(start_pos, radius, x_axis_rotation, large_arc_flag, sweep_flag, end_pos)
|
167 |
+
|
168 |
+
def draw(self, *args, **kwargs):
|
169 |
+
from .svg_path import SVGPath
|
170 |
+
return SVGPath([self]).draw(*args, **kwargs)
|
171 |
+
|
172 |
+
def reverse(self):
|
173 |
+
raise NotImplementedError
|
174 |
+
|
175 |
+
def is_left_to(self, other: SVGCommand):
|
176 |
+
p1, p2 = self.start_pos, other.start_pos
|
177 |
+
|
178 |
+
if p1.y == p2.y:
|
179 |
+
return p1.x < p2.x
|
180 |
+
|
181 |
+
return p1.y < p2.y or (np.isclose(p1.norm(), p2.norm()) and p1.x < p2.x)
|
182 |
+
|
183 |
+
def numericalize(self, n=256):
|
184 |
+
raise NotImplementedError
|
185 |
+
|
186 |
+
def get_geoms(self):
|
187 |
+
return [self.start_pos, self.end_pos]
|
188 |
+
|
189 |
+
def get_points_viz(self, first=False, last=False):
|
190 |
+
from .svg_primitive import SVGCircle
|
191 |
+
color = "red" if first else "purple" if last else "deepskyblue" # "#C4C4C4"
|
192 |
+
opacity = 0.75 if first or last else 1.0
|
193 |
+
return [SVGCircle(self.end_pos, radius=Radius(0.4), color=color, fill=True, stroke_width=".1", opacity=opacity)]
|
194 |
+
|
195 |
+
def get_handles_viz(self):
|
196 |
+
return []
|
197 |
+
|
198 |
+
def sample_points(self, n=10, return_array=False):
|
199 |
+
return []
|
200 |
+
|
201 |
+
def split(self, n=2):
|
202 |
+
raise NotImplementedError
|
203 |
+
|
204 |
+
def length(self):
|
205 |
+
raise NotImplementedError
|
206 |
+
|
207 |
+
def bbox(self):
|
208 |
+
raise NotImplementedError
|
209 |
+
|
210 |
+
|
211 |
+
class SVGCommandLinear(SVGCommand):
|
212 |
+
def __init__(self, *args, **kwargs):
|
213 |
+
super().__init__(*args, **kwargs)
|
214 |
+
|
215 |
+
def to_tensor(self, PAD_VAL=-1):
|
216 |
+
cmd_index = SVGTensor.COMMANDS_SIMPLIFIED.index(self.command.value)
|
217 |
+
return torch.tensor([cmd_index,
|
218 |
+
*([PAD_VAL] * 5),
|
219 |
+
*self.start_pos.to_tensor(),
|
220 |
+
*([PAD_VAL] * 4),
|
221 |
+
*self.end_pos.to_tensor()])
|
222 |
+
|
223 |
+
def numericalize(self, n=256):
|
224 |
+
self.start_pos.numericalize(n)
|
225 |
+
self.end_pos.numericalize(n)
|
226 |
+
|
227 |
+
def copy(self):
|
228 |
+
return self.__class__(self.start_pos.copy(), self.end_pos.copy())
|
229 |
+
|
230 |
+
def reverse(self):
|
231 |
+
return self.__class__(self.end_pos, self.start_pos)
|
232 |
+
|
233 |
+
def split(self, n=2):
|
234 |
+
return [self]
|
235 |
+
|
236 |
+
def bbox(self):
|
237 |
+
return Bbox(self.start_pos, self.end_pos)
|
238 |
+
|
239 |
+
|
240 |
+
class SVGCommandMove(SVGCommandLinear):
|
241 |
+
def __init__(self, start_pos: Point, end_pos: Point=None):
|
242 |
+
if end_pos is None:
|
243 |
+
start_pos, end_pos = Point(0.), start_pos
|
244 |
+
super().__init__(SVGCmdEnum.MOVE_TO, [end_pos], start_pos, end_pos)
|
245 |
+
|
246 |
+
def get_points_viz(self, first=False, last=False):
|
247 |
+
from .svg_primitive import SVGLine
|
248 |
+
points_viz = super().get_points_viz(first, last)
|
249 |
+
points_viz.append(SVGLine(self.start_pos, self.end_pos, color="red", dasharray=0.5))
|
250 |
+
return points_viz
|
251 |
+
|
252 |
+
def bbox(self):
|
253 |
+
return Bbox(self.end_pos, self.end_pos)
|
254 |
+
|
255 |
+
|
256 |
+
class SVGCommandLine(SVGCommandLinear):
|
257 |
+
def __init__(self, start_pos: Point, end_pos: Point):
|
258 |
+
super().__init__(SVGCmdEnum.LINE_TO, [end_pos], start_pos, end_pos)
|
259 |
+
|
260 |
+
def sample_points(self, n=10, return_array=False):
|
261 |
+
z = np.linspace(0., 1., n)
|
262 |
+
|
263 |
+
if return_array:
|
264 |
+
points = (1-z)[:, None] * self.start_pos.pos[None] + z[:, None] * self.end_pos.pos[None]
|
265 |
+
return points
|
266 |
+
|
267 |
+
points = [(1 - alpha) * self.start_pos + alpha * self.end_pos for alpha in z]
|
268 |
+
return points
|
269 |
+
|
270 |
+
def split(self, n=2):
|
271 |
+
points = self.sample_points(n+1)
|
272 |
+
return [SVGCommandLine(p1, p2) for p1, p2 in zip(points[:-1], points[1:])]
|
273 |
+
|
274 |
+
def length(self):
|
275 |
+
return self.start_pos.dist(self.end_pos)
|
276 |
+
|
277 |
+
|
278 |
+
class SVGCommandClose(SVGCommandLinear):
|
279 |
+
def __init__(self, start_pos: Point, end_pos: Point):
|
280 |
+
super().__init__(SVGCmdEnum.CLOSE_PATH, [], start_pos, end_pos)
|
281 |
+
|
282 |
+
def get_points_viz(self, first=False, last=False):
|
283 |
+
return []
|
284 |
+
|
285 |
+
|
286 |
+
class SVGCommandBezier(SVGCommand):
|
287 |
+
def __init__(self, start_pos: Point, control1: Point, control2: Point, end_pos: Point):
|
288 |
+
if control2 is None:
|
289 |
+
control2 = control1.copy()
|
290 |
+
super().__init__(SVGCmdEnum.CUBIC_BEZIER, [control1, control2, end_pos], start_pos, end_pos)
|
291 |
+
|
292 |
+
self.control1 = control1
|
293 |
+
self.control2 = control2
|
294 |
+
|
295 |
+
@property
|
296 |
+
def p1(self):
|
297 |
+
return self.start_pos
|
298 |
+
|
299 |
+
@property
|
300 |
+
def p2(self):
|
301 |
+
return self.end_pos
|
302 |
+
|
303 |
+
@property
|
304 |
+
def q1(self):
|
305 |
+
return self.control1
|
306 |
+
|
307 |
+
@property
|
308 |
+
def q2(self):
|
309 |
+
return self.control2
|
310 |
+
|
311 |
+
def copy(self):
|
312 |
+
return SVGCommandBezier(self.start_pos.copy(), self.control1.copy(), self.control2.copy(), self.end_pos.copy())
|
313 |
+
|
314 |
+
def to_tensor(self, PAD_VAL=-1):
|
315 |
+
cmd_index = SVGTensor.COMMANDS_SIMPLIFIED.index(SVGCmdEnum.CUBIC_BEZIER.value)
|
316 |
+
return torch.tensor([cmd_index,
|
317 |
+
*([PAD_VAL] * 5),
|
318 |
+
*self.start_pos.to_tensor(),
|
319 |
+
*self.control1.to_tensor(),
|
320 |
+
*self.control2.to_tensor(),
|
321 |
+
*self.end_pos.to_tensor()])
|
322 |
+
|
323 |
+
def to_vector(self):
|
324 |
+
return np.array([
|
325 |
+
self.start_pos.tolist(),
|
326 |
+
self.control1.tolist(),
|
327 |
+
self.control2.tolist(),
|
328 |
+
self.end_pos.tolist()
|
329 |
+
])
|
330 |
+
|
331 |
+
@staticmethod
|
332 |
+
def from_vector(vector):
|
333 |
+
return SVGCommandBezier(Point(vector[0]), Point(vector[1]), Point(vector[2]), Point(vector[3]))
|
334 |
+
|
335 |
+
def reverse(self):
|
336 |
+
return SVGCommandBezier(self.end_pos, self.control2, self.control1, self.start_pos)
|
337 |
+
|
338 |
+
def numericalize(self, n=256):
|
339 |
+
self.start_pos.numericalize(n)
|
340 |
+
self.control1.numericalize(n)
|
341 |
+
self.control2.numericalize(n)
|
342 |
+
self.end_pos.numericalize(n)
|
343 |
+
|
344 |
+
def get_geoms(self):
|
345 |
+
return [self.start_pos, self.control1, self.control2, self.end_pos]
|
346 |
+
|
347 |
+
def get_handles_viz(self):
|
348 |
+
from .svg_primitive import SVGLine, SVGCircle
|
349 |
+
anchor_1 = SVGCircle(self.control1, radius=Radius(0.4), color="lime", fill=True, stroke_width=".1")
|
350 |
+
anchor_2 = SVGCircle(self.control2, radius=Radius(0.4), color="lime", fill=True, stroke_width=".1")
|
351 |
+
|
352 |
+
handle_1 = SVGLine(self.start_pos, self.control1, color="grey", dasharray=0.5, stroke_width=".1")
|
353 |
+
handle_2 = SVGLine(self.end_pos, self.control2, color="grey", dasharray=0.5, stroke_width=".1")
|
354 |
+
return [handle_1, handle_2, anchor_1, anchor_2]
|
355 |
+
|
356 |
+
def eval(self, t):
|
357 |
+
return (1 - t)**3 * self.start_pos + 3 * (1 - t)**2 * t * self.control1 + 3 * (1 - t) * t**2 * self.control2 + t**3 * self.end_pos
|
358 |
+
|
359 |
+
def derivative(self, t, n=1):
|
360 |
+
if n == 1:
|
361 |
+
return 3 * (1 - t)**2 * (self.control1 - self.start_pos) + 6 * (1 - t) * t * (self.control2 - self.control1) + 3 * t**2 * (self.end_pos - self.control2)
|
362 |
+
elif n == 2:
|
363 |
+
return 6 * (1 - t) * (self.control2 - 2 * self.control1 + self.start_pos) + 6 * t * (self.end_pos - 2 * self.control2 + self.control1)
|
364 |
+
|
365 |
+
raise NotImplementedError
|
366 |
+
|
367 |
+
def angle(self, other: SVGCommandBezier):
|
368 |
+
t1, t2 = self.derivative(1.), -other.derivative(0.)
|
369 |
+
if np.isclose(t1.norm(), 0.) or np.isclose(t2.norm(), 0.):
|
370 |
+
return 0.
|
371 |
+
angle = np.arccos(np.clip(t1.normalize().dot(t2.normalize()), -1., 1.))
|
372 |
+
return np.rad2deg(angle)
|
373 |
+
|
374 |
+
def sample_points(self, n=10, return_array=False):
|
375 |
+
b = self.to_vector()
|
376 |
+
|
377 |
+
z = np.linspace(0., 1., n)
|
378 |
+
Z = np.stack([np.ones_like(z), z, z**2, z**3], axis=1)
|
379 |
+
Q = np.array([[1., 0., 0., 0.],
|
380 |
+
[-3, 3., 0., 0.],
|
381 |
+
[3., -6, 3., 0.],
|
382 |
+
[-1, 3., -3, 1]])
|
383 |
+
|
384 |
+
points = Z @ Q @ b
|
385 |
+
|
386 |
+
if return_array:
|
387 |
+
return points
|
388 |
+
|
389 |
+
return [Point(p) for p in points]
|
390 |
+
|
391 |
+
def _split_two(self, z=.5):
|
392 |
+
b = self.to_vector()
|
393 |
+
|
394 |
+
Q1 = np.array([[1, 0, 0, 0],
|
395 |
+
[-(z - 1), z, 0, 0],
|
396 |
+
[(z - 1) ** 2, -2 * (z - 1) * z, z ** 2, 0],
|
397 |
+
[-(z - 1) ** 3, 3 * (z - 1) ** 2 * z, -3 * (z - 1) * z ** 2, z ** 3]])
|
398 |
+
Q2 = np.array([[-(z - 1) ** 3, 3 * (z - 1) ** 2 * z, -3 * (z - 1) * z ** 2, z ** 3],
|
399 |
+
[0, (z - 1) ** 2, -2 * (z - 1) * z, z ** 2],
|
400 |
+
[0, 0, -(z - 1), z],
|
401 |
+
[0, 0, 0, 1]])
|
402 |
+
|
403 |
+
return SVGCommandBezier.from_vector(Q1 @ b), SVGCommandBezier.from_vector(Q2 @ b)
|
404 |
+
|
405 |
+
def split(self, n=2):
|
406 |
+
b_list = []
|
407 |
+
b = self
|
408 |
+
|
409 |
+
for i in range(n - 1):
|
410 |
+
z = 1. / (n - i)
|
411 |
+
b1, b = b._split_two(z)
|
412 |
+
b_list.append(b1)
|
413 |
+
b_list.append(b)
|
414 |
+
return b_list
|
415 |
+
|
416 |
+
def length(self):
|
417 |
+
p = self.sample_points(n=100, return_array=True)
|
418 |
+
return np.linalg.norm(p[1:] - p[:-1], axis=-1).sum()
|
419 |
+
|
420 |
+
def bbox(self):
|
421 |
+
return Bbox.from_points(self.find_extrema())
|
422 |
+
|
423 |
+
def find_roots(self):
|
424 |
+
a = 3 * (-self.p1 + 3 * self.q1 - 3 * self.q2 + self.p2)
|
425 |
+
b = 6 * (self.p1 - 2 * self.q1 + self.q2)
|
426 |
+
c = 3 * (self.q1 - self.p1)
|
427 |
+
|
428 |
+
x_roots, y_roots = get_roots(a.x, b.x, c.x), get_roots(a.y, b.y, c.y)
|
429 |
+
roots_cat = [*x_roots, *y_roots]
|
430 |
+
roots = [root for root in roots_cat if 0 <= root <= 1]
|
431 |
+
return roots
|
432 |
+
|
433 |
+
def find_extrema(self):
|
434 |
+
points = [self.start_pos, self.end_pos]
|
435 |
+
points.extend([self.eval(root) for root in self.find_roots()])
|
436 |
+
return points
|
437 |
+
|
438 |
+
|
439 |
+
class SVGCommandArc(SVGCommand):
|
440 |
+
def __init__(self, start_pos: Point, radius: Radius, x_axis_rotation: Angle, large_arc_flag: Flag, sweep_flag: Flag, end_pos: Point):
|
441 |
+
super().__init__(SVGCmdEnum.ELLIPTIC_ARC, [radius, x_axis_rotation, large_arc_flag, sweep_flag, end_pos], start_pos, end_pos)
|
442 |
+
|
443 |
+
self.radius = radius
|
444 |
+
self.x_axis_rotation = x_axis_rotation
|
445 |
+
self.large_arc_flag = large_arc_flag
|
446 |
+
self.sweep_flag = sweep_flag
|
447 |
+
|
448 |
+
def copy(self):
|
449 |
+
return SVGCommandArc(self.start_pos.copy(), self.radius.copy(), self.x_axis_rotation.copy(), self.large_arc_flag.copy(),
|
450 |
+
self.sweep_flag.copy(), self.end_pos.copy())
|
451 |
+
|
452 |
+
def to_tensor(self, PAD_VAL=-1):
|
453 |
+
cmd_index = SVGTensor.COMMANDS_SIMPLIFIED.index(SVGCmdEnum.ELLIPTIC_ARC.value)
|
454 |
+
return torch.tensor([cmd_index,
|
455 |
+
*self.radius.to_tensor(),
|
456 |
+
*self.x_axis_rotation.to_tensor(),
|
457 |
+
*self.large_arc_flag.to_tensor(),
|
458 |
+
*self.sweep_flag.to_tensor(),
|
459 |
+
*self.start_pos.to_tensor(),
|
460 |
+
*([PAD_VAL] * 4),
|
461 |
+
*self.end_pos.to_tensor()])
|
462 |
+
|
463 |
+
def _get_center_parametrization(self):
|
464 |
+
r = self.radius
|
465 |
+
p1, p2 = self.start_pos, self.end_pos
|
466 |
+
|
467 |
+
h, m = 0.5 * (p1 - p2), 0.5 * (p1 + p2)
|
468 |
+
p1_trans = h.rotate(-self.x_axis_rotation)
|
469 |
+
|
470 |
+
sign = -1 if self.large_arc_flag.flag == self.sweep_flag.flag else 1
|
471 |
+
x2, y2, rx2, ry2 = p1_trans.x**2, p1_trans.y**2, r.x**2, r.y**2
|
472 |
+
sqrt = math.sqrt(max((rx2*ry2 - rx2*y2 - ry2*x2) / (rx2*y2 + ry2*x2), 0.))
|
473 |
+
c_trans = sign * sqrt * Point(r.x * p1_trans.y / r.y, -r.y * p1_trans.x / r.x)
|
474 |
+
|
475 |
+
c = c_trans.rotate(self.x_axis_rotation) + m
|
476 |
+
|
477 |
+
d, ns = (p1_trans - c_trans) / r, -(p1_trans + c_trans) / r
|
478 |
+
|
479 |
+
theta_1 = Point(1, 0).angle(d, signed=True)
|
480 |
+
|
481 |
+
delta_theta = d.angle(ns, signed=True)
|
482 |
+
delta_theta.deg %= 360
|
483 |
+
if self.sweep_flag.flag == 0 and delta_theta.deg > 0:
|
484 |
+
delta_theta = delta_theta - Angle(360)
|
485 |
+
if self.sweep_flag.flag == 1 and delta_theta.deg < 0:
|
486 |
+
delta_theta = delta_theta + Angle(360)
|
487 |
+
|
488 |
+
return c, theta_1, delta_theta
|
489 |
+
|
490 |
+
def _get_point(self, c: Point, t: float_type):
|
491 |
+
r = self.radius
|
492 |
+
return c + Point(r.x * np.cos(t), r.y * np.sin(t)).rotate(self.x_axis_rotation)
|
493 |
+
|
494 |
+
def _get_derivative(self, t: float_type):
|
495 |
+
r = self.radius
|
496 |
+
return Point(-r.x * np.sin(t), r.y * np.cos(t)).rotate(self.x_axis_rotation)
|
497 |
+
|
498 |
+
def to_beziers(self):
|
499 |
+
""" References:
|
500 |
+
https://www.w3.org/TR/2018/CR-SVG2-20180807/implnote.html
|
501 |
+
https://mortoray.com/2017/02/16/rendering-an-svg-elliptical-arc-as-bezier-curves/
|
502 |
+
http://www.spaceroots.org/documents/ellipse/elliptical-arc.pdf """
|
503 |
+
beziers = []
|
504 |
+
|
505 |
+
c, theta_1, delta_theta = self._get_center_parametrization()
|
506 |
+
nb_curves = max(int(abs(delta_theta.deg) // 45), 1)
|
507 |
+
etas = [theta_1 + i * delta_theta / nb_curves for i in range(nb_curves+1)]
|
508 |
+
for eta_1, eta_2 in zip(etas[:-1], etas[1:]):
|
509 |
+
e1, e2 = eta_1.rad, eta_2.rad
|
510 |
+
alpha = np.sin(e2 - e1) * (math.sqrt(4 + 3 * np.tan(0.5 * (e2 - e1))**2) - 1) / 3
|
511 |
+
p1, p2 = self._get_point(c, e1), self._get_point(c, e2)
|
512 |
+
q1 = p1 + alpha * self._get_derivative(e1)
|
513 |
+
q2 = p2 - alpha * self._get_derivative(e2)
|
514 |
+
beziers.append(SVGCommandBezier(p1, q1, q2, p2))
|
515 |
+
|
516 |
+
return beziers
|
517 |
+
|
518 |
+
def reverse(self):
|
519 |
+
return SVGCommandArc(self.end_pos, self.radius, self.x_axis_rotation, self.large_arc_flag, ~self.sweep_flag, self.start_pos)
|
520 |
+
|
521 |
+
def numericalize(self, n=256):
|
522 |
+
raise NotImplementedError
|
523 |
+
|
524 |
+
def get_geoms(self):
|
525 |
+
return [self.start_pos, self.radius, self.x_axis_rotation, self.large_arc_flag, self.sweep_flag, self.end_pos]
|
526 |
+
|
527 |
+
def split(self, n=2):
|
528 |
+
raise NotImplementedError
|
529 |
+
|
530 |
+
def sample_points(self, n=10, return_array=False):
|
531 |
+
raise NotImplementedError
|
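As a quick illustration of the command API above (hypothetical coordinates; parsing starts at Point(0.)):

    cmds, pos, initial = SVGCommand.from_str("C", [4., 0., 8., 8., 12., 4.])
    bezier = cmds[0]                 # SVGCommandBezier starting at (0, 0)
    left, right = bezier.split(n=2)  # subdivide at t = 0.5
    bezier.sample_points(n=5)        # 5 Points along the curve
    bezier.length()                  # polyline approximation over 100 samples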
src/preprocessing/deepsvg/deepsvg_svglib/svg_path.py
ADDED
@@ -0,0 +1,659 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

from __future__ import annotations
from .geom import *
import src.preprocessing.deepsvg.deepsvg_svglib.geom as geom
import re
import torch
from typing import List, Union
from xml.dom import minidom
import math
import shapely.geometry  # used by to_shapely() below
import numpy as np

from .geom import union_bbox
from .svg_command import SVGCommand, SVGCommandMove, SVGCommandClose, SVGCommandBezier, SVGCommandLine, SVGCommandArc


COMMANDS = "MmZzLlHhVvCcSsQqTtAa"
COMMAND_RE = re.compile(r"([MmZzLlHhVvCcSsQqTtAa])")
FLOAT_RE = re.compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?")


empty_command = SVGCommandMove(Point(0.))


class Orientation:
    COUNTER_CLOCKWISE = 0
    CLOCKWISE = 1


class Filling:
    OUTLINE = 0
    FILL = 1
    ERASE = 2


class SVGPath:
    def __init__(self, path_commands: List[SVGCommand] = None, origin: Point = None, closed=False, filling=Filling.OUTLINE):
        self.origin = origin or Point(0.)
        self.path_commands = path_commands
        self.closed = closed

        self.filling = filling

    @property
    def start_command(self):
        return SVGCommandMove(self.origin, self.start_pos)

    @property
    def start_pos(self):
        return self.path_commands[0].start_pos

    @property
    def end_pos(self):
        return self.path_commands[-1].end_pos

    def to_group(self, *args, **kwargs):
        from .svg_primitive import SVGPathGroup
        return SVGPathGroup([self], *args, **kwargs)

    def set_filling(self, filling=True):
        self.filling = Filling.FILL if filling else Filling.ERASE
        return self

    def __len__(self):
        return 1 + len(self.path_commands)

    def __getitem__(self, idx):
        if idx == 0:
            return self.start_command
        return self.path_commands[idx-1]

    def all_commands(self, with_close=True):
        close_cmd = [SVGCommandClose(self.path_commands[-1].end_pos.copy(), self.start_pos.copy())] if self.closed and self.path_commands and with_close \
            else ()
        return [self.start_command, *self.path_commands, *close_cmd]

    def copy(self):
        return SVGPath([path_command.copy() for path_command in self.path_commands], self.origin.copy(), self.closed, filling=self.filling)

    @staticmethod
    def _tokenize_path(path_str):
        cmd = None
        for x in COMMAND_RE.split(path_str):
            if x and x in COMMANDS:
                cmd = x
            elif cmd is not None:
                yield cmd, list(map(float, FLOAT_RE.findall(x)))

    @staticmethod
    def from_xml(x: minidom.Element):
        stroke = x.getAttribute('stroke')
        dasharray = x.getAttribute('dasharray')
        stroke_width = x.getAttribute('stroke-width')

        fill = not x.hasAttribute("fill") or not x.getAttribute("fill") == "none"

        filling = Filling.OUTLINE if not x.hasAttribute("filling") else int(x.getAttribute("filling"))

        s = x.getAttribute('d')
        return SVGPath.from_str(s, fill=fill, filling=filling)

    @staticmethod
    def from_str(s: str, fill=False, filling=Filling.OUTLINE, add_closing=False):
        path_commands = []
        pos = initial_pos = Point(0.)
        prev_command = None
        for cmd, args in SVGPath._tokenize_path(s):
            cmd_parsed, pos, initial_pos = SVGCommand.from_str(cmd, args, pos, initial_pos, prev_command)
            prev_command = cmd_parsed[-1]
            path_commands.extend(cmd_parsed)

        return SVGPath.from_commands(path_commands, fill=fill, filling=filling, add_closing=add_closing)

    @staticmethod
    def from_tensor(tensor: torch.Tensor, allow_empty=False):
        return SVGPath.from_commands([SVGCommand.from_tensor(row) for row in tensor], allow_empty=allow_empty)

    @staticmethod
    def from_commands(path_commands: List[SVGCommand], fill=False, filling=Filling.OUTLINE, add_closing=False, allow_empty=False):
        from .svg_primitive import SVGPathGroup

        if not path_commands:
            return SVGPathGroup([])

        svg_paths = []
        svg_path = None

        for command in path_commands:
            if isinstance(command, SVGCommandMove):
                if svg_path is not None and (allow_empty or svg_path.path_commands):  # SVGPath contains at least one command
                    if add_closing:
                        svg_path.closed = True
                    if not svg_path.path_commands:
                        svg_path.path_commands.append(empty_command)
                    svg_paths.append(svg_path)

                svg_path = SVGPath([], command.start_pos.copy(), filling=filling)
            else:
                if svg_path is None:
                    # Ignore commands until the first moveTo command
                    continue

                if isinstance(command, SVGCommandClose):
                    if allow_empty or svg_path.path_commands:  # SVGPath contains at least one command
                        svg_path.closed = True
                        if not svg_path.path_commands:
                            svg_path.path_commands.append(empty_command)
                        svg_paths.append(svg_path)
                    svg_path = None
                else:
                    svg_path.path_commands.append(command)
        if svg_path is not None and (allow_empty or svg_path.path_commands):  # SVGPath contains at least one command
            if add_closing:
                svg_path.closed = True
            if not svg_path.path_commands:
                svg_path.path_commands.append(empty_command)
            svg_paths.append(svg_path)
        return SVGPathGroup(svg_paths, fill=fill)

    def __repr__(self):
        return "SVGPath({})".format(" ".join(command.__repr__() for command in self.all_commands()))

    def to_str(self, fill=False):
        return " ".join(command.to_str() for command in self.all_commands())

    def to_tensor(self, PAD_VAL=-1):
        return torch.stack([command.to_tensor(PAD_VAL=PAD_VAL) for command in self.all_commands()])

    def _get_viz_elements(self, with_points=False, with_handles=False, with_bboxes=False, color_firstlast=False, with_moves=True):
        points = self._get_points_viz(color_firstlast, with_moves) if with_points else ()
        handles = self._get_handles_viz() if with_handles else ()
        return [*points, *handles]

    def draw(self, viewbox=Bbox(24), *args, **kwargs):
        from .svg import SVG
        return SVG([self.to_group()], viewbox=viewbox).draw(*args, **kwargs)

    def _get_points_viz(self, color_firstlast=True, with_moves=True):
        points = []
        commands = self.all_commands(with_close=False)
        n = len(commands)
        for i, command in enumerate(commands):
            if not isinstance(command, SVGCommandMove) or with_moves:
                points_viz = command.get_points_viz(first=(color_firstlast and i <= 1), last=(color_firstlast and i >= n-2))
                points.extend(points_viz)
        return points

    def _get_handles_viz(self):
        handles = []
        for command in self.path_commands:
            handles.extend(command.get_handles_viz())
        return handles

    def _get_unique_geoms(self):
        geoms = []
        for command in self.all_commands():
            geoms.extend(command.get_geoms())
        return list(set(geoms))

    def translate(self, vec):
        for geom in self._get_unique_geoms():
            geom.translate(vec)
        return self

    def rotate(self, angle):
        for geom in self._get_unique_geoms():
            geom.rotate_(angle)
        return self

    def scale(self, factor):
        for geom in self._get_unique_geoms():
            geom.scale(factor)
        return self

    def filter_consecutives(self):
        path_commands = []
        for command in self.path_commands:
            if not command.start_pos.isclose(command.end_pos):
                path_commands.append(command)
        self.path_commands = path_commands
        return self

    def filter_duplicates(self, min_dist=0.2):
        path_commands = []
        current_command = None
        for command in self.path_commands:
            if current_command is None:
                path_commands.append(command)
                current_command = command

            if command.end_pos.dist(current_command.end_pos) >= min_dist:
                command.start_pos = current_command.end_pos
                path_commands.append(command)
                current_command = command

        self.path_commands = path_commands
        return self

    def duplicate_extremities(self):
        self.path_commands = [SVGCommandLine(self.start_pos, self.start_pos),
                              *self.path_commands,
                              SVGCommandLine(self.end_pos, self.end_pos)]
        return self

    def is_clockwise(self):
        if len(self.path_commands) == 1:
            cmd = self.path_commands[0]
            return cmd.start_pos.tolist() <= cmd.end_pos.tolist()

        det_total = 0.
        for cmd in self.path_commands:
            det_total += geom.det(cmd.start_pos, cmd.end_pos)
        return det_total >= 0.

    def set_orientation(self, orientation):
        """
        orientation: 1 (clockwise), 0 (counter-clockwise)
        """
        if orientation == self.is_clockwise():
            return self
        return self.reverse()

    def set_closed(self, closed=True):
        self.closed = closed
        return self

    def reverse(self):
        path_commands = []

        for command in reversed(self.path_commands):
            path_commands.append(command.reverse())

        self.path_commands = path_commands
        return self

    def reverse_non_closed(self):
        if not self.start_pos.isclose(self.end_pos):
            return self.reverse()
        return self

    def simplify_arcs(self):
        path_commands = []
        for command in self.path_commands:
            if isinstance(command, SVGCommandArc):
                if command.radius.iszero():
                    continue
                if command.start_pos.isclose(command.end_pos):
                    continue
                path_commands.extend(command.to_beziers())
            else:
                path_commands.append(command)

        self.path_commands = path_commands
        return self

    def _get_topleftmost_command(self):
        topleftmost_cmd = None
        topleftmost_idx = 0

        for i, cmd in enumerate(self.path_commands):
            if topleftmost_cmd is None or cmd.is_left_to(topleftmost_cmd):
                topleftmost_cmd = cmd
                topleftmost_idx = i

        return topleftmost_cmd, topleftmost_idx

    def reorder(self):
        if self.closed:
            topleftmost_cmd, topleftmost_idx = self._get_topleftmost_command()

            self.path_commands = [
                *self.path_commands[topleftmost_idx:],
                *self.path_commands[:topleftmost_idx]
            ]

        return self

    def to_video(self, wrapper, clips=None, svg_commands=None, color="grey"):
        from .svg import SVG
        from .svg_primitive import SVGLine, SVGCircle

        if clips is None:
            clips = []
        if svg_commands is None:
            svg_commands = []
        svg_dots, svg_moves = [], []

        for command in self.all_commands():
            start_pos, end_pos = command.start_pos, command.end_pos

            if isinstance(command, SVGCommandMove):
                move = SVGLine(start_pos, end_pos, color="teal", dasharray=0.5)
                svg_moves.append(move)

            dot = SVGCircle(end_pos, radius=Radius(0.1), color="red")
            svg_dots.append(dot)

            svg_path = SVGPath(svg_commands).to_group(color=color)
            svg_new_path = SVGPath([SVGCommandMove(start_pos), command]).to_group(color="red")

            svg_paths = [svg_path, svg_new_path] if svg_commands else [svg_new_path]
            im = SVG([*svg_paths, *svg_moves, *svg_dots]).draw(do_display=False, return_png=True, with_points=False)
            clips.append(wrapper(np.array(im)))

            svg_dots[-1].color = "grey"
            svg_commands.append(command)
            svg_moves = []

        return clips, svg_commands

    def numericalize(self, n=256):
        for command in self.all_commands():
            command.numericalize(n)

    def smooth(self):
        # https://github.com/paperjs/paper.js/blob/c7d85b663edb728ec78fffa9f828435eaf78d9c9/src/path/Path.js#L1288
        n = len(self.path_commands)
        knots = [self.start_pos, *(path_command.end_pos for path_command in self.path_commands)]
        r = [knots[0] + 2 * knots[1]]
        f = [2]
        p = [Point(0.)] * (n + 1)

        # Solve with the Thomas algorithm
        for i in range(1, n):
            internal = i < n - 1
            a = 1
            b = 4 if internal else 2
            u = 4 if internal else 3
            v = 2 if internal else 0
            m = a / f[i-1]

            f.append(b-m)
            r.append(u * knots[i] + v * knots[i + 1] - m * r[i-1])

        p[n-1] = r[n-1] / f[n-1]
        for i in range(n-2, -1, -1):
            p[i] = (r[i] - p[i+1]) / f[i]
        p[n] = (3 * knots[n] - p[n-1]) / 2

        for i in range(n):
            p1, p2 = knots[i], knots[i+1]
            c1, c2 = p[i], 2 * p2 - p[i+1]
            self.path_commands[i] = SVGCommandBezier(p1, c1, c2, p2)

        return self

    def simplify_heuristic(self):
        return self.copy().split(max_dist=2, include_lines=False) \
            .simplify(tolerance=0.1, epsilon=0.2, angle_threshold=150) \
            .split(max_dist=7.5)

    def simplify(self, tolerance=0.1, epsilon=0.1, angle_threshold=179., force_smooth=False):
        # https://github.com/paperjs/paper.js/blob/c044b698c6b224c10a7747664b2a4cd00a416a25/src/path/PathFitter.js#L44
        points = [self.start_pos, *(path_command.end_pos for path_command in self.path_commands)]

        def subdivide_indices():
            segments_list = []
            current_segment = []
            prev_command = None

            for i, command in enumerate(self.path_commands):
                if isinstance(command, SVGCommandLine):
                    if current_segment:
                        segments_list.append(current_segment)
                        current_segment = []
                        prev_command = None

                    continue

                if prev_command is not None and prev_command.angle(command) < angle_threshold:
                    if current_segment:
                        segments_list.append(current_segment)
                        current_segment = []

                current_segment.append(i)
                prev_command = command

            if current_segment:
                segments_list.append(current_segment)

            return segments_list

        path_commands = []

        def computeMaxError(first, last, curve: SVGCommandBezier, u):
            maxDist = 0.
            index = (last - first + 1) // 2
            for i in range(1, last - first):
                dist = curve.eval(u[i]).dist(points[first + i]) ** 2
                if dist >= maxDist:
                    maxDist = dist
                    index = first + i
            return maxDist, index

        def chordLengthParametrize(first, last):
            u = [0.]
            for i in range(1, last - first + 1):
                u.append(u[i-1] + points[first + i].dist(points[first + i-1]))

            for i, _ in enumerate(u[1:], 1):
                u[i] /= u[-1]

            return u

        def isMachineZero(val):
            MACHINE_EPSILON = 1.12e-16
            return val >= -MACHINE_EPSILON and val <= MACHINE_EPSILON

        def findRoot(curve: SVGCommandBezier, point, u):
            """
            Newton's root finding algorithm calculates f(x)=0 by reiterating
            x_n+1 = x_n - f(x_n)/f'(x_n)
            We are trying to find curve parameter u for some point p that minimizes
            the distance from that point to the curve. Distance point to curve is d=q(u)-p.
            At minimum distance the point is perpendicular to the curve.
            We are solving
                f = q(u)-p * q'(u) = 0
            with
                f' = q'(u) * q'(u) + q(u)-p * q''(u)
            gives
                u_n+1 = u_n - |q(u_n)-p * q'(u_n)| / |q'(u_n)**2 + q(u_n)-p * q''(u_n)|
            """
            diff = curve.eval(u) - point
            d1, d2 = curve.derivative(u, n=1), curve.derivative(u, n=2)
            numerator = diff.dot(d1)
            denominator = d1.dot(d1) + diff.dot(d2)

            return u if isMachineZero(denominator) else u - numerator / denominator

        def reparametrize(first, last, u, curve: SVGCommandBezier):
            for i in range(0, last - first + 1):
                u[i] = findRoot(curve, points[first + i], u[i])

            for i in range(1, len(u)):
                if u[i] <= u[i-1]:
                    return False

            return True

        def generateBezier(first, last, uPrime, tan1, tan2):
            epsilon = 1e-12
            p1, p2 = points[first], points[last]
            C = np.zeros((2, 2))
            X = np.zeros(2)

            for i in range(last - first + 1):
                u = uPrime[i]
                t = 1 - u
                b = 3 * u * t
                b0 = t**3
                b1 = b * t
                b2 = b * u
                b3 = u**3
                a1 = tan1 * b1
                a2 = tan2 * b2
                tmp = points[first + i] - p1 * (b0 + b1) - p2 * (b2 + b3)

                C[0, 0] += a1.dot(a1)
                C[0, 1] += a1.dot(a2)
                C[1, 0] = C[0, 1]
                C[1, 1] += a2.dot(a2)
                X[0] += a1.dot(tmp)
                X[1] += a2.dot(tmp)

            detC0C1 = C[0, 0] * C[1, 1] - C[1, 0] * C[0, 1]
            if abs(detC0C1) > epsilon:
                detC0X = C[0, 0] * X[1] - C[1, 0] * X[0]
                detXC1 = X[0] * C[1, 1] - X[1] * C[0, 1]
                alpha1 = detXC1 / detC0C1
                alpha2 = detC0X / detC0C1
            else:
                c0 = C[0, 0] + C[0, 1]
                c1 = C[1, 0] + C[1, 1]
                alpha1 = alpha2 = X[0] / c0 if abs(c0) > epsilon else (X[1] / c1 if abs(c1) > epsilon else 0)

            segLength = p2.dist(p1)
            eps = epsilon * segLength
            handle1 = handle2 = None

            if alpha1 < eps or alpha2 < eps:
                alpha1 = alpha2 = segLength / 3
            else:
                line = p2 - p1
                handle1 = tan1 * alpha1
                handle2 = tan2 * alpha2

                if handle1.dot(line) - handle2.dot(line) > segLength**2:
                    alpha1 = alpha2 = segLength / 3
                    handle1 = handle2 = None

            if handle1 is None or handle2 is None:
                handle1 = tan1 * alpha1
                handle2 = tan2 * alpha2

            return SVGCommandBezier(p1, p1 + handle1, p2 + handle2, p2)

        def computeLinearMaxError(first, last):
            maxDist = 0.
            index = (last - first + 1) // 2

            p1, p2 = points[first], points[last]
            for i in range(first + 1, last):
                dist = points[i].distToLine(p1, p2)
                if dist >= maxDist:
                    maxDist = dist
                    index = i
            return maxDist, index

        def ramerDouglasPeucker(first, last, epsilon):
            max_error, split_index = computeLinearMaxError(first, last)

            if max_error > epsilon:
                ramerDouglasPeucker(first, split_index, epsilon)
                ramerDouglasPeucker(split_index, last, epsilon)
            else:
                p1, p2 = points[first], points[last]
                path_commands.append(SVGCommandLine(p1, p2))

        def fitCubic(error, first, last, tan1=None, tan2=None):
            # For convenience, compute extremity tangents if not provided
            if tan1 is None and tan2 is None:
                tan1 = (points[first + 1] - points[first]).normalize()
                tan2 = (points[last - 1] - points[last]).normalize()

            if last - first == 1:
                p1, p2 = points[first], points[last]
                dist = p1.dist(p2) / 3
                path_commands.append(SVGCommandBezier(p1, p1 + dist * tan1, p2 + dist * tan2, p2))
                return

            uPrime = chordLengthParametrize(first, last)
            maxError = max(error, error**2)
            parametersInOrder = True

            for i in range(5):
                curve = generateBezier(first, last, uPrime, tan1, tan2)

                max_error, split_index = computeMaxError(first, last, curve, uPrime)

                if max_error < error and parametersInOrder:
                    path_commands.append(curve)
                    return

                if max_error >= maxError:
                    break

                parametersInOrder = reparametrize(first, last, uPrime, curve)
                maxError = max_error

            tanCenter = (points[split_index-1] - points[split_index+1]).normalize()
            fitCubic(error, first, split_index, tan1, tanCenter)
            fitCubic(error, split_index, last, -tanCenter, tan2)

        segments_list = subdivide_indices()
        if force_smooth:
            fitCubic(tolerance, 0, len(points) - 1)
        else:
            if segments_list:
                seg = segments_list[0]
                ramerDouglasPeucker(0, seg[0], epsilon)

                for seg, seg_next in zip(segments_list[:-1], segments_list[1:]):
                    fitCubic(tolerance, seg[0], seg[-1] + 1)
                    ramerDouglasPeucker(seg[-1] + 1, seg_next[0], epsilon)

                seg = segments_list[-1]
                fitCubic(tolerance, seg[0], seg[-1] + 1)
                ramerDouglasPeucker(seg[-1] + 1, len(points) - 1, epsilon)
            else:
                ramerDouglasPeucker(0, len(points) - 1, epsilon)

        self.path_commands = path_commands

        return self

    def split(self, n=None, max_dist=None, include_lines=True):
        path_commands = []

        for command in self.path_commands:
            if isinstance(command, SVGCommandLine) and not include_lines:
                path_commands.append(command)
            else:
                l = command.length()
                if max_dist is not None:
                    n = max(math.ceil(l / max_dist), 1)

                path_commands.extend(command.split(n=n))

        self.path_commands = path_commands

        return self

    def bbox(self):
        return union_bbox([cmd.bbox() for cmd in self.path_commands])

    def sample_points(self, max_dist=0.4):
        points = []

        for command in self.path_commands:
            l = command.length()
            n = max(math.ceil(l / max_dist), 1)
            points.extend(command.sample_points(n=n, return_array=True)[None])
        points = np.concatenate(points, axis=0)
        return points

    def to_shapely(self):
        polygon = shapely.geometry.Polygon(self.sample_points())

        if not polygon.is_valid:
            polygon = polygon.buffer(0)

        return polygon

    def to_points(self):
        return np.array([self.start_pos.pos, *(cmd.end_pos.pos for cmd in self.path_commands)])
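Aside: a minimal usage sketch for the path API above, assuming this repository's src/ layout is on PYTHONPATH (the path string is an arbitrary example, not from the commit):

from src.preprocessing.deepsvg.deepsvg_svglib.svg_path import SVGPath

# from_str returns an SVGPathGroup: one SVGPath per moveTo-started subpath.
group = SVGPath.from_str("M0 0 L10 0 L10 10 Z", fill=True)
print(len(group), group.start_pos)            # 1 subpath, starting at (0, 0)

# Replace arcs by cubics, then refit: straight runs go through Ramer-Douglas-Peucker,
# curved runs through the paper.js-style least-squares Bezier fitter (fitCubic).
group.simplify_arcs().simplify(tolerance=0.1, epsilon=0.2)
print(group.to_str())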
src/preprocessing/deepsvg/deepsvg_svglib/svg_primitive.py
ADDED
@@ -0,0 +1,452 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

from __future__ import annotations
from .geom import *
import torch
import re
from typing import List, Union
from xml.dom import minidom
from .svg_path import SVGPath
from .svg_command import SVGCommandLine, SVGCommandArc, SVGCommandBezier, SVGCommandClose
import shapely
import shapely.ops
import shapely.geometry
import networkx as nx


FLOAT_RE = re.compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?")


def extract_args(args):
    return list(map(float, FLOAT_RE.findall(args)))


class SVGPrimitive:
    """
    Reference: https://developer.mozilla.org/en-US/docs/Web/SVG/Tutorial/Basic_Shapes
    """
    def __init__(self, color="black", fill=False, dasharray=None, stroke_width=".3", opacity=1.0):
        self.color = color
        self.dasharray = dasharray
        self.stroke_width = stroke_width
        self.opacity = opacity

        self.fill = fill

    def _get_fill_attr(self):
        fill_attr = f'fill="{self.color}" fill-opacity="{self.opacity}"' if self.fill else f'fill="none" stroke="{self.color}" stroke-width="{self.stroke_width}" stroke-opacity="{self.opacity}"'
        if self.dasharray is not None and not self.fill:
            fill_attr += f' stroke-dasharray="{self.dasharray}"'
        return fill_attr

    @classmethod
    def from_xml(cls, x: minidom.Element):
        raise NotImplementedError

    def draw(self, viewbox=Bbox(24), *args, **kwargs):
        from .svg import SVG
        return SVG([self], viewbox=viewbox).draw(*args, **kwargs)

    def _get_viz_elements(self, with_points=False, with_handles=False, with_bboxes=False, color_firstlast=True, with_moves=True):
        return []

    def to_path(self):
        raise NotImplementedError

    def copy(self):
        raise NotImplementedError

    def bbox(self):
        raise NotImplementedError

    def fill_(self, fill=True):
        self.fill = fill
        return self


class SVGEllipse(SVGPrimitive):
    def __init__(self, center: Point, radius: Radius, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.center = center
        self.radius = radius

    def __repr__(self):
        return f'SVGEllipse(c={self.center} r={self.radius})'

    def to_str(self, *args, **kwargs):
        fill_attr = self._get_fill_attr()
        return f'<ellipse {fill_attr} cx="{self.center.x}" cy="{self.center.y}" rx="{self.radius.x}" ry="{self.radius.y}"/>'

    @classmethod
    def from_xml(_, x: minidom.Element):
        fill = not x.hasAttribute("fill") or not x.getAttribute("fill") == "none"

        center = Point(float(x.getAttribute("cx")), float(x.getAttribute("cy")))
        radius = Radius(float(x.getAttribute("rx")), float(x.getAttribute("ry")))
        return SVGEllipse(center, radius, fill=fill)

    def to_path(self):
        p0, p1 = self.center + self.radius.xproj(), self.center + self.radius.yproj()
        p2, p3 = self.center - self.radius.xproj(), self.center - self.radius.yproj()
        commands = [
            SVGCommandArc(p0, self.radius, Angle(0.), Flag(0.), Flag(1.), p1),
            SVGCommandArc(p1, self.radius, Angle(0.), Flag(0.), Flag(1.), p2),
            SVGCommandArc(p2, self.radius, Angle(0.), Flag(0.), Flag(1.), p3),
            SVGCommandArc(p3, self.radius, Angle(0.), Flag(0.), Flag(1.), p0),
        ]
        return SVGPath(commands, closed=True).to_group(fill=self.fill)


class SVGCircle(SVGEllipse):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __repr__(self):
        return f'SVGCircle(c={self.center} r={self.radius})'

    def to_str(self, *args, **kwargs):
        fill_attr = self._get_fill_attr()
        return f'<circle {fill_attr} cx="{self.center.x}" cy="{self.center.y}" r="{self.radius.x}"/>'

    @classmethod
    def from_xml(_, x: minidom.Element):
        fill = not x.hasAttribute("fill") or not x.getAttribute("fill") == "none"

        center = Point(float(x.getAttribute("cx")), float(x.getAttribute("cy")))
        radius = Radius(float(x.getAttribute("r")))
        return SVGCircle(center, radius, fill=fill)


class SVGRectangle(SVGPrimitive):
    def __init__(self, xy: Point, wh: Size, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.xy = xy
        self.wh = wh

    def __repr__(self):
        return f'SVGRectangle(xy={self.xy} wh={self.wh})'

    def to_str(self, *args, **kwargs):
        fill_attr = self._get_fill_attr()
        return f'<rect {fill_attr} x="{self.xy.x}" y="{self.xy.y}" width="{self.wh.x}" height="{self.wh.y}"/>'

    @classmethod
    def from_xml(_, x: minidom.Element):
        fill = not x.hasAttribute("fill") or not x.getAttribute("fill") == "none"

        xy = Point(0.)
        if x.hasAttribute("x"):
            xy.pos[0] = float(x.getAttribute("x"))
        if x.hasAttribute("y"):
            xy.pos[1] = float(x.getAttribute("y"))
        wh = Size(float(x.getAttribute("width")), float(x.getAttribute("height")))
        return SVGRectangle(xy, wh, fill=fill)

    def to_path(self):
        p0, p1, p2, p3 = self.xy, self.xy + self.wh.xproj(), self.xy + self.wh, self.xy + self.wh.yproj()
        commands = [
            SVGCommandLine(p0, p1),
            SVGCommandLine(p1, p2),
            SVGCommandLine(p2, p3),
            SVGCommandLine(p3, p0)
        ]
        return SVGPath(commands, closed=True).to_group(fill=self.fill)


class SVGLine(SVGPrimitive):
    def __init__(self, start_pos: Point, end_pos: Point, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.start_pos = start_pos
        self.end_pos = end_pos

    def __repr__(self):
        return f'SVGLine(xy1={self.start_pos} xy2={self.end_pos})'

    def to_str(self, *args, **kwargs):
        fill_attr = self._get_fill_attr()
        return f'<line {fill_attr} x1="{self.start_pos.x}" y1="{self.start_pos.y}" x2="{self.end_pos.x}" y2="{self.end_pos.y}"/>'

    @classmethod
    def from_xml(_, x: minidom.Element):
        fill = not x.hasAttribute("fill") or not x.getAttribute("fill") == "none"

        start_pos = Point(float(x.getAttribute("x1") or 0.), float(x.getAttribute("y1") or 0.))
        end_pos = Point(float(x.getAttribute("x2") or 0.), float(x.getAttribute("y2") or 0.))
        return SVGLine(start_pos, end_pos, fill=fill)

    def to_path(self):
        return SVGPath([SVGCommandLine(self.start_pos, self.end_pos)]).to_group(fill=self.fill)


class SVGPolyline(SVGPrimitive):
    def __init__(self, points: List[Point], *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.points = points

    def __repr__(self):
        return f'SVGPolyline(points={self.points})'

    def to_str(self, *args, **kwargs):
        fill_attr = self._get_fill_attr()
        return '<polyline {} points="{}"/>'.format(fill_attr, ' '.join([p.to_str() for p in self.points]))

    @classmethod
    def from_xml(cls, x: minidom.Element):
        fill = not x.hasAttribute("fill") or not x.getAttribute("fill") == "none"

        args = extract_args(x.getAttribute("points"))
        assert len(args) % 2 == 0, f"Expected even number of arguments for SVGPolyline: {len(args)} given"
        points = [Point(x, args[2*i+1]) for i, x in enumerate(args[::2])]
        return cls(points, fill=fill)

    def to_path(self):
        commands = [SVGCommandLine(p1, p2) for p1, p2 in zip(self.points[:-1], self.points[1:])]
        is_closed = self.__class__.__name__ == "SVGPolygon"
        return SVGPath(commands, closed=is_closed).to_group(fill=self.fill)


class SVGPolygon(SVGPolyline):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __repr__(self):
        return f'SVGPolygon(points={self.points})'

    def to_str(self, *args, **kwargs):
        fill_attr = self._get_fill_attr()
        return '<polygon {} points="{}"/>'.format(fill_attr, ' '.join([p.to_str() for p in self.points]))


class SVGPathGroup(SVGPrimitive):
    def __init__(self, svg_paths: List[SVGPath] = None, origin=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.svg_paths = svg_paths

        if origin is None:
            origin = Point(0.)
        self.origin = origin

    # Alias
    @property
    def paths(self):
        return self.svg_paths

    @property
    def path(self):
        return self.svg_paths[0]

    def __getitem__(self, idx):
        return self.svg_paths[idx]

    def __len__(self):
        return len(self.paths)

    def total_len(self):
        return sum([len(path) for path in self.svg_paths])

    @property
    def start_pos(self):
        return self.svg_paths[0].start_pos

    @property
    def end_pos(self):
        last_path = self.svg_paths[-1]
        if last_path.closed:
            return last_path.start_pos
        return last_path.end_pos

    def set_origin(self, origin: Point):
        self.origin = origin
        if self.svg_paths:
            self.svg_paths[0].origin = origin
            self.recompute_origins()

    def append(self, path: SVGPath):
        self.svg_paths.append(path)

    def copy(self):
        return SVGPathGroup([svg_path.copy() for svg_path in self.svg_paths], self.origin.copy(),
                            self.color, self.fill, self.dasharray, self.stroke_width, self.opacity)

    def __repr__(self):
        return "SVGPathGroup({})".format(", ".join(svg_path.__repr__() for svg_path in self.svg_paths))

    def _get_viz_elements(self, with_points=False, with_handles=False, with_bboxes=False, color_firstlast=True, with_moves=True):
        viz_elements = []
        for svg_path in self.svg_paths:
            viz_elements.extend(svg_path._get_viz_elements(with_points, with_handles, with_bboxes, color_firstlast, with_moves))

        if with_bboxes:
            viz_elements.append(self._get_bbox_viz())

        return viz_elements

    def _get_bbox_viz(self):
        color = "red" if self.color == "black" else self.color
        bbox = self.bbox().to_rectangle(color=color)
        return bbox

    def to_path(self):
        return self

    def to_str(self, with_markers=False, *args, **kwargs):
        fill_attr = self._get_fill_attr()
        marker_attr = 'marker-start="url(#arrow)"' if with_markers else ''
        return '<path {} {} filling="{}" d="{}"></path>'.format(fill_attr, marker_attr, self.path.filling,
                                                                " ".join(svg_path.to_str() for svg_path in self.svg_paths))

    def to_tensor(self, PAD_VAL=-1):
        return torch.cat([p.to_tensor(PAD_VAL=PAD_VAL) for p in self.svg_paths], dim=0)

    def _apply_to_paths(self, method, *args, **kwargs):
        for path in self.svg_paths:
            getattr(path, method)(*args, **kwargs)
        return self

    def translate(self, vec):
        return self._apply_to_paths("translate", vec)

    def rotate(self, angle: Angle):
        return self._apply_to_paths("rotate", angle)

    def scale(self, factor):
        return self._apply_to_paths("scale", factor)

    def numericalize(self, n=256):
        return self._apply_to_paths("numericalize", n)

    def drop_z(self):
        return self._apply_to_paths("set_closed", False)

    def recompute_origins(self):
        origin = self.origin
        for path in self.svg_paths:
            path.origin = origin.copy()
            origin = path.end_pos
        return self

    def reorder(self):
        self._apply_to_paths("reorder")
        self.recompute_origins()
        return self

    def filter_empty(self):
        self.svg_paths = [path for path in self.svg_paths if path.path_commands]
        return self

    def canonicalize(self):
        self.svg_paths = sorted(self.svg_paths, key=lambda x: x.start_pos.tolist()[::-1])
        if not self.svg_paths[0].is_clockwise():
            self._apply_to_paths("reverse")

        self.recompute_origins()
        return self

    def reverse(self):
        self._apply_to_paths("reverse")

        self.recompute_origins()
        return self

    def duplicate_extremities(self):
        self._apply_to_paths("duplicate_extremities")
        return self

    def reverse_non_closed(self):
        self._apply_to_paths("reverse_non_closed")

        self.recompute_origins()
        return self

    def simplify(self, tolerance=0.1, epsilon=0.1, angle_threshold=179., force_smooth=False):
        self._apply_to_paths("simplify", tolerance=tolerance, epsilon=epsilon, angle_threshold=angle_threshold,
                             force_smooth=force_smooth)
        self.recompute_origins()
        return self

    def split_paths(self):
        return [SVGPathGroup([svg_path], self.origin,
                             self.color, self.fill, self.dasharray, self.stroke_width, self.opacity)
                for svg_path in self.svg_paths]

    def split(self, n=None, max_dist=None, include_lines=True):
        return self._apply_to_paths("split", n=n, max_dist=max_dist, include_lines=include_lines)

    def simplify_arcs(self):
        return self._apply_to_paths("simplify_arcs")

    def filter_consecutives(self):
        return self._apply_to_paths("filter_consecutives")

    def filter_duplicates(self):
        return self._apply_to_paths("filter_duplicates")

    def bbox(self):
        return union_bbox([path.bbox() for path in self.svg_paths])

    def to_shapely(self):
        return shapely.ops.unary_union([path.to_shapely() for path in self.svg_paths])

    def compute_filling(self):
        if self.fill:
            G = self.overlap_graph()

            root_nodes = [i for i, d in G.in_degree() if d == 0]

            for root in root_nodes:
                if not self.svg_paths[root].closed:
                    continue

                current = [(1, root)]

                while current:
                    visited = set()
                    neighbors = set()
                    for d, n in current:
                        self.svg_paths[n].set_filling(d != 0)

                        for n2 in G.neighbors(n):
                            if not n2 in visited:
                                d2 = d + (self.svg_paths[n2].is_clockwise() == self.svg_paths[n].is_clockwise()) * 2 - 1
                                visited.add(n2)
                                neighbors.add((d2, n2))

                    G.remove_nodes_from([n for d, n in current])

                    current = [(d, n) for d, n in neighbors if G.in_degree(n) == 0]

        return self

    def overlap_graph(self, threshold=0.9, draw=False):
        G = nx.DiGraph()
        shapes = [path.to_shapely() for path in self.svg_paths]

        for i, path1 in enumerate(shapes):
            G.add_node(i)

            if self.svg_paths[i].closed:
                for j, path2 in enumerate(shapes):
                    if i != j and self.svg_paths[j].closed:
                        overlap = path1.intersection(path2).area / path1.area
                        if overlap > threshold:
                            G.add_edge(j, i, weight=overlap)

        if draw:
            pos = nx.spring_layout(G)
            nx.draw_networkx(G, pos, with_labels=True)
            labels = nx.get_edge_attributes(G, 'weight')
            nx.draw_networkx_edge_labels(G, pos, edge_labels=labels)
        return G

    def bbox_overlap(self, other: SVGPathGroup):
        return self.bbox().overlap(other.bbox())

    def to_points(self):
        return np.concatenate([path.to_points() for path in self.svg_paths])
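Aside: compute_filling() decides FILL vs ERASE per subpath by walking the containment DAG built by overlap_graph(); the depth alternates at each nesting level, much like the even-odd fill rule. A hypothetical usage sketch, assuming load_svg from the sibling svg module and an arbitrary input file name:

from src.preprocessing.deepsvg.deepsvg_svglib.svg import SVG

svg = SVG.load_svg("some_icon.svg")                      # hypothetical input file
for group in svg.svg_path_groups:
    group.compute_filling()                              # walks the containment DAG
    print([path.filling for path in group.svg_paths])    # 0 = OUTLINE, 1 = FILL, 2 = ERASE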
src/preprocessing/deepsvg/deepsvg_svglib/svglib_utils.py
ADDED
@@ -0,0 +1,95 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

import src.preprocessing.deepsvg.deepsvg_svglib.svg as svg_lib
from .geom import Bbox, Point
import math
import numpy as np
import IPython.display as ipd
from moviepy.editor import ImageClip, concatenate_videoclips, ipython_display


def make_grid(svgs, num_cols=3, grid_width=24):
    """
    svgs: List[svg_lib.SVG]
    """
    nb_rows = math.ceil(len(svgs) / num_cols)
    grid = svg_lib.SVG([], viewbox=Bbox(grid_width * num_cols, grid_width * nb_rows))

    for i, svg in enumerate(svgs):
        row, col = i // num_cols, i % num_cols
        svg = svg.copy().translate(Point(grid_width * col, grid_width * row))

        grid.add_path_groups(svg.svg_path_groups)

    return grid


def make_grid_grid(svg_grid, grid_width=24):
    """
    svg_grid: List[List[svg_lib.SVG]]
    """
    nb_rows = len(svg_grid)
    num_cols = len(svg_grid[0])
    grid = svg_lib.SVG([], viewbox=Bbox(grid_width * num_cols, grid_width * nb_rows))

    for i, row in enumerate(svg_grid):
        for j, svg in enumerate(row):
            svg = svg.copy().translate(Point(grid_width * j, grid_width * i))

            grid.add_path_groups(svg.svg_path_groups)

    return grid


def make_grid_lines(svg_grid, grid_width=24):
    """
    svg_grid: List[List[svg_lib.SVG]]
    """
    nb_rows = len(svg_grid)
    num_cols = max(len(r) for r in svg_grid)
    grid = svg_lib.SVG([], viewbox=Bbox(grid_width * num_cols, grid_width * nb_rows))

    for i, row in enumerate(svg_grid):
        for j, svg in enumerate(row):
            j_shift = (num_cols - len(row)) // 2
            svg = svg.copy().translate(Point(grid_width * (j + j_shift), grid_width * i))

            grid.add_path_groups(svg.svg_path_groups)

    return grid


COLORS = ["aliceblue", "antiquewhite", "aqua", "aquamarine", "azure", "beige", "bisque", "black", "blanchedalmond",
          "blue", "blueviolet", "brown", "burlywood", "cadetblue", "chartreuse", "chocolate", "coral", "cornflowerblue",
          "cornsilk", "crimson", "cyan", "darkblue", "darkcyan", "darkgoldenrod", "darkgray", "darkgreen", "darkgrey",
          "darkkhaki", "darkmagenta", "darkolivegreen", "darkorange", "darkorchid", "darkred", "darksalmon",
          "darkseagreen", "darkslateblue", "darkslategray", "darkslategrey", "darkturquoise", "darkviolet", "deeppink",
          "deepskyblue", "dimgray", "dimgrey", "dodgerblue", "firebrick", "floralwhite", "forestgreen", "fuchsia",
          "gainsboro", "ghostwhite", "gold", "goldenrod", "gray", "green", "greenyellow", "grey", "honeydew", "hotpink",
          "indianred", "indigo", "ivory", "khaki", "lavender", "lavenderblush", "lawngreen", "lemonchiffon",
          "lightblue", "lightcoral", "lightcyan", "lightgoldenrodyellow", "lightgray", "lightgreen", "lightgrey",
          "lightpink", "lightsalmon", "lightseagreen", "lightskyblue", "lightslategray", "lightslategrey",
          "lightsteelblue", "lightyellow", "lime", "limegreen", "linen", "magenta", "maroon", "mediumaquamarine",
          "mediumblue", "mediumorchid", "mediumpurple", "mediumseagreen", "mediumslateblue", "mediumspringgreen",
          "mediumturquoise", "mediumvioletred", "midnightblue", "mintcream", "mistyrose", "moccasin", "navajowhite",
          "navy", "oldlace", "olive", "olivedrab", "orange", "orangered", "orchid", "palegoldenrod", "palegreen",
          "paleturquoise", "palevioletred", "papayawhip", "peachpuff", "peru", "pink", "plum", "powderblue", "purple",
          "red", "rosybrown", "royalblue", "saddlebrown", "salmon", "sandybrown", "seagreen", "seashell", "sienna",
          "silver", "skyblue", "slateblue", "slategray", "slategrey", "snow", "springgreen", "steelblue", "tan", "teal",
          "thistle", "tomato", "turquoise", "violet", "wheat", "white", "whitesmoke", "yellow", "yellowgreen"]


def to_gif(img_list, file_path=None, frame_duration=0.1, do_display=True):
    clips = [ImageClip(np.array(img)).set_duration(frame_duration) for img in img_list]

    clip = concatenate_videoclips(clips, method="compose", bg_color=(255, 255, 255))

    if file_path is not None:
        clip.write_gif(file_path, fps=24, verbose=False, logger=None)

    if do_display:
        src = clip if file_path is None else file_path
        ipd.display(ipython_display(src, fps=24, rd_kwargs=dict(logger=None), autoplay=1, loop=1))
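Aside: a hypothetical sketch of the grid helpers above; the input file names are placeholders, and both load_svg and save_svg are assumed from the sibling svg module:

from src.preprocessing.deepsvg.deepsvg_svglib.svg import SVG
from src.preprocessing.deepsvg.deepsvg_svglib.svglib_utils import make_grid

svgs = [SVG.load_svg(f"icon_{i}.svg") for i in range(6)]   # hypothetical inputs
grid = make_grid(svgs, num_cols=3, grid_width=24)          # 3 columns x 2 rows
grid.save_svg("grid.svg")                                  # one 72 x 48 canvas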
src/preprocessing/deepsvg/deepsvg_svglib/util_fns.py
ADDED
@@ -0,0 +1,22 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

import math


def get_roots(a, b, c):
    if a == 0:
        if b == 0:
            return []
        return [-c / b]
    r = b * b - 4 * a * c
    if r < 0:
        return []
    elif r == 0:
        x0 = -b / (2 * a)
        return [x0]

    x1, x2 = (-b - math.sqrt(r)) / (2 * a), (-b + math.sqrt(r)) / (2 * a)
    return x1, x2
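Aside: get_roots solves a*x^2 + b*x + c = 0, degrading gracefully to the linear and degenerate cases (note it returns a tuple for two roots and a list otherwise). A quick sanity check:

print(get_roots(1, -3, 2))   # (1.0, 2.0)  -> two real roots of x^2 - 3x + 2
print(get_roots(0, 2, -4))   # [2.0]       -> linear fallback when a == 0
print(get_roots(1, 0, 1))    # []          -> negative discriminant, no real roots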
src/preprocessing/deepsvg/deepsvg_utils/train_utils.py
ADDED
@@ -0,0 +1,241 @@
1 |
+
"""This code is taken from <https://github.com/alexandre01/deepsvg>
|
2 |
+
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
|
3 |
+
from the paper >https://arxiv.org/pdf/2007.11301.pdf>
|
4 |
+
"""
|
5 |
+
|
6 |
+
import shutil
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import os
|
10 |
+
import random
|
11 |
+
import numpy as np
|
12 |
+
import glob
|
13 |
+
|
14 |
+
|
15 |
+
def save_ckpt(checkpoint_dir, model, cfg=None, optimizer=None, scheduler_lr=None, scheduler_warmup=None,
|
16 |
+
stats=None, train_vars=None):
|
17 |
+
if is_multi_gpu(model):
|
18 |
+
model = model.module
|
19 |
+
|
20 |
+
state = {
|
21 |
+
"model": model.state_dict()
|
22 |
+
}
|
23 |
+
|
24 |
+
if optimizer is not None:
|
25 |
+
state["optimizer"] = optimizer.state_dict()
|
26 |
+
if scheduler_lr is not None:
|
27 |
+
state["scheduler_lr"] = scheduler_lr.state_dict()
|
28 |
+
if scheduler_warmup is not None:
|
29 |
+
state["scheduler_warmup"] = scheduler_warmup.state_dict()
|
30 |
+
if cfg is not None:
|
31 |
+
state["cfg"] = cfg.to_dict()
|
32 |
+
if stats is not None:
|
33 |
+
state["stats"] = stats.to_dict()
|
34 |
+
if train_vars is not None:
|
35 |
+
state["train_vars"] = train_vars.to_dict()
|
36 |
+
|
37 |
+
checkpoint_path = os.path.join(checkpoint_dir, "{:06d}.pth.tar".format(stats.step))
|
38 |
+
|
39 |
+
if not os.path.exists(checkpoint_dir):
|
40 |
+
os.makedirs(checkpoint_dir)
|
41 |
+
torch.save(state, checkpoint_path)
|
42 |
+
|
43 |
+
if stats.is_best():
|
44 |
+
best_model_path = os.path.join(checkpoint_dir, "best.pth.tar")
|
45 |
+
shutil.copyfile(checkpoint_path, best_model_path)
|
46 |
+
|
47 |
+
|
48 |
+
def save_ckpt_list(checkpoint_dir, model, cfg=None, optimizers=None, scheduler_lrs=None, scheduler_warmups=None,
|
49 |
+
stats=None, train_vars=None):
|
50 |
+
if is_multi_gpu(model):
|
51 |
+
model = model.module
|
52 |
+
|
53 |
+
state = {
|
54 |
+
"model": model.state_dict()
|
55 |
+
}
|
56 |
+
|
57 |
+
if optimizers is not None:
|
58 |
+
state["optimizers"] = [optimizer.state_dict() if optimizer is not None else optimizer for optimizer in optimizers]
|
59 |
+
if scheduler_lrs is not None:
|
60 |
+
state["scheduler_lrs"] = [scheduler_lr.state_dict() if scheduler_lr is not None else scheduler_lr for scheduler_lr in scheduler_lrs]
|
61 |
+
if scheduler_warmups is not None:
|
62 |
+
state["scheduler_warmups"] = [scheduler_warmup.state_dict() if scheduler_warmup is not None else None for scheduler_warmup in scheduler_warmups]
|
63 |
+
if cfg is not None:
|
64 |
+
state["cfg"] = cfg.to_dict()
|
65 |
+
if stats is not None:
|
66 |
+
state["stats"] = stats.to_dict()
|
67 |
+
if train_vars is not None:
|
68 |
+
state["train_vars"] = train_vars.to_dict()
|
69 |
+
|
70 |
+
checkpoint_path = os.path.join(checkpoint_dir, "{:06d}.pth.tar".format(stats.step))
|
71 |
+
|
72 |
+
if not os.path.exists(checkpoint_dir):
|
73 |
+
os.makedirs(checkpoint_dir)
|
74 |
+
torch.save(state, checkpoint_path)
|
75 |
+
|
76 |
+
if stats.is_best():
|
77 |
+
best_model_path = os.path.join(checkpoint_dir, "best.pth.tar")
|
78 |
+
shutil.copyfile(checkpoint_path, best_model_path)
|
79 |
+
|
80 |
+
|
81 |
+
def load_ckpt(checkpoint_dir, model, cfg=None, optimizer=None, scheduler_lr=None, scheduler_warmup=None,
|
82 |
+
stats=None, train_vars=None):
|
83 |
+
if not os.path.exists(checkpoint_dir):
|
84 |
+
return False
|
85 |
+
|
86 |
+
if os.path.isfile(checkpoint_dir):
|
87 |
+
checkpoint_path = checkpoint_dir
|
88 |
+
else:
|
89 |
+
ckpts_paths = sorted(glob.glob(os.path.join(checkpoint_dir, "./[0-9]*.pth.tar")))
|
90 |
+
if not ckpts_paths:
|
91 |
+
return False
|
92 |
+
checkpoint_path = ckpts_paths[-1]
|
93 |
+
|
94 |
+
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
95 |
+
|
96 |
+
if is_multi_gpu(model):
|
97 |
+
model = model.module
|
98 |
+
model.load_state_dict(state["model"], strict=False)
|
99 |
+
|
100 |
+
if optimizer is not None:
|
101 |
+
optimizer.load_state_dict(state["optimizer"])
|
102 |
+
if scheduler_lr is not None:
|
103 |
+
scheduler_lr.load_state_dict(state["scheduler_lr"])
|
104 |
+
if scheduler_warmup is not None:
|
105 |
+
scheduler_warmup.load_state_dict(state["scheduler_warmup"])
|
106 |
+
if cfg is not None:
|
107 |
+
cfg.load_dict(state["cfg"])
|
108 |
+
if stats is not None:
|
109 |
+
stats.load_dict(state["stats"])
|
110 |
+
if train_vars is not None:
|
111 |
+
train_vars.load_dict(state["train_vars"])
|
112 |
+
|
113 |
+
return True
|
114 |
+
|
115 |
+
|
116 |
+
def load_ckpt_list(checkpoint_dir, model, cfg=None, optimizers=None, scheduler_lrs=None, scheduler_warmups=None,
|
117 |
+
stats=None, train_vars=None):
|
118 |
+
if not os.path.exists(checkpoint_dir):
|
119 |
+
return False
|
120 |
+
|
121 |
+
if os.path.isfile(checkpoint_dir):
|
122 |
+
checkpoint_path = checkpoint_dir
|
123 |
+
else:
|
124 |
+
ckpts_paths = sorted(glob.glob(os.path.join(checkpoint_dir, "./[0-9]*.pth.tar")))
|
125 |
+
if not ckpts_paths:
|
126 |
+
return False
|
127 |
+
checkpoint_path = ckpts_paths[-1]
|
128 |
+
|
129 |
+
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
130 |
+
|
131 |
+
if is_multi_gpu(model):
|
132 |
+
model = model.module
|
133 |
+
model.load_state_dict(state["model"], strict=False)
|
134 |
+
|
135 |
+
for optimizer, scheduler_lr, scheduler_warmup, optimizer_sd, scheduler_lr_sd, scheduler_warmups_sd in zip(optimizers, scheduler_lrs, scheduler_warmups, state["optimizers"], state["scheduler_lrs"], state["scheduler_warmups"]):
|
136 |
+
if optimizer is not None and optimizer_sd is not None:
|
137 |
+
optimizer.load_state_dict(optimizer_sd)
|
138 |
+
if scheduler_lr is not None and scheduler_lr_sd is not None:
|
139 |
+
scheduler_lr.load_state_dict(scheduler_lr_sd)
|
140 |
+
if scheduler_warmup is not None and scheduler_warmups_sd is not None:
|
141 |
+
scheduler_warmup.load_state_dict(scheduler_warmups_sd)
|
142 |
+
if cfg is not None and state["cfg"] is not None:
|
143 |
+
cfg.load_dict(state["cfg"])
|
144 |
+
if stats is not None and state["stats"] is not None:
|
145 |
+
stats.load_dict(state["stats"])
|
146 |
+
if train_vars is not None and state["train_vars"] is not None:
|
147 |
+
train_vars.load_dict(state["train_vars"])
|
148 |
+
|
149 |
+
return True
|
150 |
+
|
151 |
+
|
152 |
+
def load_model(checkpoint_path, model):
|
153 |
+
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
154 |
+
|
155 |
+
if is_multi_gpu(model):
|
156 |
+
model = model.module
|
157 |
+
model.load_state_dict(state["model"], strict=False)
|
158 |
+
|
159 |
+
|
160 |
+
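A minimal save/resume sketch using these helpers (hypothetical, not part of the commit; it assumes `cfg`, `stats` and `train_vars` expose the `to_dict()`/`load_dict()` methods the functions above call):

    # Hypothetical usage -- cfg/stats/train_vars are whatever the training loop provides.
    model = cfg.make_model()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Resume from the latest step-numbered checkpoint if one exists.
    resumed = load_ckpt("logs/ckpt", model, cfg=cfg, optimizer=optimizer,
                        stats=stats, train_vars=train_vars)

    # ... after a training step ...
    save_ckpt("logs/ckpt", model, cfg=cfg, optimizer=optimizer,
              stats=stats, train_vars=train_vars)  # writes {step:06d}.pth.tar (+ best.pth.tar if stats.is_best())

Note that `save_ckpt`/`load_ckpt` store a single optimizer under the key "optimizer", while the `*_list` variants store parallel lists under "optimizers" etc., so the two families are not interchangeable on the same checkpoint file.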
def is_multi_gpu(model):
    return isinstance(model, nn.DataParallel)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def pad_sequence(sequences, batch_first=False, padding_value=0, max_len=None):
    r"""Pad a list of variable-length Tensors with ``padding_value``.

    ``pad_sequence`` stacks a list of Tensors along a new dimension,
    and pads them to equal length. For example, if the input is a list of
    sequences with size ``L x *``, the output is of size ``T x B x *`` if
    batch_first is False, and ``B x T x *`` otherwise.

    `B` is the batch size. It is equal to the number of elements in ``sequences``.
    `T` is the length of the longest sequence.
    `L` is the length of the sequence.
    `*` is any number of trailing dimensions, including none.

    Example:
        >>> from torch.nn.utils.rnn import pad_sequence
        >>> a = torch.ones(25, 300)
        >>> b = torch.ones(22, 300)
        >>> c = torch.ones(15, 300)
        >>> pad_sequence([a, b, c]).size()
        torch.Size([25, 3, 300])

    Note:
        This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
        where `T` is the length of the longest sequence. This function assumes
        that the trailing dimensions and type of all the Tensors in sequences
        are the same.

    Arguments:
        sequences (list[Tensor]): list of variable length sequences.
        batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
            ``T x B x *`` otherwise
        padding_value (float, optional): value for padded elements. Default: 0.
        max_len (int, optional): pad to this length instead of the longest
            sequence length. Default: None (use the longest sequence).

    Returns:
        Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
        Tensor of size ``B x T x *`` otherwise
    """

    # assuming trailing dimensions and type of all the Tensors
    # in sequences are the same and fetching those from sequences[0]
    max_size = sequences[0].size()
    trailing_dims = max_size[1:]

    if max_len is None:
        max_len = max([s.size(0) for s in sequences])
    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
    else:
        out_dims = (max_len, len(sequences)) + trailing_dims

    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
    for i, tensor in enumerate(sequences):
        length = tensor.size(0)
        # use index notation to prevent duplicate references to the tensor
        if batch_first:
            out_tensor[i, :length, ...] = tensor
        else:
            out_tensor[:length, i, ...] = tensor

    return out_tensor

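The `max_len` argument is this version's addition over the stock PyTorch `pad_sequence`; a quick illustrative check (shapes arbitrary):

    a = torch.ones(4, 8)
    b = torch.ones(2, 8)
    out = pad_sequence([a, b], batch_first=True, padding_value=0, max_len=10)
    print(out.shape)  # torch.Size([2, 10, 8]) -- T forced to max_len instead of 4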
def set_seed(_seed=42):
    # Seed every RNG in play (Python, NumPy, PyTorch CPU/GPU) for reproducibility.
    random.seed(_seed)
    np.random.seed(_seed)
    torch.manual_seed(_seed)
    torch.cuda.manual_seed(_seed)
    torch.cuda.manual_seed_all(_seed)
    os.environ['PYTHONHASHSEED'] = str(_seed)


def infinite_range(start_idx=0):
    # Like an unbounded range(): yields start_idx, start_idx + 1, ... forever.
    while True:
        yield start_idx
        start_idx += 1
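Together these typically frame an open-ended training loop; a hypothetical skeleton (the step body and the `max_steps` stopping criterion are placeholders, not part of the commit):

    set_seed(42)  # fix all RNGs before building the model and dataloaders

    for step in infinite_range(start_idx=0):  # pass the restored stats.step to resume counting
        # ... one optimization step ...
        if step >= max_steps:  # external stopping criterion, e.g. taken from cfg
            break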
src/preprocessing/deepsvg/deepsvg_utils/utils.py
ADDED
@@ -0,0 +1,54 @@
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

import torch


def linear(a, b, x, min_x, max_x):
    """
    Linear ramp from a to b, clamped outside [min_x, max_x]:

    b            ___________
                /
               /
    a ________/
              |          |
            min_x      max_x
    """
    return a + min(max((x - min_x) / (max_x - min_x), 0), 1) * (b - a)


def batchify(data, device):
    # Add a batch dimension of size 1 to each tensor and move it to `device`.
    return (d.unsqueeze(0).to(device) for d in data)

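A few hand-computed evaluations of the clamped ramp, for illustration:

    print(linear(0, 1, 5, 10, 20))   # 0    -- x below min_x, clamped to a
    print(linear(0, 1, 15, 10, 20))  # 0.5  -- halfway up the ramp
    print(linear(0, 1, 25, 10, 20))  # 1    -- x above max_x, clamped to b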
def _make_seq_first(*args):
    # N, G, S, ... -> S, G, N, ...
    if len(args) == 1:
        arg, = args
        return arg.permute(2, 1, 0, *range(3, arg.dim())) if arg is not None else None
    return (*(arg.permute(2, 1, 0, *range(3, arg.dim())) if arg is not None else None for arg in args),)


def _make_batch_first(*args):
    # S, G, N, ... -> N, G, S, ...  (swapping dims 0 and 2 is its own inverse,
    # so the permutation is identical to _make_seq_first)
    if len(args) == 1:
        arg, = args
        return arg.permute(2, 1, 0, *range(3, arg.dim())) if arg is not None else None
    return (*(arg.permute(2, 1, 0, *range(3, arg.dim())) if arg is not None else None for arg in args),)


def _pack_group_batch(*args):
    # S, G, N, ... -> S, G * N, ...
    if len(args) == 1:
        arg, = args
        return arg.reshape(arg.size(0), arg.size(1) * arg.size(2), *arg.shape[3:]) if arg is not None else None
    return (*(arg.reshape(arg.size(0), arg.size(1) * arg.size(2), *arg.shape[3:]) if arg is not None else None for arg in args),)


def _unpack_group_batch(N, *args):
    # S, G * N, ... -> S, G, N, ...
    if len(args) == 1:
        arg, = args
        return arg.reshape(arg.size(0), -1, N, *arg.shape[2:]) if arg is not None else None
    return (*(arg.reshape(arg.size(0), -1, N, *arg.shape[2:]) if arg is not None else None for arg in args),)
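A shape round trip through the four helpers (dimensions chosen arbitrarily for illustration: batch N=4, groups G=8, sequence length S=30, feature dim 2):

    x = torch.zeros(4, 8, 30, 2)   # N, G, S, ...
    x = _make_seq_first(x)         # (30, 8, 4, 2) -- S, G, N, ...
    x = _pack_group_batch(x)       # (30, 32, 2)   -- S, G * N, ...
    x = _unpack_group_batch(4, x)  # (30, 8, 4, 2) -- groups and batch split again
    x = _make_batch_first(x)       # (4, 8, 30, 2) -- original layout restored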
src/preprocessing/preprocessing.py
ADDED
@@ -0,0 +1,157 @@
# Imports
import os
import copy
import torch
import glob
import pandas as pd
import pickle
from xml.dom import minidom
from svgpathtools import svg2paths2
from svgpathtools import wsvg
import sys
sys.path.append(os.getcwd())
from src.preprocessing.deepsvg.deepsvg_svglib.svg import SVG
from src.preprocessing.deepsvg.deepsvg_config import config_hierarchical_ordered
from src.preprocessing.deepsvg.deepsvg_utils import train_utils
from src.preprocessing.deepsvg.deepsvg_utils import utils
from src.preprocessing.deepsvg.deepsvg_dataloader import svg_dataset

# ---- Methods for embedding logos ----

def compute_embedding_folder(folder_path: str, model_path: str, save: str = None) -> pd.DataFrame:
    data_list = []
    for file in os.listdir(folder_path):
        print('File: ' + file)
        try:
            embedding = compute_embedding(os.path.join(folder_path, file), model_path)
            embedding['filename'] = file
            data_list.append(embedding)
        except Exception as e:
            print('Embedding failed: ' + str(e))
    print('Concatenating')
    data = pd.concat(data_list)
    if save is not None:
        output = open(os.path.join(save, 'svg_embedding_5000.pkl'), 'wb')
        pickle.dump(data, output)
        output.close()
    return data


def compute_embedding(path: str, model_path: str, save: str = None) -> pd.DataFrame:
    # Convert all primitives to SVG paths - TODO text
    paths, attributes, svg_attributes = svg2paths2(path)  # In the previous project, this is performed at the end
    wsvg(paths, attributes=attributes, svg_attributes=svg_attributes, filename=path)

    svg = SVG.load_svg(path)
    svg.normalize()  # Using DeepSVG normalize instead of expanding viewbox - TODO check: is this equal?
    svg_str = svg.to_str()

    # Assign an animation id to every path - TODO this changes the original logo!
    document = minidom.parseString(svg_str)
    paths = document.getElementsByTagName('path')
    for i in range(len(paths)):
        paths[i].setAttribute('animation_id', str(i))
    with open(path, 'wb') as svg_file:
        svg_file.write(document.toxml(encoding='iso-8859-1'))

    # Decompose the SVG: one single-path SVG per animation id
    decomposed_svgs = {}

    for i in range(len(paths)):
        doc_temp = copy.deepcopy(document)
        paths_temp = doc_temp.getElementsByTagName('path')
        current_path = paths_temp[i]
        # Iteratively choose path i and remove all others
        remove_temp = paths_temp[:i] + paths_temp[i+1:]
        for path in remove_temp:
            if not path.parentNode.nodeName == 'clipPath':
                path.parentNode.removeChild(path)
        # Check for style attributes; add defaults in case there are none
        if len(current_path.getAttribute('style')) <= 0:
            current_path.setAttribute('stroke', 'black')
            current_path.setAttribute('stroke-width', '2')
        id = current_path.getAttribute('animation_id')
        decomposed_svgs[id] = doc_temp.toprettyxml(encoding='iso-8859-1')
        doc_temp.unlink()

    meta = {}
    for id in decomposed_svgs:
        svg_d_str = decomposed_svgs[id]
        # Load into SVG and canonicalize
        current_svg = SVG.from_str(svg_d_str)
        current_svg.canonicalize()  # Applies DeepSVG canonicalize; previously custom methods were used
        decomposed_svgs[id] = current_svg.to_str()
        if not os.path.exists('data/temp_svg'):
            os.mkdir('data/temp_svg')
        with open(('data/temp_svg/path_' + str(id)) + '.svg', 'w') as svg_file:
            svg_file.write(decomposed_svgs[id])

        # Collect metadata
        len_groups = [path_group.total_len() for path_group in current_svg.svg_path_groups]
        start_pos = [path_group.svg_paths[0].start_pos for path_group in current_svg.svg_path_groups]
        try:
            total_len = sum(len_groups)
            nb_groups = len(len_groups)
            max_len_group = max(len_groups)
        except ValueError:  # len_groups is empty
            total_len = 0
            nb_groups = 0
            max_len_group = 0

        meta[id] = {
            'id': id,
            'total_len': total_len,
            'nb_groups': nb_groups,
            'len_groups': len_groups,
            'max_len_group': max_len_group,
            'start_pos': start_pos
        }
    metadata = pd.DataFrame(meta.values())
    if not os.path.exists('data/metadata'):
        os.mkdir('data/metadata')
    metadata.to_csv('data/metadata/metadata.csv', index=False)

    # Load the pretrained DeepSVG model
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cfg = config_hierarchical_ordered.Config()
    model = cfg.make_model().to(device)
    train_utils.load_model(model_path, model)
    model.eval()

    # Load the dataset of decomposed single-path SVGs
    cfg.data_dir = 'data/temp_svg/'
    cfg.meta_filepath = 'data/metadata/metadata.csv'
    dataset = svg_dataset.load_dataset(cfg)
    svg_files = glob.glob('data/temp_svg/*.svg')
    svg_list = []
    for svg_file in svg_files:
        # Extract the animation id from 'path_<id>.svg'
        # (os.path.basename instead of splitting on '\\' keeps this OS-independent)
        id = os.path.basename(svg_file).split('_')[1].split('.')[0]
        # Preprocessing
        svg = SVG.load_svg(svg_file)
        svg = dataset.simplify(svg)
        svg = dataset.preprocess(svg, augment=False)
        data = dataset.get(svg=svg)
        # Get the embedding
        model_args = utils.batchify((data[key] for key in cfg.model_args), device)
        with torch.no_grad():
            z = model(*model_args, encode_mode=True).cpu().numpy()[0][0][0]
        dict_data = {
            'animation_id': id,
            'embedding': z
        }
        svg_list.append(dict_data)
    # One row per path: animation_id plus the embedding unpacked into one column per dimension
    data = pd.DataFrame.from_records(svg_list, index='animation_id')['embedding'].apply(pd.Series)
    data.reset_index(level=0, inplace=True)
    data.dropna(inplace=True)
    data.reset_index(drop=True, inplace=True)
    if save is not None:
        output = open(os.path.join(save, 'svg_embedding_5000.pkl'), 'wb')
        pickle.dump(data, output)
        output.close()
    print('Embedding computed')
    return data


#compute_embedding_folder('data/raw_dataset', 'src/preprocessing/deepsvg/deepsvg_models/deepSVG_hierarchical_ordered.pth.tar', 'data/embedding')
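A usage sketch mirroring the commented call above (the single-logo path is an assumption; the checkpoint path is the model bundled with this commit):

    model_path = 'src/preprocessing/deepsvg/deepsvg_models/deepSVG_hierarchical_ordered.pth.tar'

    # One SVG -> DataFrame with one row per path (animation_id plus one column per embedding dimension)
    single = compute_embedding('data/raw_dataset/logo_0.svg', model_path)  # hypothetical input file

    # Whole folder -> concatenated DataFrame, optionally pickled to save/svg_embedding_5000.pkl
    all_logos = compute_embedding_folder('data/raw_dataset', model_path, save='data/embedding')

Note that compute_embedding rewrites the input SVG in place (adding animation_id attributes), as the TODO in the function flags.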