File size: 10,414 Bytes
fa50974 679d3c5 190bae6 fa50974 6e73f0b 0b742a9 14db543 0b742a9 6e73f0b 190bae6 74c842e 679d3c5 6e73f0b 679d3c5 fa50974 679d3c5 a231a61 fa50974 6329b7d fa50974 6329b7d fa50974 6329b7d fa50974 6329b7d fa50974 6329b7d fa50974 6329b7d fa50974 6329b7d fa50974 6329b7d fa50974 6329b7d fa50974 6329b7d fa50974 6329b7d fa50974 6329b7d fa50974 6329b7d 14db543 fa50974 679d3c5 fa50974 679d3c5 fa50974 679d3c5 fa50974 14db543 faf74d9 6e73f0b 679d3c5 14db543 679d3c5 fa50974 679d3c5 45440e8 63b70c6 679d3c5 2132edd 679d3c5 3ec7d0a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 |
from ultralytics import RTDETR
import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image
from PIL import Image, ImageDraw, ImageFont
from collections import defaultdict
from typing import List, Dict
import torch
from transformers import LayoutLMv3ForTokenClassification
# Reading-order model: a LayoutLMv3 token-classification checkpoint used below to
# rank detected layout boxes (see get_orders / parse_logits).
LAYOUT_READER_CHECKPOINT = "omarelsayeed/LayoutReader80Small"
layout_model = LayoutLMv3ForTokenClassification.from_pretrained(LAYOUT_READER_CHECKPOINT)
# NOTE(review): MAX_LEN is not referenced anywhere in the visible file.
MAX_LEN = 100
# Special token ids placed around the per-box placeholder tokens.
CLS_TOKEN_ID = 0
UNK_TOKEN_ID = 3
EOS_TOKEN_ID = 2


def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
    """Build a batch-of-one model input from a list of boxes.

    The sequence is ``[CLS] + one UNK token per box + [EOS]``. Each box token
    carries its ``[x_min, y_min, x_max, y_max]`` box; the CLS/EOS positions get an
    all-zero box. Every position is attended to.

    Args:
        boxes: Boxes as 4-element integer lists.

    Returns:
        dict with ``"bbox"`` (shape ``(1, len(boxes) + 2, 4)``),
        ``"attention_mask"`` and ``"input_ids"`` (both shape ``(1, len(boxes) + 2)``).
    """
    sentinel = [0, 0, 0, 0]
    n_boxes = len(boxes)
    encoded = {
        "bbox": [[sentinel] + boxes + [sentinel]],
        "attention_mask": [[1] * (n_boxes + 2)],
        "input_ids": [[CLS_TOKEN_ID] + [UNK_TOKEN_ID] * n_boxes + [EOS_TOKEN_ID]],
    }
    return {name: torch.tensor(values) for name, values in encoded.items()}
def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
    """
    Decode a logit matrix into one reading position per box.

    Row ``r + 1`` of ``logits`` holds the scores of box ``r`` (row 0 is the CLS
    token, matching the layout produced by ``boxes2inputs``); only the first
    ``length`` columns are considered as candidate positions.

    Every box first takes its highest-scoring position. While several boxes
    claim the same position, the box with the highest score for that position
    keeps it and the others fall back to their next-best remaining position.

    Args:
        logits: 2-D score matrix; presumably (sequence_length, num_labels) from
            the token-classification head - TODO confirm against the checkpoint.
        length: Number of boxes.

    Returns:
        ``ret[i]`` is the position assigned to box ``i``.
    """
    scores = logits[1: length + 1, :length]
    # Candidates per box from worst to best, so list.pop() yields the best left.
    candidates = scores.argsort(dim=-1).tolist()
    assignment = [row.pop() for row in candidates]
    while True:
        claimants = defaultdict(list)
        for box_idx, position in enumerate(assignment):
            claimants[position].append(box_idx)
        contested = {pos: idxs for pos, idxs in claimants.items() if len(idxs) > 1}
        if not contested:
            return assignment
        for position, idxs in contested.items():
            # Highest score for this position wins; ties keep claimant order (stable sort).
            ranked = sorted(
                ((box_idx, scores[box_idx, position]) for box_idx in idxs),
                key=lambda pair: pair[1],
                reverse=True,
            )
            for loser, _ in ranked[1:]:
                assignment[loser] = candidates[loser].pop()
def _page_size(image):
    """Return ``(width, height)`` of ``image`` in pixels, or ``None`` if it can't be determined.

    Supports PIL images, file-path strings, and array-like images exposing
    ``shape`` (assumed ``(height, width[, channels])`` - TODO confirm for
    non-PIL callers).
    """
    if isinstance(image, Image.Image):
        return image.size
    if isinstance(image, str):
        with Image.open(image) as opened:
            return opened.size
    shape = getattr(image, "shape", None)
    if shape is not None and len(shape) >= 2:
        return shape[1], shape[0]
    return None


