arthur-lima committed on
Commit
31fe822
·
1 Parent(s): 82d0c99

Início aplicação

Browse files
Files changed (2) hide show
  1. app.py +132 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+
3
+ import pandas as pd
4
+ import plotly.express as px
5
+ import streamlit as st
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from easyocr import Reader
9
+ from PIL import Image
10
+ from transformers import (
11
+ LayoutLMv3FeatureExtractor,
12
+ LayoutLMv3TokenizerFast,
13
+ LayoutLMv3Processor,
14
+ LayoutLMv3ForSequenceClassification,
15
+ )
16
+
17
# Inference is pinned to the CPU. To use a GPU when one is available, switch to:
#     DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
DEVICE = "cpu"
# Checkpoint the LayoutLMv3 tokenizer is loaded from.
MICROSOFT_MODEL_NAME = "microsoft/layoutlmv3-base"
# Backward-compatible alias for the original, misspelled constant name.
MICROSOFT_HODEL_NAME = MICROSOFT_MODEL_NAME
# Checkpoint loaded for sequence classification.
MODEL_NAME = "arthur-lima/layoutlmv3-triagem-documentos"
21
+
22
+
23
def create_bounding_box(bbox_data, width_scale: float, height_scale: float):
    """Collapse a polygon of points into a scaled, axis-aligned rectangle.

    Args:
        bbox_data: Iterable of ``(x, y)`` points.
        width_scale: Factor applied to the horizontal extremes.
        height_scale: Factor applied to the vertical extremes.

    Returns:
        ``[left, top, right, bottom]``, each value truncated to ``int``
        after scaling.
    """
    points = [(x, y) for x, y in bbox_data]
    x_coords = [x for x, _ in points]
    y_coords = [y for _, y in points]
    x_min, y_min = min(x_coords), min(y_coords)
    x_max, y_max = max(x_coords), max(y_coords)
    return [
        int(x_min * width_scale),
        int(y_min * height_scale),
        int(x_max * width_scale),
        int(y_max * height_scale),
    ]
34
+
35
+
36
@st.experimental_singleton
def create_ocr_reader():
    """Build the EasyOCR reader for the ``"pt"`` and ``"en"`` languages.

    The reader is created with ``gpu=False``. The function is wrapped in
    ``st.experimental_singleton``.
    """
    languages = ["pt", "en"]
    return Reader(languages, gpu=False)
40
+
41
+
42
@st.experimental_singleton
def create_processor():
    """Assemble the LayoutLMv3 processor from a feature extractor and tokenizer.

    The feature extractor is created with ``apply_ocr=False``; the tokenizer
    is loaded from ``MICROSOFT_HODEL_NAME``.
    """
    return LayoutLMv3Processor(
        LayoutLMv3FeatureExtractor(apply_ocr=False),
        LayoutLMv3TokenizerFast.from_pretrained(MICROSOFT_HODEL_NAME),
    )
47
+
48
+
49
@st.experimental_singleton
def create_model():
    """Load the classifier from ``MODEL_NAME``, in eval mode, on ``DEVICE``."""
    classifier = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_NAME)
    classifier.eval()
    return classifier.to(DEVICE)
53
+
54
+
55
def predict(
    image: Image.Image,
    image_bytes: bytes,
    reader: Reader,
    processor: LayoutLMv3Processor,
    model: LayoutLMv3ForSequenceClassification,
):
    """Classify one document page from its OCR words, boxes and pixels.

    Args:
        image: Decoded page image; converted to RGB when it is in another mode.
        image_bytes: Encoded bytes of the same image, passed to the OCR reader.
        reader: OCR reader whose ``readtext`` yields ``(bbox, word, _)`` items.
        processor: LayoutLMv3 processor used to encode image, words and boxes.
        model: Sequence-classification model that produces the logits.

    Returns:
        A ``(predicted_class, probabilities)`` tuple: the index of the
        largest logit as an ``int``, and the softmax of the logits flattened
        into a list of floats.
    """
    # The uploader accepts PNG files, which are frequently RGBA, palette or
    # grayscale. The processor is fed three-channel RGB images in the
    # Hugging Face examples, so normalise the mode up front.
    if image.mode != "RGB":
        image = image.convert("RGB")

    ocr_result = reader.readtext(image_bytes)

    # Boxes are rescaled from pixel coordinates to a 1000 x 1000 grid.
    width, height = image.size
    width_scale = 1000 / width
    height_scale = 1000 / height

    words = []
    boxes = []
    for bbox, word, _ in ocr_result:
        boxes.append(create_bounding_box(bbox, width_scale, height_scale))
        words.append(word)

    encoding = processor(
        image,
        words,
        boxes=boxes,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Only these four tensors are forwarded to the model, all on DEVICE.
    model_inputs = {
        key: encoding[key].to(DEVICE)
        for key in ("input_ids", "attention_mask", "bbox", "pixel_values")
    }
    with torch.inference_mode():
        output = model(**model_inputs)

    logits = output.logits
    predicted_class = logits.argmax()
    # Turn the raw logits back into probabilities.
    probabilities = F.softmax(logits, dim=-1).flatten().tolist()
    return predicted_class.detach().item(), probabilities
100
+
101
+
102
# Streamlit page: load the OCR reader, processor and model, then classify an
# uploaded document image and chart the per-class confidences.
reader = create_ocr_reader()
processor = create_processor()
model = create_model()
uploaded_file = st.file_uploader("Upload Document Image", ["jpg", "png"])
if uploaded_file is not None:
    # Read the upload; the raw bytes are kept because predict() passes them
    # to the OCR reader.
    image_bytes = uploaded_file.getvalue()
    bytes_data = io.BytesIO(image_bytes)
    image = Image.open(bytes_data)

    # Show the image.
    st.image(image, "Página do documento", width=300)

    # Run the prediction.
    predicted_class, probabilities = predict(
        image, image_bytes, reader, processor, model
    )

    # Print the result on screen.
    predicted_label = model.config.id2label[predicted_class]
    st.markdown(f"Tipo do documento previsto: **{predicted_label}**")

    # Draw the confidence chart. Labels are ordered by class id so each one
    # lines up with the probability at the same index, instead of relying on
    # the insertion order of the id2label dict.
    labels = [label for _, label in sorted(model.config.id2label.items())]
    df_predictions = pd.DataFrame(
        {
            "Tipo Documento": labels,
            "Confiança": probabilities,
        }
    )
    fig = px.bar(df_predictions, x="Tipo Documento", y="Confiança")
    st.plotly_chart(fig, use_container_width=True)
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ PyMuPDF==1.21.1
2
+ numpy==1.24.2
3
+ streamlit==1.15.2
4
+ transformers==4.25.1
5
+ pandas==2.0.0
6
+ plotly-express==0.4.1
7
+ python-dotenv==1.0.0
8
+ Pillow==9.4.0
9
+ torch==2.0.0
10
+ easyocr==1.6.2