Spaces:
Running
Running
# Imports | |
import os | |
import copy | |
import torch | |
import glob | |
import pandas as pd | |
import streamlit as st | |
import pickle | |
from xml.dom import minidom | |
from svgpathtools import svg2paths2 | |
from svgpathtools import wsvg | |
import sys | |
sys.path.append(os.getcwd()) | |
from src.preprocessing.deepsvg.deepsvg_svglib.svg import SVG | |
from src.preprocessing.deepsvg.deepsvg_config import config_hierarchical_ordered | |
from src.preprocessing.deepsvg.deepsvg_utils import train_utils | |
from src.preprocessing.deepsvg.deepsvg_utils import utils | |
from src.preprocessing.deepsvg.deepsvg_dataloader import svg_dataset | |
# ---- Methods for embedding logos ---- | |
def compute_embedding_folder(folder_path: str, model_path: str, save: str = None) -> pd.DataFrame:
    """Embed every file in a folder with ``compute_embedding`` and concatenate the results.

    Files whose embedding fails are reported and skipped, so one bad logo does
    not abort the whole batch.

    Args:
        folder_path: Directory whose entries are embedded one by one.
        model_path: Path to the pretrained DeepSVG checkpoint.
        save: Optional output directory; if given, the concatenated frame is
            pickled to ``<save>/svg_embedding_5000.pkl``.

    Returns:
        The concatenation of the per-file frames, each tagged with a
        ``filename`` column. An empty DataFrame if no file could be embedded.
    """
    data_list = []
    # Sorted so the row order does not depend on the filesystem's listing order.
    for file in sorted(os.listdir(folder_path)):
        print('File: ' + file)
        try:
            embedding = compute_embedding(os.path.join(folder_path, file), model_path)
        except Exception as err:  # best-effort batch: report and move on
            print(f'Embedding failed: {err}')
            continue
        embedding['filename'] = file
        data_list.append(embedding)
    print('Concatenating')
    # pd.concat raises ValueError on an empty list, so handle that case explicitly.
    data = pd.concat(data_list) if data_list else pd.DataFrame()
    if save is not None:
        with open(os.path.join(save, 'svg_embedding_5000.pkl'), 'wb') as output:
            pickle.dump(data, output)
    return data
# Scratch locations shared by the decomposition step and the DeepSVG dataset config.
_TEMP_SVG_DIR = 'data/temp_svg'
_METADATA_PATH = 'data/metadata/metadata.csv'


def _temp_svg_path(animation_id: str) -> str:
    """Return the scratch file path for the single-path SVG *animation_id*."""
    return _TEMP_SVG_DIR + '/path_' + str(animation_id) + '.svg'


def _normalize_and_tag(path: str) -> minidom.Document:
    """Rewrite the SVG at *path* as normalized paths, each tagged with ``animation_id``; return its DOM."""
    # Convert all primitives to SVG paths - TODO text
    paths, attributes, svg_attributes = svg2paths2(path)  # In previous project, this is performed at the end
    wsvg(paths, attributes=attributes, svg_attributes=svg_attributes, filename=path)
    svg = SVG.load_svg(path)
    svg.normalize()  # Using DeepSVG normalize instead of expanding viewbox - TODO check is this equal?
    # Assign animation id to every path - TODO this changes the original logo!
    document = minidom.parseString(svg.to_str())
    for index, path_node in enumerate(document.getElementsByTagName('path')):
        path_node.setAttribute('animation_id', str(index))
    with open(path, 'wb') as svg_file:
        svg_file.write(document.toxml(encoding='iso-8859-1'))
    return document


def _decompose(document: minidom.Document) -> dict:
    """Return ``{animation_id: xml_bytes}``: one copy of *document* per path, with all other paths removed."""
    decomposed_svgs = {}
    for index in range(len(document.getElementsByTagName('path'))):
        doc_temp = copy.deepcopy(document)
        paths_temp = doc_temp.getElementsByTagName('path')
        current_path = paths_temp[index]
        # Keep path `index`; drop every other path except those inside a <clipPath>.
        for other in paths_temp[:index] + paths_temp[index + 1:]:
            if other.parentNode.nodeName != 'clipPath':
                other.parentNode.removeChild(other)
        # Paths without a style attribute get a default black stroke.
        if not current_path.getAttribute('style'):
            current_path.setAttribute('stroke', 'black')
            current_path.setAttribute('stroke-width', '2')
        animation_id = current_path.getAttribute('animation_id')
        decomposed_svgs[animation_id] = doc_temp.toprettyxml(encoding='iso-8859-1')
        doc_temp.unlink()
    return decomposed_svgs


