MaxP commited on
Commit
9eddc4e
·
1 Parent(s): ab64ec8
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Imports
2
+ import re
3
+ import gradio as gr
4
+ import torch
5
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
6
+ from dotenv import load_dotenv
7
+ import os
8
+ from torchvision import transforms
9
+ import torch
10
+ from PIL import Image
11
+
12
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
13
+ model_name = 'naver-clova-ix/donut-base-finetuned-docvqa'
14
+
15
+ # Importante esta app esta pensada para que el modelo corra en CPU
16
+ processor = DonutProcessor.from_pretrained(model_name)
17
+ model = VisionEncoderDecoderModel.from_pretrained(model_name)
18
+
19
+ # Defino la funcion principal que ejecuta el modelo y obtiene los resultados
20
+ def process_image(image, question):
21
+ # Paso por el procesador la imagen y especifico que los outputs sean tensores de pytorch
22
+ pixel_values = processor(image, return_tensors='pt').pixel_values
23
+ # Seteo el prompt
24
+ prompt = f"<s_docvqa><s_question>{question}</s_question><s_answer>"
25
+ # Generamos la sequencia de tokens de salida esto es un vector largo con los ids
26
+ # Esta parte encodea la pregunta y se la pasa al decoder junto con la representación de la imagen post encoder
27
+ decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors='pt').input_ids
28
+ # Defino los outputs
29
+ outputs = model.generate(
30
+ pixel_values.to(device).half(),
31
+ decoder_input_ids=decoder_input_ids.to(device),
32
+ max_length=model.decoder.config.max_position_embeddings,
33
+ early_stopping=True,
34
+ pad_token_id=processor.tokenizer.pad_token_id,
35
+ eos_token_id=processor.tokenizer.eos_token_id,
36
+ use_cache=True,
37
+ num_beams=1, # Probar cambiando este parametro
38
+ bad_words_ids=[[processor.tokenizer.unk_token_id]],
39
+ return_dict_in_generate=True
40
+ )
41
+ # Realizo el Post-processing de la salida del modelo
42
+ sequence = processor.batch_decode(outputs.sequences)[0]
43
+ sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
44
+ processor.tokenizer.eos_token
45
+ sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
46
+
47
+ return processor.token2json(sequence)
48
+
49
+ description = "Esta es una aplicacción realizada con el modelo Donut fine tuned en DocVQA"
50
+
51
+ demo = gr.Interface(
52
+ fn=process_image,
53
+ inputs=['image', 'text'],
54
+ outputs='json',
55
+ title='Demo: Document Question Answering',
56
+ description=description,
57
+ enable_queue=True,
58
+ examples=[
59
+ ['examples/dni_25.jpg', 'cual es el documento / document number?'],
60
+ ['examples/extracto.jpg', 'cual es el telefono de centros servicios de banco galicia?'],
61
+ ['examples/factura_5.jpg', 'cual es el total de la factura?'],
62
+ ]
63
+ )
examples/dni_25.jpg ADDED
examples/extracto.jpg ADDED
examples/factura_5.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate @ git+https://github.com/huggingface/accelerate.git@109f3272f542cbd0e34022f5455078a4ab99d7eb
2
+ aiofiles==23.1.0
3
+ aiohttp==3.8.4
4
+ aiosignal==1.3.1
5
+ altair==5.0.1
6
+ anyio==3.7.0
7
+ arxiv==1.4.7
8
+ asttokens==2.2.1
9
+ async-timeout==4.0.2
10
+ attrs==23.1.0
11
+ backcall==0.2.0
12
+ backoff==2.2.1
13
+ bitsandbytes @ https://github.com/acpopescu/bitsandbytes/releases/download/v0.38.0-win0/bitsandbytes-0.38.1-py3-none-any.whl
14
+ certifi==2023.5.7
15
+ charset-normalizer==3.1.0
16
+ chromadb==0.3.25
17
+ click==8.1.3
18
+ clickhouse-connect==0.5.25
19
+ colorama==0.4.6
20
+ coloredlogs==15.0.1
21
+ comm==0.1.3
22
+ contourpy==1.0.7
23
+ cycler==0.11.0
24
+ dataclasses-json==0.5.7
25
+ debugpy==1.6.7
26
+ decorator==5.1.1
27
+ diffusers==0.16.1
28
+ duckdb==0.8.0
29
+ einops==0.6.1
30
+ exceptiongroup==1.1.1
31
+ executing==1.2.0
32
+ fastapi==0.96.0
33
+ feedparser==6.0.10
34
+ ffmpy==0.3.0
35
+ filelock==3.12.0
36
+ flatbuffers==23.5.26
37
+ fonttools==4.39.4
38
+ frozenlist==1.3.3
39
+ fsspec==2023.5.0
40
+ gradio==3.34.0
41
+ gradio_client==0.2.6
42
+ greenlet==2.0.2
43
+ h11==0.14.0
44
+ hnswlib==0.7.0
45
+ httpcore==0.17.2
46
+ httptools==0.5.0
47
+ httpx==0.24.1
48
+ huggingface-hub==0.15.1
49
+ humanfriendly==10.0
50
+ idna==3.4
51
+ importlib-metadata==6.6.0
52
+ ipykernel==6.23.1
53
+ ipython==8.14.0
54
+ ipywidgets==8.0.6
55
+ jedi==0.18.2
56
+ Jinja2==3.1.2
57
+ joblib==1.2.0
58
+ jsonschema==4.17.3
59
+ jupyter_client==8.2.0
60
+ jupyter_core==5.3.0
61
+ jupyterlab-widgets==3.0.7
62
+ kiwisolver==1.4.4
63
+ langchain==0.0.189
64
+ linkify-it-py==2.0.2
65
+ lz4==4.3.2
66
+ markdown-it-py==2.2.0
67
+ MarkupSafe==2.1.2
68
+ marshmallow==3.19.0
69
+ marshmallow-enum==1.5.1
70
+ matplotlib==3.7.1
71
+ matplotlib-inline==0.1.6
72
+ mdit-py-plugins==0.3.3
73
+ mdurl==0.1.2
74
+ monotonic==1.6
75
+ mpmath==1.3.0
76
+ multidict==6.0.4
77
+ mypy-extensions==1.0.0
78
+ nest-asyncio==1.5.6
79
+ networkx==3.1
80
+ nltk==3.8.1
81
+ numexpr==2.8.4
82
+ numpy==1.24.3
83
+ onnxruntime==1.15.0
84
+ openai==0.27.7
85
+ openapi-schema-pydantic==1.2.4
86
+ orjson==3.9.0
87
+ overrides==7.3.1
88
+ packaging==23.1
89
+ pandas==2.0.2
90
+ parso==0.8.3
91
+ peft @ git+https://github.com/huggingface/peft.git@fcff23f005fc7bfb816ad1f55360442c170cd5f5
92
+ pickleshare==0.7.5
93
+ Pillow==9.3.0
94
+ platformdirs==3.5.1
95
+ posthog==3.0.1
96
+ prompt-toolkit==3.0.38
97
+ protobuf==4.23.2
98
+ psutil==5.9.5
99
+ pure-eval==0.2.2
100
+ pydantic==1.10.8
101
+ pydub==0.25.1
102
+ Pygments==2.15.1
103
+ pyodbc==4.0.39
104
+ pyparsing==3.0.9
105
+ pypdf==3.9.1
106
+ pyreadline3==3.4.1
107
+ pyrsistent==0.19.3
108
+ python-dateutil==2.8.2
109
+ python-dotenv==1.0.0
110
+ python-multipart==0.0.6
111
+ pytz==2023.3
112
+ PyYAML==6.0
113
+ pyzmq==25.1.0
114
+ regex==2023.5.5
115
+ requests==2.31.0
116
+ safetensors==0.3.1
117
+ scikit-learn==1.2.2
118
+ scipy==1.10.1
119
+ semantic-version==2.10.0
120
+ sentence-transformers==2.2.2
121
+ sentencepiece==0.1.99
122
+ sgmllib3k==1.0.0
123
+ six==1.16.0
124
+ sniffio==1.3.0
125
+ SQLAlchemy==2.0.15
126
+ stack-data==0.6.2
127
+ starlette==0.27.0
128
+ sympy==1.12
129
+ tenacity==8.2.2
130
+ text-generation==0.6.0
131
+ threadpoolctl==3.1.0
132
+ tiktoken==0.4.0
133
+ tokenizers==0.13.3
134
+ toolz==0.12.0
135
+ torch
136
+ tornado==6.3.2
137
+ tqdm==4.65.0
138
+ traitlets==5.9.0
139
+ transformers @ git+https://github.com/huggingface/transformers.git@bacaab1629972b85664fe61ec3caa4da7b55b041
140
+ typing-inspect==0.9.0
141
+ typing_extensions==4.6.3
142
+ tzdata==2023.3
143
+ uc-micro-py==1.0.2
144
+ urllib3==2.0.2
145
+ uvicorn==0.22.0
146
+ watchfiles==0.19.0
147
+ wcwidth==0.2.6
148
+ websockets==11.0.3
149
+ widgetsnbextension==4.0.7
150
+ yarl==1.9.2
151
+ zipp==3.15.0
152
+ zstandard==0.21.0