aegnor8 commited on
Commit
78d1692
·
verified ·
1 Parent(s): bb640a3
Files changed (2) hide show
  1. app.py +118 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+
3
+ import pandas as pd
4
+ import plotly.express as px
5
+ import streamlit as st
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from easyocr import Reader
9
+ from PIL import Image
10
+ from transformers import(
11
+ LayoutLMv3FeatureExtractor,
12
+ LayoutLMv3ForSequenceClassification,
13
+ LayoutLMv3Processor,
14
+ LayoutLMv3TokenizerFast,
15
+ )
16
+
17
+ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
18
+ MICROSOFT_MODEL_NAME = "microsoft/layoutlmv3-base"
19
+ MODEL_NAME = "curiousily/layoutlmv3-financial-document-classification"
20
+
21
+ def create_bounding_box(bbox_data,
22
+ width_scale: float,
23
+ height_scale: float):
24
+ xs = []
25
+ ys = []
26
+ for x, y in bbox_data:
27
+ xs.append(x)
28
+ ys.append(y)
29
+
30
+ left = int(min(xs) * width_scale)
31
+ top = int(min(ys) * height_scale)
32
+ right = int(max(xs) * width_scale)
33
+ bottom = int(max(ys) * height_scale)
34
+
35
+ return [left, top, right, bottom]
36
+
37
+ @st.cache_resource
38
+ def create_ocr_reader():
39
+ return Reader(["en"])
40
+
41
+ @st.cache_resource
42
+ def create_processor():
43
+ feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
44
+ tokenizer = LayoutLMv3TokenizerFast.from_pretrained(MICROSOFT_MODEL_NAME)
45
+ return LayoutLMv3Processor(feature_extractor, tokenizer)
46
+
47
+ @st.cache_resource
48
+ def create_model():
49
+ model = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_NAME)
50
+ return model.eval().to(DEVICE)
51
+
52
+ def predict(image: Image,
53
+ reader: Reader,
54
+ processor: LayoutLMv3Processor,
55
+ model: LayoutLMv3ForSequenceClassification):
56
+
57
+ ocr_result = reader.readtext(image)
58
+
59
+ width, height = image.size
60
+ width_scale = 1000 / width
61
+ height_scale = 1000 / height
62
+
63
+ words = []
64
+ boxes = []
65
+
66
+ for bbox, word, confidence in ocr_result:
67
+ words.append(word)
68
+ boxes.append(create_bounding_box(bbox, width_scale, height_scale))
69
+
70
+ encoding = processor(image,
71
+ words,
72
+ boxes = boxes,
73
+ max_length = 512,
74
+ padding = "max_length",
75
+ truncation = True,
76
+ return_tensors = "pt",)
77
+
78
+ with torch.inference_mode():
79
+ output = model(
80
+ input_ids = encoding["input_ids"].to(DEVICE),
81
+ attention_mask = encoding["attention_mask"].to(DEVICE),
82
+ bbox = encoding["bbox"].to(DEVICE),
83
+ pixel_values = encoding["pixel_values"].to(DEVICE)
84
+ )
85
+
86
+ logits = output.logits
87
+ predicted_class = logits.argmax()
88
+ probabilities = F.softmax(logits, dim = -1).flatten().tolist()
89
+
90
+ return predicted_class.detach().item(), probabilities
91
+
92
+ reader = create_ocr_reader()
93
+ processor = create_processor()
94
+ model = create_model()
95
+
96
+ uploaded_file = st.file_uploader("Upload Document Image", ["jpg", "png"])
97
+ if uploaded_file is not None:
98
+ bytes_data = io.BytesIO(uploaded_file.getvalue())
99
+ image = Image.open(bytes_data)
100
+ st.image(image, "Your Document")
101
+
102
+ predicted_class, probabilities = predict(image,
103
+ reader,
104
+ processor,
105
+ model)
106
+ predicted_label = model.config.id2label[predicted_class]
107
+
108
+ st.markdown(f"Predicted document type: **{predicted_label}**")
109
+
110
+ df_predictions = pd.DataFrame(
111
+ {"Document": list(model.config.id2label.values()),
112
+ "Confidence": probabilities}
113
+ )
114
+
115
+ fig = px.bar(df_predictions, x = "Document", y = "Confidence")
116
+ st.plotly_chart(fig, use_container_width = True)
117
+
118
+
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ easyocr==1.7.1
2
+ pandas==2.2.0
3
+ Pillow==9.5.0
4
+ plotly-express==0.4.1
5
+ torch==2.2.0
6
+ transformers==4.37.2