import io

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
import torch
import torch.nn.functional as F
from easyocr import Reader
from PIL import Image
from transformers import (
    LayoutLMv3FeatureExtractor,
    LayoutLMv3ForSequenceClassification,
    LayoutLMv3Processor,
    LayoutLMv3TokenizerFast,
)

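# Run inference on the GPU when available. The fine-tuned classifier comes from the
# Hugging Face Hub; the base Microsoft checkpoint is only used for its tokenizer.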
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
MICROSOFT_MODEL_NAME = "microsoft/layoutlmv3-base"
MODEL_NAME = "curiousily/layoutlmv3-financial-document-classification"


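# EasyOCR returns a quadrilateral (four corner points) for each detected word; LayoutLMv3
# expects axis-aligned [left, top, right, bottom] boxes on a 0-1000 grid, so collapse the
# corners to their min/max extents and rescale.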
def create_bounding_box(bbox_data, width_scale: float, height_scale: float):
    xs = []
    ys = []
    for x, y in bbox_data:
        xs.append(x)
        ys.append(y)

    left = int(min(xs) * width_scale)
    top = int(min(ys) * height_scale)
    right = int(max(xs) * width_scale)
    bottom = int(max(ys) * height_scale)

    return [left, top, right, bottom]


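# Cache the heavyweight resources (OCR reader, processor, model) so they are created
# once per session instead of on every Streamlit rerun.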
@st.experimental_singleton
def create_ocr_reader():
    return Reader(["en"])


@st.experimental_singleton
def create_processor():
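    # OCR is handled separately by EasyOCR, so disable the feature extractor's built-in OCR.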
    feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
    tokenizer = LayoutLMv3TokenizerFast.from_pretrained(MICROSOFT_MODEL_NAME)
    return LayoutLMv3Processor(feature_extractor, tokenizer)


@st.experimental_singleton
def create_model():
    model = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_NAME)
    return model.eval().to(DEVICE)


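# End-to-end inference: run OCR, normalize the word boxes, encode words + boxes + pixels
# with the processor, and classify the document image.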
def predict(
    image: Image.Image,
    reader: Reader,
    processor: LayoutLMv3Processor,
    model: LayoutLMv3ForSequenceClassification,
):
    width, height = image.size
    # EasyOCR accepts file paths, raw bytes, or numpy arrays, so pass the image as an array.
    ocr_result = reader.readtext(np.array(image))

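    # LayoutLM expects box coordinates normalized to 0-1000 regardless of image size.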
    width_scale = 1000 / width
    height_scale = 1000 / height

    words = []
    boxes = []
    for bbox, word, confidence in ocr_result:
        words.append(word)
        boxes.append(create_bounding_box(bbox, width_scale, height_scale))

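    # Encode the image together with the OCR words and their normalized bounding boxes.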
    encoding = processor(
        image,
        words,
        boxes=boxes,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

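    # Forward pass without gradient tracking; move every input tensor to the model's device.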
    with torch.inference_mode():
        output = model(
            input_ids=encoding["input_ids"].to(DEVICE),
            attention_mask=encoding["attention_mask"].to(DEVICE),
            bbox=encoding["bbox"].to(DEVICE),
            pixel_values=encoding["pixel_values"].to(DEVICE),
        )
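    # Pick the highest-scoring class and keep the full probability distribution for the chart.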
    logits = output.logits
    predicted_class = logits.argmax()
    probabilities = F.softmax(logits, dim=-1).flatten().tolist()
    return predicted_class, probabilities


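# Streamlit app: load the cached resources, accept an uploaded document image, and show
# the predicted document type together with per-class confidences.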
reader = create_ocr_reader()
processor = create_processor()
model = create_model()

uploaded_file = st.file_uploader("Upload a document image", ["jpg", "png"])
if uploaded_file is not None:
    bytes_data = io.BytesIO(uploaded_file.getvalue())
    # Ensure a 3-channel RGB image; uploaded PNGs may carry an alpha channel.
    image = Image.open(bytes_data).convert("RGB")
    predicted_class, probabilities = predict(image, reader, processor, model)
    predicted_label = model.config.id2label[predicted_class.item()]

    st.image(image, caption="Your document image")
    st.markdown(f"Predicted document type: **{predicted_label}**")

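    # Bar chart of the model's confidence for every document class it was trained on.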
    df_predictions = pd.DataFrame(
        {"Document": list(model.config.id2label.values()), "Confidence": probabilities}
    )

    fig = px.bar(
        df_predictions,
        x="Document",
        y="Confidence",
    )
    st.plotly_chart(fig, use_container_width=True)