def get_orders(image_path, boxes):
    """Predict the reading order of ``boxes`` with the LayoutReader model.

    Args:
        image_path: The page the boxes were detected on (the Gradio app passes a
            PIL image). Its size is used to normalize the boxes onto the 0-1000
            grid LayoutLMv3 expects. If the size can't be determined, the boxes
            are normalized as if the page were 1024x1024 (the previous behavior).
        boxes: ``[x_min, y_min, x_max, y_max]`` boxes in image pixel coordinates.

    Returns:
        list[int]: ``orders[i]`` is the predicted reading position of ``boxes[i]``.
    """
    if not boxes:
        # Nothing to order; skip the model call entirely.
        return []

    page_size = _page_size(image_path)
    if page_size is None:
        normalized = scale_and_normalize_boxes(boxes)
    else:
        width, height = page_size
        # Using the real page size for both old and new dimensions yields
        # int(1000 * coord / page_dimension) per coordinate.
        normalized = scale_and_normalize_boxes(
            boxes,
            old_width=width,
            old_height=height,
            new_width=width,
            new_height=height,
        )
    # LayoutLMv3 looks coordinates up in fixed-size embedding tables; clamp so
    # detections that spill slightly past the page cannot index out of range.
    normalized = [[min(max(coord, 0), 1000) for coord in box] for box in normalized]

    inputs = boxes2inputs(normalized)
    inputs = {k: v.to(layout_model.device) for k, v in inputs.items()}  # Move inputs to model device
    with torch.no_grad():  # inference only: don't build an autograd graph
        logits = layout_model(**inputs).logits.cpu().squeeze(0)
    return parse_logits(logits, len(normalized))
# Layout detector: RT-DETR weights file "rtdetr_1024_crops.pt" fetched from the Hub repo.
_DETECTOR_REPO_DIR = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS")
model_dir = _DETECTOR_REPO_DIR + "/rtdetr_1024_crops.pt"
model = RTDETR(model_dir)
# Detector class id -> label name. Hard-coded; presumably mirrors the
# checkpoint's class list - verify against `model.names` if the weights change.
LAYOUT_CLASS_NAMES = {
    0: 'CheckBox',
    1: 'List',
    2: 'P',
    3: 'abandon',
    4: 'figure',
    5: 'gridless_table',
    6: 'handwritten_signature',
    7: 'qr_code',
    8: 'table',
    9: 'title',
}


def detect_layout(img, conf_threshold, iou_threshold):
    """Detect layout regions in ``img`` with the RT-DETR model.

    Args:
        img: Any ``source`` accepted by ``model.predict`` (the Gradio app passes
            a PIL image).
        conf_threshold: Minimum detection confidence to keep a box.
        iou_threshold: IoU threshold forwarded to ``predict``.

    Returns:
        tuple[list[list[float]], list[str]]: the ``[x1, y1, x2, y2]`` boxes (in
        the input image's pixel coordinates, as rescaled by ultralytics) and the
        class label of each box, in the same order.
    """
    results = model.predict(
        source=img,
        conf=conf_threshold,
        iou=iou_threshold,
        show_labels=True,
        show_conf=True,
        imgsz=1024,
        agnostic_nms=True,
        max_det=34,  # at most 34 detections per page
        nms=True,
    )[0]
    bboxes = results.boxes.xyxy.cpu().tolist()
    # `cls` comes back as floats (e.g. 2.0); convert explicitly before the lookup.
    classes = [LAYOUT_CLASS_NAMES[int(cls_id)] for cls_id in results.boxes.cls.cpu().tolist()]
    return bboxes, classes
