Charlie Li
update page
17f8269
raw
history blame
8.38 kB
import json
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from xml.dom import minidom
import os
from PIL import Image
import matplotlib.animation as animation
import copy
from PIL import ImageEnhance
import colorsys
import matplotlib.colors as mcolors
from matplotlib.collections import LineCollection
from matplotlib.patheffects import withStroke
import random
import warnings
from matplotlib.figure import Figure
from io import BytesIO
from matplotlib.animation import FuncAnimation, FFMpegWriter, PillowWriter
import requests
import zipfile
import base64
# Install a process-wide filter that drops every warning category.
# This affects all code running in the interpreter, not just this module.
warnings.simplefilter("ignore")
def get_svg_content(svg_path):
    """Return the full text of the SVG file at *svg_path*.

    Args:
        svg_path: Path (str or path-like) of the SVG file.

    Returns:
        The file contents as a ``str``.
    """
    # SVG is XML, whose default encoding is UTF-8; without an explicit
    # encoding, open() uses the platform locale and can fail on non-ASCII.
    with open(svg_path, "r", encoding="utf-8") as file:
        return file.read()
def download_file(url, filename, timeout=60):
    """Download *url* to *filename*, doing nothing if *filename* already exists.

    The response body is streamed into ``<filename>.part`` and only renamed to
    *filename* once it has been written completely. Because an existing
    *filename* is treated as a finished download, an interrupted transfer or
    an HTTP error must never leave a file at that path.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path (str or path-like).
        timeout: Connect/read timeout in seconds passed to ``requests.get``;
            ``None`` waits indefinitely.

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
    """
    if os.path.exists(filename):
        return
    partial_path = os.fspath(filename) + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this check an error page would be saved and then treated
        # as a valid cached download on every later call.
        response.raise_for_status()
        try:
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
            os.replace(partial_path, filename)
        except BaseException:
            # Don't leave a half-written temporary file behind.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
def unzip_file(filename, extract_to="."):
    """Extract every member of the ZIP archive *filename* into *extract_to*.

    Args:
        filename: Path or file object of the ZIP archive.
        extract_to: Destination directory; defaults to the current directory.
    """
    archive = zipfile.ZipFile(filename, mode="r")
    with archive:
        archive.extractall(path=extract_to)
