import os
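# Send Gradio's temporary files to a local, writable directory.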
os.environ["GRADIO_TEMP_DIR"] = "./tmp"

import sys
import spaces
import torch
import torchvision
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
from visualization import visualize_bbox

# == download weights ==
model_dir = snapshot_download('juliozhao/DocLayout-YOLO-DocStructBench', local_dir='./models/DocLayout-YOLO-DocStructBench')
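# snapshot_download fetches the checkpoint into local_dir; __main__ below rebuilds the same path when loading the model.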
# == select device ==
device = 'cuda' if torch.cuda.is_available() else 'cpu'

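# DocStructBench class ids mapped to the layout labels drawn by visualize_bbox.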
id_to_names = {
    0: 'title', 
    1: 'plain text',
    2: 'abandon', 
    3: 'figure', 
    4: 'figure_caption', 
    5: 'table', 
    6: 'table_caption', 
    7: 'table_footnote', 
    8: 'isolate_formula', 
    9: 'formula_caption'
}

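# @spaces.GPU requests a GPU from Hugging Face ZeroGPU for the duration of this call.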
@spaces.GPU
def recognize_image(input_img, conf_threshold, iou_threshold):
    """Run DocLayout-YOLO on one image and return the annotated visualization."""
    det_res = model.predict(
        input_img,
        imgsz=1024,
        conf=conf_threshold,
        device=device,
    )[0]
    # Access the Boxes result directly instead of going through __dict__.
    boxes = det_res.boxes.xyxy
    classes = det_res.boxes.cls
    scores = det_res.boxes.conf

    # Class-agnostic NMS to drop overlapping detections across categories.
    # The outputs are already tensors, so pass them directly; re-wrapping
    # CUDA tensors in torch.Tensor() would fail.
    indices = torchvision.ops.nms(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
    boxes, scores, classes = boxes[indices], scores[indices], classes[indices]
    if boxes.dim() == 1:
        # A single surviving box comes back 1-D; restore the leading dimension.
        boxes = boxes.unsqueeze(0)
        scores = scores.unsqueeze(0)
        classes = classes.unsqueeze(0)

    vis_result = visualize_bbox(input_img, boxes, classes, scores, id_to_names)
    return vis_result
    
def gradio_reset():
    """Clear both the input and the output image widgets."""
    return gr.update(value=None), gr.update(value=None)

    
if __name__ == "__main__":
    # == load model ==
    from doclayout_yolo import YOLOv10
    print(f"Using device: {device}")
    model = YOLOv10(
        os.path.join(
            os.path.dirname(__file__),
            "models",
            "DocLayout-YOLO-DocStructBench",
            "doclayout_yolo_docstructbench_imgsz1024.pt",
        )
    )  # load the pretrained DocStructBench checkpoint downloaded above
    
    with open("header.html", "r") as file:
        header = file.read()
    with gr.Blocks() as demo:
        gr.HTML(header)
        
        with gr.Row():
            with gr.Column():
                
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button(value="Clear")
                    predict = gr.Button(value="Detect", interactive=True, variant="primary")
                    
                with gr.Row():
                    conf_threshold = gr.Slider(
                        label="Confidence Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.25,
                    )
                    
                with gr.Row():
                    iou_threshold = gr.Slider(
                        label="NMS IOU Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.45,
                    )
                    
                with gr.Accordion("Examples:"):
                    example_root = os.path.join(os.path.dirname(__file__), "assets", "example")
                    gr.Examples(
                        examples=[
                            os.path.join(example_root, name)
                            for name in os.listdir(example_root)
                            if name.endswith("jpg")
                        ],
                        inputs=[input_img],
                    )
            with gr.Column():
                gr.Button(value="Predict Result:", interactive=False)  # disabled button used as a static label
                output_img = gr.Image(label=" ", interactive=False)
    
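        # Wire the buttons to their callbacks.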
        clear.click(gradio_reset, inputs=None, outputs=[input_img, output_img])
        predict.click(recognize_image, inputs=[input_img, conf_threshold, iou_threshold], outputs=[output_img])
    
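    # Bind to all interfaces so the app is reachable inside the Spaces container.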
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)