# Gradio demo for Surya: image OCR, layout detection, and reading order.
import gradio as gr
import json, os, copy
from surya.input.langs import replace_lang_with_code, get_unique_langs
from surya.input.load import load_from_folder, load_from_file
from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
from surya.model.recognition.model import load_model as load_recognition_model
from surya.model.recognition.processor import load_processor as load_recognition_processor
from surya.model.recognition.tokenizer import _tokenize
from surya.ocr import run_ocr
from surya.postprocessing.text import draw_text_on_image
from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.model.ordering.model import load_model as load_order_model
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.ordering import batch_ordering
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.settings import settings
# Load all models once at module import so every Gradio request reuses them.
# The layout model shares the detection architecture but loads its own checkpoint.
det_model = load_detection_model()
det_processor = load_detection_processor()
layout_model = load_detection_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_processor = load_detection_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
order_model = load_order_model()
order_processor = load_order_processor()
# Language map for the OCR dropdown; only the keys (display names) are used
# as choices below — presumably values are language codes (verify file contents).
with open("languages.json", "r", encoding='utf-8') as file:
    language_map = json.load(file)
def ocr_main(input_path, max_pages=None, start_page=0, langs=None, lang_file=None,
             det_model=det_model, det_processor=det_processor):
    """Run OCR on an image/PDF file or a folder and return an image with the
    recognized text drawn over the final page's layout.

    Args:
        input_path: Path to an image/PDF file or to a directory of images.
        max_pages: Optional cap on the number of pages to load.
        start_page: First page to load.
        langs: Comma-separated language names, e.g. "English,French".
        lang_file: Accepted for signature compatibility but never read here;
            a langs string is still required for tokenization.
        det_model / det_processor: Preloaded text-detection model/processor.

    Returns:
        A PIL image of the recognized text rendered at its detected positions.

    Raises:
        ValueError: If no langs string is given or no pages could be loaded.
    """
    if not langs:
        # The original asserted `langs or lang_file`, but lang_file was never
        # consumed and a missing langs crashed later on langs.split(); fail
        # clearly up front instead.
        raise ValueError("Must provide langs (lang_file is not supported here)")
    if os.path.isdir(input_path):
        images, names = load_from_folder(input_path, max_pages, start_page)
    else:
        images, names = load_from_file(input_path, max_pages, start_page)
    if not images:
        raise ValueError(f"No pages could be loaded from {input_path}")

    lang_list = langs.split(",")
    replace_lang_with_code(lang_list)
    image_langs = [lang_list] * len(images)

    # Prune the recognition model's moe layer to only the languages we need.
    _, lang_tokens = _tokenize("", get_unique_langs(image_langs))
    rec_model = load_recognition_model(langs=lang_tokens)
    rec_processor = load_recognition_processor()
    predictions_by_image = run_ocr(images, image_langs, det_model, det_processor,
                                   rec_model, rec_processor)

    # Only the last page's annotation was ever returned (the Gradio UI feeds a
    # single image), so render just that page instead of drawing and
    # discarding every earlier one.
    image = images[-1]
    pred = predictions_by_image[-1]
    bboxes = [line.bbox for line in pred.text_lines]
    pred_text = [line.text for line in pred.text_lines]
    return draw_text_on_image(bboxes, pred_text, image.size, lang_list,
                              has_math="_math" in lang_list)
def layout_main(input_path, max_pages=None,
                det_model=det_model, det_processor=det_processor,
                model=layout_model, processor=layout_processor):
    """Detect page layout regions and return the final page annotated with
    labeled polygons.

    Args:
        input_path: Path to an image/PDF file or to a directory of images.
        max_pages: Optional cap on the number of pages to load.
        det_model / det_processor: Preloaded text-line detection model/processor.
        model / processor: Preloaded layout model/processor.

    Returns:
        A PIL image with layout polygons and their labels drawn on it.

    Raises:
        ValueError: If no pages could be loaded from input_path.
    """
    if os.path.isdir(input_path):
        images, names = load_from_folder(input_path, max_pages)
    else:
        images, names = load_from_file(input_path, max_pages)
    if not images:
        # Originally an empty input left bbox_image undefined (NameError).
        raise ValueError(f"No pages could be loaded from {input_path}")

    line_predictions = batch_text_detection(images, det_model, det_processor)
    layout_predictions = batch_layout_detection(images, model, processor, line_predictions)

    # Only the last page's annotation was ever returned, so render just that
    # page instead of drawing and discarding every earlier one.
    image = images[-1]
    layout_pred = layout_predictions[-1]
    polygons = [box.polygon for box in layout_pred.bboxes]
    labels = [box.label for box in layout_pred.bboxes]
    # deepcopy keeps the cached source image unmodified by the drawing call.
    return draw_polys_on_image(polygons, copy.deepcopy(image), labels=labels)
