ankur-bohra committed
Commit 491a9c3
1 Parent(s): e853e36

Update app.py

Files changed (1):
  1. app.py +32 -29
app.py CHANGED
@@ -10,35 +10,38 @@ model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-fin
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
 
-def process_document(image, question):
-    # prepare encoder inputs
-    pixel_values = processor(image, return_tensors="pt").pixel_values
-
-    # prepare decoder inputs
-    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
-    prompt = task_prompt.replace("{user_input}", question)
-    decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
-
-    # generate answer
-    outputs = model.generate(
-        pixel_values.to(device),
-        decoder_input_ids=decoder_input_ids.to(device),
-        max_length=model.decoder.config.max_position_embeddings,
-        early_stopping=True,
-        pad_token_id=processor.tokenizer.pad_token_id,
-        eos_token_id=processor.tokenizer.eos_token_id,
-        use_cache=True,
-        num_beams=1,
-        bad_words_ids=[[processor.tokenizer.unk_token_id]],
-        return_dict_in_generate=True,
-    )
-
-    # postprocess
-    sequence = processor.batch_decode(outputs.sequences)[0]
-    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
-    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
-
-    return processor.token2json(sequence)
+def process_document(image, *questions):
+    output = []
+    for question in questions:
+        # prepare encoder inputs
+        pixel_values = processor(image, return_tensors="pt").pixel_values
+
+        # prepare decoder inputs
+        task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
+        prompt = task_prompt.replace("{user_input}", question)
+        decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
+
+        # generate answer
+        outputs = model.generate(
+            pixel_values.to(device),
+            decoder_input_ids=decoder_input_ids.to(device),
+            max_length=model.decoder.config.max_position_embeddings,
+            early_stopping=True,
+            pad_token_id=processor.tokenizer.pad_token_id,
+            eos_token_id=processor.tokenizer.eos_token_id,
+            use_cache=True,
+            num_beams=1,
+            bad_words_ids=[[processor.tokenizer.unk_token_id]],
+            return_dict_in_generate=True,
+        )
+
+        # postprocess
+        sequence = processor.batch_decode(outputs.sequences)[0]
+        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
+
+        output.append(processor.token2json(sequence))
+    return output
 
 description = "Gradio Demo for Donut, an instance of `VisionEncoderDecoderModel` fine-tuned on DocVQA (document visual question answering). To use it, simply upload your image and type a question and click 'submit', or click one of the examples to load them. Read more at the links below."
 article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
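
For reference, a minimal sketch of how the updated variadic signature could be called directly, assuming app.py's module-level processor, model, device, and re are already in scope. The image path and question strings below are illustrative only and are not part of this commit; note also that the Gradio Interface wiring (not shown in this diff) would need one input component per question for *questions to receive more than one value.

# Illustrative call against the updated process_document (file name and questions are hypothetical)
from PIL import Image

image = Image.open("sample_invoice.png").convert("RGB")  # hypothetical input document
answers = process_document(
    image,
    "What is the invoice number?",   # question 1
    "What is the total amount?",     # question 2
)
print(answers)  # one processor.token2json() result per question, in order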