omarelsayeed's picture
Update app.py
0f127fa verified
raw
history blame
11.1 kB
from ultralytics import RTDETR
import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import random
from collections import defaultdict
from typing import List, Dict
import torch
from transformers import LayoutLMv3ForTokenClassification
from transformers import AutoProcessor
from transformers import AutoModelForTokenClassification
# Reading-order model: a LayoutLMv3 token classifier whose logits are used
# (via parse_logits) to assign each layout box a reading position.
# Toggle `quant` to choose the 4-bit checkpoint or the full-precision one.
quant = True
finetuned_fully = (
    LayoutLMv3ForTokenClassification.from_pretrained(
        "omarelsayeed/ArabicLayoutReader_4bit", load_in_4bit=True
    )
    if quant
    else LayoutLMv3ForTokenClassification.from_pretrained(
        "omarelsayeed/YARAB_FOK_ELDE2A"
    )
)
# NOTE(review): `processor` is not referenced by the code visible in this
# file; kept in case other code relies on it.
processor = AutoProcessor.from_pretrained(
    "microsoft/layoutlmv3-base", apply_ocr=False
)

# Token ids used by boxes2inputs to build a box-only input sequence.
# NOTE(review): MAX_LEN is not referenced by the visible code.
MAX_LEN = 70
CLS_TOKEN_ID = 0
UNK_TOKEN_ID = 3
EOS_TOKEN_ID = 2
import torch
def boxes2inputs(boxes):
    """Build a batch-of-one LayoutLMv3 input from layout boxes alone.

    Each box becomes one UNK token. The sequence is wrapped in CLS ... EOS
    tokens whose boxes are all zeros, and every position is attended to.

    Args:
        boxes: list of ``[x0, y0, x1, y1]`` boxes. The caller is expected to
            have normalised them already (full_predictions passes the output
            of scale_and_normalize_boxes).

    Returns:
        dict with ``bbox``, ``attention_mask`` and ``input_ids`` tensors,
        each wrapped in a leading batch dimension of size 1.
    """
    empty_box = [0, 0, 0, 0]
    token_count = len(boxes)
    sequence_boxes = [empty_box] + boxes + [empty_box]
    token_ids = [CLS_TOKEN_ID, *([UNK_TOKEN_ID] * token_count), EOS_TOKEN_ID]
    mask = [1] * (token_count + 2)  # CLS + one per box + EOS
    return {
        "bbox": torch.tensor([sequence_boxes]),
        "attention_mask": torch.tensor([mask]),
        "input_ids": torch.tensor([token_ids]),
    }
def parse_logits(logits: torch.Tensor, length):
    """Turn reading-order logits into one distinct position per box.

    Row ``i + 1`` of ``logits`` scores box ``i`` against each candidate
    position ``0..length-1``. Row 0 belongs to the CLS token and is skipped.
    Every box first takes its highest-scoring position. Whenever several
    boxes claim the same position, the box with the highest logit for that
    position keeps it; on equal logits, the lowest box index wins. The
    other boxes fall back to their next-best untried position. This
    repeats until no position is claimed twice.

    Args:
        logits: 2-D tensor of per-token position scores for one sample.
        length: number of boxes.

    Returns:
        list of ints; element ``i`` is the position assigned to box ``i``.
    """
    scores = logits[1 : length + 1, :length]
    # Per-box candidate positions, worst -> best, so .pop() yields the best
    # remaining candidate.
    candidates = scores.argsort(descending=False).tolist()
    assigned = [row.pop() for row in candidates]

    while True:
        claimants = defaultdict(list)
        for box_idx, position in enumerate(assigned):
            claimants[position].append(box_idx)
        conflicts = {pos: rows for pos, rows in claimants.items() if len(rows) > 1}
        if not conflicts:
            return assigned

        for position, rows in conflicts.items():
            # Stable sort: on equal logits the earlier (lower-index) box wins.
            ranked = sorted(
                ((box_idx, scores[box_idx, position]) for box_idx in rows),
                key=lambda pair: pair[1],
                reverse=True,
            )
            for loser, _ in ranked[1:]:
                assigned[loser] = candidates[loser].pop()
