# Hugging Face file-viewer header captured with the source (not part of the program):
#   omarelsayeed's picture
#   Update app.py
#   0b742a9 verified
#   raw
#   history blame
#   10.4 kB
from ultralytics import RTDETR
import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image
from PIL import Image, ImageDraw, ImageFont
from collections import defaultdict
from typing import List, Dict
import torch
from transformers import LayoutLMv3ForTokenClassification
# Load the LayoutLMv3 token-classification checkpoint; get_orders() runs it on
# the detected boxes and parse_logits() turns its logits into a reading order.
layout_model = LayoutLMv3ForTokenClassification.from_pretrained("omarelsayeed/LayoutReader80Small")
# NOTE(review): MAX_LEN is not referenced anywhere in the visible code.
MAX_LEN = 100
# Input ids used by boxes2inputs(): one CLS id at the start of the sequence,
# one UNK id per box, and one EOS id at the end.
CLS_TOKEN_ID = 0
UNK_TOKEN_ID = 3
EOS_TOKEN_ID = 2
def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
    """Build a single-sample model input batch from a list of boxes.

    The sequence is framed by a CLS entry and an EOS entry (both with an
    all-zero box); every real box is represented by the UNK token id.

    Args:
        boxes: Boxes as ``[x_min, y_min, x_max, y_max]`` lists.

    Returns:
        A dict with ``bbox``, ``attention_mask`` and ``input_ids`` tensors,
        each carrying a leading batch dimension of 1.
    """
    box_count = len(boxes)
    zero_box = [0, 0, 0, 0]

    framed_boxes = [zero_box] + boxes + [zero_box]
    token_ids = [CLS_TOKEN_ID, *([UNK_TOKEN_ID] * box_count), EOS_TOKEN_ID]
    # Every position (CLS, boxes, EOS) is attended to.
    mask = [1] * (box_count + 2)

    return {
        "bbox": torch.tensor([framed_boxes]),
        "attention_mask": torch.tensor([mask]),
        "input_ids": torch.tensor([token_ids]),
    }
def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
    """Turn per-box logits into a conflict-free reading order.

    Args:
        logits: 2-D tensor; row ``i + 1`` holds the scores of box ``i`` for
            each candidate position (row 0 is skipped).
        length: Number of boxes.

    Returns:
        A list where entry ``i`` is the position assigned to box ``i``.
    """
    scores = logits[1: length + 1, :length]
    # For every box: candidate positions sorted from lowest to highest score,
    # so ``pop()`` always yields the best position not yet tried by that box.
    candidates = scores.argsort(descending=False).tolist()
    assigned = [ranking.pop() for ranking in candidates]

    while True:
        claimants = defaultdict(list)
        for box_idx, position in enumerate(assigned):
            claimants[position].append(box_idx)

        # Only positions claimed by more than one box need resolving.
        contested = {
            position: box_idxes
            for position, box_idxes in claimants.items()
            if len(box_idxes) > 1
        }
        if not contested:
            return assigned

        for position, box_idxes in contested.items():
            ranked = sorted(
                ((box_idx, scores[box_idx, position]) for box_idx in box_idxes),
                key=lambda pair: pair[1],
                reverse=True,
            )
            # The highest-scoring claimant keeps the position; every other
            # claimant moves on to its next-best candidate.
            for loser, _ in ranked[1:]:
                assigned[loser] = candidates[loser].pop()
def get_orders(image_path, boxes):
    """Predict the reading order of the given layout boxes.

    The boxes are normalised against the real image size before being fed to
    the layout model. Previously the image was ignored and a fixed 1024x1024
    size was assumed, which produces wrong (and potentially out-of-range)
    coordinates for any other image size.

    Args:
        image_path: The page image, either an in-memory ``PIL.Image.Image``
            or anything ``PIL.Image.open`` accepts. ``None`` keeps the
            defaults of ``scale_and_normalize_boxes``.
        boxes: Boxes as ``[x_min, y_min, x_max, y_max]``.
            NOTE(review): assumed to be in pixel coordinates of the original
            image, as returned by the detector - confirm against the caller.

    Returns:
        A list where entry ``i`` is the reading-order position of box ``i``;
        an empty list when there are no boxes.
    """
    if not boxes:
        # Nothing to order; skip the model call entirely.
        return []

    if image_path is None:
        normalized = scale_and_normalize_boxes(boxes)
    else:
        if isinstance(image_path, Image.Image):
            width, height = image_path.size
        else:
            with Image.open(image_path) as source_image:
                width, height = source_image.size
        normalized = scale_and_normalize_boxes(
            boxes, old_width=width, old_height=height
        )

    inputs = boxes2inputs(normalized)
    # Move inputs to whichever device the model lives on.
    inputs = {key: tensor.to(layout_model.device) for key, tensor in inputs.items()}

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = layout_model(**inputs).logits.cpu().squeeze(0)

    return parse_logits(logits, len(normalized))
