"""Gradio demo: document layout detection (RT-DETR) plus reading-order prediction.

Pipeline implemented by :func:`full_predictions`:

1. resize the uploaded image to 1024x1024,
2. detect layout regions with an RT-DETR checkpoint,
3. drop nested / overlapping detections,
4. predict an order index for every remaining box with a fine-tuned
   LayoutLMv3 token-classification model,
5. draw boxes labelled ``"<class>-<order>"`` on the image.

Importing this module downloads and loads both models.
"""

import random
from collections import defaultdict
from typing import Dict, List

import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModelForTokenClassification
from transformers import AutoProcessor
from transformers import LayoutLMv3ForTokenClassification
from ultralytics import RTDETR

# ---------------------------------------------------------------------------
# Models and constants
# ---------------------------------------------------------------------------

# Reading-order model and its processor (the processor is loaded but not used
# by any function in this file).
finetuned_fully = LayoutLMv3ForTokenClassification.from_pretrained(
    "omarelsayeed/finetuned_pretrained_model"
)
processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

MAX_LEN = 70
CLS_TOKEN_ID = 0
UNK_TOKEN_ID = 3
EOS_TOKEN_ID = 2

# Layout detector checkpoint.
model_dir = (
    snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS")
    + "/rtdetr_1024_crops.pt"
)
model = RTDETR(model_dir)

# Detector class index -> class name.
CLASS_NAMES = {
    0: 'CheckBox',
    1: 'List',
    2: 'P',
    3: 'abandon',
    4: 'figure',
    5: 'gridless_table',
    6: 'handwritten_signature',
    7: 'qr_code',
    8: 'table',
    9: 'title',
}

# Class name -> outline colour used when drawing.
CLASS_COLORS = {
    'CheckBox': 'orange',
    'List': 'blue',
    'P': 'green',
    'abandon': 'purple',
    'figure': 'cyan',
    'gridless_table': 'yellow',
    'handwritten_signature': 'magenta',
    'qr_code': 'red',
    'table': 'brown',
    'title': 'pink',
}


# ---------------------------------------------------------------------------
# Reading-order model helpers
# ---------------------------------------------------------------------------

def boxes2inputs(boxes):
    """Build LayoutLMv3 inputs for a list of boxes (no text, boxes only).

    Every box becomes one UNK token; a CLS token is prepended and an EOS token
    appended, both with an all-zero box.

    Args:
        boxes: list of ``[x_min, y_min, x_max, y_max]``. The values should be
            integers: the resulting ``bbox`` tensor takes its dtype from them,
            and float boxes yield a float tensor (see ``full_predictions`` for
            the normalization applied before calling the model).

    Returns:
        dict with ``bbox``, ``attention_mask`` and ``input_ids`` tensors, each
        with a leading batch dimension of 1.
    """
    bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
    input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
    attention_mask = [1] + [1] * len(boxes) + [1]
    return {
        "bbox": torch.tensor([bbox]),
        "attention_mask": torch.tensor([attention_mask]),
        "input_ids": torch.tensor([input_ids]),
    }


def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
    """Turn per-token logits into one distinct order index per box.

    Each box first takes its highest-scoring order. While two or more boxes
    claim the same order, the claimant with the highest logit keeps it and
    every other claimant falls back to its next-best candidate.

    :param logits: 2-D logits from the model (batch dimension already removed)
    :param length: number of boxes (special tokens excluded)
    :return: list where element ``i`` is the order index assigned to box ``i``
    """
    # Drop the CLS row and keep only the first `length` candidate orders.
    logits = logits[1 : length + 1, :length]
    # Ascending argsort, so pop() yields candidates from best to worst.
    orders = logits.argsort(descending=False).tolist()
    ret = [o.pop() for o in orders]
    while True:
        order_to_idxes = defaultdict(list)
        for idx, order in enumerate(ret):
            order_to_idxes[order].append(idx)
        # Keep only the orders claimed by more than one box.
        order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
        if not order_to_idxes:
            break
        for order, idxes in order_to_idxes.items():
            # Look up the original logit of every claimant for this order.
            idxes_to_logit = {}
            for idx in idxes:
                idxes_to_logit[idx] = logits[idx, order]
            idxes_to_logit = sorted(
                idxes_to_logit.items(), key=lambda x: x[1], reverse=True
            )
            # The highest logit keeps the order; the rest move to their next
            # candidate.
            for idx, _ in idxes_to_logit[1:]:
                ret[idx] = orders[idx].pop()

    return ret


def prepare_inputs(inputs, model):
    """Move every tensor in ``inputs`` to the model's device.

    Floating-point tensors are additionally cast to the model's dtype; integer
    tensors keep their dtype.

    Args:
        inputs: mapping of name -> tensor.
        model: object exposing ``device`` and ``dtype`` attributes.

    Returns:
        A new dict with the converted tensors.
    """
    ret = {}
    for k, v in inputs.items():
        v = v.to(model.device)
        if torch.is_floating_point(v):
            v = v.to(model.dtype)
        ret[k] = v
    return ret


