File size: 6,881 Bytes
5d6a0bb
d80ba6a
 
5d6a0bb
00e3b6c
5d6a0bb
 
d80ba6a
 
 
 
00e3b6c
d80ba6a
5d6a0bb
8d6e841
5d6a0bb
8d6e841
d80ba6a
 
 
 
8d6e841
5d6a0bb
9792e33
d80ba6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8803a47
 
 
 
 
7e03ca9
 
 
 
d80ba6a
 
 
 
 
 
 
 
 
 
 
 
8d6e841
d80ba6a
8d6e841
d80ba6a
 
 
 
8803a47
7e03ca9
19a047e
d80ba6a
00e3b6c
d80ba6a
 
 
 
 
 
 
 
d7642ea
d80ba6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508fdf6
d80ba6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3beaa28
d80ba6a
 
 
 
 
 
 
 
8803a47
 
 
 
 
 
 
 
 
 
 
 
e42ad77
 
1aa38b6
e42ad77
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
import time
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
from rapidocr_onnxruntime import RapidOCR
from streamlit_image_select import image_select

from utils import visualize

# Font file name used to draw recognised text, keyed by language code.
# inference() looks the name up here and resolves it under the ``fonts`` directory.
font_dict = dict(
    ch="chinese_cht.ttf",
    japan="japan.ttc",
    korean="korean.ttf",
    en="chinese_cht.ttf",
)


def init_sidebar():
    """Render the sidebar controls and record their values in the session state.

    Resets ``st.session_state["params"]`` and fills it with the current values
    of the ``box_thresh``, ``unclip_ratio`` and ``text_score`` sliders. Then
    shows the example-image picker and stores the picked image, as returned by
    ``cv2.imread``, under ``st.session_state["img"]``.

    Reads the module-level ``examples`` list, which must exist before the call.
    """
    st.session_state["params"] = {}

    st.sidebar.markdown(
        "### [🛠️ Parameter Settings](https://github.com/RapidAI/RapidOCR/wiki/config_parameter)"
    )

    # One row per slider, in display order:
    # (name, min_value, max_value, initial value, help text shown in the UI).
    slider_specs = (
        (
            "box_thresh",
            0.0,
            1.0,
            0.5,
            "检测到的框是文本的概率,值越大,框中是文本的概率就越大。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
        ),
        (
            "unclip_ratio",
            1.5,
            2.0,
            1.6,
            "控制文本检测框的大小,值越大,检测框整体越大。在出现框截断文字的情况,调大该值。取值范围:[1.5, 2.0],默认值为1.6",
        ),
        (
            "text_score",
            0.0,
            1.0,
            0.5,
            "文本识别结果是正确的置信度,值越大,显示出的识别结果更准确。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
        ),
    )
    for name, lowest, highest, initial, help_text in slider_specs:
        # The slider label doubles as the key under which its value is stored.
        st.session_state["params"][name] = st.sidebar.slider(
            name,
            min_value=lowest,
            max_value=highest,
            value=initial,
            step=0.1,
            help=help_text,
        )

    with st.sidebar.container():
        selected_path = image_select(
            label="Examples(click to select):",
            images=examples,
            key="equation_default",
            use_container_width=True,
        )
        example_img = cv2.imread(selected_path)
    st.session_state["img"] = example_img


def inference(
    text_det=None,
    text_rec=None,
):
    """Run RapidOCR on the image currently held in the session state.

    The image is read from ``st.session_state["img"]`` and the thresholds from
    ``st.session_state["params"]``.

    Args:
        text_det: File name of the detection model inside ``models/text_det``.
        text_rec: File name of the recognition model inside ``models/text_rec``.

    Returns:
        ``(vis_img, out_df, elapse)``: the value returned by ``visualize``, a
        DataFrame with ``Rec`` and ``Score`` columns, and a markdown string
        listing the det / cls / rec timings. ``(None, None, None)`` when the
        engine returns an empty result or no timing information.
    """
    settings = st.session_state["params"]
    img = st.session_state.get("img")

    model_root = Path("models")
    det_model_path = str(model_root / "text_det" / text_det)
    rec_model_path = str(model_root / "text_rec" / text_rec)

    # Paths mentioning v2 / korean / japan get the 32-pixel-high input shape;
    # every other recognition model gets the 48-pixel-high one. The markers are
    # matched against the whole path string, not just the file name.
    small_input_markers = ("v2", "korean", "japan")
    if any(marker in rec_model_path for marker in small_input_markers):
        rec_image_shape = [3, 32, 320]
    else:
        rec_image_shape = [3, 48, 320]

    rapid_ocr = RapidOCR(
        det_model_path=det_model_path,
        rec_model_path=rec_model_path,
        rec_img_shape=rec_image_shape,
    )

    # Pick the font language: "ch" is both the choice for ch/en models and the
    # fallback when no known language marker appears in the path.
    lan_name = "ch"
    if not ("ch" in rec_model_path or "en" in rec_model_path):
        for candidate in ("japan", "korean"):
            if candidate in rec_model_path:
                lan_name = candidate
                break

    ocr_result, infer_elapse = rapid_ocr(
        img,
        box_thresh=settings.get("box_thresh"),
        unclip_ratio=settings.get("unclip_ratio"),
        text_score=settings.get("text_score"),
    )
    if not (ocr_result and infer_elapse):
        return None, None, None

    det_cost, cls_cost, rec_cost = infer_elapse
    elapse = f"- `det cost`: {det_cost:.5f}\n - `cls cost`: {cls_cost:.5f}\n - `rec cost`: {rec_cost:.5f}"

    # Transpose the per-line result rows into three parallel sequences.
    dt_boxes, rec_res, scores = zip(*ocr_result)

    font_file = str(Path("fonts") / font_dict.get(lan_name))
    vis_img = visualize(
        Image.fromarray(img), dt_boxes, rec_res, scores, font_path=font_file
    )

    table_rows = list(zip(rec_res, scores))
    out_df = pd.DataFrame(table_rows, columns=("Rec", "Score"))
    return vis_img, out_df, elapse