# Download (or reuse the cached copy of) the detector repository and load the
# RT-DETR weights file it contains.
_detector_snapshot = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS")
model_dir = f"{_detector_snapshot}/rtdetr_1024_crops.pt"
model = RTDETR(model_dir)
def detect_layout(img, conf_threshold, iou_threshold):
    """Run the module-level layout detector on an image.

    Args:
        img: Image passed straight through as the ``source`` of
            ``model.predict``.
        conf_threshold: Confidence threshold forwarded to the detector.
        iou_threshold: IoU threshold forwarded to the detector.

    Returns:
        A ``(bboxes, classes)`` tuple: the boxes of the first result as
        ``xyxy`` lists, and the matching class names.
    """
    # Class id -> human-readable layout label.
    label_names = {
        0: 'CheckBox',
        1: 'List',
        2: 'P',
        3: 'abandon',
        4: 'figure',
        5: 'gridless_table',
        6: 'handwritten_signature',
        7: 'qr_code',
        8: 'table',
        9: 'title',
    }

    predictions = model.predict(
        source=img,
        conf=conf_threshold,
        iou=iou_threshold,
        show_labels=True,
        show_conf=True,
        imgsz=1024,
        agnostic_nms=True,
        max_det=34,
        nms=True,
    )
    # A single image is passed in, so only the first result is used.
    first_result = predictions[0]

    bboxes = first_result.boxes.xyxy.cpu().tolist()
    class_ids = first_result.boxes.cls.cpu().tolist()
    classes = [label_names[class_id] for class_id in class_ids]
    return bboxes, classes
from PIL import Image, ImageDraw, ImageFont
def _label_size(draw, text, font):
    """Return the ``(width, height)`` in pixels needed to draw *text* with *font*."""
    if hasattr(draw, "textbbox"):
        try:
            # Measured from a (0, 0) origin, so right/bottom give the full
            # extent of the rendered text including the font's own offset.
            _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
            return right, bottom
        except (AttributeError, ValueError):
            # NOTE(review): some older Pillow releases appear to reject
            # bitmap fonts in textbbox; fall through to the legacy API.
            pass
    # Legacy API, removed in Pillow 10; only reached on older releases.
    return draw.textsize(text, font=font)


