# -*- encoding: utf-8 -*-
import time
from pathlib import Path

import cv2
import gradio as gr
from rapidocr_onnxruntime import RapidOCR

from utils import visualize
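
# Font used by the `visualize` helper for each recognition language;
# Chinese and English share the same TTF.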
font_dict = {
    'ch': 'FZYTK.TTF',
    'japan': 'japan.ttc',
    'korean': 'korean.ttf',
    'en': 'FZYTK.TTF'
}


def inference(img_path, box_thresh=0.5, unclip_ratio=1.6, text_score=0.5,
              text_det=None, text_rec=None):
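    """Run RapidOCR on ``img_path`` with the selected models and thresholds.

    ``text_det`` and ``text_rec`` are ONNX file names under models/text_det
    and models/text_rec; ``box_thresh``, ``unclip_ratio`` and ``text_score``
    are passed straight through to RapidOCR. Returns a tuple of
    (rendered image path, recognized text, run log).
    """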
    out_log_list = []
    det_model_path = str(Path('models') / 'text_det' / text_det)
    rec_model_path = str(Path('models') / 'text_rec' / text_rec)
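
    # PP-OCRv2 recognition models expect a 32-pixel input height; v3 models use 48.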
    if 'v2' in rec_model_path:
        rec_image_shape = [3, 32, 320]
    else:
        rec_image_shape = [3, 48, 320]

    out_log_list.append('Init Model')
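    # Re-initialize RapidOCR for the selected model pair; the init time is
    # reported in the run log alongside the per-stage inference times.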
    s = time.time()
    rapid_ocr = RapidOCR(det_model_path=det_model_path,
                         rec_model_path=rec_model_path,
                         rec_img_shape=rec_image_shape)
    elapse = time.time() - s
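
    # Infer the visualization language from the recognition model's file
    # name; anything unrecognized falls back to the Chinese font.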
    if 'ch' in rec_model_path or 'en' in rec_model_path:
        lan_name = 'ch'
    elif 'japan' in rec_model_path:
        lan_name = 'japan'
    elif 'korean' in rec_model_path:
        lan_name = 'korean'
    else:
        lan_name = 'ch'

    out_log_list.append(f'Init Model cost: {elapse:.5f}')
    out_log_list.extend([f'det_model: {det_model_path}',
                         f'rec_model: {rec_model_path}',
                         f'rec_image_shape: {rec_image_shape}'])

    img = cv2.imread(img_path)
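    # RapidOCR returns a list of [box, text, score] triples plus the elapsed
    # time of each stage (detection, classification, recognition).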
    ocr_result, infer_elapse = rapid_ocr(img, box_thresh=box_thresh,
                                         unclip_ratio=unclip_ratio,
                                         text_score=text_score)
    det_cost, cls_cost, rec_cost = infer_elapse
    out_log_list.extend([f'det cost: {det_cost:.5f}',
                         f'cls cost: {cls_cost:.5f}',
                         f'rec cost: {rec_cost:.5f}'])
    out_log = '\n'.join([str(v) for v in out_log_list])

    if not ocr_result:
        return img_path, 'No valid text detected', out_log
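
    # Split the [box, text, score] triples into parallel tuples for drawing.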
    dt_boxes, rec_res, scores = list(zip(*ocr_result))
    font_path = Path('fonts') / font_dict.get(lan_name)
    img_save_path = visualize(img_path, dt_boxes, rec_res, scores,
                              font_path=str(font_path))
    output_text = '\n'.join(f'{one_rec} {float(score):.4f}'
                            for one_rec, score in zip(rec_res, scores))
    return img_save_path, output_text, out_log
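

# A minimal sketch of calling `inference` directly, outside Gradio (assumes
# the bundled example image and the default v3 model files are present):
#
#     img_path, texts, log = inference('images/1.jpg',
#                                      text_det='ch_PP-OCRv3_det_infer.onnx',
#                                      text_rec='ch_PP-OCRv3_rec_infer.onnx')
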
if __name__ == '__main__':
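    # Bundled sample images; selecting one pre-fills the input image below.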
    examples = [['images/1.jpg'],
                ['images/ch_en_num.jpg'],
                ['images/air_ticket.jpg'],
                ['images/car_plate.jpeg'],
                ['images/idcard.jpg'],
                ['images/train_ticket.jpeg'],
                ['images/japan_2.jpg'],
                ['images/korean_1.jpg']]

    with gr.Blocks(title='RapidOCR') as demo:
        gr.Markdown("""
        <h1><center><a href="https://github.com/RapidAI/RapidOCR" target="_blank">Rapid⚡OCR</a></center></h1>

        ### [Docs](https://rapidocr.rtfd.io/)
        ### Runtime environment:
        Python: 3.8 | onnxruntime: 1.14.1 | rapidocr_onnxruntime: 1.2.5""")
        gr.Markdown(
            '''**[Hyperparameter tuning](https://github.com/RapidAI/RapidOCR/tree/main/python#configyaml%E4%B8%AD%E5%B8%B8%E7%94%A8%E5%8F%82%E6%95%B0%E4%BB%8B%E7%BB%8D)**
            - **box_thresh**: the probability that a detected box is text; the larger the value, the more likely the box contains text. Lower it when text regions are missed. Range: [0, 1.0]
            - **unclip_ratio**: controls the size of the detected text boxes; the larger the value, the larger the boxes overall. Raise it when boxes cut off characters. Range: [1.5, 2.0]
            - **text_score**: the confidence that a recognition result is correct; the larger the value, the more accurate the displayed results. Lower it when results are missing. Range: [0, 1.0]
            ''')
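
        # Sliders for the three RapidOCR hyperparameters documented above.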
        with gr.Row():
            box_thresh = gr.Slider(minimum=0, maximum=1.0, value=0.5,
                                   label='box_thresh', step=0.1,
                                   interactive=True,
                                   info='[0, 1.0]')
            unclip_ratio = gr.Slider(minimum=1.5, maximum=2.0, value=1.6,
                                     label='unclip_ratio', step=0.1,
                                     interactive=True,
                                     info='[1.5, 2.0]')
            text_score = gr.Slider(minimum=0, maximum=1.0, value=0.5,
                                   label='text_score', step=0.1,
                                   interactive=True,
                                   info='[0, 1.0]')

        gr.Markdown('**[Model selection](https://github.com/RapidAI/RapidOCR/blob/main/docs/models.md)**')
        with gr.Row():
            text_det = gr.Dropdown(['ch_PP-OCRv3_det_infer.onnx',
                                    'ch_PP-OCRv2_det_infer.onnx',
                                    'ch_ppocr_server_v2.0_det_infer.onnx'],
                                   label='Text detection model',
                                   value='ch_PP-OCRv3_det_infer.onnx',
                                   interactive=True)
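
            # Populate the recognition dropdown from whatever models ship in
            # models/text_rec.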
            rec_model_list = [v.name for v in Path('models/text_rec').iterdir()]
            text_rec = gr.Dropdown(rec_model_list,
                                   label='Text recognition model (Chinese/English and multilingual)',
                                   value='ch_PP-OCRv3_rec_infer.onnx',
                                   interactive=True)

        with gr.Row():
            input_img = gr.Image(type='filepath', label='Input')
            out_img = gr.Image(type='filepath', label='Output')

        out_log = gr.Textbox(label='Run Log')
        out_txt = gr.Textbox(label='RecText')
        button = gr.Button('Submit')
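        # One click runs `inference` with the current image, slider, and
        # model-dropdown values.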
        button.click(fn=inference,
                     inputs=[input_img, box_thresh, unclip_ratio, text_score,
                             text_det, text_rec],
                     outputs=[out_img, out_txt, out_log])

        gr.Examples(examples=examples,
                    inputs=[input_img, box_thresh, unclip_ratio, text_score,
                            text_det, text_rec],
                    outputs=[out_img, out_txt, out_log], fn=inference)

    demo.launch(debug=True, enable_queue=True)