"""Gradio demo: document layout detection with reading-order prediction.

Pipeline (see ``full_predictions``):

1. An RT-DETR detector finds layout regions in the uploaded image.
2. Nested / overlapping detections are pruned.
3. A LayoutLMv3 token-classification model predicts the reading order of the
   remaining boxes.
4. Boxes, class names and reading-order indices are drawn on the image.
"""
import itertools
import logging
import os
from collections import defaultdict
from typing import Dict, List

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
from transformers import LayoutLMv3ForTokenClassification
from ultralytics import RTDETR

logger = logging.getLogger(__name__)

# Load the LayoutLMv3 model used for reading-order prediction.
layout_model = LayoutLMv3ForTokenClassification.from_pretrained("omarelsayeed/LayoutReader80Small")

MAX_LEN = 100
CLS_TOKEN_ID = 0
UNK_TOKEN_ID = 3
EOS_TOKEN_ID = 2

# Detector class id -> class name.
CLASS_NAMES = {
    0: 'CheckBox',
    1: 'List',
    2: 'P',
    3: 'abandon',
    4: 'figure',
    5: 'gridless_table',
    6: 'handwritten_signature',
    7: 'qr_code',
    8: 'table',
    9: 'title',
}

# Outline colour per class name.
CLASS_COLORS = {
    'CheckBox': '#FFA500',               # Orange
    'List': '#4682B4',                   # Steel Blue
    'P': '#32CD32',                      # Lime Green
    'abandon': '#8A2BE2',                # Blue Violet
    'figure': '#00CED1',                 # Dark Turquoise
    'gridless_table': '#FFD700',         # Gold
    'handwritten_signature': '#FF69B4',  # Hot Pink
    'qr_code': '#FF4500',                # Orange Red
    'table': '#8B4513',                  # Saddle Brown
    'title': '#FF1493',                  # Deep Pink
}


def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
    """Build a batch-of-one LayoutLMv3 input from a list of boxes.

    Every box becomes one ``UNK`` token; the sequence is wrapped in a ``CLS``
    and an ``EOS`` token, both of which get an all-zero box.

    Args:
        boxes: Boxes as ``[x_min, y_min, x_max, y_max]`` integer lists.

    Returns:
        Dict with ``bbox``, ``attention_mask`` and ``input_ids`` tensors, each
        with a leading batch dimension of 1.
    """
    bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
    input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
    attention_mask = [1] + [1] * len(boxes) + [1]
    return {
        "bbox": torch.tensor([bbox]),
        "attention_mask": torch.tensor([attention_mask]),
        "input_ids": torch.tensor([input_ids]),
    }


def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
    """Parse logits to determine the reading order.

    Args:
        logits: 2-D tensor; row ``r`` holds the scores of sequence position
            ``r`` for every candidate order. Row 0 (the CLS position) is
            skipped.
        length: Number of boxes.

    Returns:
        A list of ``length`` ints where element ``i`` is the order assigned to
        box ``i``. Each order is assigned to at most one box.
    """
    logits = logits[1: length + 1, :length]
    # Ascending argsort: ``pop()`` yields the best remaining candidate order.
    orders = logits.argsort(descending=False).tolist()
    ret = [o.pop() for o in orders]
    while True:
        order_to_idxes = defaultdict(list)
        for idx, order in enumerate(ret):
            order_to_idxes[order].append(idx)
        # Keep only the orders claimed by more than one box.
        order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
        if not order_to_idxes:
            break
        # Resolve conflicts: the box with the highest logit keeps the order,
        # every other claimant falls back to its next-best candidate.
        for order, idxes in order_to_idxes.items():
            idxes_to_logit = {idx: logits[idx, order] for idx in idxes}
            idxes_to_logit = sorted(idxes_to_logit.items(), key=lambda x: x[1], reverse=True)
            for idx, _ in idxes_to_logit[1:]:
                ret[idx] = orders[idx].pop()
    return ret


def _load_image(source):
    """Return ``source`` as a PIL image; accepts a PIL image or anything ``Image.open`` takes."""
    if isinstance(source, Image.Image):
        return source
    return Image.open(source)


def _image_size(source):
    """Return ``(width, height)`` of a PIL image or of the image file at ``source``."""
    if isinstance(source, Image.Image):
        return source.size
    with Image.open(source) as img:
        return img.size


def get_orders(image_path, boxes):
    """Predict the reading order of ``boxes``.

    Args:
        image_path: The image the boxes were detected on (PIL image or path).
            Only its size is used, to normalize the boxes.
        boxes: Boxes as ``[x_min, y_min, x_max, y_max]`` in pixel coordinates
            of that image.

    Returns:
        List with one order index per box (see ``parse_logits``); empty when
        ``boxes`` is empty.
    """
    if not boxes:
        return []

    # Normalize against the real image size so coordinates stay within the
    # 0-1000 range regardless of the image resolution.
    width, height = _image_size(image_path)
    logger.debug("raw boxes: %s", boxes)
    boxes = scale_and_normalize_boxes(boxes, old_width=width, old_height=height)
    logger.debug("normalized boxes: %s", boxes)

    inputs = boxes2inputs(boxes)
    inputs = {k: v.to(layout_model.device) for k, v in inputs.items()}  # Move inputs to model device
    with torch.no_grad():  # inference only: no autograd graph needed
        logits = layout_model(**inputs).logits.cpu().squeeze(0)
    return parse_logits(logits, len(boxes))


