File size: 9,922 Bytes
fa50974
679d3c5
 
190bae6
fa50974
6e73f0b
 
 
 
 
 
 
 
8645db9
6e73f0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a79011
 
6e73f0b
 
5a79011
6e73f0b
 
190bae6
74c842e
679d3c5
6e73f0b
679d3c5
fa50974
679d3c5
 
 
 
 
 
 
a231a61
fa50974
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6329b7d
fa50974
 
e996bd4
fa50974
e996bd4
 
 
 
 
 
 
 
 
 
fa50974
 
 
e996bd4
fa50974
e996bd4
 
 
 
fa50974
 
e996bd4
fa50974
e996bd4
 
fa50974
 
 
 
 
 
e996bd4
 
 
 
 
fa50974
e996bd4
 
fa50974
e996bd4
 
 
 
 
 
 
fa50974
 
e996bd4
 
 
 
fa50974
e996bd4
 
 
 
 
6329b7d
14db543
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa50974
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
679d3c5
fa50974
 
679d3c5
fa50974
 
 
 
679d3c5
fa50974
14db543
 
faf74d9
5a79011
6e73f0b
 
 
 
 
 
679d3c5
14db543
679d3c5
fa50974
679d3c5
 
 
 
 
 
 
 
 
45440e8
 
63b70c6
679d3c5
2132edd
679d3c5
 
 
3ec7d0a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
from ultralytics import RTDETR
import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image
from PIL import Image, ImageDraw, ImageFont


from collections import defaultdict
from typing import List, Dict
import torch
from transformers import LayoutLMv3ForTokenClassification

# Load the LayoutLMv3 token-classification model that get_orders() uses to
# predict the reading order of detected boxes. This runs at import time and
# fetches weights from the Hugging Face Hub repo "omarelsayeed/test_3".
layout_model = LayoutLMv3ForTokenClassification.from_pretrained("omarelsayeed/test_3")

# NOTE(review): MAX_LEN is not referenced by the code visible in this file;
# nothing here truncates the box list to it.
MAX_LEN = 100
# Special token ids placed around the per-box UNK tokens.
CLS_TOKEN_ID = 0
UNK_TOKEN_ID = 3
EOS_TOKEN_ID = 2


def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
    """Build a batch-of-one LayoutLMv3 input dict from a list of boxes.

    Every box becomes one UNK token. The sequence is framed by a CLS token
    and an EOS token, each paired with an all-zero box. Every position is
    attended to.

    Args:
        boxes: list of ``[x_min, y_min, x_max, y_max]`` integer boxes
            (``get_orders`` passes boxes already normalized to 0..1000).

    Returns:
        Dict with keys ``bbox`` (shape ``(1, len(boxes) + 2, 4)``),
        ``attention_mask`` and ``input_ids`` (both ``(1, len(boxes) + 2)``).
    """
    count = len(boxes)
    zero_box = [0, 0, 0, 0]

    framed_boxes = [zero_box] + boxes + [zero_box]
    token_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * count + [EOS_TOKEN_ID]
    mask = [1] * (count + 2)

    return dict(
        bbox=torch.tensor([framed_boxes]),
        attention_mask=torch.tensor([mask]),
        input_ids=torch.tensor([token_ids]),
    )

def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
    """Turn LayoutLMv3 logits into a conflict-free reading-order assignment.

    Rows ``1..length`` of ``logits`` belong to the boxes (row 0 is the CLS
    position). Only the first ``length`` columns are treated as candidate
    positions. Each box first takes its highest-scoring position.

    Whenever several boxes claim the same position, the one with the highest
    score there keeps it. Ties go to the lowest box index because the sort
    is stable. Every other claimant moves to its next-best remaining
    position. This repeats until no position is claimed twice.

    Args:
        logits: 2-D score tensor, presumably ``(seq_len, num_labels)`` from
            the model with the batch dimension already squeezed — TODO confirm.
        length: number of boxes.

    Returns:
        One position per box, all distinct.
    """
    scores = logits[1: length + 1, :length]
    # Candidates per box in ascending score order, so pop() yields the best
    # remaining candidate.
    candidates = scores.argsort(descending=False).tolist()
    assigned = [row.pop() for row in candidates]

    while True:
        claimants = defaultdict(list)
        for box_idx, position in enumerate(assigned):
            claimants[position].append(box_idx)

        contested = {pos: idxs for pos, idxs in claimants.items() if len(idxs) > 1}
        if not contested:
            return assigned

        for position, idxs in contested.items():
            # Stable descending sort: best score first, ties keep index order.
            ranked = sorted(idxs, key=lambda i: scores[i, position], reverse=True)
            for loser in ranked[1:]:
                assigned[loser] = candidates[loser].pop()

