"""Streamlit app: YOLOv8 text-line segmentation followed by PyLaia recognition.

An uploaded page image is segmented into text lines with a YOLOv8 model
(class 0 only). Each line crop is transcribed with a PyLaia CTC model,
optionally rescored with a language model. The page is shown with the detected
boxes drawn on it, together with a table of box coordinates and recognized text.
"""
import streamlit as st
import warnings

# Silence UserWarnings before the heavy ML imports below get a chance to emit them.
warnings.simplefilter("ignore", UserWarning)

from uuid import uuid4
from laia.scripts.htr.decode_ctc import run as decode
from laia.common.arguments import CommonArgs, DataArgs, TrainerArgs, DecodeArgs
import sys
from tempfile import NamedTemporaryFile, TemporaryDirectory, mkdtemp
from pathlib import Path
from contextlib import redirect_stdout
import re
from PIL import Image
from bidi.algorithm import get_display
import multiprocessing
from ultralytics import YOLO
import cv2
import numpy as np
import pandas as pd
import logging

# Configure logging: keep PyTorch Lightning quiet during decoding.
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
logger = logging.getLogger(__name__)

YOLO_WEIGHTS = "model.pt"
HTR_MODEL_NAME = "pylaia-samaritan_v1"
DEFAULT_HEIGHT = 128
TEXT_DIRECTION = "RTL"
NUM_WORKERS = multiprocessing.cpu_count()

# Regex pattern for extracting decoder results: "<image_id> <confidence>[ <text>]".
IMAGE_ID_PATTERN = r"(?P<image_id>[-a-z0-9]{36})"
CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)"  # For line
TEXT_PATTERN = r"\s*(?P<text>.*)\s*"
# The text part is optional so that an empty transcription still parses.
LINE_PREDICTION = re.compile(
    rf"{IMAGE_ID_PATTERN} {CONFIDENCE_PATTERN}(?: {TEXT_PATTERN})?"
)


@st.cache_resource
def load_segmentation_model(weights=YOLO_WEIGHTS):
    """Load the YOLOv8 line-segmentation model once per server process.

    Streamlit re-executes this script on every widget interaction; caching the
    resource avoids reloading the weights on each rerun.
    """
    return YOLO(weights)


# Load YOLOv8 model
model = load_segmentation_model()


def get_width(image, height=DEFAULT_HEIGHT):
    """Return the width that preserves ``image``'s aspect ratio at ``height``."""
    aspect_ratio = image.width / image.height
    return height * aspect_ratio


def predict(model_name, input_img):
    """Transcribe a single text-line image with a PyLaia model.

    Args:
        model_name: Directory of the PyLaia model. It must contain ``weights.ckpt``
            and ``syms.txt``. While the language model is enabled, it must also
            contain ``language_model.binary``, ``lexicon.txt`` and ``tokens.txt``.
        input_img: PIL image of one text line.

    Returns:
        ``(resized_img, {"text": ..., "score": ...})``. ``resized_img`` is the line
        image scaled to ``DEFAULT_HEIGHT``. ``text`` is reordered for display with
        ``get_display`` when ``TEXT_DIRECTION`` is ``"RTL"``. ``score`` is the line
        confidence as printed by the decoder (a string). If the decoder produces no
        parsable prediction, ``text`` is ``""`` and ``score`` is ``None``.
    """
    model_dir = model_name
    temperature = 2.0
    batch_size = 1
    syms_path = f"{model_dir}/syms.txt"

    language_model_params = {"language_model_weight": 1.0}
    use_language_model = True
    if use_language_model:
        language_model_params.update({
            "language_model_path": f"{model_dir}/language_model.binary",
            "lexicon_path": f"{model_dir}/lexicon.txt",
            "tokens_path": f"{model_dir}/tokens.txt",
        })

    common_args = CommonArgs(
        checkpoint="weights.ckpt",
        train_path=f"{model_dir}",
        experiment_dirname="",
    )
    data_args = DataArgs(batch_size=batch_size, color_mode="L")
    trainer_args = TrainerArgs(progress_bar_refresh_rate=0)
    decode_args = DecodeArgs(
        include_img_ids=True,
        join_string="",
        convert_spaces=True,
        print_line_confidence_scores=True,
        print_word_confidence_scores=False,
        temperature=temperature,
        use_language_model=use_language_model,
        **language_model_params,
    )

    # JPEG cannot store alpha/palette images; normalise before saving.
    if input_img.mode not in ("RGB", "L"):
        input_img = input_img.convert("RGB")
    # Scale to the fixed line height, keeping the aspect ratio; never let a very
    # narrow crop collapse to a zero width (PIL rejects that).
    input_img = input_img.resize((max(1, int(get_width(input_img))), DEFAULT_HEIGHT))
    image_id = str(uuid4())

    # One scratch directory per call holds the line image, the image list and the
    # captured decoder output; it is deleted afterwards so nothing accumulates.
    with TemporaryDirectory() as tmp:
        workdir = Path(tmp)
        input_img.save(workdir / f"{image_id}.jpg")
        img_list = workdir / "img_list.txt"
        img_list.write_text(image_id, encoding="utf-8")
        pred_path = workdir / "predictions.txt"

        # decode() prints its predictions to stdout. Capture them in a UTF-8 file:
        # the locale's default encoding may not be able to represent the script.
        with open(pred_path, "w", encoding="utf-8") as pred_stdout, redirect_stdout(pred_stdout):
            decode(
                syms=str(syms_path),
                img_list=str(img_list),
                img_dirs=[str(workdir)],
                common=common_args,
                data=data_args,
                trainer=trainer_args,
                decode=decode_args,
                num_workers=1,
            )
            sys.stdout.flush()
        predictions = pred_path.read_text(encoding="utf-8").splitlines()

    # Pick the line for this image; ignore any unrelated stdout chatter.
    match = None
    for line in predictions:
        candidate = LINE_PREDICTION.match(line.strip())
        if candidate is not None and candidate.group("image_id") == image_id:
            match = candidate
            break
    if match is None:
        logger.warning("PyLaia produced no parsable prediction for line %s", image_id)
        return input_img, {"text": "", "score": None}

    score = match.group("confidence")
    text = match.group("text") or ""
    if TEXT_DIRECTION == "RTL":
        text = get_display(text)
    return input_img, {"text": text, "score": score}


