yolo-pylaia / app.py
johnlockejrr's picture
Update app.py
4d40b14 verified
import streamlit as st
import warnings
warnings.simplefilter("ignore", UserWarning)
from uuid import uuid4
from laia.scripts.htr.decode_ctc import run as decode
from laia.common.arguments import CommonArgs, DataArgs, TrainerArgs, DecodeArgs
import sys
from tempfile import NamedTemporaryFile, mkdtemp
from pathlib import Path
from contextlib import redirect_stdout
import re
from PIL import Image
from bidi.algorithm import get_display
import multiprocessing
from ultralytics import YOLO
import cv2
import numpy as np
import pandas as pd
import logging
# Raise the "lightning.pytorch" logger threshold so only ERROR and above are emitted.
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
# Load the YOLOv8 model once, at import time, from the local file 'model.pt'.
# process_image() calls it with classes=0 to select text lines only.
model = YOLO('model.pt')
# Fresh temporary directory; predict() writes each line crop here as "<uuid>.jpg"
# and passes the directory to the decoder as its image directory.
images = Path(mkdtemp())
# Height, in pixels, that every line crop is resized to before decoding.
DEFAULT_HEIGHT = 128
# When set to "RTL", predict() passes the decoded text through bidi's get_display().
TEXT_DIRECTION = "RTL"
# Number of CPUs reported by the OS.
# NOTE(review): not referenced in the visible code; decode() is called with num_workers=1.
NUM_WORKERS = multiprocessing.cpu_count()
# Regex fragments used to parse one decoder output line of the form
# "<image_id> <confidence> <text>".
# 36 characters of lowercase letters, digits and hyphens (the length of str(uuid4())).
IMAGE_ID_PATTERN = r"(?P<image_id>[-a-z0-9]{36})"
# Line-level confidence, captured as a string of digits and dots.
CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)" # For line
# Remainder of the line; leading whitespace is excluded from the captured text.
TEXT_PATTERN = r"\s*(?P<text>.*)\s*"
# Full line pattern: the three fragments separated by single spaces.
LINE_PREDICTION = re.compile(rf"{IMAGE_ID_PATTERN} {CONFIDENCE_PATTERN} {TEXT_PATTERN}")
def get_width(image, height=DEFAULT_HEIGHT):
    """Return the width that keeps *image*'s aspect ratio at the given height.

    Args:
        image: Object exposing ``width`` and ``height`` attributes
            (a PIL image in this file).
        height: Target height in pixels; defaults to ``DEFAULT_HEIGHT``.

    Returns:
        The scaled width as a float (callers truncate it with ``int``).
    """
    return height * (image.width / image.height)
def predict(model_name, input_img):
    """Transcribe one text-line image with the PyLaia CTC decoder.

    The image is resized to ``DEFAULT_HEIGHT`` (keeping its aspect ratio),
    written as a JPEG into the shared ``images`` directory under a fresh UUID
    name, and decoded through ``laia``'s ``decode_ctc`` entry point. The
    decoder prints its result to stdout, so stdout is redirected into a
    temporary file whose first non-blank line is parsed with
    ``LINE_PREDICTION``. The temporary JPEG is removed before returning.

    Args:
        model_name: Accepted for interface compatibility; currently unused,
            the model directory is fixed to ``pylaia-samaritan_v1``.
        input_img: PIL image containing a single text line.

    Returns:
        A tuple ``(resized_image, result)``. ``result`` is a dict with the
        keys ``"text"`` (the transcription, reordered for display when
        ``TEXT_DIRECTION`` is ``"RTL"``) and ``"score"`` (the line confidence
        exactly as captured from the decoder output, i.e. a string).

    Raises:
        RuntimeError: If the decoder printed nothing, or its first output
            line does not look like ``<image_id> <confidence> [text]``.
    """
    model_dir = 'pylaia-samaritan_v1'
    temperature = 2.0
    batch_size = 1
    syms_path = f"{model_dir}/syms.txt"

    language_model_params = {"language_model_weight": 1.0}
    use_language_model = True
    if use_language_model:
        language_model_params.update({
            "language_model_path": f"{model_dir}/language_model.binary",
            "lexicon_path": f"{model_dir}/lexicon.txt",
            "tokens_path": f"{model_dir}/tokens.txt",
        })

    common_args = CommonArgs(
        checkpoint="weights.ckpt",
        train_path=f"{model_dir}",
        experiment_dirname="",
    )
    data_args = DataArgs(batch_size=batch_size, color_mode="L")
    trainer_args = TrainerArgs(progress_bar_refresh_rate=0)
    decode_args = DecodeArgs(
        include_img_ids=True,
        join_string="",
        convert_spaces=True,
        print_line_confidence_scores=True,
        print_word_confidence_scores=False,
        temperature=temperature,
        use_language_model=use_language_model,
        **language_model_params,
    )

    # JPEG cannot store alpha or palette modes; normalise those to RGB first.
    if input_img.mode not in ("L", "RGB"):
        input_img = input_img.convert("RGB")
    # A very tall, narrow crop would otherwise truncate to a zero-pixel width.
    target_width = max(1, int(get_width(input_img)))
    input_img = input_img.resize((target_width, DEFAULT_HEIGHT))

    image_id = uuid4()
    image_path = Path(images) / f"{image_id}.jpg"
    try:
        input_img.save(image_path)
        with NamedTemporaryFile() as pred_stdout, NamedTemporaryFile() as img_list:
            Path(img_list.name).write_text(str(image_id))
            # Close the capture file before reading it back so that everything
            # the decoder printed is guaranteed to be flushed to disk.
            with open(pred_stdout.name, mode="w") as captured, redirect_stdout(captured):
                decode(
                    syms=str(syms_path),
                    img_list=img_list.name,
                    img_dirs=[str(images)],
                    common=common_args,
                    data=data_args,
                    trainer=trainer_args,
                    decode=decode_args,
                    num_workers=1,
                )
            output = Path(pred_stdout.name).read_text()
    finally:
        # The crop is only needed while decoding; do not let the shared
        # directory grow with every recognised line.
        if image_path.exists():
            image_path.unlink()

    prediction_lines = [line.strip() for line in output.splitlines() if line.strip()]
    if not prediction_lines:
        raise RuntimeError("PyLaia decoding produced no output")

    # LINE_PREDICTION requires a separator after the confidence. Stripping the
    # line removes it when the transcription is empty, so put it back; the
    # extra trailing space is stripped from the captured text below.
    match = LINE_PREDICTION.match(prediction_lines[0] + " ")
    if match is None:
        raise RuntimeError(
            f"Could not parse PyLaia output line: {prediction_lines[0]!r}"
        )

    score = match.group("confidence")
    text = match.group("text").strip()
    if TEXT_DIRECTION == "RTL":
        text = get_display(text)
    return input_img, {"text": text, "score": score}