from PIL import Image, ImageDraw, ImageFont
def draw_bboxes_on_image(image_path, bboxes, classes, reading_order):
    """Draw labelled, colour-coded boxes on a copy of the page image.

    Each box is outlined with a rounded rectangle in its class colour (titles
    get a thicker outline and larger font) and tagged ``"<class>-<order>"`` on a
    semi-transparent black label. The label sits just above the box, or just
    inside its top edge when there is no room above.

    Args:
        image_path: A PIL image (what the Gradio app passes) or anything
            ``PIL.Image.open`` accepts, e.g. a file path. The input image is not
            modified.
        bboxes: ``[x1, y1, x2, y2]`` boxes in image pixel coordinates.
        classes: Class label for each box.
        reading_order: Reading position for each box.

    Returns:
        PIL.Image.Image: the annotated image in RGB mode.

    Raises:
        ValueError: if ``bboxes``, ``classes`` and ``reading_order`` differ in length.
    """
    # Define a color palette for classes
    class_colors = {
        'CheckBox': '#FFA500',  # Orange
        'List': '#4682B4',  # Steel Blue
        'P': '#32CD32',  # Lime Green
        'abandon': '#8A2BE2',  # Blue Violet
        'figure': '#00CED1',  # Dark Turquoise
        'gridless_table': '#FFD700',  # Gold
        'handwritten_signature': '#FF69B4',  # Hot Pink
        'qr_code': '#FF4500',  # Orange Red
        'table': '#8B4513',  # Saddle Brown
        'title': '#FF1493',  # Deep Pink
    }
    if not (len(bboxes) == len(classes) == len(reading_order)):
        raise ValueError(
            f"bboxes ({len(bboxes)}), classes ({len(classes)}) and "
            f"reading_order ({len(reading_order)}) must have the same length"
        )

    # Image.open() cannot take an already-decoded PIL image (it fails reading
    # from it), so only open real paths/file objects, and close them afterwards.
    if isinstance(image_path, Image.Image):
        image = image_path.convert("RGBA")
    else:
        with Image.open(image_path) as source:
            image = source.convert("RGBA")
    overlay = Image.new("RGBA", image.size, (255, 255, 255, 0))  # Transparent overlay
    draw = ImageDraw.Draw(overlay)

    # Try loading a TrueType font, or fall back to the default bitmap font
    try:
        font = ImageFont.truetype("arial.ttf", 20)
        title_font = ImageFont.truetype("arial.ttf", 30)
    except OSError:
        font = ImageFont.load_default()
        title_font = font

    padding = 5  # space between the label text and its background edges
    for (x1, y1, x2, y2), class_name, order in zip(bboxes, classes, reading_order):
        color = class_colors.get(class_name, "#FFFFFF")  # Default white if class is unknown
        if class_name == 'title':
            box_thickness, label_font = 6, title_font
        else:
            box_thickness, label_font = 3, font

        draw.rounded_rectangle(
            [(x1, y1), (x2, y2)],
            radius=10,  # Rounded corners
            outline=color,
            width=box_thickness,
        )

        label = f"{class_name}-{order}"
        # ImageDraw.textsize was removed in Pillow 10; textbbox replaces it.
        # Right/bottom are measured from the text origin so the extent includes
        # the font's offset and the glyphs stay inside the background.
        _, _, text_width, text_height = draw.textbbox((0, 0), label, font=label_font)

        label_top = y1 - text_height - 2 * padding
        if label_top < 0:
            # No room above the box: put the label just inside its top edge.
            label_top = y1
        draw.rectangle(
            [x1, label_top, x1 + text_width + 2 * padding, label_top + text_height + 2 * padding],
            fill=(0, 0, 0, 128),  # Semi-transparent black
        )
        draw.text(
            (x1 + padding, label_top + padding),
            label,
            fill="white",
            font=label_font,
        )

    # Merge the overlay with the original image, then return RGB for display
    image_with_overlay = Image.alpha_composite(image, overlay)
    return image_with_overlay.convert("RGB")
def scale_and_normalize_boxes(
    bboxes,
    old_width=1024,
    old_height=1024,
    new_width=640,
    new_height=640,
    normalize_width=1000,
    normalize_height=1000,
):
    """
    Rescale boxes to new dimensions, then map them onto an integer grid.

    Each coordinate is multiplied by ``new / old`` for its axis, divided by
    ``new``, multiplied by the normalization range, and truncated with
    ``int()``. The ``new_*`` sizes cancel mathematically, so the result is about
    ``normalize * coord / old``; they only affect floating-point rounding.

    Args:
        bboxes (list of lists): Boxes in [x_min, y_min, x_max, y_max] format.
        old_width (int or float): Width of the original image.
        old_height (int or float): Height of the original image.
        new_width (int or float): Width of the intermediate scaled image.
        new_height (int or float): Height of the intermediate scaled image.
        normalize_width (int or float): Upper end of the target x range.
        normalize_height (int or float): Upper end of the target y range.

    Returns:
        list of lists: One [x_min, y_min, x_max, y_max] integer box per input box.
    """
    x_factor = new_width / old_width
    y_factor = new_height / old_height
    # (scale factor, intermediate size, target range) for x_min, y_min, x_max, y_max.
    x_axis = (x_factor, new_width, normalize_width)
    y_axis = (y_factor, new_height, normalize_height)
    per_coord = (x_axis, y_axis, x_axis, y_axis)

    normalized = []
    for bbox in bboxes:
        x_min, y_min, x_max, y_max = bbox  # exactly four coordinates expected
        normalized.append([
            int(target * ((coord * factor) / size))
            for coord, (factor, size, target) in zip((x_min, y_min, x_max, y_max), per_coord)
        ])
    return normalized