# NOTE: despite its name, ``model_dir`` is the path of the weights file.
model_dir = os.path.join(
    snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS"),
    "rtdetr_1024_crops.pt",
)
model = RTDETR(model_dir)


def detect_layout(img, conf_threshold, iou_threshold):
    """Detect layout regions in an image with the RT-DETR model.

    Args:
        img: Image passed to ``model.predict`` as ``source``.
        conf_threshold: Confidence threshold forwarded as ``conf``.
        iou_threshold: IoU threshold forwarded as ``iou``.

    Returns:
        Tuple ``(bboxes, classes)``: ``bboxes`` is a list of
        ``[x1, y1, x2, y2]`` float lists, ``classes`` the matching class names.
    """
    results = model.predict(
        source=img,
        conf=conf_threshold,
        iou=iou_threshold,
        show_labels=True,
        show_conf=True,
        imgsz=1024,
        agnostic_nms=True,
        max_det=34,
        nms=True,
    )[0]
    bboxes = results.boxes.xyxy.cpu().tolist()
    # ``cls`` comes back as floats (e.g. 2.0); make the key type explicit.
    classes = [CLASS_NAMES[int(i)] for i in results.boxes.cls.cpu().tolist()]
    return bboxes, classes


def _text_size(draw, text, font):
    """Return ``(width, height)`` needed to render ``text`` from its anchor point.

    ``ImageDraw.textsize`` was removed in Pillow 10, so ``textbbox`` is
    preferred. The right/bottom edges are used (not right-left / bottom-top)
    so the glyph offset from the anchor is included, like ``textsize`` did.
    """
    try:
        _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
    except (AttributeError, ValueError):
        # Older Pillow: ``textbbox`` is missing, or (presumably) rejects
        # non-TrueType fonts such as the ``load_default`` bitmap font.
        return draw.textsize(text, font=font)
    return right, bottom


def draw_bboxes_on_image(image_path, bboxes, classes, reading_order):
    """Draw boxes labelled ``<class>-<order>`` on the image.

    Args:
        image_path: PIL image, or a path / file object accepted by
            ``Image.open``.
        bboxes: Boxes as ``[x1, y1, x2, y2]`` in pixel coordinates.
        classes: Class name for each box.
        reading_order: Reading-order index for each box.

    Returns:
        A new RGB ``PIL.Image`` with the annotations composited on top.
    """
    image = _load_image(image_path).convert("RGBA")
    overlay = Image.new("RGBA", image.size, (255, 255, 255, 0))  # Transparent overlay
    draw = ImageDraw.Draw(overlay)

    # Try loading a TrueType font, or fall back to Pillow's default font.
    try:
        font = ImageFont.truetype("arial.ttf", 20)
        title_font = ImageFont.truetype("arial.ttf", 30)
    except IOError:
        font = ImageFont.load_default()
        title_font = font

    padding = 5  # space between label text and its background edge
    for (x1, y1, x2, y2), class_name, order in zip(bboxes, classes, reading_order):
        color = CLASS_COLORS.get(class_name, "#FFFFFF")  # Default white if class is unknown

        # Titles get a thicker outline and a larger label.
        if class_name == 'title':
            box_thickness = 6
            label_font = title_font
        else:
            box_thickness = 3
            label_font = font

        draw.rounded_rectangle(
            [(x1, y1), (x2, y2)],
            radius=10,  # Rounded corners
            outline=color,
            width=box_thickness,
        )

        label = f"{class_name}-{order}"
        text_width, text_height = _text_size(draw, label, label_font)

        # The label sits directly above the box; clamp it so labels of boxes
        # touching the top edge are not drawn off-canvas.
        label_top = max(y1 - text_height - 2 * padding, 0)
        label_bg_coords = [
            x1,
            label_top,
            x1 + text_width + 2 * padding,
            label_top + text_height + 2 * padding,
        ]
        draw.rectangle(label_bg_coords, fill=(0, 0, 0, 128))  # Semi-transparent black
        draw.text(
            (x1 + padding, label_top + padding),
            label,
            fill="white",
            font=label_font,
        )

    # Merge the overlay with the original image and drop the alpha channel.
    return Image.alpha_composite(image, overlay).convert("RGB")