def draw_bboxes_on_image(image_path, bboxes, classes, reading_order):
    """Draw labelled layout boxes on top of an image.

    Each box is outlined in a class-specific colour and tagged with
    ``"<class>-<reading order>"`` on a semi-transparent background.

    Args:
        image_path: Either an in-memory ``PIL.Image.Image`` (what Gradio
            passes with ``type="pil"``) or anything ``PIL.Image.open``
            accepts (path or file object).
        bboxes: Boxes as ``[x1, y1, x2, y2]`` in pixel coordinates.
        classes: Class name for each box, same order as ``bboxes``.
        reading_order: Reading-order value for each box, same order as
            ``bboxes``.

    Returns:
        A new RGB ``PIL.Image.Image`` with the annotations composited in.
    """
    # Colour palette per layout class.
    class_colors = {
        'CheckBox': '#FFA500',               # Orange
        'List': '#4682B4',                   # Steel Blue
        'P': '#32CD32',                      # Lime Green
        'abandon': '#8A2BE2',                # Blue Violet
        'figure': '#00CED1',                 # Dark Turquoise
        'gridless_table': '#FFD700',         # Gold
        'handwritten_signature': '#FF69B4',  # Hot Pink
        'qr_code': '#FF4500',                # Orange Red
        'table': '#8B4513',                  # Saddle Brown
        'title': '#FF1493',                  # Deep Pink
    }

    # Image.open() cannot take an already-loaded PIL image, so handle that
    # case directly; for paths, close the file once the pixels are copied.
    if isinstance(image_path, Image.Image):
        image = image_path.convert("RGBA")
    else:
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGBA")

    overlay = Image.new("RGBA", image.size, (255, 255, 255, 0))  # Transparent overlay
    draw = ImageDraw.Draw(overlay)

    # Try loading a TrueType font, or fall back to Pillow's built-in one.
    try:
        font = ImageFont.truetype("arial.ttf", 20)
        title_font = ImageFont.truetype("arial.ttf", 30)
    except IOError:
        font = ImageFont.load_default()
        title_font = font

    padding = 5  # Space between the label text and its background edges

    for i, (x1, y1, x2, y2) in enumerate(bboxes):
        class_name = classes[i]
        order = reading_order[i]
        color = class_colors.get(class_name, "#FFFFFF")  # White if class is unknown

        # Titles get a heavier outline and a larger label.
        if class_name == 'title':
            box_thickness = 6
            label_font = title_font
        else:
            box_thickness = 3
            label_font = font

        draw.rounded_rectangle(
            [(x1, y1), (x2, y2)],
            radius=10,
            outline=color,
            width=box_thickness,
        )

        label = f"{class_name}-{order}"
        text_width, text_height = _label_size(draw, label, label_font)

        # Label sits directly above the box's top-left corner.
        label_bg_coords = [
            x1, y1 - text_height - 2 * padding,
            x1 + text_width + 2 * padding, y1,
        ]
        draw.rectangle(label_bg_coords, fill=(0, 0, 0, 128))  # Semi-transparent black
        draw.text(
            (x1 + padding, y1 - text_height - padding),
            label,
            fill="white",
            font=label_font,
        )

    # Merge the overlay with the original image, then drop the alpha channel.
    image_with_overlay = Image.alpha_composite(image, overlay)
    return image_with_overlay.convert("RGB")
def scale_and_normalize_boxes(bboxes, old_width = 1024, old_height= 1024, new_width=640, new_height=640, normalize_width=1000, normalize_height=1000):
    """
    Rescale bounding boxes and map them onto a fixed normalization range.

    Each coordinate is first scaled from the original image size to the new
    size, then expressed as a fraction of the new size and multiplied by the
    normalization extent. The result is truncated to an integer.

    Args:
        bboxes (list of lists): Boxes in [x_min, y_min, x_max, y_max] format.
        old_width (int or float): Width of the original image.
        old_height (int or float): Height of the original image.
        new_width (int or float): Width of the scaled image.
        new_height (int or float): Height of the scaled image.
        normalize_width (int or float): Width of the normalization range.
        normalize_height (int or float): Height of the normalization range.

    Returns:
        list of lists: Integer boxes in [x_min, y_min, x_max, y_max] format.
    """
    x_factor = new_width / old_width
    y_factor = new_height / old_height

    def along_x(value):
        # Scale first, then normalize against the scaled width.
        return int(normalize_width * ((value * x_factor) / new_width))

    def along_y(value):
        # Scale first, then normalize against the scaled height.
        return int(normalize_height * ((value * y_factor) / new_height))

    converted = []
    for x_min, y_min, x_max, y_max in bboxes:
        converted.append(
            [along_x(x_min), along_y(y_min), along_x(x_max), along_y(y_max)]
        )
    return converted
from PIL import Image, ImageDraw
def is_inside(box1, box2):
    """Return True when box1 lies within box2 (shared edges count as inside).

    Both boxes are ``[x_min, y_min, x_max, y_max]`` sequences.
    """
    # The top-left corner must not be before box2's, and the bottom-right
    # corner must not be past box2's.
    starts_within = all(box1[axis] >= box2[axis] for axis in (0, 1))
    ends_within = all(box1[axis] <= box2[axis] for axis in (2, 3))
    return starts_within and ends_within
