Daniel Gil-U Fuhge commited on
Commit
f7da327
1 Parent(s): 91b5220

update to new temperature approach

Browse files
AnimationTransformer.py CHANGED
@@ -1,5 +1,6 @@
1
  import math
2
  import time
 
3
 
4
  import torch
5
  import torch.nn as nn
@@ -170,33 +171,39 @@ def fit(model, optimizer, loss_function, train_dataloader, val_dataloader, epoch
170
  return train_loss_list, validation_loss_list
171
 
172
 
173
- def predict(model, source_sequence, sos_token: torch.Tensor, device, max_length=32, eos_scaling=1, backpropagate=False, showResult= True):
174
  if backpropagate:
175
  model.train()
176
  else:
177
- model.eval()
178
 
179
  source_sequence = source_sequence.float().to(device)
180
  y_input = torch.unsqueeze(sos_token, dim=0).float().to(device)
181
-
182
  i = 0
183
  while i < max_length:
184
  # Get source mask
 
185
  prediction = model(source_sequence.unsqueeze(0), y_input.unsqueeze(0), # un-squeeze for batch
186
  # tgt_mask=get_tgt_mask(y_input.size(0)).to(device),
187
  src_key_padding_mask=create_pad_mask(source_sequence.unsqueeze(0)).to(device))
188
-
189
  next_embedding = prediction[0, -1, :] # prediction on last token
 
190
  pred_deep_svg, pred_type, pred_parameters = dataset_helper.unpack_embedding(next_embedding, dim=0)
191
  #print(pred_deep_svg, pred_type, pred_parameters)
192
  pred_deep_svg, pred_type, pred_parameters = pred_deep_svg.to(device), pred_type.to(device), pred_parameters.to(
193
  device)
194
 
 
 
195
  # === TYPE ===
196
  # Apply Softmax
197
  type_softmax = torch.softmax(pred_type, dim=0)
198
  type_softmax[0] = type_softmax[0] * eos_scaling # Reduce EOS
199
- animation_type = torch.argmax(type_softmax, dim=0)
 
 
 
200
 
201
  # Break if EOS is most likely
202
  if animation_type == 0:
@@ -222,6 +229,7 @@ def predict(model, source_sequence, sos_token: torch.Tensor, device, max_length=
222
 
223
  # === SEQUENCE ===
224
  y_new = torch.concat([closest_token[:-26], pred_type.to(device), pred_parameters], dim=0)
 
225
  y_input = torch.cat((y_input, y_new.unsqueeze(0)), dim=0)
226
 
227
  # === INFO PRINT ===
 
1
  import math
2
  import time
3
+ import random
4
 
5
  import torch
6
  import torch.nn as nn
 
171
  return train_loss_list, validation_loss_list
172
 
173
 
174
+ def predict(model, source_sequence, sos_token: torch.Tensor, device, max_length=32, eos_scaling=1, backpropagate=False, showResult= True, temperature=1):
175
  if backpropagate:
176
  model.train()
177
  else:
178
+ model.eval()
179
 
180
  source_sequence = source_sequence.float().to(device)
181
  y_input = torch.unsqueeze(sos_token, dim=0).float().to(device)
182
+ #print(source_sequence, source_sequence.unsqueeze(0))
183
  i = 0
184
  while i < max_length:
185
  # Get source mask
186
+ #print(y_input, y_input.unsqueeze(0))
187
  prediction = model(source_sequence.unsqueeze(0), y_input.unsqueeze(0), # un-squeeze for batch
188
  # tgt_mask=get_tgt_mask(y_input.size(0)).to(device),
189
  src_key_padding_mask=create_pad_mask(source_sequence.unsqueeze(0)).to(device))
 
190
  next_embedding = prediction[0, -1, :] # prediction on last token
191
+
192
  pred_deep_svg, pred_type, pred_parameters = dataset_helper.unpack_embedding(next_embedding, dim=0)
193
  #print(pred_deep_svg, pred_type, pred_parameters)
194
  pred_deep_svg, pred_type, pred_parameters = pred_deep_svg.to(device), pred_type.to(device), pred_parameters.to(
195
  device)
196
 
197
+ pred_type = pred_type / temperature
198
+
199
  # === TYPE ===
200
  # Apply Softmax
201
  type_softmax = torch.softmax(pred_type, dim=0)
202
  type_softmax[0] = type_softmax[0] * eos_scaling # Reduce EOS
203
+
204
+ indices = torch.argsort(type_softmax, descending=True)
205
+ animation_type = random.choice(indices[:3])
206
+ #animation_type = torch.argmax(type_softmax, dim=0)
207
 
208
  # Break if EOS is most likely
209
  if animation_type == 0:
 
229
 
230
  # === SEQUENCE ===
231
  y_new = torch.concat([closest_token[:-26], pred_type.to(device), pred_parameters], dim=0)
232
+ #y_new = torch.concat([pred_deep_svg, pred_type.to(device), pred_parameters], dim=0)
233
  y_input = torch.cat((y_input, y_new.unsqueeze(0)), dim=0)
234
 
235
  # === INFO PRINT ===
animationPipeline.py CHANGED
@@ -15,10 +15,10 @@ def animateLogo(path : str, targetPath : str):
15
  except Exception as e:
16
  print(f"An error occurred: {e}")
17
  #transformer
18
- NUM_HEADS = 6 # Dividers of 282: {1, 2, 3, 6, 47, 94, 141, 282}
19
- NUM_ENCODER_LAYERS = 2
20
- NUM_DECODER_LAYERS = 8
21
- DROPOUT=0.1
22
  # CONSTANTS
23
  FEATURE_DIM = 282
24
 
@@ -34,7 +34,7 @@ def animateLogo(path : str, targetPath : str):
34
  use_positional_encoder=True
35
  ).to(device)
36
 
37
- model.load_state_dict(torch.load("models/animation_transformer.pth", map_location=torch.device('cpu')), strict=False)
38
 
39
  df = compute_embedding(path, "models/deepSVG_hierarchical_ordered.pth.tar")
40
  df = df.drop("animation_id", axis=1)
@@ -46,7 +46,7 @@ def animateLogo(path : str, targetPath : str):
46
 
47
  sos_token = torch.zeros(282)
48
  sos_token[256] = 1
49
- result = predict(model, inp, sos_token=sos_token, device=device, max_length=inp.shape[0], eos_scaling=1)
50
  result = pd.DataFrame(result[1:, -26:].cpu().detach().numpy())
51
  result = pd.DataFrame({"model_output" : [row.tolist() for index, row in result.iterrows()]})
52
  result["animation_id"] = range(len(result))
 
15
  except Exception as e:
16
  print(f"An error occurred: {e}")
17
  #transformer
18
+ NUM_HEADS = 47 # Dividers of 282: {1, 2, 3, 6, 47, 94, 141, 282}
19
+ NUM_ENCODER_LAYERS = 6
20
+ NUM_DECODER_LAYERS = 4
21
+ DROPOUT=0.21
22
  # CONSTANTS
23
  FEATURE_DIM = 282
24
 
 
34
  use_positional_encoder=True
35
  ).to(device)
36
 
37
+ model.load_state_dict(torch.load("models/animation_transformer2.pth", map_location=torch.device('cpu')), strict=False)
38
 
39
  df = compute_embedding(path, "models/deepSVG_hierarchical_ordered.pth.tar")
40
  df = df.drop("animation_id", axis=1)
 
46
 
47
  sos_token = torch.zeros(282)
48
  sos_token[256] = 1
49
+ result = predict(model, inp, sos_token=sos_token, device=device, max_length=inp.shape[0], eos_scaling=0.5, temperature=100)
50
  result = pd.DataFrame(result[1:, -26:].cpu().detach().numpy())
51
  result = pd.DataFrame({"model_output" : [row.tolist() for index, row in result.iterrows()]})
52
  result["animation_id"] = range(len(result))
models/{animation_transformer.pth → animation_transformer2.pth} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:12ae92d0b1a5ada8a8681122f76ea7c4e6b3fdf0169dd4b3a5d908899e563f86
3
- size 60658902
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e63638f545c6f925a1a6d31578d507834de8ed30b71db2a0762c86859c597c44
3
+ size 69927679