# Spaces: Sleeping
import streamlit as st | |
import warnings | |
warnings.simplefilter("ignore", UserWarning) | |
from uuid import uuid4 | |
from laia.scripts.htr.decode_ctc import run as decode | |
from laia.common.arguments import CommonArgs, DataArgs, TrainerArgs, DecodeArgs | |
import sys | |
from tempfile import NamedTemporaryFile, mkdtemp | |
from pathlib import Path | |
from contextlib import redirect_stdout | |
import re | |
from PIL import Image | |
from bidi.algorithm import get_display | |
import multiprocessing | |
from ultralytics import YOLO | |
import cv2 | |
import numpy as np | |
import pandas as pd | |
import logging | |
# Configure logging: only ERROR and above from the "lightning.pytorch" logger.
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)

# Load YOLOv8 model (text-line segmentation).
# NOTE(review): Streamlit re-executes this script on every interaction, so the
# weights are reloaded on each rerun. Sharing a single instance through
# st.cache_resource would avoid that, but it would then be used concurrently by
# every session — confirm ultralytics thread-safety before caching it.
model = YOLO('model.pt')


@st.cache_resource
def _image_workdir():
    """Create the temporary directory for line crops once per server process.

    Streamlit reruns this script on every interaction; without caching, each
    rerun would create (and never remove) a fresh temporary directory.
    """
    return Path(mkdtemp())


# Directory the line crops are written to; PyLaia reads them from here.
images = _image_workdir()
# Recreate it if something (e.g. a /tmp cleaner) removed it after it was cached.
images.mkdir(parents=True, exist_ok=True)
DEFAULT_HEIGHT = 128  # height (pixels) line crops are resized to before decoding
TEXT_DIRECTION = "RTL"
NUM_WORKERS = multiprocessing.cpu_count()
# Regex patterns for parsing PyLaia decode output lines of the form
# "<image_id> <line_confidence> <text>".
IMAGE_ID_PATTERN = r"(?P<image_id>[-a-z0-9]{36})"
CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)"  # For line
# Leading/trailing whitespace is excluded from the text group. The text, and
# the whitespace separating it from the confidence, may be missing entirely
# (empty transcription, or output that was stripped before matching).
TEXT_PATTERN = r"\s*(?P<text>.*?)\s*"
LINE_PREDICTION = re.compile(rf"{IMAGE_ID_PATTERN} {CONFIDENCE_PATTERN}{TEXT_PATTERN}$")
def get_width(image, height=DEFAULT_HEIGHT):
    """Return the width *image* would have when scaled to *height* pixels tall.

    The aspect ratio is preserved. The result is a float; callers truncate it
    (e.g. with ``int``) before using it as a pixel size.
    """
    return height * (image.width / image.height)
def _build_decode_config(model_dir, temperature, use_language_model, batch_size=1):
    """Build the PyLaia (common, data, trainer, decode) argument objects for *model_dir*."""
    language_model_params = {"language_model_weight": 1.0}
    if use_language_model:
        language_model_params.update(
            {
                "language_model_path": f"{model_dir}/language_model.binary",
                "lexicon_path": f"{model_dir}/lexicon.txt",
                "tokens_path": f"{model_dir}/tokens.txt",
            }
        )
    common_args = CommonArgs(
        checkpoint="weights.ckpt",
        train_path=f"{model_dir}",
        experiment_dirname="",
    )
    data_args = DataArgs(batch_size=batch_size, color_mode="L")
    trainer_args = TrainerArgs(progress_bar_refresh_rate=0)
    decode_args = DecodeArgs(
        include_img_ids=True,
        join_string="",
        convert_spaces=True,
        print_line_confidence_scores=True,
        print_word_confidence_scores=False,
        temperature=temperature,
        use_language_model=use_language_model,
        **language_model_params,
    )
    return common_args, data_args, trainer_args, decode_args