def reading_main(input_path, max_pages=None, model=order_model, processor=order_processor,
                 layout_model=layout_model, layout_processor=layout_processor,
                 det_model=det_model, det_processor=det_processor):
    """Compute reading order of layout regions and return the final page
    annotated with each region's position number.

    Args:
        input_path: Path to an image/PDF file or to a directory of images.
        max_pages: Optional cap on the number of pages to load.
        model / processor: Preloaded reading-order model/processor.
        layout_model / layout_processor: Preloaded layout model/processor.
        det_model / det_processor: Preloaded text-line detection model/processor.

    Returns:
        A PIL image with region polygons labeled by reading-order position.

    Raises:
        ValueError: If no pages could be loaded from input_path.
    """
    if os.path.isdir(input_path):
        images, names = load_from_folder(input_path, max_pages)
    else:
        images, names = load_from_file(input_path, max_pages)
    if not images:
        # Originally an empty input left bbox_image undefined (NameError).
        raise ValueError(f"No pages could be loaded from {input_path}")

    line_predictions = batch_text_detection(images, det_model, det_processor)
    layout_predictions = batch_layout_detection(images, layout_model, layout_processor,
                                                line_predictions)
    # One bbox list per page, fed to the ordering model.
    bboxes = [[box.bbox for box in layout_pred.bboxes]
              for layout_pred in layout_predictions]
    order_predictions = batch_ordering(images, bboxes, model, processor)

    # Only the last page's annotation was ever returned, so render just that
    # page instead of drawing and discarding every earlier one.
    image = images[-1]
    order_pred = order_predictions[-1]
    polys = [box.polygon for box in order_pred.bboxes]
    labels = [str(box.position) for box in order_pred.bboxes]
    # deepcopy keeps the cached source image unmodified by the drawing call.
    return draw_polys_on_image(polys, copy.deepcopy(image), labels=labels,
                               label_font_size=20)
def model1(image_path, languages):
    """Gradio handler: OCR the uploaded image in the selected languages.

    Args:
        image_path: Filepath of the uploaded image (gr.Image type="filepath").
        languages: List of language names from the dropdown, possibly empty/None.

    Returns:
        The annotated OCR image produced by ocr_main.
    """
    # `languages == [] or not languages` collapsed to one truthiness check;
    # join replaces the quadratic `+=` loop and the trailing-comma trim.
    langs = ",".join(languages) if languages else "English"
    return ocr_main(image_path, langs=langs)
def model2(image_path):
    """Gradio handler: run layout detection on the uploaded image."""
    return layout_main(image_path)
def model3(image_path):
    """Gradio handler: run reading-order detection on the uploaded image."""
    return reading_main(image_path)
# Build the two-column UI: inputs/controls on the left, one result tab per model
# on the right. Each button routes the uploaded image to its handler.
with gr.Blocks() as demo:
    gr.Markdown("<center><h1>Surya - Image OCR/Layout/Reading Order</h1></center>")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(type="filepath", label="Input Image", sources="upload")
            with gr.Row():
                lang_dropdown = gr.Dropdown(
                    label="Select Languages for OCR",
                    choices=list(language_map.keys()),
                    multiselect=True,
                    value=["English"],
                    interactive=True,
                )
            with gr.Row():
                ocr_btn = gr.Button("OCR", variant="primary")
                layout_btn = gr.Button("Layout", variant="primary")
                order_btn = gr.Button("Reading Order", variant="primary")
            with gr.Row():
                clear_btn = gr.ClearButton()
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("OCR"):
                    ocr_output = gr.Image()
                with gr.TabItem("Layout"):
                    layout_output = gr.Image()
                with gr.TabItem("Reading Order"):
                    order_output = gr.Image()

    ocr_btn.click(fn=model1, inputs=[input_image, lang_dropdown], outputs=ocr_output)
    layout_btn.click(fn=model2, inputs=[input_image], outputs=layout_output)
    order_btn.click(fn=model3, inputs=[input_image], outputs=order_output)
    clear_btn.add(components=[input_image, ocr_output, layout_output, order_output])

demo.launch()