# File size: 3,689 Bytes
# 78d1692
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import io

import pandas as pd
import plotly.express as px
import streamlit as st
import torch
import torch.nn.functional as F
from easyocr import Reader
from PIL import Image
from transformers import(
    LayoutLMv3FeatureExtractor,
    LayoutLMv3ForSequenceClassification,
    LayoutLMv3Processor,
    LayoutLMv3TokenizerFast,
)

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
MICROSOFT_MODEL_NAME = "microsoft/layoutlmv3-base"
MODEL_NAME = "curiousily/layoutlmv3-financial-document-classification"

def create_bounding_box(bbox_data, width_scale: float, height_scale: float):
    """Collapse a set of corner points into one scaled, axis-aligned box.

    Args:
        bbox_data: Iterable of ``(x, y)`` pairs describing the region.
        width_scale: Multiplier applied to the horizontal extremes.
        height_scale: Multiplier applied to the vertical extremes.

    Returns:
        ``[left, top, right, bottom]`` as ints; each scaled value is
        truncated by ``int``.
    """
    # Materialise once so one-shot iterables are only consumed a single time.
    points = [(x, y) for x, y in bbox_data]
    x_coords = [x for x, _ in points]
    y_coords = [y for _, y in points]

    return [
        int(min(x_coords) * width_scale),
        int(min(y_coords) * height_scale),
        int(max(x_coords) * width_scale),
        int(max(y_coords) * height_scale),
    ]

@st.cache_resource
def create_ocr_reader():
    """Build the EasyOCR ``Reader`` configured with the ``"en"`` language list.

    Decorated with ``st.cache_resource`` so Streamlit manages the instance.
    """
    languages = ["en"]
    ocr_reader = Reader(languages)
    return ocr_reader

@st.cache_resource
def create_processor():
    """Assemble the LayoutLMv3 processor from a feature extractor and tokenizer.

    The feature extractor is created with ``apply_ocr=False`` (words and boxes
    are passed in by the caller), and the fast tokenizer is loaded from
    ``MICROSOFT_MODEL_NAME``.
    """
    return LayoutLMv3Processor(
        LayoutLMv3FeatureExtractor(apply_ocr=False),
        LayoutLMv3TokenizerFast.from_pretrained(MICROSOFT_MODEL_NAME),
    )

@st.cache_resource
def create_model():
    """Load the sequence-classification model from ``MODEL_NAME``.

    The returned model is switched to evaluation mode and moved to ``DEVICE``.
    """
    classifier = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_NAME)
    classifier = classifier.eval()
    return classifier.to(DEVICE)

def predict(image: Image.Image,
            reader: Reader,
            processor: LayoutLMv3Processor,
            model: LayoutLMv3ForSequenceClassification) -> tuple:
    """Classify a document image from its OCR words, word boxes and pixels.

    The image is OCR'd, every detected word gets a bounding box scaled by
    ``1000 / width`` and ``1000 / height``, and the words, boxes and image are
    encoded once and passed through the model in a single forward pass.

    Args:
        image: The document image; its ``size`` is used to scale the boxes.
        reader: OCR reader whose ``readtext`` yields
            ``(bbox, word, confidence)`` triples.
        processor: Processor that turns image, words and boxes into tensors.
        model: Sequence-classification model, expected to live on ``DEVICE``.

    Returns:
        ``(predicted_class, probabilities)`` where ``predicted_class`` is the
        index of the largest logit and ``probabilities`` is the flattened
        softmax over the logits as a list of floats.

    Raises:
        ValueError: If OCR detects no text, since there is nothing to encode.
    """
    # NOTE(review): the PIL image is handed to readtext as-is; confirm the
    # installed easyocr version accepts every uploaded image type this way.
    ocr_result = reader.readtext(image)
    if not ocr_result:
        # Previously this path fell through to an UnboundLocalError.
        raise ValueError("No text was detected in the image; cannot classify it.")

    width, height = image.size
    width_scale = 1000 / width
    height_scale = 1000 / height

    words = [word for _, word, _ in ocr_result]
    boxes = [
        create_bounding_box(bbox, width_scale, height_scale)
        for bbox, _, _ in ocr_result
    ]

    # Encode and run the model once, after all words are collected. Doing this
    # inside the word loop repeated the forward pass per word for the same
    # final result.
    encoding = processor(
        image,
        words,
        boxes=boxes,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    with torch.inference_mode():
        output = model(
            input_ids=encoding["input_ids"].to(DEVICE),
            attention_mask=encoding["attention_mask"].to(DEVICE),
            bbox=encoding["bbox"].to(DEVICE),
            pixel_values=encoding["pixel_values"].to(DEVICE),
        )

    logits = output.logits
    predicted_class = logits.argmax().item()
    probabilities = F.softmax(logits, dim=-1).flatten().tolist()

    return predicted_class, probabilities

# Build the OCR reader, processor and model through their cached factories.
reader = create_ocr_reader()
processor = create_processor()
model = create_model()

uploaded_file = st.file_uploader("Upload Document Image", ["jpg", "png"])
if uploaded_file is not None:
    # Wrap the uploaded bytes so PIL can read them like a file.
    bytes_data = io.BytesIO(uploaded_file.getvalue())
    image = Image.open(bytes_data)
    st.image(image, "Your Document")

    predicted_class, probabilities = predict(image, reader, processor, model)
    id2label = model.config.id2label
    predicted_label = id2label[predicted_class]

    st.markdown(f"Predicted document type: **{predicted_label}**")

    # Probabilities are indexed by class id, so order the labels by class id
    # rather than relying on the dict's insertion order.
    labels = [id2label[class_id] for class_id in sorted(id2label)]
    df_predictions = pd.DataFrame(
        {"Document": labels, "Confidence": probabilities}
    )

    fig = px.bar(df_predictions, x="Document", y="Confidence")
    st.plotly_chart(fig, use_container_width=True)