def prepare_inputs(inputs, model):
    """Move every input tensor to the model's device.

    Floating-point tensors are additionally cast to ``model.dtype``. Integer
    tensors such as ids, masks and bboxes keep their dtype.

    Args:
        inputs: mapping of name -> tensor.
        model: object exposing ``device`` and ``dtype`` attributes.

    Returns:
        A new dict with the same keys, in the same order.
    """
    def _place(tensor):
        tensor = tensor.to(model.device)
        return tensor.to(model.dtype) if torch.is_floating_point(tensor) else tensor

    return {name: _place(tensor) for name, tensor in inputs.items()}
def get_orders(image_path, boxes):
    """Predict the reading order of layout boxes with the LayoutLMv3 reader.

    Args:
        image_path: unused; kept so existing callers keep working (they
            pass the page image here).
        boxes: list of ``[x0, y0, x1, y1]`` boxes. They are expected to be
            normalised already (full_predictions passes the output of
            scale_and_normalize_boxes).

    Returns:
        list[int] where element ``i`` is the predicted reading position of
        box ``i``. Returns an empty list when ``boxes`` is empty.
    """
    if not boxes:
        # Nothing to order; skip the model call entirely.
        return []
    inputs = prepare_inputs(boxes2inputs(boxes), finetuned_fully)
    # Inference only: disable autograd so no computation graph is kept.
    with torch.no_grad():
        logits = finetuned_fully(**inputs).logits
    logits = logits.cpu().squeeze(0)
    return parse_logits(logits, len(boxes))
# Fetch the RT-DETR layout-detection checkpoint from the Hugging Face Hub and
# load it with Ultralytics. `model_dir` is the path to the .pt weights file.
_layout_repo_dir = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS")
model_dir = f"{_layout_repo_dir}/rtdetr_1024_crops.pt"
model = RTDETR(model_dir)
# Class-id -> label mapping for the layout detector's outputs.
_LAYOUT_CLASS_NAMES = {
    0: 'CheckBox',
    1: 'List',
    2: 'P',
    3: 'abandon',
    4: 'figure',
    5: 'gridless_table',
    6: 'handwritten_signature',
    7: 'qr_code',
    8: 'table',
    9: 'title',
}


def detect_layout(img, conf_threshold, iou_threshold):
    """Detect layout regions in a page image with the RT-DETR model.

    Args:
        img: image passed straight to ``model.predict`` (full_predictions
            passes a 1024x1024 PIL image).
        conf_threshold: minimum detection confidence.
        iou_threshold: IoU threshold forwarded to the predictor.

    Returns:
        ``(bboxes, classes)``. ``bboxes`` is a list of ``[x1, y1, x2, y2]``
        boxes (Ultralytics ``boxes.xyxy``, in the input image's coordinates)
        and ``classes`` is the matching list of label strings.

    Raises:
        KeyError: if the model emits a class id missing from the mapping.
    """
    results = model.predict(
        source=img,
        conf=conf_threshold,
        iou=iou_threshold,
        show_labels=True,
        show_conf=True,
        imgsz=1024,
        agnostic_nms=True,
        max_det=34,
        nms=True,
    )[0]
    bboxes = results.boxes.xyxy.cpu().tolist()
    class_ids = results.boxes.cls.cpu().tolist()
    # Ultralytics reports class ids as floats; cast explicitly rather than
    # relying on 2.0 and 2 hashing to the same dict key.
    classes = [_LAYOUT_CLASS_NAMES[int(class_id)] for class_id in class_ids]
    return bboxes, classes