from PIL import Image, ImageDraw
def is_inside(box1, box2):
    """Return True if ``box1`` lies entirely within ``box2`` (edges may touch).

    Boxes are ``[x_min, y_min, x_max, y_max]``. Identical boxes are inside each other.
    """
    return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]


def is_overlap(box1, box2):
    """Return True if the interiors of the two boxes intersect.

    Boxes that only share an edge or a corner do not count as overlapping.
    """
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2
    # Disjoint if one box is entirely left of, right of, above, or below the other.
    return not (x2 <= x3 or x4 <= x1 or y2 <= y3 or y4 <= y1)


def _box_area(box):
    """Area of an ``[x_min, y_min, x_max, y_max]`` box."""
    return (box[2] - box[0]) * (box[3] - box[1])


def remove_overlapping_and_inside_boxes(boxes, classes):
    """Drop boxes that are nested in, or overlap with, another box.

    Each unordered pair of boxes is examined exactly once:

    * if one box is inside the other, the inner box is dropped; for identical
      boxes the later one is dropped, so exactly one copy survives;
    * otherwise, if they overlap, the smaller box is dropped; on equal areas the
      later box is dropped.

    Earlier code visited every ordered pair, which removed *both* boxes of an
    identical or equal-area pair. Decisions are still made pairwise against the
    full list, so a box that is itself removed can still cause another to be
    removed.

    Args:
        boxes: ``[x_min, y_min, x_max, y_max]`` boxes. Modified in place.
        classes: Labels parallel to ``boxes``. Modified in place.

    Returns:
        tuple: the same ``boxes`` and ``classes`` list objects, filtered.
    """
    from itertools import combinations

    to_remove = set()
    for (i, box_i), (j, box_j) in combinations(enumerate(boxes), 2):
        if is_inside(box_j, box_i):
            # j nested in i, or identical to it: keep the earlier box.
            to_remove.add(j)
        elif is_inside(box_i, box_j):
            to_remove.add(i)
        elif is_overlap(box_i, box_j):
            # Keep the larger box; on a tie keep the earlier one.
            to_remove.add(i if _box_area(box_i) < _box_area(box_j) else j)

    # Delete from the back so the remaining indices stay valid.
    for idx in sorted(to_remove, reverse=True):
        del boxes[idx]
        del classes[idx]
    return boxes, classes
def full_predictions(IMAGE_PATH, conf_threshold, iou_threshold):
    """Detect layout regions, predict their reading order, and draw the result.

    Args:
        IMAGE_PATH: The page to analyse. The Gradio UI passes a PIL image
            (``gr.Image(type="pil")``), or ``None`` when nothing was uploaded.
        conf_threshold: Detection confidence threshold.
        iou_threshold: IoU threshold forwarded to the detector.

    Returns:
        The RGB image annotated with ``<class>-<order>`` labels, or ``None`` if
        no image was provided.
    """
    if IMAGE_PATH is None:
        # Nothing uploaded: return an empty result instead of running the
        # detector on a missing image and then failing while drawing.
        return None
    bboxes, classes = detect_layout(IMAGE_PATH, conf_threshold, iou_threshold)
    # Keep one box per region before ordering, so orders index the kept boxes.
    bboxes, classes = remove_overlapping_and_inside_boxes(bboxes, classes)
    orders = get_orders(IMAGE_PATH, bboxes)
    return draw_bboxes_on_image(IMAGE_PATH, bboxes, classes, orders)
# Gradio UI: one page image plus the two detector thresholds in, annotated image out.
iface = gr.Interface(
    fn=full_predictions,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU threshold"),
    ],
    outputs=gr.Image(type="pil", label="Result"),
    title="Ultralytics Gradio",
    # The previous text advertised a "YOLO11n" model; the app actually runs an
    # RT-DETR detector followed by a LayoutReader reading-order model.
    description=(
        "Upload a document image for inference. An RT-DETR model detects layout "
        "regions and a LayoutReader (LayoutLMv3) model predicts their reading order."
    ),
    examples=[
        ["kashida.png", 0.2, 0.45],
        ["image.jpg", 0.2, 0.45],
        ["Screenshot 2024-11-06 130230.png", 0.25, 0.45],
    ],
    theme=gr.themes.Default(),
)
if __name__ == "__main__":
iface.launch() |