def tips(txt: str, wait_time: int = 2, icon: str = "🎉") -> None:
    """Show a toast notification, then pause the script.

    Args:
        txt: Message passed to ``st.toast``.
        wait_time: Seconds to sleep after the toast has been issued.
        icon: Icon passed to ``st.toast``.
    """
    st.toast(txt, icon=icon)
    # The sleep blocks the running script for ``wait_time`` seconds.
    time.sleep(wait_time)


if __name__ == "__main__":
    # Page header: project title linking to the RapidOCR repository.
    st.markdown(
        "<h1 style='text-align: center;'><a href='https://github.com/RapidAI/RapidOCR' style='text-decoration: none'>Rapid⚡OCR</a></h1>",
        unsafe_allow_html=True,
    )
    # Badge row: Python versions, supported OSes, download count, PyPI version.
    st.markdown(
        """
    <p align="left">
        <a href=""><img src="https://img.shields.io/badge/Python->=3.6,<3.12-aff.svg"></a>
        <a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
        <a href="https://pepy.tech/project/rapidocr_onnxruntime"><img src="https://static.pepy.tech/personalized-badge/rapidocr_onnxruntime?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads%20Ort"></a>
        <a href="https://pypi.org/project/rapidocr-onnxruntime/"><img alt="PyPI" src="https://img.shields.io/pypi/v/rapidocr-onnxruntime"></a>
    </p>
    """,
        unsafe_allow_html=True,
    )

    # Example images offered in the sidebar picker. init_sidebar() reads this
    # module-level name, so it has to be defined before that call.
    examples = [
        "images/1.jpg",
        "images/ch_en_num.jpg",
        "images/air_ticket.jpg",
        "images/car_plate.jpeg",
        "images/train_ticket.jpeg",
        "images/japan_2.jpg",
        "images/korean_1.jpg",
    ]

    # Draws the sidebar and stores the slider values plus the selected example
    # image in st.session_state ("params" and "img").
    init_sidebar()

    # Two equally wide columns: detection model on the left, recognition on the right.
    menu_det, menu_rec = st.columns([1, 1])
    det_models = [
        "ch_PP-OCRv4_det_infer.onnx",
        "ch_PP-OCRv3_det_infer.onnx",
        "ch_PP-OCRv2_det_infer.onnx",
        "ch_ppocr_server_v2.0_det_infer.onnx",
    ]
    select_det = menu_det.selectbox("Det model:", det_models)

    # NOTE(review): "ch_PP-OCRv4_det_server_infer.onnx" carries "det" in its name
    # although it is listed as a recognition model and is resolved under
    # models/text_rec by inference() — confirm the file name is intended.
    rec_models = [
        "ch_PP-OCRv4_rec_infer.onnx",
        "ch_PP-OCRv3_rec_infer.onnx",
        "ch_PP-OCRv2_rec_infer.onnx",
        "ch_PP-OCRv4_det_server_infer.onnx",
        "ch_ppocr_server_v2.0_rec_infer.onnx",
        "en_PP-OCRv3_rec_infer.onnx",
        "en_number_mobile_v2.0_rec_infer.onnx",
        "korean_mobile_v2.0_rec_infer.onnx",
        "japan_rec_crnn_v2.onnx",
    ]
    select_rec = menu_rec.selectbox("Rec model:", rec_models)

    # Upload form. A submitted file replaces the example image that
    # init_sidebar() stored in the session state for this run.
    with st.form("my-form", clear_on_submit=True):
        img_file_buffer = st.file_uploader(
            "Upload an image",
            accept_multiple_files=False,
            label_visibility="visible",
            type=["png", "jpg", "jpeg", "bmp"],
        )
        submit = st.form_submit_button("Upload")
        if submit and img_file_buffer is not None:
            # NOTE(review): uploads are decoded with PIL while the examples are
            # read with cv2.imread, so the two paths may not share the same
            # channel order or channel count — confirm what inference() and
            # visualize() expect.
            image = Image.open(img_file_buffer)
            img = np.array(image)
            st.session_state["img"] = img

    # Run OCR only when an image is available (cv2.imread may have returned None).
    if st.session_state["img"] is not None:
        out_img, out_json, elapse = inference(select_det, select_rec)
        # inference() returns (None, None, None) when nothing was recognised.
        if all(v is not None for v in [out_img, out_json, elapse]):
            st.markdown("#### Visualize:")
            st.image(out_img)

            st.markdown("### Rec Result:")
            st.markdown(elapse)
            st.dataframe(out_json, use_container_width=True)
        else:
            # Toast text means "the recognition result is empty".
            tips("识别结果为空", wait_time=5, icon="⚠️")