def process_image(image):
    """Detect text lines, transcribe each one, and draw the detections.

    The YOLO model is run with ``classes=0`` (text lines only). Detected
    boxes are processed top to bottom (sorted by their upper ``y``
    coordinate); each crop is passed to ``predict`` and its box is drawn in
    green on a copy of the page.

    Crops are always cut from the untouched pixels, never from the annotated
    copy, so rectangles drawn for earlier lines cannot leak into the
    recognizer's input for later, overlapping lines.

    Args:
        image: PIL image of the page. Any mode is accepted; it is converted
            to RGB before processing.

    Returns:
        A tuple ``(annotated_image, table_data)``: the RGB PIL image with the
        bounding boxes drawn, and a ``pandas.DataFrame`` with the columns
        ``"Polygons"`` and ``"Recognized Text"``, one row per recognised line.
    """
    # Grayscale and palette uploads give 2-D arrays that the drawing and
    # cropping code below cannot treat as colour images; normalise up front.
    rgb_image = image.convert("RGB")

    # Perform inference on the image, select textline only
    results = model(rgb_image, classes=0)

    pixels = np.array(rgb_image)   # pristine source for the crops
    annotated = pixels.copy()      # separate canvas for the rectangles
    img_height, img_width = pixels.shape[:2]

    boxes = sorted(results[0].boxes.xyxy.tolist(), key=lambda box: box[1])

    polygons = []
    texts = []
    for box in boxes:
        x1, y1, x2, y2 = map(int, box)
        # Keep the box inside the image (negative indices would wrap around
        # in NumPy slicing) and skip boxes that collapse to an empty crop.
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(img_width, x2), min(img_height, y2)
        if x2 <= x1 or y2 <= y1:
            continue

        crop_pil = Image.fromarray(pixels[y1:y2, x1:x2])

        # Recognize text using PyLaia model
        predicted = predict('pylaia-samaritan_v1', crop_pil)
        texts.append(predicted[1]["text"])
        polygons.append(
            f"Line {len(texts)}: {[(x1, y1), (x2, y1), (x2, y2), (x1, y2)]}"
        )

        # Draw bounding box; pure green is the same triple in RGB and BGR,
        # so no colour-space round trip is needed.
        cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)

    # Combine polygons and texts into a DataFrame for table display
    table_data = pd.DataFrame({"Polygons": polygons, "Recognized Text": texts})
    return Image.fromarray(annotated), table_data
def segment_and_recognize(image):
    """Run the segmentation and recognition pipeline on *image*.

    Thin entry point for the UI that delegates to ``process_image``.

    Returns:
        The ``(segmented_image, table_data)`` pair produced by
        ``process_image``.
    """
    return process_image(image)
# --- Streamlit page -------------------------------------------------------
st.title("YOLOv8 Text Line Segmentation & PyLaia Text Recognition")

# Let the user pick a single PNG or JPEG image.
uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

# Nothing below is rendered until a file has been uploaded.
if uploaded_image is not None:
    image = Image.open(uploaded_image)

    # The pipeline runs only on the rerun triggered by pressing the button.
    if st.button("Segment and Recognize"):
        segmented_image, table_data = segment_and_recognize(image)

        # Annotated page first, then the per-line polygons and transcriptions.
        st.image(
            segmented_image,
            caption="Segmented Image with Bounding Boxes",
            use_column_width=True,
        )
        st.table(table_data)