File size: 6,596 Bytes
d948a30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import gradio as gr
import json, os, copy

from surya.input.langs import replace_lang_with_code, get_unique_langs
from surya.input.load import load_from_folder, load_from_file
from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
from surya.model.recognition.model import load_model as load_recognition_model
from surya.model.recognition.processor import load_processor as load_recognition_processor
from surya.model.recognition.tokenizer import _tokenize
from surya.ocr import run_ocr
from surya.postprocessing.text import draw_text_on_image

from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection

from surya.model.ordering.model import load_model as load_order_model
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.ordering import batch_ordering
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.settings import settings


# Load all models once at import time; the functions below use these
# module-level instances as their default arguments.
#   det_model/det_processor     -> text-line detection
#   layout_model/layout_processor -> layout detection (same architecture,
#                                    different checkpoint)
#   order_model/order_processor -> reading-order prediction
det_model = load_detection_model()
det_processor = load_detection_processor()

layout_model = load_detection_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_processor = load_detection_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)

order_model = load_order_model()
order_processor = load_order_processor()


# Human-readable language names for the UI dropdown; only the keys are used
# below (values presumably map to surya language codes — confirm against
# languages.json).
with open("languages.json", "r", encoding='utf-8') as file:
    language_map = json.load(file)

def ocr_main(input_path, max_pages=None, start_page=0, langs=None, lang_file=None,
            det_model=det_model, det_processor=det_processor):
    """Run OCR on an image/PDF file or a folder of images.

    Args:
        input_path: Path to an image/PDF file, or a directory of images.
        max_pages: Optional cap on the number of pages to load.
        start_page: First page to load (0-based).
        langs: Comma-separated language names, e.g. "English,French".
        lang_file: Kept for interface compatibility but NOT implemented here;
            ``langs`` must be provided.
        det_model: Text-line detection model (defaults to module-level one).
        det_processor: Matching detection processor.

    Returns:
        A PIL image of the FIRST loaded page with the recognized text drawn
        on it, or None when no pages were loaded.

    Raises:
        ValueError: If no usable language specification was given.
    """
    # Validate with a real exception instead of `assert`, which is stripped
    # when Python runs with -O.
    if not langs and not lang_file:
        raise ValueError("Must provide either langs or lang_file")
    if not langs:
        # The original silently ignored lang_file and then crashed on
        # langs.split(); fail loudly instead.
        raise ValueError("lang_file is not supported in this demo; pass langs")

    if os.path.isdir(input_path):
        images, names = load_from_folder(input_path, max_pages, start_page)
    else:
        images, names = load_from_file(input_path, max_pages, start_page)

    lang_list = langs.split(",")
    replace_lang_with_code(lang_list)  # maps human names to codes in place
    image_langs = [lang_list] * len(images)

    # Prune the recognition model's moe layer to only the languages we need.
    _, lang_tokens = _tokenize("", get_unique_langs(image_langs))
    rec_model = load_recognition_model(langs=lang_tokens)
    rec_processor = load_recognition_processor()

    predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor)

    if not images:
        return None

    # The Gradio UI displays a single page, so only the first prediction is
    # rendered (the original looped but returned on the first iteration).
    pred = predictions_by_image[0]
    page_langs = image_langs[0]
    bboxes = [line.bbox for line in pred.text_lines]
    pred_text = [line.text for line in pred.text_lines]
    return draw_text_on_image(bboxes, pred_text, images[0].size, page_langs,
                              has_math="_math" in page_langs)

def layout_main(input_path, max_pages=None,
                det_model=det_model, det_processor=det_processor,
                model=layout_model, processor=layout_processor):
    """Run layout detection on an image/PDF file or a folder of images.

    Args:
        input_path: Path to an image/PDF file, or a directory of images.
        max_pages: Optional cap on the number of pages to load.
        det_model: Text-line detection model (defaults to module-level one).
        det_processor: Matching detection processor.
        model: Layout detection model (defaults to module-level one).
        processor: Matching layout processor.

    Returns:
        A PIL image of the FIRST loaded page with labeled layout polygons
        drawn on it, or None when no pages were loaded.
    """
    if os.path.isdir(input_path):
        images, names = load_from_folder(input_path, max_pages)
    else:
        images, names = load_from_file(input_path, max_pages)

    # Layout detection is conditioned on detected text lines.
    line_predictions = batch_text_detection(images, det_model, det_processor)
    layout_predictions = batch_layout_detection(images, model, processor, line_predictions)

    if not images:
        return None

    # The UI shows a single page, so only the first prediction is rendered
    # (the original looped but returned on the first iteration).
    first = layout_predictions[0]
    polygons = [box.polygon for box in first.bboxes]
    labels = [box.label for box in first.bboxes]
    # Deep-copy so drawing does not mutate the loaded page image.
    return draw_polys_on_image(polygons, copy.deepcopy(images[0]), labels=labels)

