"""This code is taken from by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte from the paper >https://arxiv.org/pdf/2007.11301.pdf> """ from __future__ import annotations from .geom import * import torch import re from typing import List, Union from xml.dom import minidom from .svg_path import SVGPath from .svg_command import SVGCommandLine, SVGCommandArc, SVGCommandBezier, SVGCommandClose import shapely as shapely import shapely.ops import shapely.geometry import networkx as nx FLOAT_RE = re.compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?") def extract_args(args): return list(map(float, FLOAT_RE.findall(args))) class SVGPrimitive: """ Reference: https://developer.mozilla.org/en-US/docs/Web/SVG/Tutorial/Basic_Shapes """ def __init__(self, color="black", fill=False, dasharray=None, stroke_width=".3", opacity=1.0): self.color = color self.dasharray = dasharray self.stroke_width = stroke_width self.opacity = opacity self.fill = fill def _get_fill_attr(self): fill_attr = f'fill="{self.color}" fill-opacity="{self.opacity}"' if self.fill else f'fill="none" stroke="{self.color}" stroke-width="{self.stroke_width}" stroke-opacity="{self.opacity}"' if self.dasharray is not None and not self.fill: fill_attr += f' stroke-dasharray="{self.dasharray}"' return fill_attr @classmethod def from_xml(cls, x: minidom.Element): raise NotImplementedError def draw(self, viewbox=Bbox(24), *args, **kwargs): from .svg import SVG return SVG([self], viewbox=viewbox).draw(*args, **kwargs) def _get_viz_elements(self, with_points=False, with_handles=False, with_bboxes=False, color_firstlast=True, with_moves=True): return [] def to_path(self): raise NotImplementedError def copy(self): raise NotImplementedError def bbox(self): raise NotImplementedError def fill_(self, fill=True): self.fill = fill return self class SVGEllipse(SVGPrimitive): def __init__(self, center: Point, radius: Radius, *args, **kwargs): super().__init__(*args, **kwargs) self.center = center self.radius = radius def __repr__(self): return f'SVGEllipse(c={self.center} r={self.radius})' def to_str(self, *args, **kwargs): fill_attr = self._get_fill_attr() return f'' @classmethod def from_xml(_, x: minidom.Element): fill = not x.hasAttribute("fill") or not x.getAttribute("fill") == "none" center = Point(float(x.getAttribute("cx")), float(x.getAttribute("cy"))) radius = Radius(float(x.getAttribute("rx")), float(x.getAttribute("ry"))) return SVGEllipse(center, radius, fill=fill) def to_path(self): p0, p1 = self.center + self.radius.xproj(), self.center + self.radius.yproj() p2, p3 = self.center - self.radius.xproj(), self.center - self.radius.yproj() commands = [ SVGCommandArc(p0, self.radius, Angle(0.), Flag(0.), Flag(1.), p1), SVGCommandArc(p1, self.radius, Angle(0.), Flag(0.), Flag(1.), p2), SVGCommandArc(p2, self.radius, Angle(0.), Flag(0.), Flag(1.), p3), SVGCommandArc(p3, self.radius, Angle(0.), Flag(0.), Flag(1.), p0), ] return SVGPath(commands, closed=True).to_group(fill=self.fill) class SVGCircle(SVGEllipse): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def __repr__(self): return f'SVGCircle(c={self.center} r={self.radius})' def to_str(self, *args, **kwargs): fill_attr = self._get_fill_attr() return f'' @classmethod def from_xml(_, x: minidom.Element): fill = not x.hasAttribute("fill") or not x.getAttribute("fill") == "none" center = Point(float(x.getAttribute("cx")), float(x.getAttribute("cy"))) radius = Radius(float(x.getAttribute("r"))) return SVGCircle(center, radius, fill=fill) class SVGRectangle(SVGPrimitive): def __init__(self, xy: Point, wh: Size, *args, **kwargs): super().__init__(*args, **kwargs) self.xy = xy self.wh = wh def __repr__(self): return f'SVGRectangle(xy={self.xy} wh={self.wh})' def to_str(self, *args, **kwargs): fill_attr = self._get_fill_attr() return f'' @classmethod def from_xml(_, x: minidom.Element): fill = not x.hasAttribute("fill") or not x.getAttribute("fill") == "none" xy = Point(0.) if x.hasAttribute("x"): xy.pos[0] = float(x.getAttribute("x")) if x.hasAttribute("y"): xy.pos[1] = float(x.getAttribute("y")) wh = Size(float(x.getAttribute("width")), float(x.getAttribute("height"))) return SVGRectangle(xy, wh, fill=fill) def to_path(self): p0, p1, p2, p3 = self.xy, self.xy + self.wh.xproj(), self.xy + self.wh, self.xy + self.wh.yproj() commands = [ SVGCommandLine(p0, p1), SVGCommandLine(p1, p2), SVGCommandLine(p2, p3), SVGCommandLine(p3, p0) ] return SVGPath(commands, closed=True).to_group(fill=self.fill) class SVGLine(SVGPrimitive): def __init__(self, start_pos: Point, end_pos: Point, *args, **kwargs): super().__init__(*args, **kwargs) self.start_pos = start_pos self.end_pos = end_pos def __repr__(self): return f'SVGLine(xy1={self.start_pos} xy2={self.end_pos})' def to_str(self, *args, **kwargs): fill_attr = self._get_fill_attr() return f'' @classmethod def from_xml(_, x: minidom.Element): fill = not x.hasAttribute("fill") or not x.getAttribute("fill") == "none" start_pos = Point(float(x.getAttribute("x1") or 0.), float(x.getAttribute("y1") or 0.)) end_pos = Point(float(x.getAttribute("x2") or 0.), float(x.getAttribute("y2") or 0.)) return SVGLine(start_pos, end_pos, fill=fill) def to_path(self): return SVGPath([SVGCommandLine(self.start_pos, self.end_pos)]).to_group(fill=self.fill) class SVGPolyline(SVGPrimitive): def __init__(self, points: List[Point], *args, **kwargs): super().__init__(*args, **kwargs) self.points = points def __repr__(self): return f'SVGPolyline(points={self.points})' def to_str(self, *args, **kwargs): fill_attr = self._get_fill_attr() return ''.format(fill_attr, ' '.join([p.to_str() for p in self.points])) @classmethod def from_xml(cls, x: minidom.Element): fill = not x.hasAttribute("fill") or not x.getAttribute("fill") == "none" args = extract_args(x.getAttribute("points")) assert len(args) % 2 == 0, f"Expected even number of arguments for SVGPolyline: {len(args)} given" points = [Point(x, args[2*i+1]) for i, x in enumerate(args[::2])] return cls(points, fill=fill) def to_path(self): commands = [SVGCommandLine(p1, p2) for p1, p2 in zip(self.points[:-1], self.points[1:])] is_closed = self.__class__.__name__ == "SVGPolygon" return SVGPath(commands, closed=is_closed).to_group(fill=self.fill) class SVGPolygon(SVGPolyline): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def __repr__(self): return f'SVGPolygon(points={self.points})' def to_str(self, *args, **kwargs): fill_attr = self._get_fill_attr() return ''.format(fill_attr, ' '.join([p.to_str() for p in self.points])) class SVGPathGroup(SVGPrimitive): def __init__(self, svg_paths: List[SVGPath] = None, origin=None, *args, **kwargs): super().__init__(*args, **kwargs) self.svg_paths = svg_paths if origin is None: origin = Point(0.) self.origin = origin # Alias @property def paths(self): return self.svg_paths @property def path(self): return self.svg_paths[0] def __getitem__(self, idx): return self.svg_paths[idx] def __len__(self): return len(self.paths) def total_len(self): return sum([len(path) for path in self.svg_paths]) @property def start_pos(self): return self.svg_paths[0].start_pos @property def end_pos(self): last_path = self.svg_paths[-1] if last_path.closed: return last_path.start_pos return last_path.end_pos def set_origin(self, origin: Point): self.origin = origin if self.svg_paths: self.svg_paths[0].origin = origin self.recompute_origins() def append(self, path: SVGPath): self.svg_paths.append(path) def copy(self): return SVGPathGroup([svg_path.copy() for svg_path in self.svg_paths], self.origin.copy(), self.color, self.fill, self.dasharray, self.stroke_width, self.opacity) def __repr__(self): return "SVGPathGroup({})".format(", ".join(svg_path.__repr__() for svg_path in self.svg_paths)) def _get_viz_elements(self, with_points=False, with_handles=False, with_bboxes=False, color_firstlast=True, with_moves=True): viz_elements = [] for svg_path in self.svg_paths: viz_elements.extend(svg_path._get_viz_elements(with_points, with_handles, with_bboxes, color_firstlast, with_moves)) if with_bboxes: viz_elements.append(self._get_bbox_viz()) return viz_elements def _get_bbox_viz(self): color = "red" if self.color == "black" else self.color bbox = self.bbox().to_rectangle(color=color) return bbox def to_path(self): return self def to_str(self, with_markers=False, *args, **kwargs): fill_attr = self._get_fill_attr() marker_attr = 'marker-start="url(#arrow)"' if with_markers else '' return ''.format(fill_attr, marker_attr, self.path.filling, " ".join(svg_path.to_str() for svg_path in self.svg_paths)) def to_tensor(self, PAD_VAL=-1): return torch.cat([p.to_tensor(PAD_VAL=PAD_VAL) for p in self.svg_paths], dim=0) def _apply_to_paths(self, method, *args, **kwargs): for path in self.svg_paths: getattr(path, method)(*args, **kwargs) return self def translate(self, vec): return self._apply_to_paths("translate", vec) def rotate(self, angle: Angle): return self._apply_to_paths("rotate", angle) def scale(self, factor): return self._apply_to_paths("scale", factor) def numericalize(self, n=256): return self._apply_to_paths("numericalize", n) def drop_z(self): return self._apply_to_paths("set_closed", False) def recompute_origins(self): origin = self.origin for path in self.svg_paths: path.origin = origin.copy() origin = path.end_pos return self def reorder(self): self._apply_to_paths("reorder") self.recompute_origins() return self def filter_empty(self): self.svg_paths = [path for path in self.svg_paths if path.path_commands] return self def canonicalize(self): self.svg_paths = sorted(self.svg_paths, key=lambda x: x.start_pos.tolist()[::-1]) if not self.svg_paths[0].is_clockwise(): self._apply_to_paths("reverse") self.recompute_origins() return self def reverse(self): self._apply_to_paths("reverse") self.recompute_origins() return self def duplicate_extremities(self): self._apply_to_paths("duplicate_extremities") return self def reverse_non_closed(self): self._apply_to_paths("reverse_non_closed") self.recompute_origins() return self def simplify(self, tolerance=0.1, epsilon=0.1, angle_threshold=179., force_smooth=False): self._apply_to_paths("simplify", tolerance=tolerance, epsilon=epsilon, angle_threshold=angle_threshold, force_smooth=force_smooth) self.recompute_origins() return self def split_paths(self): return [SVGPathGroup([svg_path], self.origin, self.color, self.fill, self.dasharray, self.stroke_width, self.opacity) for svg_path in self.svg_paths] def split(self, n=None, max_dist=None, include_lines=True): return self._apply_to_paths("split", n=n, max_dist=max_dist, include_lines=include_lines) def simplify_arcs(self): return self._apply_to_paths("simplify_arcs") def filter_consecutives(self): return self._apply_to_paths("filter_consecutives") def filter_duplicates(self): return self._apply_to_paths("filter_duplicates") def bbox(self): return union_bbox([path.bbox() for path in self.svg_paths]) def to_shapely(self): return shapely.ops.unary_union([path.to_shapely() for path in self.svg_paths]) def compute_filling(self): if self.fill: G = self.overlap_graph() root_nodes = [i for i, d in G.in_degree() if d == 0] for root in root_nodes: if not self.svg_paths[root].closed: continue current = [(1, root)] while current: visited = set() neighbors = set() for d, n in current: self.svg_paths[n].set_filling(d != 0) for n2 in G.neighbors(n): if not n2 in visited: d2 = d + (self.svg_paths[n2].is_clockwise() == self.svg_paths[n].is_clockwise()) * 2 - 1 visited.add(n2) neighbors.add((d2, n2)) G.remove_nodes_from([n for d, n in current]) current = [(d, n) for d, n in neighbors if G.in_degree(n) == 0] return self def overlap_graph(self, threshold=0.9, draw=False): G = nx.DiGraph() shapes = [path.to_shapely() for path in self.svg_paths] for i, path1 in enumerate(shapes): G.add_node(i) if self.svg_paths[i].closed: for j, path2 in enumerate(shapes): if i != j and self.svg_paths[j].closed: overlap = path1.intersection(path2).area / path1.area if overlap > threshold: G.add_edge(j, i, weight=overlap) if draw: pos = nx.spring_layout(G) nx.draw_networkx(G, pos, with_labels=True) labels = nx.get_edge_attributes(G, 'weight') nx.draw_networkx_edge_labels(G, pos, edge_labels=labels) return G def bbox_overlap(self, other: SVGPathGroup): return self.bbox().overlap(other.bbox()) def to_points(self): return np.concatenate([path.to_points() for path in self.svg_paths])