Spaces:
Running
Running
Daniel Gil-U Fuhge
committed on
Commit
•
f7da327
1
Parent(s):
91b5220
update to new temperature approach
Browse files
AnimationTransformer.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import math
|
2 |
import time
|
|
|
3 |
|
4 |
import torch
|
5 |
import torch.nn as nn
|
@@ -170,33 +171,39 @@ def fit(model, optimizer, loss_function, train_dataloader, val_dataloader, epoch
|
|
170 |
return train_loss_list, validation_loss_list
|
171 |
|
172 |
|
173 |
-
def predict(model, source_sequence, sos_token: torch.Tensor, device, max_length=32, eos_scaling=1, backpropagate=False, showResult= True):
|
174 |
if backpropagate:
|
175 |
model.train()
|
176 |
else:
|
177 |
-
model.eval()
|
178 |
|
179 |
source_sequence = source_sequence.float().to(device)
|
180 |
y_input = torch.unsqueeze(sos_token, dim=0).float().to(device)
|
181 |
-
|
182 |
i = 0
|
183 |
while i < max_length:
|
184 |
# Get source mask
|
|
|
185 |
prediction = model(source_sequence.unsqueeze(0), y_input.unsqueeze(0), # un-squeeze for batch
|
186 |
# tgt_mask=get_tgt_mask(y_input.size(0)).to(device),
|
187 |
src_key_padding_mask=create_pad_mask(source_sequence.unsqueeze(0)).to(device))
|
188 |
-
|
189 |
next_embedding = prediction[0, -1, :] # prediction on last token
|
|
|
190 |
pred_deep_svg, pred_type, pred_parameters = dataset_helper.unpack_embedding(next_embedding, dim=0)
|
191 |
#print(pred_deep_svg, pred_type, pred_parameters)
|
192 |
pred_deep_svg, pred_type, pred_parameters = pred_deep_svg.to(device), pred_type.to(device), pred_parameters.to(
|
193 |
device)
|
194 |
|
|
|
|
|
195 |
# === TYPE ===
|
196 |
# Apply Softmax
|
197 |
type_softmax = torch.softmax(pred_type, dim=0)
|
198 |
type_softmax[0] = type_softmax[0] * eos_scaling # Reduce EOS
|
199 |
-
|
|
|
|
|
|
|
200 |
|
201 |
# Break if EOS is most likely
|
202 |
if animation_type == 0:
|
@@ -222,6 +229,7 @@ def predict(model, source_sequence, sos_token: torch.Tensor, device, max_length=
|
|
222 |
|
223 |
# === SEQUENCE ===
|
224 |
y_new = torch.concat([closest_token[:-26], pred_type.to(device), pred_parameters], dim=0)
|
|
|
225 |
y_input = torch.cat((y_input, y_new.unsqueeze(0)), dim=0)
|
226 |
|
227 |
# === INFO PRINT ===
|
|
|
1 |
import math
|
2 |
import time
|
3 |
+
import random
|
4 |
|
5 |
import torch
|
6 |
import torch.nn as nn
|
|
|
171 |
return train_loss_list, validation_loss_list
|
172 |
|
173 |
|
174 |
+
def predict(model, source_sequence, sos_token: torch.Tensor, device, max_length=32, eos_scaling=1, backpropagate=False, showResult= True, temperature=1):
|
175 |
if backpropagate:
|
176 |
model.train()
|
177 |
else:
|
178 |
+
model.eval()
|
179 |
|
180 |
source_sequence = source_sequence.float().to(device)
|
181 |
y_input = torch.unsqueeze(sos_token, dim=0).float().to(device)
|
182 |
+
#print(source_sequence, source_sequence.unsqueeze(0))
|
183 |
i = 0
|
184 |
while i < max_length:
|
185 |
# Get source mask
|
186 |
+
#print(y_input, y_input.unsqueeze(0))
|
187 |
prediction = model(source_sequence.unsqueeze(0), y_input.unsqueeze(0), # un-squeeze for batch
|
188 |
# tgt_mask=get_tgt_mask(y_input.size(0)).to(device),
|
189 |
src_key_padding_mask=create_pad_mask(source_sequence.unsqueeze(0)).to(device))
|
|
|
190 |
next_embedding = prediction[0, -1, :] # prediction on last token
|
191 |
+
|
192 |
pred_deep_svg, pred_type, pred_parameters = dataset_helper.unpack_embedding(next_embedding, dim=0)
|
193 |
#print(pred_deep_svg, pred_type, pred_parameters)
|
194 |
pred_deep_svg, pred_type, pred_parameters = pred_deep_svg.to(device), pred_type.to(device), pred_parameters.to(
|
195 |
device)
|
196 |
|
197 |
+
pred_type = pred_type / temperature
|
198 |
+
|
199 |
# === TYPE ===
|
200 |
# Apply Softmax
|
201 |
type_softmax = torch.softmax(pred_type, dim=0)
|
202 |
type_softmax[0] = type_softmax[0] * eos_scaling # Reduce EOS
|
203 |
+
|
204 |
+
indices = torch.argsort(type_softmax, descending=True)
|
205 |
+
animation_type = random.choice(indices[:3])
|
206 |
+
#animation_type = torch.argmax(type_softmax, dim=0)
|
207 |
|
208 |
# Break if EOS is most likely
|
209 |
if animation_type == 0:
|
|
|
229 |
|
230 |
# === SEQUENCE ===
|
231 |
y_new = torch.concat([closest_token[:-26], pred_type.to(device), pred_parameters], dim=0)
|
232 |
+
#y_new = torch.concat([pred_deep_svg, pred_type.to(device), pred_parameters], dim=0)
|
233 |
y_input = torch.cat((y_input, y_new.unsqueeze(0)), dim=0)
|
234 |
|
235 |
# === INFO PRINT ===
|
animationPipeline.py
CHANGED
@@ -15,10 +15,10 @@ def animateLogo(path : str, targetPath : str):
|
|
15 |
except Exception as e:
|
16 |
print(f"An error occurred: {e}")
|
17 |
#transformer
|
18 |
-
NUM_HEADS =
|
19 |
-
NUM_ENCODER_LAYERS =
|
20 |
-
NUM_DECODER_LAYERS =
|
21 |
-
DROPOUT=0.
|
22 |
# CONSTANTS
|
23 |
FEATURE_DIM = 282
|
24 |
|
@@ -34,7 +34,7 @@ def animateLogo(path : str, targetPath : str):
|
|
34 |
use_positional_encoder=True
|
35 |
).to(device)
|
36 |
|
37 |
-
model.load_state_dict(torch.load("models/
|
38 |
|
39 |
df = compute_embedding(path, "models/deepSVG_hierarchical_ordered.pth.tar")
|
40 |
df = df.drop("animation_id", axis=1)
|
@@ -46,7 +46,7 @@ def animateLogo(path : str, targetPath : str):
|
|
46 |
|
47 |
sos_token = torch.zeros(282)
|
48 |
sos_token[256] = 1
|
49 |
-
result = predict(model, inp, sos_token=sos_token, device=device, max_length=inp.shape[0], eos_scaling=
|
50 |
result = pd.DataFrame(result[1:, -26:].cpu().detach().numpy())
|
51 |
result = pd.DataFrame({"model_output" : [row.tolist() for index, row in result.iterrows()]})
|
52 |
result["animation_id"] = range(len(result))
|
|
|
15 |
except Exception as e:
|
16 |
print(f"An error occurred: {e}")
|
17 |
#transformer
|
18 |
+
NUM_HEADS = 47 # Divisors of 282: {1, 2, 3, 6, 47, 94, 141, 282}
|
19 |
+
NUM_ENCODER_LAYERS = 6
|
20 |
+
NUM_DECODER_LAYERS = 4
|
21 |
+
DROPOUT=0.21
|
22 |
# CONSTANTS
|
23 |
FEATURE_DIM = 282
|
24 |
|
|
|
34 |
use_positional_encoder=True
|
35 |
).to(device)
|
36 |
|
37 |
+
model.load_state_dict(torch.load("models/animation_transformer2.pth", map_location=torch.device('cpu')), strict=False)
|
38 |
|
39 |
df = compute_embedding(path, "models/deepSVG_hierarchical_ordered.pth.tar")
|
40 |
df = df.drop("animation_id", axis=1)
|
|
|
46 |
|
47 |
sos_token = torch.zeros(282)
|
48 |
sos_token[256] = 1
|
49 |
+
result = predict(model, inp, sos_token=sos_token, device=device, max_length=inp.shape[0], eos_scaling=0.5, temperature=100)
|
50 |
result = pd.DataFrame(result[1:, -26:].cpu().detach().numpy())
|
51 |
result = pd.DataFrame({"model_output" : [row.tolist() for index, row in result.iterrows()]})
|
52 |
result["animation_id"] = range(len(result))
|
models/{animation_transformer.pth → animation_transformer2.pth}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e63638f545c6f925a1a6d31578d507834de8ed30b71db2a0762c86859c597c44
|
3 |
+
size 69927679
|