da03 commited on
Commit
de88b8c
·
1 Parent(s): ed21af3
Files changed (1) hide show
  1. app.py +27 -20
app.py CHANGED
@@ -7,24 +7,31 @@ import os
7
  API_ENDPOINT = os.getenv('API_ENDPOINT')
8
  API_KEY = os.getenv('API_KEY')
9
 
10
- # setup
11
- gallery = gr.Gallery(label="Rendered Image", show_label=False, elem_id="gallery").style(grid=[1], height="auto")
12
 
13
- # infer
14
- def infer(latex):
15
- formula = latex
16
- data = {'formula': formula, 'api_key': API_KEY}
17
- with requests.post(url=API_ENDPOINT, data=data, timeout=600, stream=True) as r:
18
- i = 0
19
- for line in r.iter_lines():
20
- response = line.decode('ascii').strip()
21
- r = base64.decodebytes(response.encode('ascii'))
22
- q = np.frombuffer(r, dtype=np.float32).reshape((64, 320, 3))
23
- i += 1
24
- yield i, [q,]
25
-
26
- title = "Markup-to-Image Diffusion Models with Scheduled Sampling"
27
- description="Yuntian Deng, Noriyuki Kojima, Alexander M. Rush"
28
-
29
- # launch
30
- gr.Interface(fn=infer, inputs=["text"], outputs=[gr.Slider(0, 1000, value=0, label='step (out of 1000)'), gallery],title=title,description=description).queue(concurrency_count=20, max_size=200).launch(enable_queue=True)
 
 
 
 
 
 
 
 
7
  API_ENDPOINT = os.getenv('API_ENDPOINT')
8
  API_KEY = os.getenv('API_KEY')
9
 
10
+ title = "<h1><center>Markup-to-Image Diffusion Models with Scheduled Sampling</center></h1>"
11
+ description = "<center>Yuntian Deng, Noriyuki Kojima, Alexander M. Rush</center>"
12
 
13
+ with gr.Blocks() as demo:
14
+ gr.Markdown(title)
15
+ gr.Markdown(description)
16
+ with gr.Row():
17
+ with gr.Column(scale=2):
18
+ textbox = gr.Textbox(label="LaTeX", lines=1, max_lines=1, placeholder='Type LaTeX formula here and click "Generate"')
19
+ submit_btn = gr.Button("Generate", elem_id="btn")
20
+ with gr.Column(scale=3):
21
+ slider = gr.Slider(0, 1000, value=0, label='step (out of 1000)')
22
+ image = gr.Image(label="Rendered Image", show_label=False, elem_id="image")
23
+ inputs = [textbox]
24
+ outputs = [slider, image, submit_btn]
25
+ def infer(formula):
26
+ data = {'formula': formula, 'api_key': API_KEY}
27
+ with requests.post(url=API_ENDPOINT, data=data, timeout=600, stream=True) as r:
28
+ i = 0
29
+ for line in r.iter_lines():
30
+ response = line.decode('ascii').strip()
31
+ r = base64.decodebytes(response.encode('ascii'))
32
+ q = np.frombuffer(r, dtype=np.float32).reshape((64, 320, 3))
33
+ i += 1
34
+ yield i, q, submit_btn.update(visible=False)
35
+ yield i, q, submit_btn.update(visible=True)
36
+ submit_btn.click(fn=infer, inputs=inputs, outputs=outputs)
37
+ demo.queue(concurrency_count=20, max_size=200).launch(enable_queue=True)