def get_base64_encoded_gif(gif_path):
    """Read *gif_path* as raw bytes and return them base64-encoded as a ``str``.

    Args:
        gif_path: Path of the file to encode.

    Returns:
        The standard base64 encoding of the file contents.
    """
    with open(gif_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def load_and_pad_img_dir(file_dir, target_size=224):
    """Load an image, scale it to fit a square, and centre it on white.

    The image is resized with its aspect ratio preserved so that its longer
    side is ``target_size`` pixels. It is then pasted centred onto a
    ``target_size`` x ``target_size`` white RGB canvas.

    Args:
        file_dir: Path to the image file (despite the name, a file path).
        target_size: Edge length in pixels of the square output.

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(target_size, target_size)``.
    """
    with Image.open(file_dir) as image:
        width, height = image.size
        ratio = min(target_size / width, target_size / height)
        # Float rounding plus int() truncation can leave the longer side one
        # pixel short (e.g. 223), and extreme aspect ratios can truncate the
        # shorter side to 0; clamp both into [1, target_size].
        new_width = min(target_size, max(1, int(width * ratio)))
        new_height = min(target_size, max(1, int(height * ratio)))
        # resize() loads the pixels, so the result outlives the closed file.
        resized = image.resize((new_width, new_height))

    # Pad both axes so the output is square even when both sides fall short.
    left_padding = (target_size - new_width) // 2
    top_padding = (target_size - new_height) // 2
    padded_image = Image.new("RGB", (target_size, target_size), (255, 255, 255))
    padded_image.paste(resized, (left_padding, top_padding))
    return padded_image
def plot_ink(ink, ax, lw=1.8, input_image=None, with_path=True, path_color="white"):
    """Draw every stroke of *ink* onto the Matplotlib axes *ax*.

    Each stroke gets a colour from the ``rainbow`` colormap (the first stroke
    takes the last colour), darkened to 65% of its HSV value. Its alpha fades
    from 1.0 at the first point towards 0.5 at the end. The axes are fixed
    to a 0-224 frame with y increasing downwards, as in image coordinates.

    Args:
        ink: Object with a ``strokes`` sequence; each stroke exposes ``x`` and
            ``y`` coordinate sequences.
        ax: Matplotlib axes to draw on.
        lw: Line width of the strokes.
        input_image: Optional PIL image drawn underneath at 45% brightness.
            The caller's image is not modified.
        with_path: If True, outline each stroke with *path_color*.
        path_color: Colour of the outline used when *with_path* is True.
    """
    if input_image is not None:
        # enhance() returns a new image, so no defensive copy is needed.
        img = ImageEnhance.Brightness(input_image).enhance(0.45)
        ax.imshow(img)

    num_strokes = len(ink.strokes)
    # matplotlib.cm.get_cmap (plt.cm.get_cmap) was removed in Matplotlib 3.9;
    # pyplot.get_cmap is the supported spelling. max(1, ...) keeps the lookup
    # table non-empty for an ink without strokes.
    base_colors = plt.get_cmap("rainbow", max(1, num_strokes))
    for i, stroke in enumerate(ink.strokes):
        x, y = np.array(stroke.x), np.array(stroke.y)
        base_color = base_colors(num_strokes - 1 - i)
        hue, saturation, value = colorsys.rgb_to_hsv(*base_color[:3])
        darker_color = colorsys.hsv_to_rgb(hue, saturation, max(0, value * 0.65))
        # One RGBA entry per point, fading linearly in alpha along the stroke.
        colors = [mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / len(x))) for j in range(len(x))]
        points = np.array([x, y]).T.reshape(-1, 1, 2)
        # Consecutive point pairs -> array of shape (n_points - 1, 2, 2).
        segments = np.concatenate([points[:-1], points[1:]], axis=1)
        lc = LineCollection(segments, colors=colors, linewidth=lw)
        if with_path:
            lc.set_path_effects([withStroke(linewidth=lw * 1.25, foreground=path_color)])
        ax.add_collection(lc)
    ax.set_xlim(0, 224)
    # Descending limits put y=0 at the top; unlike invert_yaxis(), this does
    # not toggle, so it gives the same result however often it is applied.
    ax.set_ylim(224, 0)
