latex2im / app.py
yuntian-deng's picture
Update app.py
48fd5b5 verified
raw
history blame
5.33 kB
import spaces
import gradio as gr
import numpy as np
import requests
import base64
import os
from datetime import datetime
from pytz import timezone
import torch
import diffusers
from diffusers import DDPMPipeline
from transformers import AutoTokenizer, AutoModel
# Timezone used only for the request timestamps printed in infer().
# NOTE(review): pytz's 'EST' is a fixed UTC-5 offset with no daylight saving;
# switch to 'America/New_York' if wall-clock Eastern time is intended.
tz = timezone('EST')
# Remote-inference settings read from the environment (presumably Space
# secrets). They are optional; os.getenv returns None when unset.
API_ENDPOINT = os.getenv('API_ENDPOINT')
API_KEY = os.getenv('API_KEY')
print(API_ENDPOINT)
# Never write the secret itself to the logs -- only report whether it exists.
print('API_KEY set:', API_KEY is not None)
# HTML snippets rendered at the top of the page via gr.Markdown.
title = "<h1><center>Markup-to-Image Diffusion Models with Scheduled Sampling</center></h1>"
authors = "<center>Yuntian Deng, Noriyuki Kojima, Alexander M. Rush</center>"
info = '<center><a href="https://openreview.net/pdf?id=81VJDmOE2ol">Paper</a> <a href="https://github.com/da03/markup2im">Code</a></center>'
#notice = "<p><center><strong>Notice:</strong> Due to resource constraints, we've transitioned from GPU to CPU processing for this demo, which results in significantly longer inference times. We appreciate your understanding.</center></p>"
notice = "<p><center>Acknowledgment: This demo is powered by GPU resources supported by the Hugging Face Community Grant.</center></p>"
def setup(compile_unet=False):
    """Load the LaTeX-to-image diffusion pipeline and build its text encoder.

    Args:
        compile_unet: when True, wrap every UNet down- and up-block in
            ``torch.compile``. Defaults to False, which matches the previous
            hard-disabled (``if False:``) behaviour.

    Returns:
        ``(img_pipe, forward_encoder)``: the ``DDPMPipeline`` moved to the
        selected device, and a function mapping a LaTeX string to the
        encoder's (masked) last hidden states.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    img_pipe = DDPMPipeline.from_pretrained("yuntian-deng/latex2im_ss_finetunegptneo")
    img_pipe.to(device)
    # Only the tokenizer is loaded from the base GPT-Neo checkpoint; the encoder
    # module itself is the one stored on the pipeline's UNet.
    model_type = "EleutherAI/gpt-neo-125M"
    encoder = img_pipe.unet.text_encoder
    if compile_unet:
        for blocks in (img_pipe.unet.down_blocks, img_pipe.unet.up_blocks):
            for idx, block in enumerate(blocks):
                blocks[idx] = torch.compile(block)
    tokenizer = AutoTokenizer.from_pretrained(model_type, max_length=1024)
    eos_id = tokenizer.encode(tokenizer.eos_token)[0]

    @spaces.GPU
    def forward_encoder(latex):
        """Encode a LaTeX string into hidden states for the UNet.

        The string is tokenized (truncated to 1024 tokens), an EOS token is
        appended, and the encoder's ``last_hidden_state`` is returned after
        being multiplied by the attention mask.
        """
        encoded = tokenizer(latex, return_tensors='pt', truncation=True, max_length=1024)
        # Append one EOS id to the token ids and a matching 1 to the mask
        # (concatenated along the sequence dimension).
        input_ids = torch.cat(
            (encoded['input_ids'], torch.LongTensor([eos_id]).unsqueeze(0)), dim=-1
        ).to(device)
        attention_mask = torch.cat(
            (encoded['attention_mask'], torch.LongTensor([1]).unsqueeze(0)), dim=-1
        ).to(device)
        with torch.no_grad():
            outputs = encoder(input_ids=input_ids, attention_mask=attention_mask)
            # Defensive masking; for a single unpadded input the mask is all
            # ones, so this shouldn't be necessary.
            return attention_mask.unsqueeze(-1) * outputs.last_hidden_state

    return img_pipe, forward_encoder
img_pipe, forward_encoder = setup()

# (height, width, channels) of every frame shown in the UI; the pipeline's
# numpy output is reshaped to this before display.
IMAGE_SHAPE = (64, 320, 3)

with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(authors)
    gr.Markdown(info)
    gr.Markdown(notice)
    with gr.Row():
        with gr.Column(scale=2):
            textbox = gr.Textbox(
                label=r'Type LaTeX formula below and click "Generate"',
                lines=1,
                max_lines=1,
                placeholder='Type LaTeX formula here and click "Generate"',
                value=r'\sum_{t=1}^T\E_{y_t \sim {\tilde P(y_t| y_0)}} \left\| \frac{y_t - \sqrt{\bar{\alpha}_t}y_0}{\sqrt{1-\bar{\alpha}_t}} - \epsilon_\theta(y_t, t)\right\|^2.',
            )
            submit_btn = gr.Button("Generate", elem_id="btn")
        with gr.Column(scale=3):
            slider = gr.Slider(0, 1000, value=0, label='step (out of 1000)')
            image = gr.Image(label="Rendered Image", show_label=False, elem_id="image")
    inputs = [textbox]
    outputs = [slider, image, submit_btn]

    def infer(formula):
        """Stream the denoising progress for ``formula`` to the UI.

        Yields ``(step, frame, button_update)`` tuples matching ``outputs``:
        the step counter for the slider, the current clean-image estimate,
        and an update that hides the Generate button while sampling runs and
        shows it again at the end. On any error a white frame is shown and
        the button is re-enabled.
        """
        current_time = datetime.now(tz)
        print(current_time, formula)
        latex = formula  # TODO: normalize
        step = 0
        frame = None
        try:
            # Inside the try so encoder failures also reach the fallback below.
            encoder_hidden_states = forward_encoder(latex)
            for _, image_clean in img_pipe.run_clean(
                batch_size=1,
                generator=torch.manual_seed(0),
                encoder_hidden_states=encoder_hidden_states,
                output_type="numpy",
            ):
                step += 1
                frame = np.ascontiguousarray(image_clean[0]).reshape(IMAGE_SHAPE)
                yield step, frame, submit_btn.update(visible=False)
            yield step, frame, submit_btn.update(visible=True)
        except Exception:
            import traceback
            traceback.print_exc()
            # uint8 255 is unambiguously white; a float array of 255s is out of
            # the [0, 1] range Gradio accepts for float images.
            yield 1000, np.full(IMAGE_SHAPE, 255, dtype=np.uint8), submit_btn.update(visible=True)

    # Event listeners must be registered inside the Blocks context.
    submit_btn.click(fn=infer, inputs=inputs, outputs=outputs)

demo.queue(concurrency_count=1, max_size=20).launch(enable_queue=True)