def get_orders(image_path, boxes):
    """Predict a reading-order position for each box with ``layout_model``.

    Args:
        image_path: unused; kept for signature compatibility
            (``full_predictions`` passes the PIL image here).
        boxes: ``[x_min, y_min, x_max, y_max]`` boxes in pixel coordinates.
            ``scale_and_normalize_boxes`` is called with its defaults, which
            assume a 1024x1024 source image.

    Returns:
        List with one position per box, as produced by ``parse_logits``.
    """
    normalized = scale_and_normalize_boxes(boxes)
    inputs = boxes2inputs(normalized)
    # Move every input tensor to the device the model lives on.
    inputs = {k: v.to(layout_model.device) for k, v in inputs.items()}
    # Pure inference: disable autograd so no graph or gradient buffers are
    # kept alive for the forward pass.
    with torch.no_grad():
        logits = layout_model(**inputs).logits
    # Drop the batch dimension (batch size is always 1 here).
    logits = logits.cpu().squeeze(0)
    return parse_logits(logits, len(normalized))


# Fetch the layout-detection repo snapshot from the Hugging Face Hub and point
# at the RT-DETR checkpoint inside it. Runs at import time.
# NOTE: despite its name, model_dir is the path to a .pt file, not a directory.
model_dir = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS") + "/rtdetr_1024_crops.pt"
# Ultralytics RT-DETR detector used by detect_layout().
model = RTDETR(model_dir)


def detect_layout(img, conf_threshold, iou_threshold):
    """Detect layout regions in ``img`` with the RT-DETR ``model``.

    Args:
        img: image passed as ``source`` to ``model.predict``
            (``full_predictions`` passes a 1024x1024 PIL image).
        conf_threshold: minimum detection confidence, forwarded as ``conf``.
        iou_threshold: forwarded to the predictor as ``iou``.

    Returns:
        ``(bboxes, classes)``. ``bboxes`` is a list of ``[x1, y1, x2, y2]``
        float lists taken from ``boxes.xyxy``. ``classes`` is the matching
        list of class-name strings.

    Raises:
        KeyError: if the model predicts a class id missing from ``mapping``.
    """
    results = model.predict(
        source=img,
        conf=conf_threshold,
        iou=iou_threshold,
        show_labels=True,
        show_conf=True,
        imgsz=1024,
        agnostic_nms=True,
        max_det=34,
        nms=True,
    )[0]  # one source image -> take its single result
    bboxes = results.boxes.xyxy.cpu().tolist()
    class_ids = results.boxes.cls.cpu().tolist()
    # Class-id -> name table for this checkpoint. Keep it in sync with the
    # colour table in draw_bboxes_on_image.
    mapping = {0: 'CheckBox',
               1: 'List',
               2: 'P',
               3: 'abandon',
               4: 'figure',
               5: 'gridless_table',
               6: 'handwritten_signature',
               7: 'qr_code',
               8: 'table',
               9: 'title'}
    # cls values come back as floats; cast explicitly instead of relying on
    # float/int dict-key equality.
    classes = [mapping[int(class_id)] for class_id in class_ids]
    return bboxes, classes
    
from PIL import Image, ImageDraw, ImageFont