def plot_ink_to_video(ink, output_name, lw=1.8, input_image=None, path_color="white", fps=30, with_path=True):
    """Animate *ink* being drawn one line segment per frame and save it as video.

    Frame ``k`` shows the first ``k`` line segments of the drawing, taking
    strokes in order. Stroke colours come from the ``rainbow`` colormap (the
    first stroke takes the last colour), darkened to 65% of the HSV value.
    The alpha of the visible part of each stroke fades from 1.0 towards 0.5.

    Args:
        ink: Object with a ``strokes`` sequence; each stroke exposes ``x`` and
            ``y`` coordinate sequences.
        output_name: Destination path passed to ``Animation.save``.
        lw: Line width of the strokes.
        input_image: Optional PIL image shown underneath at 45% brightness.
            The caller's image is not modified.
        path_color: Colour of the outline drawn around each stroke.
        fps: Frames per second of the output video.
        with_path: If False, the outline around strokes is omitted.

    The video is written with ``FFMpegWriter``, so the ``ffmpeg`` executable
    must be available.
    """
    fig, ax = plt.subplots(figsize=(4, 4), dpi=150)
    try:
        img = None
        if input_image is not None:
            # enhance() returns a new image, so no defensive copy is needed.
            img = ImageEnhance.Brightness(input_image).enhance(0.45)

        num_strokes = len(ink.strokes)
        # matplotlib.cm.get_cmap (plt.cm.get_cmap) was removed in Matplotlib
        # 3.9; pyplot.get_cmap is the supported spelling.
        base_colors = plt.get_cmap("rainbow", max(1, num_strokes))

        # Geometry and colour of each stroke do not depend on the frame, so
        # compute them once instead of inside every update() call.
        stroke_segments = []
        stroke_colors = []
        for stroke_index, stroke in enumerate(ink.strokes):
            x, y = np.array(stroke.x), np.array(stroke.y)
            points = np.array([x, y]).T.reshape(-1, 1, 2)
            stroke_segments.append(np.concatenate([points[:-1], points[1:]], axis=1))
            base_color = base_colors(num_strokes - 1 - stroke_index)
            hue, saturation, value = colorsys.rgb_to_hsv(*base_color[:3])
            stroke_colors.append(colorsys.hsv_to_rgb(hue, saturation, max(0, value * 0.65)))

        all_points = sum(len(stroke.x) for stroke in ink.strokes)

        def reset_axes():
            # Background image first, then a fixed 0-224 frame with y downwards.
            if img is not None:
                ax.imshow(img)
            ax.set_xlim(0, 224)
            ax.set_ylim(224, 0)
            ax.axis("off")

        def update(frame):
            ax.clear()
            reset_axes()
            segments_drawn = 0
            for segments, darker_color in zip(stroke_segments, stroke_colors):
                # frame - segments_drawn is never negative: the loop stops as
                # soon as segments_drawn >= frame.
                visible_segments = segments[: frame - segments_drawn]
                n_visible = len(visible_segments)
                if n_visible > 0:
                    colors = [
                        mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / n_visible))
                        for j in range(n_visible)
                    ]
                    lc = LineCollection(visible_segments, colors=colors, linewidth=lw)
                    if with_path:
                        lc.set_path_effects([withStroke(linewidth=lw * 1.25, foreground=path_color)])
                    ax.add_collection(lc)
                segments_drawn += len(segments)
                if segments_drawn >= frame:
                    break

        reset_axes()
        # The frame count is based on points, not segments, so the video ends
        # with one extra frame of the finished drawing per non-empty stroke.
        ani = FuncAnimation(fig, update, frames=all_points + 1, blit=False)
        fig.tight_layout()
        ani.save(output_name, writer=FFMpegWriter(fps=fps))
    finally:
        # Release the figure even if rendering or encoding fails.
        plt.close(fig)
class Stroke:
    """One pen stroke, kept as two parallel coordinate lists ``x`` and ``y``."""

    def __init__(self, list_of_coordinates=None) -> None:
        # Each coordinate only needs to support indexing: [0] is x, [1] is y.
        # The input is iterated exactly once, so generators work too.
        pairs = [(point[0], point[1]) for point in list_of_coordinates] if list_of_coordinates else []
        self.x = [px for px, _ in pairs]
        self.y = [py for _, py in pairs]

    def __len__(self):
        """Return how many points the stroke holds."""
        return len(self.x)

    def __getitem__(self, index):
        """Return point *index* as an ``(x, y)`` tuple."""
        x_value = self.x[index]
        y_value = self.y[index]
        return x_value, y_value
class Ink:
    """A drawing: an ordered list of strokes stored in ``strokes``."""

    def __init__(self, list_of_strokes=None) -> None:
        # A truthy argument is stored as-is (not copied); anything falsy
        # yields a fresh empty list.
        self.strokes = list_of_strokes if list_of_strokes else []

    def __len__(self):
        """Return the number of strokes."""
        return len(self.strokes)

    def __getitem__(self, index):
        """Return the stroke at *index*."""
        return self.strokes[index]