def get_orders(image_path, boxes):
    """Predict an order index for every box with the reading-order model.

    Args:
        image_path: unused; kept for interface compatibility.
        boxes: list of ``[x_min, y_min, x_max, y_max]``. These are fed to the
            model as-is, so they must already be integers in the coordinate
            range the model expects (``full_predictions`` normalizes to
            0..1000 before calling this function).

    Returns:
        list with one order index per input box (empty for no boxes).
    """
    if not boxes:
        # Nothing to order; skip the forward pass.
        return []
    inputs = boxes2inputs(boxes)
    inputs = prepare_inputs(inputs, finetuned_fully)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = finetuned_fully(**inputs).logits.cpu().squeeze(0)
    return parse_logits(logits, len(boxes))


# ---------------------------------------------------------------------------
# Layout detection
# ---------------------------------------------------------------------------

def detect_layout(img, conf_threshold, iou_threshold):
    """Detect layout regions in an image with the RT-DETR model.

    Args:
        img: image passed to ``model.predict`` as ``source``.
        conf_threshold: confidence threshold forwarded to the detector.
        iou_threshold: IoU threshold forwarded to the detector.

    Returns:
        ``(bboxes, classes)``: ``bboxes`` is a list of
        ``[x1, y1, x2, y2]`` floats and ``classes`` the matching class names.
    """
    results = model.predict(
        source=img,
        conf=conf_threshold,
        iou=iou_threshold,
        show_labels=True,
        show_conf=True,
        imgsz=1024,
        agnostic_nms=True,
        max_det=34,
        nms=True,
    )[0]
    bboxes = results.boxes.xyxy.cpu().tolist()
    # Class ids come back as floats (e.g. 2.0); make the lookup key explicit.
    classes = [CLASS_NAMES[int(i)] for i in results.boxes.cls.cpu().tolist()]
    return bboxes, classes


# ---------------------------------------------------------------------------
# Drawing
# ---------------------------------------------------------------------------

def _load_label_fonts():
    """Return ``(regular_font, title_font)``, falling back to PIL's default."""
    try:
        return ImageFont.truetype("arial.ttf", 20), ImageFont.truetype("arial.ttf", 30)
    except OSError:
        try:
            fallback = ImageFont.load_default(size=30)
        except TypeError:
            # Pillow versions whose load_default() takes no `size` argument.
            fallback = ImageFont.load_default()
        return fallback, fallback


def draw_bboxes_on_image(image_path, bboxes, classes, reading_order):
    """Draw labelled bounding boxes on an image.

    The image is modified in place and also returned.

    Args:
        image_path: PIL image to draw on (despite the name, not a path).
        bboxes: list of ``[x1, y1, x2, y2]`` in the image's pixel coordinates.
        classes: class name for each box; must be a key of ``CLASS_COLORS``.
        reading_order: order index for each box, shown in the label.

    Returns:
        The same PIL image, with boxes and ``"<class>-<order>"`` labels drawn.
    """
    image = image_path
    draw = ImageDraw.Draw(image)
    font, title_font = _load_label_fonts()

    for (x1, y1, x2, y2), class_name, order in zip(bboxes, classes, reading_order):
        color = CLASS_COLORS[class_name]

        # Titles get a thicker outline and the larger font.
        if class_name == 'title':
            box_thickness = 4
            label_font = title_font
        else:
            box_thickness = 2
            label_font = font

        draw.rectangle([x1, y1, x2, y2], outline=color, width=box_thickness)

        label = f"{class_name}-{order}"

        # Measure the label so it can be placed just above the box.
        text_box = draw.textbbox((x1, y1 - 20), label, font=label_font)
        label_height = text_box[3] - text_box[1]

        # Clamp so labels of boxes touching the top edge stay on the image.
        label_y = max(0, y1 - label_height)
        draw.text((x1, label_y), label, fill="black", font=label_font)

    return image


# ---------------------------------------------------------------------------
# Box geometry
# ---------------------------------------------------------------------------

def scale_and_normalize_boxes(bboxes, old_width=1024, old_height=1024, new_width=640, new_height=640, normalize_width=1000, normalize_height=1000):
    """Scale boxes to new dimensions, then normalize them to an integer range.

    Args:
        bboxes (list of lists): boxes in ``[x_min, y_min, x_max, y_max]`` format.
        old_width (int or float): width of the image the boxes refer to.
        old_height (int or float): height of the image the boxes refer to.
        new_width (int or float): width of the intermediate scaled image.
        new_height (int or float): height of the intermediate scaled image.
        normalize_width (int or float): width of the normalization range.
        normalize_height (int or float): height of the normalization range.

    Returns:
        list of lists: boxes in ``[x_min, y_min, x_max, y_max]`` format with
        integer coordinates (fractions are truncated by ``int``).
    """
    scale_x = new_width / old_width
    scale_y = new_height / old_height

    def scale_and_normalize_single(bbox):
        x_min, y_min, x_max, y_max = bbox

        # Scale to the intermediate dimensions.
        x_min *= scale_x
        y_min *= scale_y
        x_max *= scale_x
        y_max *= scale_y

        # Normalize to the target range.
        return [
            int(normalize_width * (x_min / new_width)),
            int(normalize_height * (y_min / new_height)),
            int(normalize_width * (x_max / new_width)),
            int(normalize_height * (y_max / new_height)),
        ]

    return [scale_and_normalize_single(bbox) for bbox in bboxes]


