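# Embeds SVG logos with a pretrained DeepSVG model: every path of an input
# SVG is split into its own single-path file, canonicalized, and encoded;
# the resulting latent vectors are returned as a pandas DataFrame.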
# Imports
import os
import copy
import torch
import glob
import pandas as pd
import pickle
from xml.dom import minidom
from svgpathtools import svg2paths2
from svgpathtools import wsvg
import sys
sys.path.append(os.getcwd())
from src.preprocessing.deepsvg.deepsvg_svglib.svg import SVG
from src.preprocessing.deepsvg.deepsvg_config import config_hierarchical_ordered
from src.preprocessing.deepsvg.deepsvg_utils import train_utils
from src.preprocessing.deepsvg.deepsvg_utils import utils
from src.preprocessing.deepsvg.deepsvg_dataloader import svg_dataset
# ---- Methods for embedding logos ----
def compute_embedding_folder(folder_path: str, model_path: str, save: str = None) -> pd.DataFrame:
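    """Embed every SVG in folder_path and concatenate the results.

    Files that fail to embed are skipped with a warning. If save is given,
    the combined DataFrame is also pickled to that directory.
    """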
    data_list = []
    for file in os.listdir(folder_path):
        print('File: ' + file)
        try:
            embedding = compute_embedding(os.path.join(folder_path, file), model_path)
            embedding['filename'] = file
            data_list.append(embedding)
        except Exception as e:
            print('Embedding failed: ' + str(e))
    print('Concatenating')
    data = pd.concat(data_list)
    if save is not None:
        with open(os.path.join(save, 'svg_embedding_5000.pkl'), 'wb') as output:
            pickle.dump(data, output)
    return data
def compute_embedding(path: str, model_path: str, save: str = None) -> pd.DataFrame:
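    """Embed a single SVG file with a pretrained DeepSVG model.

    Note: the file at path is rewritten in place (primitives are converted
    to paths and an animation_id attribute is added to every path element).
    Returns a DataFrame with one row per path and one column per embedding
    dimension.
    """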
    # Convert all primitives to SVG paths - TODO: text elements are not handled yet
    paths, attributes, svg_attributes = svg2paths2(path)  # In the previous project, this was performed at the end
    wsvg(paths, attributes=attributes, svg_attributes=svg_attributes, filename=path)
    svg = SVG.load_svg(path)
    svg.normalize()  # Uses DeepSVG's normalize instead of expanding the viewbox - TODO: check whether this is equivalent
    svg_str = svg.to_str()
    # Assign an animation_id to every path - TODO: this modifies the original logo!
    document = minidom.parseString(svg_str)
    paths = document.getElementsByTagName('path')
    for i in range(len(paths)):
        paths[i].setAttribute('animation_id', str(i))
    with open(path, 'wb') as svg_file:
        svg_file.write(document.toxml(encoding='iso-8859-1'))
    # Decompose the SVG: build one single-path document per animation_id
    decomposed_svgs = {}
    for i in range(len(paths)):
        doc_temp = copy.deepcopy(document)
        paths_temp = doc_temp.getElementsByTagName('path')
        current_path = paths_temp[i]
        # Keep path i and remove all others
        remove_temp = paths_temp[:i] + paths_temp[i + 1:]
        for path_element in remove_temp:
            if path_element.parentNode.nodeName != 'clipPath':
                path_element.parentNode.removeChild(path_element)
        # Check for style attributes; add defaults in case there are none
        if not current_path.getAttribute('style'):
            current_path.setAttribute('stroke', 'black')
            current_path.setAttribute('stroke-width', '2')
        animation_id = current_path.getAttribute('animation_id')
        decomposed_svgs[animation_id] = doc_temp.toprettyxml(encoding='iso-8859-1')
        doc_temp.unlink()
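    # decomposed_svgs now maps each animation_id to the XML (bytes) of an
    # SVG document that contains only that single path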
    # Canonicalize each decomposed SVG, write it to a temp directory, and collect metadata
    meta = {}
    for animation_id in decomposed_svgs:
        svg_d_str = decomposed_svgs[animation_id]
        # Load into an SVG object and canonicalize
        current_svg = SVG.from_str(svg_d_str)
        current_svg.canonicalize()  # Applies DeepSVG's canonicalize; previously custom methods were used
        decomposed_svgs[animation_id] = current_svg.to_str()
        os.makedirs('data/temp_svg', exist_ok=True)
        with open('data/temp_svg/path_' + str(animation_id) + '.svg', 'w') as svg_file:
            svg_file.write(decomposed_svgs[animation_id])
        # Collect metadata
        len_groups = [path_group.total_len() for path_group in current_svg.svg_path_groups]
        start_pos = [path_group.svg_paths[0].start_pos for path_group in current_svg.svg_path_groups]
        try:
            total_len = sum(len_groups)
            nb_groups = len(len_groups)
            max_len_group = max(len_groups)
        except ValueError:  # max() fails on SVGs without any path group
            total_len = 0
            nb_groups = 0
            max_len_group = 0
        meta[animation_id] = {
            'id': animation_id,
            'total_len': total_len,
            'nb_groups': nb_groups,
            'len_groups': len_groups,
            'max_len_group': max_len_group,
            'start_pos': start_pos
        }
    metadata = pd.DataFrame(list(meta.values()))
    os.makedirs('data/metadata', exist_ok=True)
    metadata.to_csv('data/metadata/metadata.csv', index=False)
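    # DeepSVG's dataloader expects a directory of per-path SVG files plus a
    # metadata CSV, which is why both were written out above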
    # Load the pretrained DeepSVG model
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cfg = config_hierarchical_ordered.Config()
    model = cfg.make_model().to(device)
    train_utils.load_model(model_path, model)
    model.eval()
    # Load the dataset of decomposed paths
    cfg.data_dir = 'data/temp_svg/'
    cfg.meta_filepath = 'data/metadata/metadata.csv'
    dataset = svg_dataset.load_dataset(cfg)
    svg_files = glob.glob('data/temp_svg/*.svg')
    svg_list = []
    for svg_file in svg_files:
        # Recover the animation_id from the filename ('path_<id>.svg');
        # os.path.basename keeps this platform-independent
        animation_id = os.path.basename(svg_file).split('_')[1].split('.')[0]
        # Preprocessing
        svg = SVG.load_svg(svg_file)
        svg = dataset.simplify(svg)
        svg = dataset.preprocess(svg, augment=False)
        data = dataset.get(svg=svg)
        # Get the embedding: the encoder's latent vector for this path
        model_args = utils.batchify((data[key] for key in cfg.model_args), device)
        with torch.no_grad():
            z = model(*model_args, encode_mode=True).cpu().numpy()[0][0][0]
        svg_list.append({
            'animation_id': animation_id,
            'embedding': z
        })
    # One row per path, one column per embedding dimension
    data = pd.DataFrame.from_records(svg_list, index='animation_id')['embedding'].apply(pd.Series)
    data.reset_index(level=0, inplace=True)
    data.dropna(inplace=True)
    data.reset_index(drop=True, inplace=True)
    if save is not None:
        with open(os.path.join(save, 'svg_embedding_5000.pkl'), 'wb') as output:
            pickle.dump(data, output)
    print('Embedding computed')
    return data
#compute_embedding_folder('data/raw_dataset', 'src/preprocessing/deepsvg/deepsvg_models/deepSVG_hierarchical_ordered.pth.tar', 'data/embedding')
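# Minimal usage sketch (assumes the checkpoint and directories from the
# commented-out call above exist in your setup):
# if __name__ == '__main__':
#     embeddings = compute_embedding_folder(
#         'data/raw_dataset',
#         'src/preprocessing/deepsvg/deepsvg_models/deepSVG_hierarchical_ordered.pth.tar',
#         save='data/embedding')
#     print(embeddings.head())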