File size: 3,310 Bytes
9262ebb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import io

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
import torch
import torch.nn.functional as F
from easyocr import Reader
from PIL import Image
from transformers import (
    LayoutLMv3FeatureExtractor,
    LayoutLMv3ForSequenceClassification,
    LayoutLMv3Processor,
    LayoutLMv3TokenizerFast,
)
# Run on the first CUDA device when one is visible to torch, otherwise on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# Base checkpoint that supplies the tokenizer vocabulary.
MICROSOFT_MODEL_NAME = "microsoft/layoutlmv3-base"
# Fine-tuned sequence-classification checkpoint used for prediction.
MODEL_NAME = "curiousily/layoutlmv3-financial-document-classification"
def create_bounding_box(bbox_data, width_scale: float, height_scale: float):
    """Convert an OCR polygon into an axis-aligned ``[left, top, right, bottom]`` box.

    Args:
        bbox_data: Iterable of ``(x, y)`` points in pixel coordinates, e.g. the
            corner list EasyOCR returns for each detection.
        width_scale: Factor applied to x coordinates (``predict`` passes
            ``1000 / image_width``).
        height_scale: Factor applied to y coordinates (``predict`` passes
            ``1000 / image_height``).

    Returns:
        ``[left, top, right, bottom]`` as ints, each clamped to ``[0, 1000]``.

    Raises:
        ValueError: If ``bbox_data`` contains no points.
    """
    # LayoutLMv3 expects box coordinates normalized to a 0-1000 grid.
    max_coord = 1000

    points = list(bbox_data)
    if not points:
        raise ValueError("bbox_data must contain at least one (x, y) point")
    xs = [x for x, _ in points]
    ys = [y for _, y in points]

    def _scaled(value, scale):
        # OCR polygons are not guaranteed to lie inside the image, so clamp
        # rather than feed out-of-range positions to the model's embeddings.
        return min(max(int(value * scale), 0), max_coord)

    return [
        _scaled(min(xs), width_scale),
        _scaled(min(ys), height_scale),
        _scaled(max(xs), width_scale),
        _scaled(max(ys), height_scale),
    ]
@st.experimental_singleton
def create_ocr_reader():
    """Return an English-only EasyOCR reader, cached by ``st.experimental_singleton``.

    NOTE(review): newer Streamlit releases replace ``experimental_singleton``
    with ``st.cache_resource`` — confirm against the pinned Streamlit version.
    """
    ocr_languages = ["en"]
    ocr_reader = Reader(ocr_languages)
    return ocr_reader
@st.experimental_singleton
def create_processor():
    """Build the LayoutLMv3 processor (image features + tokenizer), cached by Streamlit.

    The built-in OCR of the feature extractor is disabled because ``predict``
    supplies words and boxes from EasyOCR. The tokenizer is loaded from the
    base checkpoint ``MICROSOFT_MODEL_NAME``.
    """
    image_features = LayoutLMv3FeatureExtractor(apply_ocr=False)
    layout_tokenizer = LayoutLMv3TokenizerFast.from_pretrained(MICROSOFT_MODEL_NAME)
    processor = LayoutLMv3Processor(image_features, layout_tokenizer)
    return processor
@st.experimental_singleton
def create_model():
    """Load the fine-tuned classifier from ``MODEL_NAME``, cached by Streamlit.

    The model is switched to eval mode and moved to ``DEVICE`` before being returned.
    """
    classifier = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_NAME)
    # eval() mutates the module in place; .to() returns the same module.
    classifier.eval()
    return classifier.to(DEVICE)
def predict(
    image: Image.Image,
    reader: Reader,
    processor: LayoutLMv3Processor,
    model: LayoutLMv3ForSequenceClassification,
):
    """Classify a document image using EasyOCR text plus LayoutLMv3.

    Args:
        image: The uploaded document as a PIL image (any mode; converted to RGB).
        reader: EasyOCR reader used to extract words and their polygons.
        processor: LayoutLMv3 processor with its own OCR disabled.
        model: Sequence-classification model already placed on ``DEVICE``.

    Returns:
        ``(predicted_class, probabilities)``. ``predicted_class`` is a 0-d
        tensor with the arg-max class index (callers use ``.item()``).
        ``probabilities`` is a flat list of softmax scores, indexed by class id.
    """
    # PNG uploads may be RGBA, palette or greyscale; both EasyOCR and the
    # LayoutLMv3 image processor are given a 3-channel RGB image instead.
    rgb_image = image.convert("RGB")
    width, height = rgb_image.size

    # EasyOCR's readtext documents file paths, bytes and numpy arrays as
    # inputs, so pass an array rather than a PIL image object.
    ocr_result = reader.readtext(np.array(rgb_image))

    # Boxes are rescaled from pixels to LayoutLMv3's 0-1000 coordinate grid.
    width_scale = 1000 / width
    height_scale = 1000 / height

    words = []
    boxes = []
    for bbox, word, _confidence in ocr_result:
        words.append(word)
        boxes.append(create_bounding_box(bbox, width_scale, height_scale))

    # Pad/truncate to a fixed 512-token sequence.
    encoding = processor(
        rgb_image,
        words,
        boxes=boxes,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    with torch.inference_mode():
        output = model(
            input_ids=encoding["input_ids"].to(DEVICE),
            attention_mask=encoding["attention_mask"].to(DEVICE),
            bbox=encoding["bbox"].to(DEVICE),
            pixel_values=encoding["pixel_values"].to(DEVICE),
        )

    logits = output.logits
    # argmax over the flattened logits; the batch holds a single document.
    predicted_class = logits.argmax()
    probabilities = F.softmax(logits, dim=-1).flatten().tolist()
    return predicted_class, probabilities
# Streamlit re-runs this script on every interaction; the create_* factories
# are decorated with st.experimental_singleton so these objects are cached.
reader = create_ocr_reader()
processor = create_processor()
model = create_model()

uploaded_file = st.file_uploader("Upload Document image", ["jpg", "jpeg", "png"])
if uploaded_file is not None:
    bytes_data = io.BytesIO(uploaded_file.getvalue())
    image = Image.open(bytes_data)
    predicted_class, probabilities = predict(image, reader, processor, model)

    id2label = model.config.id2label
    predicted_label = id2label[predicted_class.item()]

    st.image(image, "Your document image")
    st.markdown(f"Predicted document type: **{predicted_label}**")

    # probabilities[i] is the score for class id i, so look labels up by id
    # instead of relying on the insertion order of the id2label dict.
    df_predictions = pd.DataFrame(
        {
            "Document": [id2label[class_id] for class_id in range(len(probabilities))],
            "Confidence": probabilities,
        }
    )
    fig = px.bar(
        df_predictions,
        x="Document",
        y="Confidence",
    )
    st.plotly_chart(fig, use_container_width=True)
|