from PIL import Image, ImageDraw, ImageFont
def draw_bboxes_on_image(image_path, bboxes, classes, reading_order):
    """Draw coloured, labelled layout boxes onto an image.

    The image is modified in place (``ImageDraw.Draw`` draws directly on
    it), and the same object is returned.

    Args:
        image_path: despite the name, the PIL ``Image`` to draw on.
        bboxes: list of ``[x1, y1, x2, y2]`` boxes in the image's coordinates.
        classes: class label for each box; must be a key of ``class_colors``.
        reading_order: predicted reading position for each box.

    Returns:
        The same PIL image with rectangles and ``"<class>-<order>"`` labels.

    Raises:
        KeyError: for a class label without a colour.
        IndexError: if ``classes`` or ``reading_order`` is shorter than
            ``bboxes``.
    """
    # Outline colour for each class name.
    class_colors = {
        'CheckBox': 'orange',
        'List': 'blue',
        'P': 'green',
        'abandon': 'purple',
        'figure': 'cyan',
        'gridless_table': 'yellow',
        'handwritten_signature': 'magenta',
        'qr_code': 'red',
        'table': 'brown',
        'title': 'pink'
    }
    image = image_path
    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("arial.ttf", 20)
        title_font = ImageFont.truetype("arial.ttf", 30)  # larger font for titles
    except IOError:
        # arial.ttf is unavailable; fall back to Pillow's built-in font
        # (the ``size`` argument needs Pillow >= 10.1).
        font = ImageFont.load_default(size=30)
        title_font = font

    for i, (x1, y1, x2, y2) in enumerate(bboxes):
        class_name = classes[i]
        order = reading_order[i]
        color = class_colors[class_name]
        # Titles get a thicker outline and a larger label.
        is_title = class_name == 'title'
        box_thickness = 4 if is_title else 2
        label_font = title_font if is_title else font

        draw.rectangle([x1, y1, x2, y2], outline=color, width=box_thickness)

        label = f"{class_name}-{order}"
        text_box = draw.textbbox((x1, y1 - 20), label, font=label_font)
        label_height = text_box[3] - text_box[1]
        # Place the label just above the box, but never above the image's top
        # edge; otherwise labels of boxes near the top get clipped.
        label_y = max(0, y1 - label_height)
        draw.text((x1, label_y), label, fill="black", font=label_font)
    return image
def scale_and_normalize_boxes(bboxes, old_width=1024, old_height=1024, new_width=640, new_height=640, normalize_width=1000, normalize_height=1000):
    """Rescale boxes to an intermediate size, then map them to a fixed integer range.

    Each coordinate goes through two steps. First it is multiplied by
    ``new / old``. Then it is divided by ``new`` and multiplied by
    ``normalize``. Finally it is truncated with ``int()``. Mathematically
    the intermediate size cancels out, so the net effect is
    ``int(coord * normalize / old)`` up to float rounding.

    Args:
        bboxes: iterable of ``[x_min, y_min, x_max, y_max]`` boxes.
        old_width, old_height: size of the image the boxes refer to.
        new_width, new_height: intermediate scaled size.
        normalize_width, normalize_height: target coordinate range.

    Returns:
        list of ``[x_min, y_min, x_max, y_max]`` integer boxes.
    """
    factor_x = new_width / old_width
    factor_y = new_height / old_height

    def _map_x(value):
        return int(normalize_width * ((value * factor_x) / new_width))

    def _map_y(value):
        return int(normalize_height * ((value * factor_y) / new_height))

    normalized = []
    for left, top, right, bottom in bboxes:
        normalized.append([_map_x(left), _map_y(top), _map_x(right), _map_y(bottom)])
    return normalized
from PIL import Image, ImageDraw
def is_inside(box1, box2):
    """Return True if ``box1`` lies entirely within ``box2``.

    Boxes are indexed as ``[x1, y1, x2, y2]``. Shared edges are allowed, so
    identical boxes are inside each other.
    """
    starts_after = box1[0] >= box2[0] and box1[1] >= box2[1]
    ends_before = box1[2] <= box2[2] and box1[3] <= box2[3]
    return starts_after and ends_before
def is_overlap(box1, box2):
    """Return True unless the two ``[x1, y1, x2, y2]`` boxes are separated.

    Boxes count as separated when one lies entirely to the left/right of,
    or above/below, the other. Merely touching edges counts as separated.
    """
    a_left, a_top, a_right, a_bottom = box1
    b_left, b_top, b_right, b_bottom = box2
    apart_horizontally = a_right <= b_left or b_right <= a_left
    apart_vertically = a_bottom <= b_top or b_bottom <= a_top
    return not (apart_horizontally or apart_vertically)
