Daniel Gil-U Fuhge commited on
Commit
a09c67b
1 Parent(s): 2b65f48

update postprocessing

Browse files
animationPipeline.py CHANGED
@@ -5,8 +5,15 @@ from AnimationTransformer import predict
5
  import torch.nn as torch
6
  import torch
7
  import pandas as pd
8
-
9
- def animateLogo(path : str):
 
 
 
 
 
 
 
10
  #transformer
11
  NUM_HEADS = 6 # Dividers of 282: {1, 2, 3, 6, 47, 94, 141, 282}
12
  NUM_ENCODER_LAYERS = 2
@@ -44,7 +51,7 @@ def animateLogo(path : str):
44
  result = pd.DataFrame({"model_output" : [row.tolist() for index, row in result.iterrows()]})
45
  result["animation_id"] = range(len(result))
46
  print(result, path)
47
- animate_logo(result, path)
48
 
49
  #logo = "data/examples/logo_181.svg"
50
  #animateLogo(logo)
 
5
  import torch.nn as torch
6
  import torch
7
  import pandas as pd
8
+ import shutil
9
+
10
+ def animateLogo(path : str, targetPath : str):
11
+ try:
12
+ # Copy the original file to the new location with the new filename
13
+ shutil.copyfile(path, targetPath)
14
+ print(f"File copied and renamed to {targetPath}")
15
+ except Exception as e:
16
+ print(f"An error occurred: {e}")
17
  #transformer
18
  NUM_HEADS = 6 # Dividers of 282: {1, 2, 3, 6, 47, 94, 141, 282}
19
  NUM_ENCODER_LAYERS = 2
 
51
  result = pd.DataFrame({"model_output" : [row.tolist() for index, row in result.iterrows()]})
52
  result["animation_id"] = range(len(result))
53
  print(result, path)
54
+ animate_logo(result, targetPath)
55
 
56
  #logo = "data/examples/logo_181.svg"
57
  #animateLogo(logo)
app.py CHANGED
@@ -21,5 +21,5 @@ if uploaded_file is not None:
21
  sys.setrecursionlimit(1500)
22
  animateLogo(path)
23
  with open(path, "rb") as file:
24
- st.download_button('Download animated SVG', data=file, file_name=uploaded_file.name+"_animated.svg")
25
 
 
21
  sys.setrecursionlimit(1500)
22
  animateLogo(path)
23
  with open(path, "rb") as file:
24
+ st.download_button('Download animated SVG', data=file, file_name=uploaded_file.name[:-3]+"_animated.svg")
25
 
src/postprocessing/postprocessing.py CHANGED
@@ -15,6 +15,16 @@ random.seed(0)
15
  filter_id = 0
16
 
17
  def animate_logo(model_output: pd.DataFrame, logo_path: str):
 
 
 
 
 
 
 
 
 
 
18
  logo_xmin, logo_xmax, logo_ymin, logo_ymax = get_svg_bbox(logo_path)
19
  # ---- Normalize model output ----
20
  animations_by_id = defaultdict(list)
 
15
  filter_id = 0
16
 
17
  def animate_logo(model_output: pd.DataFrame, logo_path: str):
18
+ # Add animation id
19
+ document = minidom.parse(logo_path)
20
+
21
+ paths = document.getElementsByTagName('path')
22
+ for i in range(len(paths)):
23
+ paths[i].setAttribute('animation_id', str(i))
24
+ with open(path, 'wb') as svg_file:
25
+ svg_file.write(document.toxml(encoding='iso-8859-1'))
26
+
27
+
28
  logo_xmin, logo_xmax, logo_ymin, logo_ymax = get_svg_bbox(logo_path)
29
  # ---- Normalize model output ----
30
  animations_by_id = defaultdict(list)