def predict(model_name, input_img, *, temperature=2.0, use_language_model=True):
    """Recognize the text on a single text-line image with a PyLaia model.

    Args:
        model_name: Directory holding the PyLaia model files (``weights.ckpt``,
            ``syms.txt`` and, when the language model is used,
            ``language_model.binary``, ``lexicon.txt`` and ``tokens.txt``).
        input_img: PIL image of one text line.
        temperature: Decoding temperature passed to PyLaia.
        use_language_model: Whether to decode with the language-model files
            found in *model_name*.

    Returns:
        ``(resized_img, {"text": text, "score": score})``. *resized_img* is the
        line image rescaled to ``DEFAULT_HEIGHT`` pixels tall. *text* is the
        recognized string, passed through ``bidi.algorithm.get_display`` when
        ``TEXT_DIRECTION == "RTL"``. *score* is the line confidence exactly as
        printed by PyLaia (a string).

    Raises:
        RuntimeError: if PyLaia's output contains no line matching
            ``LINE_PREDICTION``.
    """
    model_dir = model_name
    syms_path = f"{model_dir}/syms.txt"
    common_args, data_args, trainer_args, decode_args = _build_decode_config(
        model_dir, temperature, use_language_model
    )

    image_id = uuid4()
    # Rescale to the fixed height fed to the model, preserving the aspect ratio;
    # clamp so an extremely narrow crop never gets a width of 0.
    input_img = input_img.resize((max(1, int(get_width(input_img))), DEFAULT_HEIGHT))
    if input_img.mode not in ("RGB", "L"):
        # JPEG cannot store modes such as RGBA or palette ("P").
        input_img = input_img.convert("RGB")
    image_path = Path(images) / f"{image_id}.jpg"

    try:
        with NamedTemporaryFile() as pred_stdout, NamedTemporaryFile() as img_list:
            input_img.save(image_path)
            # PyLaia is told which image(s) to decode through a list-of-ids file.
            Path(img_list.name).write_text(str(image_id))
            # PyLaia prints its predictions on stdout. Capture them in a file
            # that is explicitly closed (hence flushed) before it is read back.
            with open(pred_stdout.name, mode="w") as capture, redirect_stdout(capture):
                decode(
                    syms=str(syms_path),
                    img_list=img_list.name,
                    img_dirs=[str(images)],
                    common=common_args,
                    data=data_args,
                    trainer=trainer_args,
                    decode=decode_args,
                    num_workers=1,
                )
            output = Path(pred_stdout.name).read_text()
    finally:
        # The crop is only needed for this call; don't let files pile up.
        image_path.unlink(missing_ok=True)

    # Ignore anything else written to stdout and use the first prediction line.
    # Lines are kept unstripped so the separator before an empty transcription
    # is not lost.
    match = next(filter(None, map(LINE_PREDICTION.match, output.splitlines())), None)
    if match is None:
        raise RuntimeError(
            f"PyLaia produced no parsable prediction for image {image_id}: {output!r}"
        )
    _, score, text = match.groups()
    # Trailing whitespace is not part of the transcription.
    text = text.rstrip()
    if TEXT_DIRECTION == "RTL":
        text = get_display(text)
    return input_img, {"text": text, "score": score}
def process_image(image):
    """Detect text lines with YOLO, recognize each with PyLaia, and annotate the image.

    Args:
        image: PIL image uploaded by the user (any mode; it is converted to RGB).

    Returns:
        ``(annotated, table)``. *annotated* is a PIL RGB image with a green box
        drawn around every detected line. *table* is a DataFrame with one row
        per line, sorted top to bottom by the box's top edge. Its
        "Polygons" column holds the box corners and its "Recognized Text"
        column holds the PyLaia transcription.
    """
    # Uploads may be grayscale ("L"), palette ("P") or RGBA; the cropping,
    # drawing and downstream JPEG export all assume 3-channel RGB data.
    rgb_image = image.convert("RGB")
    # Perform inference on an image, select textline only (class 0)
    results = model(rgb_image, classes=0)
    source = np.array(rgb_image)
    # Draw on a separate copy so that crops always come from the clean image and
    # never contain rectangles drawn for earlier, overlapping lines.
    canvas = source.copy()
    img_height, img_width = source.shape[:2]

    boxes = sorted(results[0].boxes.xyxy.tolist(), key=lambda box: box[1])
    polygons = []
    texts = []
    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = map(int, box)
        # Clamp to the image: negative indices would wrap around in NumPy slicing.
        x1, x2 = max(0, x1), min(img_width, x2)
        y1, y2 = max(0, y1), min(img_height, y2)
        if x2 <= x1 or y2 <= y1:
            # Zero-area box after truncation/clamping: nothing to recognize.
            continue
        crop_pil = Image.fromarray(source[y1:y2, x1:x2])
        # Recognize text using PyLaia model
        _, prediction = predict('pylaia-samaritan_v1', crop_pil)
        texts.append(prediction["text"])
        polygons.append(f"Line {i+1}: {[(x1, y1), (x2, y1), (x2, y2), (x1, y2)]}")
        # Draw bounding box; pure green is (0, 255, 0) in both RGB and BGR order.
        cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)

    # Combine polygons and texts into a DataFrame for table display
    table_data = pd.DataFrame({"Polygons": polygons, "Recognized Text": texts})
    return Image.fromarray(canvas), table_data
def segment_and_recognize(image):
    """Segment *image* into text lines and recognize each one.

    Thin wrapper over ``process_image``; returns ``(annotated_image, table)``.
    """
    annotated, lines_table = process_image(image)
    return annotated, lines_table
# Streamlit app layout
st.title("YOLOv8 Text Line Segmentation & PyLaia Text Recognition")

# File uploader
uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

# Process the image if uploaded
if uploaded_image is not None:
    image = Image.open(uploaded_image)
    if st.button("Segment and Recognize"):
        # Recognition runs one PyLaia decode per detected line and can take a
        # while, so give the user visible progress feedback.
        with st.spinner("Segmenting and recognizing text lines..."):
            segmented_image, table_data = segment_and_recognize(image)
        # Display the segmented image
        st.image(segmented_image, caption="Segmented Image with Bounding Boxes", use_column_width=True)
        # Display the table with polygons and recognized text, or explain why
        # there is nothing to show.
        if table_data.empty:
            st.info("No text lines were detected in this image.")
        else:
            st.table(table_data)