Spaces:
Build error
Build error
Ankur Goyal
commited on
Commit
·
253dc57
1
Parent(s):
d703b38
May remove the fields variable
Browse files
app.py
CHANGED
@@ -6,6 +6,7 @@ from PIL import Image, ImageDraw
|
|
6 |
import traceback
|
7 |
|
8 |
import gradio as gr
|
|
|
9 |
|
10 |
import torch
|
11 |
from docquery import pipeline
|
@@ -99,16 +100,17 @@ FIELDS = {
|
|
99 |
"Payment Terms": ["Payment Terms?"],
|
100 |
}
|
101 |
|
102 |
-
EMPTY_TABLE = dict(
|
103 |
-
headers=["Field", "Value"], value=[[name, None] for name in FIELDS.keys()]
|
104 |
-
)
|
105 |
|
|
|
|
|
106 |
|
107 |
-
|
|
|
108 |
if document is not None and error is None:
|
109 |
-
preview, json_output, table = process_fields(document, model)
|
110 |
return (
|
111 |
document,
|
|
|
112 |
preview,
|
113 |
gr.update(visible=True),
|
114 |
gr.update(visible=False, value=None),
|
@@ -118,6 +120,7 @@ def process_document(document, model, error=None):
|
|
118 |
else:
|
119 |
return (
|
120 |
None,
|
|
|
121 |
None,
|
122 |
gr.update(visible=False),
|
123 |
gr.update(visible=True, value=error) if error is not None else None,
|
@@ -129,6 +132,7 @@ def process_document(document, model, error=None):
|
|
129 |
def process_path(path, model):
|
130 |
error = None
|
131 |
document = None
|
|
|
132 |
if path:
|
133 |
try:
|
134 |
document = load_document(path)
|
@@ -136,7 +140,7 @@ def process_path(path, model):
|
|
136 |
traceback.print_exc()
|
137 |
error = str(e)
|
138 |
|
139 |
-
return process_document(document, model, error)
|
140 |
|
141 |
|
142 |
def process_upload(file, model):
|
@@ -159,40 +163,36 @@ def annotate_page(prediction, pages, document):
|
|
159 |
draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
|
160 |
|
161 |
|
162 |
-
def process_question(
|
|
|
|
|
163 |
if not question or document is None:
|
164 |
return None, None, None
|
165 |
|
166 |
text_value = None
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
break
|
176 |
-
|
177 |
-
annotate_page(p, pages, document)
|
178 |
-
|
179 |
return (
|
180 |
gr.update(visible=True, value=pages),
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
value=text_value,
|
185 |
-
),
|
186 |
)
|
187 |
|
188 |
|
189 |
-
def process_fields(document, model=list(CHECKPOINTS.keys())[0]):
|
190 |
pages = [x.copy().convert("RGB") for x in document.preview]
|
191 |
|
192 |
ret = {}
|
193 |
table = []
|
194 |
|
195 |
-
for (field_name, questions) in
|
196 |
answers = [run_pipeline(model, q, document, top_k=1) for q in questions]
|
197 |
answers.sort(key=lambda x: -x.get("score", 0) if x else 0)
|
198 |
top = answers[0]
|
@@ -208,23 +208,22 @@ def process_fields(document, model=list(CHECKPOINTS.keys())[0]):
|
|
208 |
|
209 |
|
210 |
def load_example_document(img, title, model):
|
|
|
|
|
211 |
if img is not None:
|
212 |
if title in QUESTION_FILES:
|
213 |
-
print("using document")
|
214 |
document = load_document(QUESTION_FILES[title])
|
215 |
else:
|
216 |
document = ImageDocument(Image.fromarray(img), ocr_reader=get_ocr_reader())
|
217 |
-
else:
|
218 |
-
document = None
|
219 |
|
220 |
-
return process_document(document, model)
|
221 |
|
222 |
|
223 |
CSS = """
|
224 |
#question input {
|
225 |
font-size: 16px;
|
226 |
}
|
227 |
-
#url-textbox {
|
228 |
padding: 0 !important;
|
229 |
}
|
230 |
#short-upload-box .w-full {
|
@@ -327,6 +326,7 @@ with gr.Blocks(css=CSS) as demo:
|
|
327 |
)
|
328 |
|
329 |
document = gr.Variable()
|
|
|
330 |
example_question = gr.Textbox(visible=False)
|
331 |
example_image = gr.Image(visible=False)
|
332 |
|
@@ -364,13 +364,16 @@ with gr.Blocks(css=CSS) as demo:
|
|
364 |
)
|
365 |
|
366 |
with gr.Column() as col:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
gr.Markdown("## 2. Ask a question")
|
368 |
-
question = gr.Textbox(
|
369 |
-
label="Question",
|
370 |
-
placeholder="e.g. What is the invoice number?",
|
371 |
-
lines=1,
|
372 |
-
max_lines=1,
|
373 |
-
)
|
374 |
model = gr.Radio(
|
375 |
choices=list(CHECKPOINTS.keys()),
|
376 |
value=list(CHECKPOINTS.keys())[0],
|
@@ -379,24 +382,27 @@ with gr.Blocks(css=CSS) as demo:
|
|
379 |
)
|
380 |
|
381 |
with gr.Row():
|
382 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
383 |
submit_button = gr.Button(
|
384 |
-
"
|
385 |
)
|
386 |
-
with gr.Tabs():
|
387 |
-
with gr.TabItem("Table"):
|
388 |
-
output_table = gr.Dataframe(**EMPTY_TABLE)
|
389 |
-
|
390 |
-
with gr.TabItem("JSON"):
|
391 |
-
output = gr.JSON(label="Output", visible=False)
|
392 |
|
393 |
for cb in [img_clear_button, clear_button]:
|
394 |
cb.click(
|
395 |
lambda _: (
|
396 |
-
gr.update(visible=False, value=None),
|
397 |
-
None,
|
398 |
-
|
399 |
-
gr.update(
|
|
|
400 |
gr.update(visible=False),
|
401 |
None,
|
402 |
None,
|
@@ -408,6 +414,7 @@ with gr.Blocks(css=CSS) as demo:
|
|
408 |
outputs=[
|
409 |
image,
|
410 |
document,
|
|
|
411 |
output,
|
412 |
output_table,
|
413 |
img_clear_button,
|
@@ -419,22 +426,32 @@ with gr.Blocks(css=CSS) as demo:
|
|
419 |
],
|
420 |
)
|
421 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
422 |
upload.change(
|
423 |
fn=process_upload,
|
424 |
inputs=[upload, model],
|
425 |
-
outputs=
|
426 |
)
|
427 |
|
428 |
submit.click(
|
429 |
fn=process_path,
|
430 |
inputs=[url, model],
|
431 |
-
outputs=
|
432 |
)
|
433 |
|
434 |
question.submit(
|
435 |
fn=process_question,
|
436 |
-
inputs=[question, document, model],
|
437 |
-
outputs=[image, output, output_table],
|
438 |
)
|
439 |
|
440 |
submit_button.click(
|
@@ -452,7 +469,7 @@ with gr.Blocks(css=CSS) as demo:
|
|
452 |
example_image.change(
|
453 |
fn=load_example_document,
|
454 |
inputs=[example_image, example_question, model],
|
455 |
-
outputs=
|
456 |
)
|
457 |
|
458 |
if __name__ == "__main__":
|
|
|
6 |
import traceback
|
7 |
|
8 |
import gradio as gr
|
9 |
+
from gradio import processing_utils
|
10 |
|
11 |
import torch
|
12 |
from docquery import pipeline
|
|
|
100 |
"Payment Terms": ["Payment Terms?"],
|
101 |
}
|
102 |
|
|
|
|
|
|
|
103 |
|
104 |
+
def empty_table(fields):
|
105 |
+
return {"value": [[name, None] for name in fields.keys()], "interactive": False}
|
106 |
|
107 |
+
|
108 |
+
def process_document(document, fields, model, error=None):
|
109 |
if document is not None and error is None:
|
110 |
+
preview, json_output, table = process_fields(document, fields, model)
|
111 |
return (
|
112 |
document,
|
113 |
+
fields,
|
114 |
preview,
|
115 |
gr.update(visible=True),
|
116 |
gr.update(visible=False, value=None),
|
|
|
120 |
else:
|
121 |
return (
|
122 |
None,
|
123 |
+
fields,
|
124 |
None,
|
125 |
gr.update(visible=False),
|
126 |
gr.update(visible=True, value=error) if error is not None else None,
|
|
|
132 |
def process_path(path, model):
|
133 |
error = None
|
134 |
document = None
|
135 |
+
fields = {**FIELDS}
|
136 |
if path:
|
137 |
try:
|
138 |
document = load_document(path)
|
|
|
140 |
traceback.print_exc()
|
141 |
error = str(e)
|
142 |
|
143 |
+
return process_document(document, fields, model, error)
|
144 |
|
145 |
|
146 |
def process_upload(file, model):
|
|
|
163 |
draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
|
164 |
|
165 |
|
166 |
+
def process_question(
|
167 |
+
question, document, img_gallery, model, fields, output, output_table
|
168 |
+
):
|
169 |
if not question or document is None:
|
170 |
return None, None, None
|
171 |
|
172 |
text_value = None
|
173 |
+
pages = [processing_utils.decode_base64_to_image(p) for p in img_gallery]
|
174 |
+
prediction = run_pipeline(model, question, document, 1)
|
175 |
+
annotate_page(prediction, pages, document)
|
176 |
+
|
177 |
+
field_name = question.rstrip("?")
|
178 |
+
fields = {**FIELDS, field_name: [question]}
|
179 |
+
output[field_name] = prediction
|
180 |
+
table = output_table.values.tolist() + [[field_name, prediction.get("answer")]]
|
|
|
|
|
|
|
|
|
181 |
return (
|
182 |
gr.update(visible=True, value=pages),
|
183 |
+
fields,
|
184 |
+
output,
|
185 |
+
gr.update(value=table, interactive=False),
|
|
|
|
|
186 |
)
|
187 |
|
188 |
|
189 |
+
def process_fields(document, fields, model=list(CHECKPOINTS.keys())[0]):
|
190 |
pages = [x.copy().convert("RGB") for x in document.preview]
|
191 |
|
192 |
ret = {}
|
193 |
table = []
|
194 |
|
195 |
+
for (field_name, questions) in fields.items():
|
196 |
answers = [run_pipeline(model, q, document, top_k=1) for q in questions]
|
197 |
answers.sort(key=lambda x: -x.get("score", 0) if x else 0)
|
198 |
top = answers[0]
|
|
|
208 |
|
209 |
|
210 |
def load_example_document(img, title, model):
|
211 |
+
document = None
|
212 |
+
fields = {**FIELDS}
|
213 |
if img is not None:
|
214 |
if title in QUESTION_FILES:
|
|
|
215 |
document = load_document(QUESTION_FILES[title])
|
216 |
else:
|
217 |
document = ImageDocument(Image.fromarray(img), ocr_reader=get_ocr_reader())
|
|
|
|
|
218 |
|
219 |
+
return process_document(document, fields, model)
|
220 |
|
221 |
|
222 |
CSS = """
|
223 |
#question input {
|
224 |
font-size: 16px;
|
225 |
}
|
226 |
+
#url-textbox, #question-textbox {
|
227 |
padding: 0 !important;
|
228 |
}
|
229 |
#short-upload-box .w-full {
|
|
|
326 |
)
|
327 |
|
328 |
document = gr.Variable()
|
329 |
+
fields = gr.Variable(value={**FIELDS})
|
330 |
example_question = gr.Textbox(visible=False)
|
331 |
example_image = gr.Image(visible=False)
|
332 |
|
|
|
364 |
)
|
365 |
|
366 |
with gr.Column() as col:
|
367 |
+
with gr.Tabs():
|
368 |
+
with gr.TabItem("Table"):
|
369 |
+
output_table = gr.Dataframe(
|
370 |
+
headers=["Field", "Value"], **empty_table(fields.value)
|
371 |
+
)
|
372 |
+
|
373 |
+
with gr.TabItem("JSON"):
|
374 |
+
output = gr.JSON(label="Output", visible=False)
|
375 |
+
|
376 |
gr.Markdown("## 2. Ask a question")
|
|
|
|
|
|
|
|
|
|
|
|
|
377 |
model = gr.Radio(
|
378 |
choices=list(CHECKPOINTS.keys()),
|
379 |
value=list(CHECKPOINTS.keys())[0],
|
|
|
382 |
)
|
383 |
|
384 |
with gr.Row():
|
385 |
+
question = gr.Textbox(
|
386 |
+
label="Question",
|
387 |
+
show_label=False,
|
388 |
+
placeholder="e.g. What is the invoice number?",
|
389 |
+
lines=1,
|
390 |
+
max_lines=1,
|
391 |
+
elem_id="question-textbox",
|
392 |
+
)
|
393 |
+
clear_button = gr.Button("Clear", variant="secondary", visible=False)
|
394 |
submit_button = gr.Button(
|
395 |
+
"Add", variant="primary", elem_id="submit-button"
|
396 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
|
398 |
for cb in [img_clear_button, clear_button]:
|
399 |
cb.click(
|
400 |
lambda _: (
|
401 |
+
gr.update(visible=False, value=None), # image
|
402 |
+
None, # document
|
403 |
+
{**FIELDS}, # fields
|
404 |
+
gr.update(visible=False, value=None), # output
|
405 |
+
gr.update(**empty_table(FIELDS)), # output_table
|
406 |
gr.update(visible=False),
|
407 |
None,
|
408 |
None,
|
|
|
414 |
outputs=[
|
415 |
image,
|
416 |
document,
|
417 |
+
fields,
|
418 |
output,
|
419 |
output_table,
|
420 |
img_clear_button,
|
|
|
426 |
],
|
427 |
)
|
428 |
|
429 |
+
submit_outputs = [
|
430 |
+
document,
|
431 |
+
fields,
|
432 |
+
image,
|
433 |
+
img_clear_button,
|
434 |
+
url_error,
|
435 |
+
output,
|
436 |
+
output_table,
|
437 |
+
]
|
438 |
+
|
439 |
upload.change(
|
440 |
fn=process_upload,
|
441 |
inputs=[upload, model],
|
442 |
+
outputs=submit_outputs,
|
443 |
)
|
444 |
|
445 |
submit.click(
|
446 |
fn=process_path,
|
447 |
inputs=[url, model],
|
448 |
+
outputs=submit_outputs,
|
449 |
)
|
450 |
|
451 |
question.submit(
|
452 |
fn=process_question,
|
453 |
+
inputs=[question, document, image, model, fields, output, output_table],
|
454 |
+
outputs=[image, fields, output, output_table],
|
455 |
)
|
456 |
|
457 |
submit_button.click(
|
|
|
469 |
example_image.change(
|
470 |
fn=load_example_document,
|
471 |
inputs=[example_image, example_question, model],
|
472 |
+
outputs=submit_outputs,
|
473 |
)
|
474 |
|
475 |
if __name__ == "__main__":
|