def is_inside(box1, box2):
    """Return True if ``box1`` lies entirely within ``box2`` (edges may touch)."""
    return (
        box1[0] >= box2[0]
        and box1[1] >= box2[1]
        and box1[2] <= box2[2]
        and box1[3] <= box2[3]
    )


def is_overlap(box1, box2):
    """Return True if the two boxes share a region of non-zero area."""
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2
    # No overlap if one box is entirely left of, right of, above or below the
    # other; boxes that merely touch do not count as overlapping.
    return not (x2 <= x3 or x4 <= x1 or y2 <= y3 or y4 <= y1)


def _box_area(box):
    """Area of an ``[x1, y1, x2, y2]`` box."""
    return (box[2] - box[0]) * (box[3] - box[1])


def remove_overlapping_and_inside_boxes(boxes, classes):
    """Drop boxes that are nested in, or overlap with, another box.

    For every pair of boxes:

    * if one is inside the other, the inner one is removed;
    * otherwise, if they overlap, the one with the smaller area is removed;
    * on a tie (identical boxes, or overlapping boxes of equal area) the
      earlier box is kept and the later one removed, so exactly one of the
      two survives.

    Both lists are modified in place and also returned.

    Args:
        boxes: list of ``[x1, y1, x2, y2]``.
        classes: list parallel to ``boxes``.

    Returns:
        ``(boxes, classes)`` after removal.
    """
    to_remove = set()
    count = len(boxes)
    # Visit each unordered pair once so that a tie can never remove both boxes.
    for i in range(count):
        for j in range(i + 1, count):
            first, second = boxes[i], boxes[j]
            first_in_second = is_inside(first, second)
            second_in_first = is_inside(second, first)
            if first_in_second and second_in_first:
                # Identical boxes: keep the earlier one.
                to_remove.add(j)
            elif first_in_second:
                to_remove.add(i)
            elif second_in_first:
                to_remove.add(j)
            elif is_overlap(first, second):
                if _box_area(first) < _box_area(second):
                    to_remove.add(i)
                else:
                    # Smaller or equal area: drop the later box.
                    to_remove.add(j)

    # Delete from the back so the remaining indices stay valid.
    for idx in sorted(to_remove, reverse=True):
        del boxes[idx]
        del classes[idx]

    return boxes, classes


# ---------------------------------------------------------------------------
# End-to-end pipeline and UI
# ---------------------------------------------------------------------------

def full_predictions(IMAGE_PATH, conf_threshold, iou_threshold):
    """Run detection, de-duplication, ordering and drawing on one image.

    Args:
        IMAGE_PATH: PIL image supplied by Gradio (despite the name, not a path).
        conf_threshold: detector confidence threshold.
        iou_threshold: detector IoU threshold.

    Returns:
        A 1024x1024 PIL image with the labelled boxes drawn on it.
    """
    image = IMAGE_PATH.resize((1024, 1024))
    bboxes, classes = detect_layout(image, conf_threshold, iou_threshold)
    bboxes, classes = remove_overlapping_and_inside_boxes(bboxes, classes)

    # The reading-order model looks boxes up in integer position-embedding
    # tables, so the float pixel boxes are converted to integers in 0..1000.
    # The pixel boxes are kept for drawing.
    width, height = image.size
    layout_boxes = scale_and_normalize_boxes(
        bboxes, old_width=width, old_height=height
    )
    orders = get_orders(image, layout_boxes)

    return draw_bboxes_on_image(image, bboxes, classes, orders)


# NOTE(review): the title/description below still mention YOLO11n although the
# detector loaded above is RT-DETR; the strings are left unchanged here.
iface = gr.Interface(
    fn=full_predictions,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU threshold"),
    ],
    outputs=gr.Image(type="pil", label="Result"),
    title="Ultralytics Gradio",
    description="Upload images for inference. The Ultralytics YOLO11n model is used by default.",
    examples=[
        ["kashida.png", 0.2, 0.45],
        ["image.jpg", 0.2, 0.45],
        ["Screenshot 2024-11-06 130230.png", 0.25, 0.45],
    ],
    theme=gr.themes.Default(),
)

if __name__ == "__main__":
    iface.launch()