def _canonicalize_and_write(decomposed_svgs: dict) -> pd.DataFrame:
    """Canonicalize each single-path SVG, write it to the scratch dir, and return one metadata row per id."""
    os.makedirs(_TEMP_SVG_DIR, exist_ok=True)
    rows = []
    for animation_id, svg_d_str in decomposed_svgs.items():
        current_svg = SVG.from_str(svg_d_str)
        current_svg.canonicalize()  # Applies DeepSVG canonicalize; previously custom methods were used
        with open(_temp_svg_path(animation_id), 'w') as svg_file:
            svg_file.write(current_svg.to_str())
        len_groups = [group.total_len() for group in current_svg.svg_path_groups]
        start_pos = [group.svg_paths[0].start_pos for group in current_svg.svg_path_groups]
        rows.append({
            'id': animation_id,
            'total_len': sum(len_groups),
            'nb_groups': len(len_groups),
            'len_groups': len_groups,
            # An SVG with no path groups gets 0 instead of max() raising on [].
            'max_len_group': max(len_groups, default=0),
            'start_pos': start_pos,
        })
    return pd.DataFrame(rows)


def _embed_temp_svgs(animation_ids: list, model_path: str) -> list:
    """Encode the scratch SVG of each id with the pretrained DeepSVG model; return one record per id."""
    # Load pretrained DeepSVG model
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cfg = config_hierarchical_ordered.Config()
    model = cfg.make_model().to(device)
    train_utils.load_model(model_path, model)
    model.eval()
    # Load dataset; it supplies the simplify/preprocess/get steps applied below.
    cfg.data_dir = _TEMP_SVG_DIR + '/'
    cfg.meta_filepath = _METADATA_PATH
    dataset = svg_dataset.load_dataset(cfg)
    records = []
    # Only this call's ids are read back. Globbing the scratch dir would also pick
    # up stale files from earlier calls, and the old '\\' split was Windows-only.
    for animation_id in animation_ids:
        svg = SVG.load_svg(_temp_svg_path(animation_id))
        svg = dataset.simplify(svg)
        svg = dataset.preprocess(svg, augment=False)
        data = dataset.get(svg=svg)
        model_args = utils.batchify((data[key] for key in cfg.model_args), device)
        with torch.no_grad():
            z = model(*model_args, encode_mode=True).cpu().numpy()[0][0][0]
        records.append({'animation_id': animation_id, 'embedding': z})
    return records


def compute_embedding(path: str, model_path: str, save: str = None) -> pd.DataFrame:
    """Compute a DeepSVG embedding for each path of one SVG file.

    Pipeline:

    1. Rewrite the file as normalized paths tagged with ``animation_id``.
    2. Split it into one SVG per path.
    3. Canonicalize each split SVG and write it, plus a metadata CSV, under
       ``data/``.
    4. Encode each split SVG with the pretrained model.

    Warning: the file at *path* is overwritten in place (see the TODO notes).

    Args:
        path: SVG file to embed. It is rewritten in place.
        model_path: DeepSVG checkpoint loaded with ``train_utils.load_model``.
        save: Optional directory. If given, the result is pickled to
            ``<save>/svg_embedding_5000.pkl``.

    Returns:
        One row per path: an ``animation_id`` column, then the model output
        expanded into one column per component via ``apply(pd.Series)``.
        Rows containing NaN are dropped.

    Raises:
        ValueError: if the normalized SVG contains no ``<path>`` elements.
    """
    document = _normalize_and_tag(path)
    try:
        decomposed_svgs = _decompose(document)
    finally:
        document.unlink()
    if not decomposed_svgs:
        raise ValueError('No <path> elements found in ' + path)

    metadata = _canonicalize_and_write(decomposed_svgs)
    os.makedirs(os.path.dirname(_METADATA_PATH), exist_ok=True)
    metadata.to_csv(_METADATA_PATH, index=False)

    svg_list = _embed_temp_svgs(list(decomposed_svgs), model_path)
    data = pd.DataFrame.from_records(svg_list, index='animation_id')['embedding'].apply(pd.Series)
    data.reset_index(level=0, inplace=True)
    data.dropna(inplace=True)
    data.reset_index(drop=True, inplace=True)
    if save is not None:
        with open(os.path.join(save, 'svg_embedding_5000.pkl'), 'wb') as output:
            pickle.dump(data, output)
    print('Embedding computed')
    return data
#compute_embedding_folder('data/raw_dataset', 'src/preprocessing/deepsvg/deepsvg_models/deepSVG_hierarchical_ordered.pth.tar', 'data/embedding') |