def remove_overlapping_and_inside_boxes(boxes, classes):
    """Drop boxes that are nested in, or overlap, another box.

    Each unordered pair of boxes is judged once:

    * if one box lies inside the other, the inner box is dropped;
    * otherwise, if they overlap, the box with the smaller area is dropped.

    Ties (identical boxes, or overlapping boxes of equal area) drop the
    later box, so one box of the pair always survives. Every pair is
    judged against the original list, so a box that is itself dropped can
    still cause another box to be dropped.

    Args:
        boxes: list of ``[x1, y1, x2, y2]`` boxes.
        classes: list of labels parallel to ``boxes``.

    Returns:
        ``(boxes, classes)``: the same list objects, modified in place.
    """
    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    to_remove = set()
    count = len(boxes)
    for i in range(count):
        for j in range(i + 1, count):
            first, second = boxes[i], boxes[j]
            if is_inside(second, first):
                # Also covers identical boxes: keep the earlier one.
                to_remove.add(j)
            elif is_inside(first, second):
                to_remove.add(i)
            elif is_overlap(first, second):
                # Drop the smaller box; on equal area drop the later one.
                to_remove.add(i if _area(first) < _area(second) else j)

    # Delete from the highest index down so lower indices stay valid.
    for idx in sorted(to_remove, reverse=True):
        del boxes[idx]
        del classes[idx]
    return boxes, classes
def process_r1(r1):
    """Promote the last reading position to second place.

    For a permutation of ``0..n-1`` with ``n >= 3`` the result maps
    ``0 -> 0``, ``n-1 -> 1`` and every other value ``k -> k + 1``. In other
    words, the box predicted to be read last is moved to be read second.
    NOTE(review): the motivation for this adjustment is not visible here.

    Lists with fewer than three entries have nothing to rotate and are
    returned unchanged. The input list is never mutated.

    Args:
        r1: list of reading positions (as returned by get_orders).

    Returns:
        The adjusted list of reading positions.
    """
    if len(r1) <= 2:
        return r1
    ranks = list(r1)  # work on a copy so the caller's list is untouched
    max_index = ranks.index(max(ranks))
    one_index = ranks.index(1)
    # The last-read box becomes position 1 ...
    ranks[max_index] = 1
    # ... and every position from 2 upward shifts back by one ...
    ranks = [x + 1 if x not in (0, 1) else x for x in ranks]
    # ... while the box previously at position 1 moves to position 2.
    ranks[one_index] += 1
    return ranks
def full_predictions(IMAGE_PATH, conf_threshold, iou_threshold):
    """Detect layout regions, predict their reading order and draw both.

    Args:
        IMAGE_PATH: despite the name, the uploaded PIL image (the Gradio
            input uses ``type="pil"``). It is resized to 1024x1024 without
            preserving aspect ratio. Detection runs at ``imgsz=1024``, and
            scale_and_normalize_boxes assumes a 1024x1024 page by default.
        conf_threshold: detection confidence threshold for detect_layout.
        iou_threshold: IoU threshold for detect_layout.

    Returns:
        The resized image with labelled boxes drawn on it.

    Raises:
        gr.Error: if no image was supplied.
    """
    if IMAGE_PATH is None:
        # Gradio passes None when the user submits without an image.
        raise gr.Error("Please upload an image first.")
    image = IMAGE_PATH.resize((1024, 1024))
    bboxes, classes = detect_layout(image, conf_threshold, iou_threshold)
    bboxes, classes = remove_overlapping_and_inside_boxes(bboxes, classes)
    if not bboxes:
        # Nothing detected: skip reading-order prediction, return the page as-is.
        return image
    orders = get_orders(image, scale_and_normalize_boxes(bboxes))
    if len(orders) > 2:
        # process_r1 only reorders pages with three or more boxes.
        orders = process_r1(orders)
    return draw_bboxes_on_image(image, bboxes, classes, orders)
# Gradio UI: upload a page image, tune the detector thresholds, and get back
# the page annotated with layout classes and predicted reading order.
iface = gr.Interface(
    fn=full_predictions,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU threshold"),
    ],
    outputs=gr.Image(type="pil", label="Result"),
    title="Ultralytics Gradio",
    # The previous text claimed a YOLO11n model; the app actually uses an
    # RT-DETR detector plus a LayoutLMv3 reading-order model.
    description=(
        "Upload a document image to detect its layout regions with an RT-DETR "
        "model and predict their reading order with a LayoutLMv3-based reader."
    ),
    examples=[
        ["kashida.png", 0.2, 0.45],
        ["image.jpg", 0.2, 0.45],
        ["Screenshot 2024-11-06 130230.png", 0.25, 0.45],
    ],
    theme=gr.themes.Default(),
)

if __name__ == "__main__":
    iface.launch()