|
--- |
|
language: |
|
- zh |
|
library_name: transformers |
|
tags: |
|
- donut |
|
- donut-python |
|
--- |
|
|
|
### Installation
|
```bash |
|
pip install torch |
|
pip install transformers==4.11.3 |
|
pip install opencv-python==4.6.0.66 |
|
pip install donut-python |
|
``` |
|
|
|
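### Usage

This model generates a question from an answer. Pass an image together with an answer text (e.g. `零钱通`), and it returns a question about the image whose answer is that text.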
|
```python |
|
import sys |
|
import os |
|
import pandas as pd |
|
import numpy as np |
|
import shutil |
|
|
|
from tqdm import tqdm |
|
import re |
|
|
|
from donut import DonutModel |
|
import torch |
|
from PIL import Image |
|
|
|
# Checkpoint location passed to DonutModel.from_pretrained.
zh_model_path = "question_generator_by_zh_on_pic"
zh_pretrained_model = DonutModel.from_pretrained(zh_model_path)

# DocVQA-style prompt; "{user_input}" is substituted by demo_process_vqa and
# decoding continues from the opening <s_answer> tag.
task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"

# With a GPU present, cast the weights to fp16 and move them onto it;
# otherwise the model stays on CPU in its default precision.
if torch.cuda.is_available():
    device = torch.device("cuda")
    zh_pretrained_model.half()
    zh_pretrained_model.to(device)

# Inference mode (disables dropout etc.).
zh_pretrained_model.eval()
print("have load !")
|
|
|
def demo_process_vqa(input_img, question):
    """Generate a question about ``input_img`` whose answer is ``question``.

    This checkpoint is used as a *question generator*. The text passed as
    ``question`` is really the target answer: it is substituted into the
    ``<s_question>`` slot of ``task_prompt``. The text the model decodes in the
    ``<s_answer>`` slot is the generated question. The two fields are
    therefore deliberately swapped when building the returned dict.

    Args:
        input_img: A PIL image, or a NumPy array (e.g. from a Gradio image
            component), which is converted with ``Image.fromarray``.
        question: The answer text to generate a question for.

    Returns:
        dict with keys ``"question"`` (the generated question) and
        ``"answer"`` (the model's ``question`` field, which in the sample run
        echoes the input text).
    """
    # Image widgets often hand over NumPy arrays; Donut inference is given a
    # PIL image. NOTE(review): assumes RGB / grayscale channel order — confirm
    # for the caller (e.g. OpenCV arrays are BGR).
    if isinstance(input_img, np.ndarray):
        input_img = Image.fromarray(input_img)

    # Module-level model and prompt are only read here, so no `global` needed.
    user_prompt = task_prompt.replace("{user_input}", question)
    output = zh_pretrained_model.inference(input_img, prompt=user_prompt)["predictions"][0]

    return {
        # Decoded in the <s_answer> slot -> the generated question.
        "question": output["answer"],
        # Model's <s_question> field -> the answer text supplied by the caller.
        "answer": output["question"],
    }
|
|
|
|
|
img_path = "zh_img.png"

# Open the image in a context manager so the file handle is released, and keep
# and print the result so running the snippet actually shows the output.
with Image.open(img_path) as img:
    result = demo_process_vqa(img, "零钱通")
print(result)

# Expected output:
# {'question': '支付方式是什么?', 'answer': '零钱通'}
|
|
|
``` |
|
|
|
### Sample Image |
|
<img src="https://raw.githubusercontent.com/svjack/docvqa-gen/main/imgs/zh_img.png" alt="Sample Chinese document image" width="500" height="500"/>