def inkml_to_ink(inkml_file):
    """Convert an InkML file into an ``Ink`` made of ``Stroke`` objects.

    Every ``<trace>`` element in the InkML namespace becomes one ``Stroke``,
    in document order, including traces nested inside ``<traceGroup>``. Trace
    text is parsed as whitespace-separated ``x,y`` pairs. Values after the
    second in a point (e.g. ``x,y,t``) are ignored.

    Args:
        inkml_file: Path or file object accepted by ``ElementTree.parse``.

    Returns:
        An ``Ink`` whose strokes hold float coordinates. An empty ``<trace/>``
        produces an empty ``Stroke`` instead of raising.

    Raises:
        xml.etree.ElementTree.ParseError: If the file is not well-formed XML.
        ValueError: If a point has fewer than two comma-separated values or a
            value is not a number.
    """
    root = ET.parse(inkml_file).getroot()
    inkml_namespace = {"inkml": "http://www.w3.org/2003/InkML"}
    strokes = []
    # ".//" searches all descendants, not only direct children of the root.
    for trace in root.findall(".//inkml:trace", inkml_namespace):
        stroke_points = []
        # trace.text is None for an empty element such as <trace/>.
        for point in (trace.text or "").split():
            x, y = point.split(",")[:2]
            stroke_points.append((float(x), float(y)))
        strokes.append(Stroke(stroke_points))
    return Ink(strokes)
def parse_inkml_annotations(inkml_file):
    """Map the ``type`` attribute of each InkML ``<annotation>`` to its text.

    All ``<annotation>`` descendants of the root are collected in document
    order. When several share a ``type``, the last one wins. Empty elements
    map to ``None``.

    Args:
        inkml_file: Path or file object accepted by ``ElementTree.parse``.

    Returns:
        ``dict`` from annotation type to annotation text.
    """
    document_root = ET.parse(inkml_file).getroot()
    annotation_tag = ".//{http://www.w3.org/2003/InkML}annotation"
    return {node.get("type"): node.text for node in document_root.findall(annotation_tag)}
def pregenerate_videos(video_cache_dir):
    """Render a stroke-animation MP4 for every available sample, model and mode.

    For each dataset, sample images are read from
    ``./derendering_supp/<Dataset>/images_sample``. For each model and query
    mode, the matching InkML file is read from
    ``./derendering_supp/<model-lowercased>_<Dataset>_inkml/<mode>/<example_id>.inkml``.
    The output is written to *video_cache_dir* as
    ``<Model>_<Dataset>_<mode>_<example_id>.mp4``. Datasets without an
    image directory, samples without an InkML file, and videos that already
    exist are skipped. Paths are relative to the current working directory.

    Args:
        video_cache_dir: Output directory (str or path-like); created if it
            does not exist.
    """
    from pathlib import Path  # local import: the module header has no pathlib

    # Accept plain strings too; the original code required a Path.
    video_cache_dir = Path(video_cache_dir)
    video_cache_dir.mkdir(parents=True, exist_ok=True)
    datasets = ["IAM", "IMGUR5K", "HierText"]
    models = ["Small-i", "Large-i", "Small-p"]
    query_modes = ["d+t", "r+d", "vanilla"]
    for dataset in datasets:
        image_dir = f"./derendering_supp/{dataset}/images_sample"
        if not os.path.exists(image_dir):
            continue
        # The image list depends only on the dataset; read it once, sorted
        # so the generation order is deterministic.
        samples = sorted(os.listdir(image_dir))
        for model in models:
            inkml_path_base = f"./derendering_supp/{model.lower()}_{dataset}_inkml"
            for mode in query_modes:
                for name in tqdm(samples, desc=f"Generating {model}-{dataset}-{mode} videos"):
                    # splitext, not str.strip(".png"): strip() removes any of
                    # the characters ".", "p", "n", "g" from both ends, e.g.
                    # "page.png" would become "age".
                    example_id = os.path.splitext(name)[0]
                    inkml_file = os.path.join(inkml_path_base, mode, f"{example_id}.inkml")
                    if not os.path.exists(inkml_file):
                        continue
                    video_filepath = video_cache_dir / f"{model}_{dataset}_{mode}_{example_id}.mp4"
                    if video_filepath.exists():
                        continue
                    img = load_and_pad_img_dir(os.path.join(image_dir, name))
                    ink = inkml_to_ink(inkml_file)
                    plot_ink_to_video(ink, str(video_filepath), input_image=img)