def scale_and_normalize_boxes(bboxes, old_width=1024, old_height=1024, new_width=640, new_height=640,
                              normalize_width=1000, normalize_height=1000):
    """Scale boxes from the original dimensions and normalize them.

    Args:
        bboxes (list of lists): Boxes in [x_min, y_min, x_max, y_max] format.
        old_width (int or float): Width of the original image.
        old_height (int or float): Height of the original image.
        new_width (int or float): Width of the intermediate scaled image.
        new_height (int or float): Height of the intermediate scaled image.
        normalize_width (int or float): Upper bound of the normalized x range.
        normalize_height (int or float): Upper bound of the normalized y range.

    Returns:
        list of lists: Integer boxes in [x_min, y_min, x_max, y_max] format,
        clamped to ``[0, normalize_width]`` / ``[0, normalize_height]``.
    """
    scale_x = new_width / old_width
    scale_y = new_height / old_height

    def _convert(value, scale, new_size, normalize_size):
        # Scale to the new dimensions, then normalize to the target range.
        raw = normalize_size * ((value * scale) / new_size)
        # Clamp so a box exceeding the assumed image size cannot leave the
        # range the downstream model accepts.
        return int(min(max(raw, 0), normalize_size))

    return [
        [
            _convert(x_min, scale_x, new_width, normalize_width),
            _convert(y_min, scale_y, new_height, normalize_height),
            _convert(x_max, scale_x, new_width, normalize_width),
            _convert(y_max, scale_y, new_height, normalize_height),
        ]
        for x_min, y_min, x_max, y_max in bboxes
    ]


def is_inside(box1, box2):
    """Return True if ``box1`` lies entirely within ``box2`` (edges may touch)."""
    return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]


def is_overlap(box1, box2):
    """Return True if the two boxes share a region of non-zero area."""
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2
    # No overlap if one box is to the left, right, above, or below the other box
    return not (x2 <= x3 or x4 <= x1 or y2 <= y3 or y4 <= y1)


def _box_area(box):
    """Return the area of an ``[x1, y1, x2, y2]`` box."""
    return (box[2] - box[0]) * (box[3] - box[1])


def remove_overlapping_and_inside_boxes(boxes, classes):
    """Drop boxes that are nested in, or overlap with, another box.

    For every pair of boxes: if one is inside the other the inner one is
    dropped; otherwise, if they overlap, the one with the smaller area is
    dropped. On a tie (identical boxes, or equal areas) the earlier box is
    kept, so a pair never eliminates both of its members.

    Args:
        boxes: Boxes as ``[x1, y1, x2, y2]``.
        classes: Class name for each box.

    Returns:
        Tuple ``(boxes, classes)`` of new lists holding the surviving entries
        in their original order. The input lists are not modified.
    """
    to_remove = set()
    # Visit each unordered pair once (i < j) so ties resolve deterministically.
    for i, j in itertools.combinations(range(len(boxes)), 2):
        box_i, box_j = boxes[i], boxes[j]
        if is_inside(box_j, box_i):
            # Also covers identical boxes: the earlier one survives.
            to_remove.add(j)
        elif is_inside(box_i, box_j):
            to_remove.add(i)
        elif is_overlap(box_i, box_j):
            if _box_area(box_i) < _box_area(box_j):
                to_remove.add(i)
            else:
                to_remove.add(j)

    kept = [k for k in range(len(boxes)) if k not in to_remove]
    return [boxes[k] for k in kept], [classes[k] for k in kept]


def full_predictions(IMAGE_PATH, conf_threshold, iou_threshold):
    """Run detection, box pruning, reading-order prediction and drawing.

    Args:
        IMAGE_PATH: Input image (a PIL image when called from the Gradio UI;
            a path also works).
        conf_threshold: Detector confidence threshold.
        iou_threshold: Detector IoU threshold.

    Returns:
        RGB ``PIL.Image`` with the annotated layout.

    Raises:
        gr.Error: If no image was provided.
    """
    if IMAGE_PATH is None:
        raise gr.Error("Please upload an image before running the prediction.")

    bboxes, classes = detect_layout(IMAGE_PATH, conf_threshold, iou_threshold)
    bboxes, classes = remove_overlapping_and_inside_boxes(bboxes, classes)
    orders = get_orders(IMAGE_PATH, bboxes)
    return draw_bboxes_on_image(IMAGE_PATH, bboxes, classes, orders)


# NOTE(review): title/description still mention YOLO11n although the detector
# loaded above is RT-DETR; the user-facing strings are left unchanged here.
iface = gr.Interface(
    fn=full_predictions,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU threshold"),
    ],
    outputs=gr.Image(type="pil", label="Result"),
    title="Ultralytics Gradio",
    description="Upload images for inference. The Ultralytics YOLO11n model is used by default.",
    examples=[
        ["kashida.png", 0.2, 0.45],
        ["image.jpg", 0.2, 0.45],
        ["Screenshot 2024-11-06 130230.png", 0.25, 0.45],
    ],
    theme=gr.themes.Default(),
)

if __name__ == "__main__":
    iface.launch()