def draw_bboxes_on_image(image_path, bboxes, classes, reading_order):
    """Draw each box with a ``<class>-<order>`` label onto the image.

    Args:
        image_path: despite the name, a PIL image (``ImageDraw.Draw`` is
            called on it directly). It is modified in place.
        bboxes: ``[x1, y1, x2, y2]`` boxes in the image's pixel coordinates.
        classes: class name per box; must be a key of ``class_colors``.
        reading_order: reading-order value per box, shown in the label.

    Returns:
        The same image object, with outlines and labels drawn on it.

    Raises:
        KeyError: for a class name without a colour entry.
        IndexError: if ``classes`` or ``reading_order`` is shorter than
            ``bboxes``.
    """
    # Outline colour per class name.
    class_colors = {
        'CheckBox': 'orange',
        'List': 'blue',
        'P': 'green',
        'abandon': 'purple',
        'figure': 'cyan',
        'gridless_table': 'yellow',
        'handwritten_signature': 'magenta',
        'qr_code': 'red',
        'table': 'brown',
        'title': 'pink'
    }

    image = image_path
    draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype("arial.ttf", 20)
        title_font = ImageFont.truetype("arial.ttf", 30)  # Larger font for titles
    except IOError:
        # arial.ttf could not be loaded: fall back to Pillow's default font
        # for both label kinds. NOTE: the ``size`` argument needs Pillow >= 10.1.
        font = ImageFont.load_default(size=30)
        title_font = font

    for i, (x1, y1, x2, y2) in enumerate(bboxes):
        class_name = classes[i]
        order = reading_order[i]
        color = class_colors[class_name]

        # Titles get a thicker outline and, when arial is available, a larger label.
        is_title = class_name == 'title'
        box_thickness = 4 if is_title else 2
        label_font = title_font if is_title else font

        draw.rectangle([x1, y1, x2, y2], outline=color, width=box_thickness)

        label = f"{class_name}-{order}"
        # Only the rendered text height is needed; it does not depend on the anchor.
        _, text_top, _, text_bottom = draw.textbbox((x1, y1 - 20), label, font=label_font)
        label_height = text_bottom - text_top
        # Put the label just above the box, but clamp at the top edge so that
        # labels of boxes near y=0 are not drawn off the canvas.
        label_y = max(0, y1 - label_height)
        draw.text((x1, label_y), label, fill="black", font=label_font)

    return image



def scale_and_normalize_boxes(
    bboxes,
    old_width=1024,
    old_height=1024,
    new_width=640,
    new_height=640,
    normalize_width=1000,
    normalize_height=1000,
):
    """Map pixel-space boxes onto an integer ``normalize_width`` x ``normalize_height`` grid.

    Each coordinate goes through two steps:

    1. It is rescaled from an ``old_width`` x ``old_height`` image to a
       ``new_width`` x ``new_height`` one.
    2. It is expressed as a fraction of that new size and multiplied by the
       normalization range.

    Mathematically the intermediate size cancels out (the result is
    ``coord * normalize / old``). The two-step float arithmetic is kept as-is
    so that truncated results stay bit-for-bit the same.

    Args:
        bboxes: iterable of ``[x_min, y_min, x_max, y_max]`` boxes.
        old_width, old_height: size of the image the boxes come from.
        new_width, new_height: intermediate rescaled image size.
        normalize_width, normalize_height: size of the target coordinate range.

    Returns:
        A new list of ``[x_min, y_min, x_max, y_max]`` int lists. Values are
        truncated toward zero by ``int()``.

    Raises:
        ValueError: if a box does not have exactly four coordinates.
        ZeroDivisionError: if ``old_width``/``old_height`` is zero, or if
            ``new_width``/``new_height`` is zero and ``bboxes`` is non-empty.
    """
    x_factor = new_width / old_width
    y_factor = new_height / old_height

    def to_grid(coord, factor, size, target):
        # Rescale first, then normalize relative to the rescaled size.
        resized = coord * factor
        return int(target * (resized / size))

    normalized = []
    for box in bboxes:
        left, top, right, bottom = box
        normalized.append([
            to_grid(left, x_factor, new_width, normalize_width),
            to_grid(top, y_factor, new_height, normalize_height),
            to_grid(right, x_factor, new_width, normalize_width),
            to_grid(bottom, y_factor, new_height, normalize_height),
        ])
    return normalized



from PIL import Image, ImageDraw

def is_inside(box1, box2):
    """Return True if ``box1`` lies entirely within ``box2`` (shared edges allowed).

    Boxes are ``[x_min, y_min, x_max, y_max]``.
    """
    return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]