def reading_main(input_path, max_pages=None, model=order_model, processor=order_processor,
                layout_model=layout_model, layout_processor=layout_processor,
                det_model=det_model, det_processor=det_processor):
    """Run reading-order prediction on an image/PDF file or a folder.

    Pipeline: text-line detection -> layout detection -> ordering of the
    layout boxes.

    Args:
        input_path: Path to an image/PDF file, or a directory of images.
        max_pages: Optional cap on the number of pages to load.
        model: Reading-order model (defaults to module-level one).
        processor: Matching ordering processor.
        layout_model / layout_processor: Layout detection model/processor.
        det_model / det_processor: Text-line detection model/processor.

    Returns:
        A PIL image of the FIRST loaded page with polygons labeled by their
        reading-order position, or None when no pages were loaded.
    """
    if os.path.isdir(input_path):
        images, names = load_from_folder(input_path, max_pages)
    else:
        images, names = load_from_file(input_path, max_pages)

    line_predictions = batch_text_detection(images, det_model, det_processor)
    layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)

    # One bbox list per page: ordering is computed over the layout regions.
    bboxes = [[box.bbox for box in pred.bboxes] for pred in layout_predictions]

    order_predictions = batch_ordering(images, bboxes, model, processor)

    if not images:
        return None

    # The UI shows a single page, so only the first prediction is rendered
    # (the original looped but returned on the first iteration).
    order_pred = order_predictions[0]
    polys = [box.polygon for box in order_pred.bboxes]
    labels = [str(box.position) for box in order_pred.bboxes]
    # Deep-copy so drawing does not mutate the loaded page image.
    return draw_polys_on_image(polys, copy.deepcopy(images[0]), labels=labels, label_font_size=20)




def model1(image_path, languages):
    """OCR button callback.

    Args:
        image_path: Filepath of the uploaded image (Gradio `type="filepath"`).
        languages: List of language names from the dropdown; may be empty/None.

    Returns:
        The annotated page image from ocr_main.
    """
    # ",".join replaces the original concatenate-then-trim loop; falls back
    # to English when nothing is selected, exactly as before.
    langs = ",".join(languages) if languages else "English"
    return ocr_main(image_path, langs=langs)

def model2(image_path):
    """Layout button callback: annotate the uploaded image via layout_main."""
    return layout_main(image_path)

def model3(image_path):
    """Reading-order button callback: annotate the image via reading_main."""
    return reading_main(image_path)


# Gradio UI: one input column (image upload, language picker, action buttons)
# and one output column with a tab per task.
with gr.Blocks() as demo:
    gr.Markdown("<center><h1>Surya - Image OCR/Layout/Reading Order</h1></center>")
    
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(type="filepath", label="Input Image", sources="upload")
            with gr.Row():
                # Choices come from languages.json loaded at module init;
                # only the OCR task uses this selection.
                dropdown = gr.Dropdown(label="Select Languages for OCR", choices=list(language_map.keys()), multiselect=True, value=["English"], interactive=True)
            with gr.Row():
                btn1 = gr.Button("OCR", variant="primary")
                btn2 = gr.Button("Layout", variant="primary")
                btn3 = gr.Button("Reading Order", variant="primary")
            with gr.Row():
                clear = gr.ClearButton()

        # Right column: one output image per task, in separate tabs.
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("OCR"):
                    output_image1 = gr.Image()
                with gr.TabItem("Layout"):
                    output_image2 = gr.Image()
                with gr.TabItem("Reading Order"):
                    output_image3 = gr.Image()

    # Wire each button to its task callback and destination tab.
    btn1.click(fn=model1, inputs=[input_image, dropdown], outputs=output_image1)
    btn2.click(fn=model2, inputs=[input_image], outputs=output_image2)
    btn3.click(fn=model3, inputs=[input_image], outputs=output_image3)
    # Clear resets the input and all three outputs at once.
    clear.add(components=[input_image, output_image1, output_image2, output_image3])

demo.launch()