Daniel Gil-U Fuhge
add model files
e17e8cc
raw
history blame
20.1 kB
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""
from __future__ import annotations
from .geom import *
from xml.dom import expatbuilder
import torch
from typing import List, Union
import IPython.display as ipd
import cairosvg
from PIL import Image
import io
import os
from moviepy.editor import ImageClip, concatenate_videoclips, ipython_display
import math
import random
import networkx as nx
# Type alias for a plain numeric scalar (either an int or a float).
Num = Union[int, float]
from .svg_command import SVGCommandBezier
from .svg_path import SVGPath, Filling, Orientation
from .svg_primitive import SVGPathGroup, SVGRectangle, SVGCircle, SVGEllipse, SVGLine, SVGPolyline, SVGPolygon
from .geom import union_bbox
class SVG:
    """An SVG drawing: an ordered list of path groups plus a viewbox.

    Most transforming methods mutate the instance in place and return ``self``
    so calls can be chained, e.g. ``svg.copy().normalize().split_paths()``.
    """

    def __init__(self, svg_path_groups: List[SVGPathGroup], viewbox: Bbox = None):
        """Create an SVG.

        Args:
            svg_path_groups: Path groups making up the drawing; the list object
                is stored as-is, not copied.
            viewbox: Viewbox of the drawing. Defaults to ``Bbox(24)``.
        """
        if viewbox is None:
            viewbox = Bbox(24)
        self.svg_path_groups = svg_path_groups
        self.viewbox = viewbox

    def __add__(self, other: SVG):
        """Return a new SVG with copies of this drawing's groups followed by ``other``'s.

        Neither operand is modified. ``other``'s groups are copied as well, so
        later in-place edits of the result cannot leak back into ``other``.
        The result keeps ``self``'s viewbox.
        """
        svg = self.copy()
        svg.svg_path_groups.extend(path_group.copy() for path_group in other.svg_path_groups)
        return svg

    @property
    def paths(self):
        """Iterate over every SVGPath of every group, in group order."""
        for path_group in self.svg_path_groups:
            yield from path_group.svg_paths

    def __getitem__(self, idx):
        """Return ``svg_path_groups[idx]``; a 2-tuple ``(i, j)`` returns ``svg_path_groups[i][j]``."""
        if isinstance(idx, tuple):
            assert len(idx) == 2, "Dimension out of range"
            i, j = idx
            return self.svg_path_groups[i][j]
        return self.svg_path_groups[idx]

    def __len__(self):
        """Return the number of path groups."""
        return len(self.svg_path_groups)

    def total_length(self):
        """Return the sum of ``total_len()`` over all path groups (0 when there are none)."""
        return sum(path_group.total_len() for path_group in self.svg_path_groups)

    @property
    def start_pos(self):
        """Start position of the drawing; always the origin."""
        return Point(0.)

    @property
    def end_pos(self):
        """End position of the last group, or the origin when there are no groups."""
        if not self.svg_path_groups:
            return Point(0.)
        return self.svg_path_groups[-1].end_pos

    def copy(self):
        """Return a new SVG whose groups and viewbox are copied via their own ``copy()`` methods."""
        return SVG([svg_path_group.copy() for svg_path_group in self.svg_path_groups], self.viewbox.copy())

    @staticmethod
    def load_svg(file_path):
        """Read the SVG file at ``file_path`` (decoded as UTF-8) and parse it with :meth:`from_str`."""
        with open(file_path, "r", encoding="utf-8") as f:
            return SVG.from_str(f.read())

    @staticmethod
    def load_splineset(spline_str: str, width, height, add_closing=True):
        """Build an SVG from the ``SplineSet ... EndSplineSet`` section of ``spline_str``.

        Args:
            spline_str: Text containing a SplineSet section.
            width: Viewbox width of the result.
            height: Viewbox height; also used to flip y coordinates.
            add_closing: Forwarded to ``SVGPath.from_str``.

        Raises:
            ValueError: if there is no SplineSet, no ``EndSplineSet`` marker
                (from ``str.index``), or the section yields no commands.
        """
        if "SplineSet" not in spline_str:
            raise ValueError("Not a SplineSet")
        # +10 skips the 9-character "SplineSet" keyword and the separator after it.
        spline = spline_str[spline_str.index('SplineSet') + 10:spline_str.index('EndSplineSet')]
        svg_str = SVG._spline_to_svg_str(spline, height)
        if not svg_str:
            raise ValueError("Empty SplineSet")
        svg_path_group = SVGPath.from_str(svg_str, add_closing=add_closing)
        return SVG([svg_path_group], viewbox=Bbox(width, height))

    @staticmethod
    def _spline_to_svg_str(spline_str: str, height, replace_with_prev=False):
        """Convert SplineSet lines into an SVG path-data string.

        Each non-empty line holds space-separated coordinates followed by a
        command token and one trailing token; the command (second-to-last
        token) must be exactly ``m``, ``l`` or ``c``. Odd-indexed values
        (y coordinates) are flipped to ``height - y``.

        Args:
            spline_str: The lines between ``SplineSet`` and ``EndSplineSet``.
            height: Height used to flip y coordinates.
            replace_with_prev: If True, the first point of every ``c`` command
                is replaced by the last point of the previous line.

        Returns:
            e.g. ``"M 1.0 8.0 L 3.0 6.0"``; an empty string for empty input.

        Raises:
            ValueError: if a command token is not one of ``m``, ``l``, ``c``.
        """
        path = []
        prev_xy = []
        for line in spline_str.splitlines():
            if not line:
                continue
            tokens = line.split(' ')
            cmd = tokens[-2]
            # Exact match: a substring test (`in 'cml'`) would also accept "".
            if cmd not in ('c', 'm', 'l'):
                raise ValueError(f"Command not recognized: {cmd}")
            args = [float(x) for x in tokens[:-2] if x]
            if replace_with_prev and cmd == 'c':
                args[:2] = prev_xy
            prev_xy = args[-2:]
            # Even indices are x (kept as-is), odd indices are y (flipped).
            coords = [str(height - a) if i % 2 == 1 else str(a) for i, a in enumerate(args)]
            path.extend([cmd.upper()] + coords)
        return " ".join(path)

    @staticmethod
    def from_str(svg_str: str):
        """Parse an SVG document string.

        The ``viewBox`` of the first ``<svg>`` element becomes the viewbox; its
        numbers may be separated by whitespace and/or commas. Each supported
        primitive element becomes one path group. Groups are collected tag type
        by tag type (all paths, then rects, ...), not in document order.

        Raises:
            ValueError: if the ``<svg>`` element has no ``viewBox``.
        """
        # NOTE(review): expat-based parsing of untrusted SVG can be exposed to
        # XML entity-expansion attacks depending on the expat version in use.
        svg_dom = expatbuilder.parseString(svg_str, False)
        svg_root = svg_dom.getElementsByTagName('svg')[0]

        # The SVG spec allows whitespace and/or commas between viewBox numbers.
        viewbox_tokens = svg_root.getAttribute("viewBox").replace(",", " ").split()
        if not viewbox_tokens:
            raise ValueError("SVG root element has no viewBox attribute")
        view_box = Bbox(*map(float, viewbox_tokens))

        primitives = {
            "path": SVGPath,
            "rect": SVGRectangle,
            "circle": SVGCircle, "ellipse": SVGEllipse,
            "line": SVGLine,
            "polyline": SVGPolyline, "polygon": SVGPolygon
        }
        svg_path_groups = [Primitive.from_xml(x)
                           for tag, Primitive in primitives.items()
                           for x in svg_dom.getElementsByTagName(tag)]
        return SVG(svg_path_groups, view_box)

    def to_tensor(self, concat_groups=True, PAD_VAL=-1):
        """Convert each group with ``to_tensor(PAD_VAL=...)``; concatenate along dim 0 if ``concat_groups``."""
        group_tensors = [p.to_tensor(PAD_VAL=PAD_VAL) for p in self.svg_path_groups]
        if concat_groups:
            return torch.cat(group_tensors, dim=0)
        return group_tensors

    def to_fillings(self):
        """Return the ``filling`` of each group's ``path``, in group order."""
        return [p.path.filling for p in self.svg_path_groups]

    @staticmethod
    def from_tensor(tensor: torch.Tensor, viewbox: Bbox = None, allow_empty=False):
        """Build a single-group SVG from one tensor via ``SVGPath.from_tensor``. Viewbox defaults to ``Bbox(24)``."""
        if viewbox is None:
            viewbox = Bbox(24)
        return SVG([SVGPath.from_tensor(tensor, allow_empty=allow_empty)], viewbox=viewbox)

    @staticmethod
    def from_tensors(tensors: List[torch.Tensor], viewbox: Bbox = None, allow_empty=False):
        """Build an SVG with one group per tensor via ``SVGPath.from_tensor``. Viewbox defaults to ``Bbox(24)``."""
        if viewbox is None:
            viewbox = Bbox(24)
        return SVG([SVGPath.from_tensor(t, allow_empty=allow_empty) for t in tensors], viewbox=viewbox)

    def save_svg(self, file_path, **to_str_kwargs):
        """Write the SVG markup to ``file_path`` as UTF-8.

        Extra keyword arguments are forwarded to :meth:`to_str` (e.g. ``fill=True``).
        """
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(self.to_str(**to_str_kwargs))

    def save_png(self, file_path, **to_str_kwargs):
        """Rasterize the SVG with cairosvg into ``file_path``; kwargs are forwarded to :meth:`to_str`."""
        cairosvg.svg2png(bytestring=self.to_str(**to_str_kwargs), write_to=file_path)

    def draw(self, fill=False, file_path=None, do_display=True, return_png=False,
             with_points=False, with_handles=False, with_bboxes=False, with_markers=False, color_firstlast=False,
             with_moves=True):
        """Render the drawing, optionally saving, displaying and/or returning it.

        Args:
            fill, with_points, with_handles, with_bboxes, with_markers,
            color_firstlast, with_moves: Rendering options forwarded to :meth:`to_str`.
            file_path: If given, save the rendering there; the extension must be
                ``.svg`` or ``.png``. The saved file uses the same options as the
                displayed/returned rendering.
            do_display: Display the SVG through IPython.
            return_png: If True, return a PIL image of the rendering.

        Returns:
            A PIL image when ``return_png`` is True, otherwise None.

        Raises:
            ValueError: for an unsupported ``file_path`` extension.
        """
        svg_str = self.to_str(fill=fill, with_points=with_points, with_handles=with_handles, with_bboxes=with_bboxes,
                              with_markers=with_markers, color_firstlast=color_firstlast, with_moves=with_moves)

        file_extension = os.path.splitext(file_path)[1] if file_path is not None else None
        if file_path is not None:
            if file_extension == ".svg":
                with open(file_path, "w", encoding="utf-8") as f:
                    f.write(svg_str)
            elif file_extension == ".png":
                cairosvg.svg2png(bytestring=svg_str, write_to=file_path)
            else:
                raise ValueError(f"Unsupported file_path extension {file_extension}")

        if do_display:
            ipd.display(ipd.SVG(svg_str))

        if return_png:
            if file_extension == ".png":
                return Image.open(file_path)
            img_data = cairosvg.svg2png(bytestring=svg_str)
            return Image.open(io.BytesIO(img_data))

    def draw_colored(self, *args, **kwargs):
        """Draw a normalized, split, randomly colored copy; returns whatever :meth:`draw` returns."""
        return self.copy().normalize().split_paths().set_color("random").draw(*args, **kwargs)

    def __repr__(self):
        return "SVG[{}](\n{}\n)".format(self.viewbox,
                                        ",\n".join([f"\t{svg_path_group}" for svg_path_group in self.svg_path_groups]))

    def _get_viz_elements(self, with_points=False, with_handles=False, with_bboxes=False, color_firstlast=False,
                          with_moves=True):
        """Collect the visualization helper elements produced by every group's ``_get_viz_elements``."""
        viz_elements = []
        for svg_path_group in self.svg_path_groups:
            viz_elements.extend(
                svg_path_group._get_viz_elements(with_points, with_handles, with_bboxes, color_firstlast, with_moves))
        return viz_elements

    def _markers(self):
        """Return a ``<defs>`` block defining a red arrow marker with id ``arrow``."""
        return ('<defs>'
                '<marker id="arrow" viewBox="0 0 10 10" markerWidth="4" markerHeight="4" refX="0" refY="3" orient="auto" markerUnits="strokeWidth">'
                '<path d="M0,0 L0,6 L9,3 z" fill="#f00" />'
                '</marker>'
                '</defs>')

    def to_str(self, fill=False, with_points=False, with_handles=False, with_bboxes=False, with_markers=False,
               color_firstlast=False, with_moves=True) -> str:
        """Serialize to an ``<svg>`` string (200x200px) containing the groups followed by any viz elements."""
        viz_elements = self._get_viz_elements(with_points, with_handles, with_bboxes, color_firstlast, with_moves)
        body = "\n".join(svg_path_group.to_str(fill=fill, with_markers=with_markers)
                         for svg_path_group in [*self.svg_path_groups, *viz_elements])
        markers = self._markers() if with_markers else ""
        return (
            f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="{self.viewbox.to_str()}" height="200px" width="200px">'
            f'{markers}'
            f'{body}'
            '</svg>')

    def _apply_to_paths(self, method, *args, **kwargs):
        """Call ``method(*args, **kwargs)`` on every path group; return ``self``."""
        for path_group in self.svg_path_groups:
            getattr(path_group, method)(*args, **kwargs)
        return self

    def split_paths(self):
        """Replace each group by the groups returned from its ``split_paths()``."""
        self.svg_path_groups = [split_group
                                for path_group in self.svg_path_groups
                                for split_group in path_group.split_paths()]
        return self

    def merge_groups(self):
        """Merge the paths of all groups into the first group, which becomes the only group.

        No-op on an SVG without groups.
        """
        if not self.svg_path_groups:
            return self
        merged = self.svg_path_groups[0]
        for other in self.svg_path_groups[1:]:
            merged.svg_paths.extend(other.svg_paths)
        self.svg_path_groups = [merged]
        return self

    def empty(self):
        """Return True when there are no path groups."""
        return not self.svg_path_groups

    def drop_z(self):
        return self._apply_to_paths("drop_z")

    def filter_empty(self):
        """Filter empty content inside each group, then drop groups left with no paths."""
        self._apply_to_paths("filter_empty")
        self.svg_path_groups = [path_group for path_group in self.svg_path_groups if path_group.svg_paths]
        return self

    def translate(self, vec: Point):
        return self._apply_to_paths("translate", vec)

    def rotate(self, angle: Angle, center: Point = None):
        """Rotate all paths by ``angle`` about the viewbox center.

        The drawing is shifted so the viewbox center sits at the origin,
        rotated, then shifted so the origin lands on ``center``. ``center``
        is thus the destination of the viewbox center (not the pivot); it
        defaults to the viewbox center, giving an in-place rotation.
        """
        if center is None:
            center = self.viewbox.center
        self.translate(-self.viewbox.center)
        self._apply_to_paths("rotate", angle)
        self.translate(center)
        return self

    def zoom(self, factor, center: Point = None):
        """Scale all paths by ``factor`` about the viewbox center, then move that center to ``center``.

        ``center`` defaults to the viewbox center (in-place zoom).
        """
        if center is None:
            center = self.viewbox.center
        self.translate(-self.viewbox.center)
        self._apply_to_paths("scale", factor)
        self.translate(center)
        return self

    def normalize(self, viewbox: Bbox = None):
        """Uniformly rescale the drawing to fit ``viewbox`` (default ``Bbox(24)``) and adopt it.

        The scale is ``min(target size) / max(current size)``, so the aspect
        ratio is preserved; the current viewbox center is moved to the target's.
        """
        if viewbox is None:
            viewbox = Bbox(24)
        size = self.viewbox.size
        scale_factor = viewbox.size.min() / size.max()
        self.zoom(scale_factor, viewbox.center)
        self.viewbox = viewbox
        return self

    def compute_filling(self):
        return self._apply_to_paths("compute_filling")

    def recompute_origins(self):
        """Chain origins: each group starts where the previous one ends (the first at the origin)."""
        origin = self.start_pos
        for path_group in self.svg_path_groups:
            path_group.set_origin(origin.copy())
            origin = path_group.end_pos
        return self

    def _reorder_and_canonicalize_groups(self):
        """Reorder each group, sort groups by reversed start position (``tolist()[::-1]``), canonicalize each."""
        self._apply_to_paths("reorder")
        self.svg_path_groups = sorted(self.svg_path_groups, key=lambda x: x.start_pos.tolist()[::-1])
        self._apply_to_paths("canonicalize")

    def _canonicalize(self, normalize, with_filling):
        """Shared pipeline of :meth:`canonicalize` and :meth:`canonicalize_new`."""
        self.to_path().simplify_arcs()
        if with_filling:
            self.compute_filling()
        if normalize:
            self.normalize()
        self.split_paths()
        self.filter_consecutives()
        self.filter_empty()
        self._reorder_and_canonicalize_groups()
        self.recompute_origins()
        self.drop_z()
        return self

    def canonicalize_new(self, normalize=False):
        """Like :meth:`canonicalize`, but computes fillings right after arc simplification."""
        return self._canonicalize(normalize, with_filling=True)

    def canonicalize(self, normalize=False):
        """Convert to paths, simplify arcs, optionally normalize, split, filter, sort and canonicalize groups."""
        return self._canonicalize(normalize, with_filling=False)

    def reorder(self):
        return self._apply_to_paths("reorder")

    def canonicalize_old(self):
        """Legacy canonicalization: filter, sort/canonicalize groups, then split paths."""
        self.filter_empty()
        self._reorder_and_canonicalize_groups()
        self.split_paths()
        self.recompute_origins()
        self.drop_z()
        return self

    def to_video(self, wrapper, color="grey"):
        """Return frames: a blank canvas, per-path frames from ``SVGPath.to_video``, then the full drawing.

        ``wrapper`` turns a rendered image array into a clip.
        """
        clips, svg_commands = [], []
        im = SVG([]).draw(do_display=False, return_png=True)
        clips.append(wrapper(np.array(im)))
        for svg_path in self.paths:
            clips, svg_commands = svg_path.to_video(wrapper, clips, svg_commands, color=color)
        im = self.draw(do_display=False, return_png=True)
        clips.append(wrapper(np.array(im)))
        return clips

    def animate(self, file_path=None, frame_duration=0.1, do_display=True):
        """Build a drawing animation; optionally write it as a GIF and/or display it (24 fps)."""
        clips = self.to_video(lambda img: ImageClip(img).set_duration(frame_duration))
        clip = concatenate_videoclips(clips, method="compose", bg_color=(255, 255, 255))
        if file_path is not None:
            clip.write_gif(file_path, fps=24, verbose=False, logger=None)
        if do_display:
            src = clip if file_path is None else file_path
            ipd.display(ipython_display(src, fps=24, rd_kwargs=dict(logger=None), autoplay=1, loop=1))

    def numericalize(self, n=256):
        """Normalize into ``Bbox(n)`` and then numericalize every group with ``n``."""
        self.normalize(viewbox=Bbox(n))
        return self._apply_to_paths("numericalize", n)

    def simplify(self, tolerance=0.1, epsilon=0.1, angle_threshold=179., force_smooth=False):
        """Simplify every group with the given parameters, then recompute origins."""
        self._apply_to_paths("simplify", tolerance=tolerance, epsilon=epsilon, angle_threshold=angle_threshold,
                             force_smooth=force_smooth)
        self.recompute_origins()
        return self

    def reverse(self):
        self._apply_to_paths("reverse")
        return self

    def reverse_non_closed(self):
        self._apply_to_paths("reverse_non_closed")
        return self

    def duplicate_extremities(self):
        self._apply_to_paths("duplicate_extremities")
        return self

    def simplify_heuristic(self, tolerance=0.1, force_smooth=False):
        """Return a simplified copy: fine split (curves only), simplify, then coarse split."""
        return self.copy().split(max_dist=2, include_lines=False) \
            .simplify(tolerance=tolerance, epsilon=0.2, angle_threshold=150, force_smooth=force_smooth) \
            .split(max_dist=7.5)

    def simplify_heuristic2(self):
        """Variant of :meth:`simplify_heuristic` with ``tolerance=0.2`` and default smoothing."""
        return self.copy().split(max_dist=2, include_lines=False) \
            .simplify(tolerance=0.2, epsilon=0.2, angle_threshold=150) \
            .split(max_dist=7.5)

    def split(self, n=None, max_dist=None, include_lines=True):
        return self._apply_to_paths("split", n=n, max_dist=max_dist, include_lines=include_lines)

    @staticmethod
    def unit_circle():
        """Return a circle of diameter 1 inside ``Bbox(1)``, built from four cubic Bezier arcs."""
        # Standard cubic-Bezier circle control distance 4/3*(sqrt(2)-1)*r with r = 0.5.
        d = 2 * (math.sqrt(2) - 1) / 3
        circle = SVGPath([
            SVGCommandBezier(Point(.5, 0.), Point(.5 + d, 0.), Point(1., .5 - d), Point(1., .5)),
            SVGCommandBezier(Point(1., .5), Point(1., .5 + d), Point(.5 + d, 1.), Point(.5, 1.)),
            SVGCommandBezier(Point(.5, 1.), Point(.5 - d, 1.), Point(0., .5 + d), Point(0., .5)),
            SVGCommandBezier(Point(0., .5), Point(0., .5 - d), Point(.5 - d, 0.), Point(.5, 0.))
        ]).to_group()
        return SVG([circle], viewbox=Bbox(1))

    @staticmethod
    def unit_square():
        """Return the unit square ``[0, 1] x [0, 1]`` inside ``Bbox(1)``."""
        square = SVGPath.from_str("m 0,0 h1 v1 h-1 v-1")
        return SVG([square], viewbox=Bbox(1))

    def add_path_group(self, path_group: SVGPathGroup):
        """Append ``path_group``, setting its origin to the current end position."""
        path_group.set_origin(self.end_pos.copy())
        self.svg_path_groups.append(path_group)
        return self

    def add_path_groups(self, path_groups: List[SVGPathGroup]):
        """Append each group in order via :meth:`add_path_group`."""
        for path_group in path_groups:
            self.add_path_group(path_group)
        return self

    def simplify_arcs(self):
        return self._apply_to_paths("simplify_arcs")

    def to_path(self):
        """Replace each group, in place in the existing list, with the result of its ``to_path()``."""
        for i, path_group in enumerate(self.svg_path_groups):
            self.svg_path_groups[i] = path_group.to_path()
        return self

    def filter_consecutives(self):
        return self._apply_to_paths("filter_consecutives")

    def filter_duplicates(self):
        return self._apply_to_paths("filter_duplicates")

    def set_color(self, color):
        """Set ``color`` on every group.

        ``"random"`` cycles through a fixed palette, ``"random_random"`` through
        a shuffled copy of it, and a list cycles through the given colors; any
        other value is assigned to every group unchanged.
        """
        colors = ["deepskyblue", "lime", "deeppink", "gold", "coral", "darkviolet", "royalblue", "darkmagenta", "teal",
                  "gold",
                  "green", "maroon", "aqua", "grey", "steelblue", "lime", "orange"]
        if color == "random_random":
            random.shuffle(colors)
        if isinstance(color, list):
            colors = color
        cycle_palette = color == "random" or color == "random_random" or isinstance(color, list)
        for i, path_group in enumerate(self.svg_path_groups):
            path_group.color = colors[i % len(colors)] if cycle_palette else color
        return self

    def bbox(self):
        """Return the union of all groups' bounding boxes."""
        return union_bbox([path_group.bbox() for path_group in self.svg_path_groups])

    def overlap_graph(self, threshold=0.95, draw=False):
        """Build a directed containment graph between path groups.

        Nodes are group indices. For each non-OUTLINE group ``i`` and each other
        FILL group ``j``, an edge ``j -> i`` weighted by the overlap ratio is
        added when more than ``threshold`` of ``i``'s area lies inside ``j``.
        Zero-area groups get no incoming edges. If ``draw`` is True, the graph
        is also plotted with networkx.
        """
        G = nx.DiGraph()
        shapes = [group.to_shapely() for group in self.svg_path_groups]
        for i, shape_i in enumerate(shapes):
            G.add_node(i)
            # Outlines are never nested; a zero-area shape would divide by zero below.
            if self.svg_path_groups[i].path.filling == Filling.OUTLINE or shape_i.area == 0:
                continue
            for j, shape_j in enumerate(shapes):
                if i != j and self.svg_path_groups[j].path.filling == Filling.FILL:
                    overlap = shape_i.intersection(shape_j).area / shape_i.area
                    if overlap > threshold:
                        G.add_edge(j, i, weight=overlap)
        if draw:
            pos = nx.spring_layout(G)
            nx.draw_networkx(G, pos, with_labels=True)
            labels = nx.get_edge_attributes(G, 'weight')
            nx.draw_networkx_edge_labels(G, pos, edge_labels=labels)
        return G

    def group_overlapping_paths(self):
        """Regroup filled paths with the ERASE paths they contain; return a new SVG.

        Starting from FILL groups that no other group contains, the containment
        graph is walked breadth-first. Each visited FILL path becomes a filled
        group (clockwise), to which its directly-contained ERASE paths are
        appended counter-clockwise. OUTLINE groups are appended at the end.
        The result keeps a copy of this SVG's viewbox.
        """
        from collections import deque

        G = self.overlap_graph()
        path_groups = []
        root_nodes = [i for i, d in G.in_degree() if d == 0]
        for root in root_nodes:
            if self[root].path.filling != Filling.FILL:
                continue
            queue = deque([root])
            while queue:
                n = queue.popleft()
                fill_neighbors, erase_neighbors = [], []
                for m in G.neighbors(n):
                    # Only children whose sole container is n are attached here.
                    if G.in_degree(m) == 1:
                        if self[m].path.filling == Filling.ERASE:
                            erase_neighbors.append(m)
                        else:
                            fill_neighbors.append(m)
                G.remove_node(n)
                path_group = SVGPathGroup([self[n].path.copy().set_orientation(Orientation.CLOCKWISE)], fill=True)
                for hole in erase_neighbors:
                    path_group.append(self[hole].path.copy().set_orientation(Orientation.COUNTER_CLOCKWISE))
                G.remove_nodes_from(erase_neighbors)
                path_groups.append(path_group)
                queue.extend(fill_neighbors)
        # Add outlines in the end
        path_groups.extend(path_group for path_group in self.svg_path_groups
                           if path_group.path.filling == Filling.OUTLINE)
        return SVG(path_groups, viewbox=self.viewbox.copy())

    def to_points(self, sort=True):
        """Concatenate every group's ``to_points()``.

        With ``sort`` True, rows are sorted by column 1 then column 0 and
        consecutive duplicate rows are removed.
        """
        points = np.concatenate([path_group.to_points() for path_group in self.svg_path_groups])
        if sort:
            ind = np.lexsort((points[:, 0], points[:, 1]))
            points = points[ind]
            # Remove duplicates (adjacent after sorting)
            row_mask = np.append([True], np.any(np.diff(points, axis=0), 1))
            points = points[row_mask]
        return points

    def permute(self, indices=None):
        """Reorder groups as ``[groups[i] for i in indices]``; no-op when ``indices`` is None."""
        if indices is not None:
            self.svg_path_groups = [self.svg_path_groups[i] for i in indices]
        return self

    def fill_(self, fill=True):
        return self._apply_to_paths("fill_", fill)