def is_overlap(box1, box2):
    """Return True if the two ``[x_min, y_min, x_max, y_max]`` boxes share positive area.

    Boxes that only touch along an edge do not count as overlapping.
    """
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2

    # No overlap if one box is entirely left of, right of, above, or below the other.
    return not (x2 <= x3 or x4 <= x1 or y2 <= y3 or y4 <= y1)

def _box_area(box):
    """Return the area of an ``[x_min, y_min, x_max, y_max]`` box."""
    return (box[2] - box[0]) * (box[3] - box[1])

def remove_overlapping_and_inside_boxes(boxes, classes):
    """Drop boxes that are contained in, or overlap, another box.

    Each pair of boxes is examined once and yields at most one removal:

    * a box contained in the other (including an identical duplicate) is removed;
    * for a plain overlap, the smaller-area box is removed;
    * on a tie (identical boxes, equal areas) the later box in the list is
      removed and the earlier one kept.

    Removals are decided against the full original list, so a box can be
    removed because of a box that is itself removed.

    Args:
        boxes: list of ``[x_min, y_min, x_max, y_max]`` boxes. Modified in place.
        classes: list of class names parallel to ``boxes``. Modified in place.

    Returns:
        ``(boxes, classes)``: the same list objects, with removed entries deleted.
    """
    to_remove = set()

    # Visit each unordered pair exactly once. Visiting (i, j) and (j, i)
    # separately used to mark BOTH boxes of an identical or equal-area pair.
    for i, box_i in enumerate(boxes):
        for j, box_j in enumerate(boxes[i + 1:], start=i + 1):
            if is_inside(box_j, box_i):
                # j is contained in i, or identical to it: keep the earlier box.
                to_remove.add(j)
            elif is_inside(box_i, box_j):
                to_remove.add(i)
            elif is_overlap(box_i, box_j):
                # Drop the smaller box; on equal areas drop the later one.
                to_remove.add(j if _box_area(box_j) <= _box_area(box_i) else i)

    # Delete from the highest index down so earlier indices stay valid.
    for idx in sorted(to_remove, reverse=True):
        del boxes[idx]
        del classes[idx]

    return boxes, classes


def full_predictions(IMAGE_PATH, conf_threshold, iou_threshold):
    """Detect layout regions, predict their reading order, and return an annotated image.

    The steps are: resize, detect, drop nested/overlapping boxes, predict the
    reading order, then draw.

    Args:
        IMAGE_PATH: despite the name, the PIL image from the Gradio
            ``gr.Image(type="pil")`` input. It is ``None`` when nothing was uploaded.
        conf_threshold: detector confidence threshold (forwarded to detect_layout).
        iou_threshold: detector IoU threshold (forwarded to detect_layout).

    Returns:
        A 1024x1024 PIL image with labelled boxes drawn on it.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message.
    """
    if IMAGE_PATH is None:
        raise gr.Error("Please upload an image first.")

    # Downstream code works in 1024x1024 pixel space: the detector runs with
    # imgsz=1024, and scale_and_normalize_boxes defaults to a 1024x1024 source.
    # NOTE: this resize does not preserve the aspect ratio.
    image = IMAGE_PATH.resize((1024, 1024))
    bboxes, classes = detect_layout(image, conf_threshold, iou_threshold)
    bboxes, classes = remove_overlapping_and_inside_boxes(bboxes, classes)
    orders = get_orders(image, bboxes)
    return draw_bboxes_on_image(image, bboxes, classes, orders)



# Gradio UI: one image plus the two detector thresholds, passed positionally
# to full_predictions(image, conf_threshold, iou_threshold).
# The example files are expected to sit next to this script.
iface = gr.Interface(
    fn=full_predictions,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU threshold"),
    ],
    outputs=gr.Image(type="pil", label="Result"),
    title="Ultralytics Gradio",
    # The previous text claimed a YOLO11n model; the app actually runs an
    # RT-DETR detector followed by a LayoutLMv3 reading-order model.
    description="Upload a document image to detect layout regions with an RT-DETR model and predict their reading order with LayoutLMv3.",
    examples=[
        ["kashida.png", 0.2, 0.45],
        ["image.jpg", 0.2, 0.45],
        ["Screenshot 2024-11-06 130230.png", 0.25, 0.45],
    ],
    theme=gr.themes.Default(),
)

if __name__ == "__main__":
    iface.launch()