def process_image(image):
    """Segment text lines with YOLO, transcribe each one and draw its box.

    Args:
        image: PIL image of the page (any mode; converted to RGB).

    Returns:
        ``(annotated PIL image, DataFrame)``. The DataFrame has the columns
        ``"Polygons"`` and ``"Recognized Text"``, one row per line, top to bottom.
    """
    # Uploaded PNGs may be RGBA, palette or grayscale; everything below needs RGB.
    image = image.convert("RGB")

    # Perform inference on an image, select textline only (class 0).
    results = model(image, classes=0)

    page = np.array(image)  # clean RGB pixels, used for cropping only
    canvas = page.copy()    # separate copy the boxes are drawn on
    height, width = page.shape[:2]

    # Reading order: top edge (y1) ascending.
    boxes = sorted(results[0].boxes.xyxy.tolist(), key=lambda box: box[1])

    polygons = []
    texts = []
    for box in boxes:
        x1, y1, x2, y2 = map(int, box)
        # Clamp to the page so slicing never yields a wrapped or empty crop.
        x1, x2 = max(0, x1), min(width, x2)
        y1, y2 = max(0, y1), min(height, y2)
        if x2 <= x1 or y2 <= y1:
            continue

        # Crop from the untouched page, not the canvas, so boxes drawn for earlier
        # (possibly overlapping) lines never leak into the recognizer's input.
        crop_pil = Image.fromarray(page[y1:y2, x1:x2])

        # Recognize text using PyLaia model
        predicted = predict(HTR_MODEL_NAME, crop_pil)
        texts.append(predicted[1]["text"])
        line_no = len(polygons) + 1
        polygons.append(f"Line {line_no}: {[(x1, y1), (x2, y1), (x2, y2), (x1, y2)]}")

        # Draw bounding box (green in RGB).
        cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)

    # Combine polygons and texts into a DataFrame for table display
    table_data = pd.DataFrame({"Polygons": polygons, "Recognized Text": texts})
    return Image.fromarray(canvas), table_data


def segment_and_recognize(image):
    """Run segmentation + recognition; returns (annotated image, results table)."""
    segmented_image, table_data = process_image(image)
    return segmented_image, table_data


# Streamlit app layout
st.title("YOLOv8 Text Line Segmentation & PyLaia Text Recognition")

# File uploader
uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

# Process the image if uploaded
if uploaded_image is not None:
    image = Image.open(uploaded_image)

    if st.button("Segment and Recognize"):
        # Perform segmentation and recognition
        segmented_image, table_data = segment_and_recognize(image)

        # Display the segmented image
        st.image(segmented_image, caption="Segmented Image with Bounding Boxes", use_column_width=True)

        # Display the table with polygons and recognized text
        st.table(table_data)