rishabh062 committed
Commit 2adef2c
1 Parent(s): 6d35390

Adding files

bill_of_lading_1.png ADDED
japanese-invoice.png ADDED
main.py ADDED
@@ -0,0 +1,530 @@
+ import io
+ import os
+ import boto3
+ import traceback
+ import re
+ import logging
+
+ import gradio as gr
+ from PIL import Image, ImageDraw
+
+ from docquery.document import load_document, ImageDocument
+ from docquery.ocr_reader import get_ocr_reader
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
+ from transformers import pipeline
+
+ # avoid ssl errors
+ import ssl
+
+ ssl._create_default_https_context = ssl._create_unverified_context
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ logging.basicConfig(level=logging.DEBUG)
+ logger = logging.getLogger(__name__)
+
+ # Init models
+
+ layoutlm_pipeline = pipeline(
+     "document-question-answering",
+     model="impira/layoutlm-document-qa",
+ )
+ lilt_tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-infoxlm-base")
+ lilt_model = AutoModelForQuestionAnswering.from_pretrained(
+     "nielsr/lilt-xlm-roberta-base"
+ )
+
+ donut_processor = DonutProcessor.from_pretrained(
+     "naver-clova-ix/donut-base-finetuned-docvqa"
+ )
+ donut_model = VisionEncoderDecoderModel.from_pretrained(
+     "naver-clova-ix/donut-base-finetuned-docvqa"
+ )
+
+ TEXTRACT = "Textract Query"
+ LAYOUTLM = "LayoutLM"
+ DONUT = "Donut"
+ LILT = "LiLT"
+
+
+ def image_to_byte_array(image: Image.Image) -> bytes:
+     image_as_byte_array = io.BytesIO()
+     image.save(image_as_byte_array, format="PNG")
+     image_as_byte_array = image_as_byte_array.getvalue()
+     return image_as_byte_array
+
+
+ def run_textract(question, document):
+     logger.info(f"Running Textract model.")
+     image_as_byte_base64 = image_to_byte_array(image=document.b)
+     response = boto3.client("textract").analyze_document(
+         Document={
+             "Bytes": image_as_byte_base64,
+         },
+         FeatureTypes=[
+             "QUERIES",
+         ],
+         QueriesConfig={
+             "Queries": [
+                 {
+                     "Text": question,
+                     "Pages": [
+                         "*",
+                     ],
+                 },
+             ]
+         },
+     )
+     logger.info(f"Output of Textract model {response}.")
+     for element in response["Blocks"]:
+         if element["BlockType"] == "QUERY_RESULT":
+             return {
+                 "score": element["Confidence"],
+                 "answer": element["Text"],
+                 # "word_ids": element
+             }
+     else:
+         raise Exception("No QUERY_RESULT found in the response from Textract.")
+
+
+ def run_layoutlm(question, document):
+     logger.info(f"Running layoutlm model.")
+     result = layoutlm_pipeline(document.context["image"][0][0], question)[0]
+     logger.info(f"Output of layoutlm model {result}.")
+     # [{'score': 0.9999411106109619, 'answer': 'LETTER OF CREDIT', 'start': 106, 'end': 108}]
+     return {
+         "score": result["score"],
+         "answer": result["answer"],
+         "word_ids": [result["start"], result["end"]],
+         "page": 0,
+     }
+
+
+ def run_lilt(question, document):
+     logger.info(f"Running lilt model.")
+     # use this model + tokenizer
+     processed_document = document.context["image"][0][1]
+     words = [x[0] for x in processed_document]
+     boxes = [x[1] for x in processed_document]
+
+     encoding = lilt_tokenizer(
+         text=question,
+         text_pair=words,
+         boxes=boxes,
+         add_special_tokens=True,
+         return_tensors="pt",
+     )
+     outputs = lilt_model(**encoding)
+     logger.info(f"Output for lilt model {outputs}.")
+
+     answer_start_index = outputs.start_logits.argmax()
+     answer_end_index = outputs.end_logits.argmax()
+
+     predict_answer_tokens = encoding.input_ids[
+         0, answer_start_index: answer_end_index + 1
+     ]
+     predict_answer = lilt_tokenizer.decode(
+         predict_answer_tokens, skip_special_tokens=True
+     )
+     return {
+         "score": "n/a",
+         "answer": predict_answer,
+         # "word_ids": element
+     }
+
+
+ def run_donut(question, document):
+     logger.info(f"Running donut model.")
+     # prepare encoder inputs
+     pixel_values = donut_processor(
+         document.context["image"][0][0], return_tensors="pt"
+     ).pixel_values
+
+     # prepare decoder inputs
+     task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
+     prompt = task_prompt.replace("{user_input}", question)
+     decoder_input_ids = donut_processor.tokenizer(
+         prompt, add_special_tokens=False, return_tensors="pt"
+     ).input_ids
+
+     # generate answer
+     outputs = donut_model.generate(
+         pixel_values,
+         decoder_input_ids=decoder_input_ids,
+         max_length=donut_model.decoder.config.max_position_embeddings,
+         early_stopping=True,
+         pad_token_id=donut_processor.tokenizer.pad_token_id,
+         eos_token_id=donut_processor.tokenizer.eos_token_id,
+         use_cache=True,
+         num_beams=1,
+         bad_words_ids=[[donut_processor.tokenizer.unk_token_id]],
+         return_dict_in_generate=True,
+     )
+     logger.info(f"Output for donut {outputs}")
+     sequence = donut_processor.batch_decode(outputs.sequences)[0]
+     sequence = sequence.replace(donut_processor.tokenizer.eos_token, "").replace(
+         donut_processor.tokenizer.pad_token, ""
+     )
+     sequence = re.sub(
+         r"<.*?>", "", sequence, count=1
+     ).strip()  # remove first task start token
+
+     result = donut_processor.token2json(sequence)
+     return {
+         "score": "n/a",
+         "answer": result["answer"],
+         # "word_ids": element
+     }
+
+
+ def process_path(path):
+     error = None
+     if path:
+         try:
+             document = load_document(path)
+             return (
+                 document,
+                 gr.update(visible=True, value=document.preview),
+                 gr.update(visible=True),
+                 gr.update(visible=False, value=None),
+                 gr.update(visible=False, value=None),
+                 None,
+             )
+         except Exception as e:
+             traceback.print_exc()
+             error = str(e)
+     return (
+         None,
+         gr.update(visible=False, value=None),
+         gr.update(visible=False),
+         gr.update(visible=False, value=None),
+         gr.update(visible=False, value=None),
+         gr.update(visible=True, value=error) if error is not None else None,
+         None,
+     )
+
+
+ def process_upload(file):
+     if file:
+         return process_path(file.name)
+     else:
+         return (
+             None,
+             gr.update(visible=False, value=None),
+             gr.update(visible=False),
+             gr.update(visible=False, value=None),
+             gr.update(visible=False, value=None),
+             None,
+         )
+
+
+
+ def lift_word_boxes(document, page):
+     return document.context["image"][page][1]
+
+
+ def expand_bbox(word_boxes):
+     if len(word_boxes) == 0:
+         return None
+
+     min_x, min_y, max_x, max_y = zip(*[x[1] for x in word_boxes])
+     min_x, min_y, max_x, max_y = [min(min_x), min(min_y), max(max_x), max(max_y)]
+     return [min_x, min_y, max_x, max_y]
+
+
+ # LayoutLM boxes are normalized to 0, 1000
+ def normalize_bbox(box, width, height, padding=0.005):
+     min_x, min_y, max_x, max_y = [c / 1000 for c in box]
+     if padding != 0:
+         min_x = max(0, min_x - padding)
+         min_y = max(0, min_y - padding)
+         max_x = min(max_x + padding, 1)
+         max_y = min(max_y + padding, 1)
+     return [min_x * width, min_y * height, max_x * width, max_y * height]
+
+
+ MODELS = {
+     LAYOUTLM: run_layoutlm,
+     DONUT: run_donut,
+     # LILT: run_lilt,
+     TEXTRACT: run_textract,
+ }
+
+
+ def process_question(question, document, model=list(MODELS.keys())[0]):
+     if not question or document is None:
+         return None, None, None
+     logger.info(f"Running for model {model}")
+     prediction = MODELS[model](question=question, document=document)
+     logger.info(f"Got prediction {prediction}")
+     pages = [x.copy().convert("RGB") for x in document.preview]
+     text_value = prediction["answer"]
+     if "word_ids" in prediction:
+         logger.info(f"Setting bounding boxes.")
+         image = pages[prediction["page"]]
+         draw = ImageDraw.Draw(image, "RGBA")
+         word_boxes = lift_word_boxes(document, prediction["page"])
+         x1, y1, x2, y2 = normalize_bbox(
+             expand_bbox([word_boxes[i] for i in prediction["word_ids"]]),
+             image.width,
+             image.height,
+         )
+         draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
+
+     return (
+         gr.update(visible=True, value=pages),
+         gr.update(visible=True, value=prediction),
+         gr.update(
+             visible=True,
+             value=text_value,
+         ),
+     )
+
+
+ def load_example_document(img, question, model):
+     if img is not None:
+         document = ImageDocument(Image.fromarray(img), get_ocr_reader())
+         preview, answer, answer_text = process_question(question, document, model)
+         return document, question, preview, gr.update(visible=True), answer, answer_text
+     else:
+         return None, None, None, gr.update(visible=False), None, None
+
+
+
+ CSS = """
+ #question input {
+     font-size: 16px;
+ }
+ #url-textbox {
+     padding: 0 !important;
+ }
+ #short-upload-box .w-full {
+     min-height: 10rem !important;
+ }
+ /* I think something like this can be used to re-shape
+  * the table
+  */
+ /*
+ .gr-samples-table tr {
+     display: inline;
+ }
+ .gr-samples-table .p-2 {
+     width: 100px;
+ }
+ */
+ #select-a-file {
+     width: 100%;
+ }
+ #file-clear {
+     padding-top: 2px !important;
+     padding-bottom: 2px !important;
+     padding-left: 8px !important;
+     padding-right: 8px !important;
+     margin-top: 10px;
+ }
+ .gradio-container .gr-button-primary {
+     background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
+     border: 1px solid #B0DCCC;
+     border-radius: 8px;
+     color: #1B8700;
+ }
+ .gradio-container.dark button#submit-button {
+     background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
+     border: 1px solid #B0DCCC;
+     border-radius: 8px;
+     color: #1B8700;
+ }
+ table.gr-samples-table tr td {
+     border: none;
+     outline: none;
+ }
+ table.gr-samples-table tr td:first-of-type {
+     width: 0%;
+ }
+ div#short-upload-box div.absolute {
+     display: none !important;
+ }
+ gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div {
+     gap: 0px 2%;
+ }
+ gradio-app div div div div.w-full, .gradio-app div div div div.w-full {
+     gap: 0px;
+ }
+ gradio-app h2, .gradio-app h2 {
+     padding-top: 10px;
+ }
+ #answer {
+     overflow-y: scroll;
+     color: white;
+     background: #666;
+     border-color: #666;
+     font-size: 20px;
+     font-weight: bold;
+ }
+ #answer span {
+     color: white;
+ }
+ #answer textarea {
+     color: white;
+     background: #777;
+     border-color: #777;
+     font-size: 18px;
+ }
+ #url-error input {
+     color: red;
+ }
+ """
+
+ examples = [
+     [
+         "scenario-1.png",
+         "What is the final consignee?",
+     ],
+     [
+         "scenario-1.png",
+         "What are the payment terms?",
+     ],
+     [
+         "scenario-2.png",
+         "What is the actual manufacturer?",
+     ],
+     [
+         "scenario-3.png",
+         'What is the "ship to" destination?',
+     ],
+     [
+         "scenario-4.png",
+         "What is the color?",
+     ],
+     [
+         "scenario-5.png",
+         'What is the "said to contain"?',
+     ],
+     [
+         "scenario-5.png",
+         'What is the "Net Weight"?',
+     ],
+     [
+         "scenario-5.png",
+         'What is the "Freight Collect"?',
+     ],
+     [
+         "bill_of_lading_1.png",
+         "What is the shipper?",
+     ],
+     [
+         "japanese-invoice.png",
+         "What is the total amount?",
+     ]
+ ]
+
+ with gr.Blocks(css=CSS) as demo:
+     gr.Markdown("# Document Question Answering Comparator")
+     gr.Markdown("""
+ This space compares some of the latest models that can be used commercially.
+ - [LayoutLM](https://huggingface.co/impira/layoutlm-document-qa) uses text/layout and images. Uses tesseract for OCR.
+ - [Donut](https://huggingface.co/naver-clova-ix/donut-base-finetuned-docvqa) OCR-free document understanding. Uses a vision encoder for OCR and a text decoder for providing the answer.
+ - [Textract Query](https://docs.aws.amazon.com/textract/latest/dg/what-is.html) AWS's OCR + document understanding solution.
+     """)
+
+     document = gr.Variable()
+     example_question = gr.Textbox(visible=False)
+     example_image = gr.Image(visible=False)
+
+     with gr.Row(equal_height=True):
+         with gr.Column():
+             with gr.Row():
+                 gr.Markdown("## 1. Select a file", elem_id="select-a-file")
+                 img_clear_button = gr.Button(
+                     "Clear", variant="secondary", elem_id="file-clear", visible=False
+                 )
+             image = gr.Gallery(visible=False)
+             upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
+             gr.Examples(
+                 examples=examples,
+                 inputs=[example_image, example_question],
+             )
+
+         with gr.Column() as col:
+             gr.Markdown("## 2. Ask a question")
+             question = gr.Textbox(
+                 label="Question",
+                 placeholder="e.g. What is the invoice number?",
+                 lines=1,
+                 max_lines=1,
+             )
+             model = gr.Radio(
+                 choices=list(MODELS.keys()),
+                 value=list(MODELS.keys())[0],
+                 label="Model",
+             )
+
+             with gr.Row():
+                 clear_button = gr.Button("Clear", variant="secondary")
+                 submit_button = gr.Button(
+                     "Submit", variant="primary", elem_id="submit-button"
+                 )
+         with gr.Column():
+             output_text = gr.Textbox(
+                 label="Top Answer", visible=False, elem_id="answer"
+             )
+             output = gr.JSON(label="Output", visible=False)
+
+     for cb in [img_clear_button, clear_button]:
+         cb.click(
+             lambda _: (
+                 gr.update(visible=False, value=None),
+                 None,
+                 gr.update(visible=False, value=None),
+                 gr.update(visible=False, value=None),
+                 gr.update(visible=False),
+                 None,
+                 None,
+                 None,
+                 gr.update(visible=False, value=None),
+                 None,
+             ),
+             inputs=clear_button,
+             outputs=[
+                 image,
+                 document,
+                 output,
+                 output_text,
+                 img_clear_button,
+                 example_image,
+                 upload,
+                 question,
+             ],
+         )
+
+     upload.change(
+         fn=process_upload,
+         inputs=[upload],
+         outputs=[document, image, img_clear_button, output, output_text],
+     )
+
+     question.submit(
+         fn=process_question,
+         inputs=[question, document, model],
+         outputs=[image, output, output_text],
+     )
+
+     submit_button.click(
+         process_question,
+         inputs=[question, document, model],
+         outputs=[image, output, output_text],
+     )
+
+     model.change(
+         process_question,
+         inputs=[question, document, model],
+         outputs=[image, output, output_text],
+     )
+
+     example_image.change(
+         fn=load_example_document,
+         inputs=[example_image, example_question, model],
+         outputs=[document, question, image, img_clear_button, output, output_text],
+     )
+
+ if __name__ == "__main__":
+     demo.launch(enable_queue=False)
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ docquery[web,donut]
+ transformers
+ gradio
+ boto3
+ pillow
scenario-1.png ADDED
scenario-2.png ADDED
scenario-3.png ADDED
scenario-4.png ADDED
scenario-5.png ADDED