def is_overlap(box1, box2):
    """Return True when the two boxes share a region of non-zero area.

    Both boxes are ``[x_min, y_min, x_max, y_max]``; boxes that merely touch
    along an edge or corner do not count as overlapping.
    """
    left_a, top_a, right_a, bottom_a = box1
    left_b, top_b, right_b, bottom_b = box2

    # The boxes are disjoint if one sits entirely beside, above or below
    # the other.
    apart_horizontally = right_a <= left_b or right_b <= left_a
    apart_vertically = bottom_a <= top_b or bottom_b <= top_a
    return not (apart_horizontally or apart_vertically)
def remove_overlapping_and_inside_boxes(boxes, classes):
    """Drop boxes that are nested inside, or overlap with, another box.

    For every pair of boxes that conflict, the contained box (or, for a
    partial overlap, the smaller-area box) is removed. When the two boxes
    are identical or have equal area, the earlier one in the list is kept,
    so a region is never lost because both members of a tie were discarded.

    Both lists are modified in place and also returned.

    Args:
        boxes: Boxes as ``[x_min, y_min, x_max, y_max]``.
        classes: Class label for each box, same order as ``boxes``.

    Returns:
        The ``(boxes, classes)`` pair after removal.
    """
    from itertools import combinations

    def area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    to_remove = set()
    # Each unordered pair is examined exactly once; visiting both (i, j) and
    # (j, i) would remove both boxes whenever the comparison is a tie.
    for (i, box_a), (j, box_b) in combinations(enumerate(boxes), 2):
        a_in_b = is_inside(box_a, box_b)
        b_in_a = is_inside(box_b, box_a)
        if a_in_b and b_in_a:
            # Identical boxes: keep the first occurrence.
            to_remove.add(j)
        elif a_in_b:
            to_remove.add(i)
        elif b_in_a:
            to_remove.add(j)
        elif is_overlap(box_a, box_b):
            # Partial overlap: drop the smaller box, the later one on a tie.
            if area(box_a) < area(box_b):
                to_remove.add(i)
            else:
                to_remove.add(j)

    # Delete from the end so earlier indices stay valid.
    for idx in sorted(to_remove, reverse=True):
        del boxes[idx]
        del classes[idx]
    return boxes, classes
def full_predictions(IMAGE_PATH, conf_threshold, iou_threshold):
    """Run the whole pipeline on one image and return the annotated result.

    Steps: detect layout boxes, prune nested/overlapping boxes, predict the
    reading order of the remaining boxes, then draw them on the image.

    Args:
        IMAGE_PATH: The input image, forwarded unchanged to every stage.
        conf_threshold: Detector confidence threshold.
        iou_threshold: Detector IoU threshold.

    Returns:
        The image produced by ``draw_bboxes_on_image``.
    """
    detected_boxes, detected_classes = detect_layout(
        IMAGE_PATH, conf_threshold, iou_threshold
    )
    kept_boxes, kept_classes = remove_overlapping_and_inside_boxes(
        detected_boxes, detected_classes
    )
    reading_order = get_orders(IMAGE_PATH, kept_boxes)
    return draw_bboxes_on_image(IMAGE_PATH, kept_boxes, kept_classes, reading_order)
# Input widgets, in the order full_predictions() expects its arguments.
_image_input = gr.Image(type="pil", label="Upload Image")
_conf_slider = gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold")
_iou_slider = gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU threshold")
_result_output = gr.Image(type="pil", label="Result")

# Each example row is [image file, confidence threshold, IoU threshold].
_example_rows = [
    ["kashida.png", 0.2, 0.45],
    ["image.jpg", 0.2, 0.45],
    ["Screenshot 2024-11-06 130230.png", 0.25, 0.45],
]

# NOTE(review): the title/description mention "Ultralytics YOLO11n", while the
# detector loaded in this file is an RTDETR model - confirm the wording.
iface = gr.Interface(
    fn=full_predictions,
    inputs=[_image_input, _conf_slider, _iou_slider],
    outputs=_result_output,
    title="Ultralytics Gradio",
    description="Upload images for inference. The Ultralytics YOLO11n model is used by default.",
    examples=_example_rows,
    theme=gr.themes.Default(),
)

# Only start the web server when run as a script, not on import.
if __name__ == "__main__":
    iface.launch()