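# Hugging Face file-viewer header (not part of the program):
#   last commit a8f430e "add full page" by Charlie Li
#   viewer links: raw / history / blame
#   file size: 8.66 kB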
import gradio as gr
import os
import random
import datetime
from utils import *
# ---------------------------------------------------------------------------
# Static assets: everything below runs once at import time, in this order.
# ---------------------------------------------------------------------------

# Supplementary archive holding the sample images and pre-computed inks.
file_url = "https://storage.googleapis.com/derendering_model/derendering_supp.zip"
filename = "derendering_supp.zip"

download_file(file_url, filename)
unzip_file(filename)
print("Downloaded and unzipped the file.")

# Inline SVGs embedded in the page header.
diagram = get_svg_content("derendering_supp/derender_diagram.svg")
org = get_svg_content("org/cor.svg")
org_content = "{}".format(org)

# Word-level showcase: captions[i] is the label shown under gif_filenames[i].
gif_filenames = [
    "christians.gif",
    "good.gif",
    "october.gif",
    "welcome.gif",
    "you.gif",
    "letter.gif",
]
captions = [
    "CHRISTIANS",
    "Good",
    "October",
    "WELOME",
    "you",
    "letter",
]

# Sketch showcase: keyed by file name, no caption.
sketches = [
    "bird.gif",
    "cat.gif",
    "coffee.gif",
    "penguin.gif",
]


def _encode_gifs(folder, names):
    """Return a lazy iterator of base64 payloads for ``folder/<name>``, in order."""
    return (get_base64_encoded_gif(f"{folder}/{name}") for name in names)


# caption -> base64 GIF (pairs are matched positionally, shortest list wins).
gif_base64_strings = dict(zip(captions, _encode_gifs("gifs", gif_filenames)))

# file name -> base64 GIF.
sketches_base64_strings = dict(zip(sketches, _encode_gifs("sketches", sketches)))
def _render_ink_image(ink, background):
    """Draw *ink* over *background* on an axis-less figure and return it as an image.

    NOTE(review): ``plt``, ``BytesIO``, ``Image`` and ``plot_ink`` are not
    imported explicitly in this file; presumably they arrive through
    ``from utils import *`` -- confirm.
    """
    fig, ax = plt.subplots()
    try:
        ax.axis("off")
        plot_ink(ink, ax, input_image=background)
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def demo(Dataset, Model, Output_Format):
    """Pick one random sample of *Dataset* and render *Model*'s stored inks for it.

    Args:
        Dataset: Dataset folder name under ``./derendering_supp``.
        Model: One of ``"Small-i"``, ``"Small-p"`` or ``"Large-i"``.
        Output_Format: ``"Image+Video"`` additionally writes one ``.mp4`` per
            inference mode; any other value yields ``None`` for the videos.

    Returns:
        A 10-tuple: the padded input image followed by
        ``(text, rendered image, video path or None)`` for each of the modes
        ``"d+t"``, ``"r+d"`` and ``"vanilla"``, in that order.

    Raises:
        ValueError: If *Model* is not a supported variant.
        FileNotFoundError: If the dataset's sample directory is missing or
            contains no files.
    """
    supported_models = ("Small-i", "Small-p", "Large-i")
    if Model not in supported_models:
        # Fail early with a clear message instead of an UnboundLocalError later.
        raise ValueError(
            f"Unknown model {Model!r}; expected one of {supported_models}"
        )
    # Ink folders are named after the lower-cased model variant.
    inkml_path = f"./derendering_supp/{Model.lower()}_{Dataset}_inkml"

    now = datetime.datetime.now()
    random.seed(now.timestamp())
    print(
        now.strftime("%Y-%m-%d %H:%M:%S"),
        "Taking sample from dataset:",
        Dataset,
        "and model:",
        Model,
        "with output format:",
        Output_Format,
    )

    path = f"./derendering_supp/{Dataset}/images_sample"
    samples = os.listdir(path)
    if not samples:
        raise FileNotFoundError(f"No sample images found in {path}")

    # Randomly pick a sample
    name = random.choice(samples)
    img = load_and_pad_img_dir(os.path.join(path, name))

    # Drop only a trailing ".png". ``str.strip(".png")`` would instead remove
    # any of the characters '.', 'p', 'n', 'g' from both ends of the name
    # (e.g. "ping.png" -> "i") and point at a non-existent ink file.
    suffix = ".png"
    example_id = name[: -len(suffix)] if name.endswith(suffix) else name

    # Text output for three modes
    # d+t: OCR recognition input to the model
    # r+d: Recognition from the model
    # vanilla: None
    query_modes = ["d+t", "r+d", "vanilla"]
    plot_title = {"r+d": "Recognized: ", "d+t": "OCR Input: ", "vanilla": ""}

    outputs = [img]
    for mode in query_modes:
        inkml_file = os.path.join(inkml_path, mode, example_id + ".inkml")
        text_field = parse_inkml_annotations(inkml_file)["textField"]
        output_text = f"{plot_title[mode]}{text_field}"

        ink = inkml_to_ink(inkml_file)

        video_filename = None
        if Output_Format == "Image+Video":
            # NOTE(review): the file name depends only on the mode, so
            # simultaneous requests write to the same path -- confirm whether
            # per-request file names are needed for this deployment.
            video_filename = mode + ".mp4"
            plot_ink_to_video(ink, video_filename, input_image=img)

        outputs.extend((output_text, _render_ink_image(ink, img), video_filename))

    return tuple(outputs)
def _flex_gallery(cells):
    """Wrap pre-rendered HTML cells in the page's wrapping flex-row container."""
    opening = (
        '<div style="display: flex; justify-content: space-around; '
        'flex-wrap: wrap; gap: 0px;">'
    )
    return "\n".join([opening, *cells, "</div>"])


# Page layout: header, interactive sampler, then three static galleries.
with gr.Blocks() as app:
    # ---- Header -----------------------------------------------------------
    gr.HTML(org_content)
    gr.Markdown(
        "# InkSight: Offline-to-Online Handwriting Conversion by Learning to Read and Write"
    )
    # rel="noopener noreferrer" keeps the new tab from reaching back into
    # this page through window.opener.
    gr.HTML(
        """
        <div style="display: flex; align-items: center; margin-bottom: 20px;">
            <a href="https://arxiv.org/abs/2402.05804" target="_blank" rel="noopener noreferrer" style="font-size: 16px; background-color: #4CAF50; color: white; padding: 5px 7px; text-decoration: none; border-radius: 2px;">
                📄 Read the Paper
            </a>
        </div>
        """
    )
    gr.HTML(f"<div style='margin: 20px 0;'>{diagram}</div>")
    # Built from adjacent literals so no line carries leading indentation
    # that Markdown could interpret as a code block.
    gr.Markdown(
        "🚀 This demo highlights the capabilities of Small-i, Small-p, and Large-i across three public datasets (word-level, with 100 random samples each).<br>\n"
        "🎲 Select a model variant and dataset (IAM, IMGUR5K, HierText), then hit 'Sample' to view a randomly selected input alongside its corresponding outputs for all three types of inference.<br>\n"
        "🖼️ Output options: Image or Image+Video. Opting for images yields quicker results, adding videos offers a dynamic view of the digital ink writing process.<br>\n"
    )

    # ---- Interactive sampler: inputs --------------------------------------
    with gr.Row():
        dataset = gr.Dropdown(
            ["IAM", "IMGUR5K", "HierText"], label="Dataset", value="IAM"
        )
        model = gr.Dropdown(
            ["Small-i", "Large-i", "Small-p"],
            label="InkSight Model Variant",
            value="Small-i",
        )
        output_format = gr.Dropdown(
            ["Image", "Image+Video"], label="Output Format", value="Image"
        )

    # ---- Interactive sampler: outputs -------------------------------------
    im = gr.Image(label="Input Image")
    with gr.Row():
        d_t_img = gr.Image(label="Derender with Text")
        r_d_img = gr.Image(label="Recognize and Derender")
        vanilla_img = gr.Image(label="Vanilla")
    with gr.Row():
        d_t_text = gr.Textbox(
            label="OCR recognition input to the model", interactive=False
        )
        r_d_text = gr.Textbox(label="Recognition from the model", interactive=False)
        vanilla_text = gr.Textbox(label="Vanilla", interactive=False)
    gr.Markdown(
        "To visualize the writing process in video, select *Output format* as **Image+Video**."
    )
    with gr.Row():
        d_t_vid = gr.Video(
            label="Derender with Text (Click to stop/play)", autoplay=True
        )
        r_d_vid = gr.Video(
            label="Recognize and Derender (Click to stop/play)", autoplay=True
        )
        vanilla_vid = gr.Video(label="Vanilla (Click to stop/play)", autoplay=True)
    with gr.Row():
        btn_sub = gr.Button("Sample")

    # The order of `outputs` must match the tuple returned by `demo`:
    # input image, then (text, image, video) for d+t, r+d and vanilla.
    btn_sub.click(
        fn=demo,
        inputs=[dataset, model, output_format],
        outputs=[
            im,
            d_t_text,
            d_t_img,
            d_t_vid,
            r_d_text,
            r_d_img,
            r_d_vid,
            vanilla_text,
            vanilla_img,
            vanilla_vid,
        ],
    )

    # ---- Static gallery: word-level GIFs with captions --------------------
    gr.Markdown("## More Word-level Samples")
    word_cells = [
        f"""
        <div>
            <img src="data:image/gif;base64,{base64_string}" alt="{caption}" style="width: 100%; max-width: 200px;">
            <p style="text-align: center;">{caption}</p>
        </div>
        """
        for caption, base64_string in gif_base64_strings.items()
    ]
    gr.HTML(_flex_gallery(word_cells))

    # ---- Static gallery: sketches -----------------------------------------
    gr.Markdown("## Sketch Samples")
    sketch_cells = [
        f"""
        <div>
            <img src="data:image/gif;base64,{base64_string}" alt="{os.path.splitext(name)[0]} sketch" style="width: 100%; max-width: 200px;">
        </div>
        """
        for name, base64_string in sketches_base64_strings.items()
    ]
    gr.HTML(_flex_gallery(sketch_cells))

    # ---- Static gallery: full-page SVGs -----------------------------------
    gr.Markdown("## Scale Up to Full Page")
    # (svg path, caption HTML), rendered top to bottom in this order.
    full_page_figures = [
        (
            "full_page/danke.svg",
            'Writings on the beach. <a href="https://unsplash.com/photos/text-rG-PerMFjFA">Credit</a>',
        ),
        (
            "full_page/multilingual_demo.svg",
            "Multilingual handwriting.",
        ),
        (
            "full_page/unsplash_frame.svg",
            "Handwriting in a frame. <a href='https://unsplash.com/photos/white-wooden-framed-white-board-t7fLWMQl2Lw'>Credit</a>",
        ),
    ]
    # The SVG markup is passed as a format *argument*, so braces inside the
    # SVG are never interpreted as replacement fields.
    figure_template = """
    <div>
        <div style="margin-bottom: 10px;">{}</div>
        <p style="text-align: center;">{}</p>
    </div>
    """
    full_svg_display = (
        '<div style="display: block;">'
        + "".join(
            figure_template.format(get_svg_content(svg_path), caption_html)
            for svg_path, caption_html in full_page_figures
        )
        + "</div>"
    )
    gr.HTML(full_svg_display)

app.launch()