File size: 6,047 Bytes
24c4def
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import argparse
import os
import sys

import cv2
import numpy as np
from PIL import Image

sys.path.append("/home/wcx/wcx/GroundingDINO/LVLM/mmocr")
# MMOCR
from mmocr.apis.inferencers import MMOCRInferencer  # noqa: E402

# BUILD MMOCR


def arg_parse(argv=None):
    """Parse the command-line options for the MMOCR demo.

    Args:
        argv (list[str] | None, optional): Arguments to parse. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which is the
            original behaviour. Passing an explicit list lets callers and
            tests build the options without touching the real command line.

    Returns:
        argparse.Namespace: Has the attributes ``rec_config``, ``rec_weight``,
        ``det_config``, ``det_weight`` and ``device``.
    """
    parser = argparse.ArgumentParser(description='MMOCR demo for gradio app')
    # NOTE: the default paths below are machine-specific; override them on
    # the command line when running elsewhere.
    parser.add_argument(
        '--rec_config',
        type=str,
        default='/home/wcx/wcx/GroundingDINO/LVLM/mmocr/configs/textrecog/maerec/maerec_b_union14m.py',  # noqa: E501
        help='The recognition config file.')
    parser.add_argument(
        '--rec_weight',
        type=str,
        default='/newdisk3/wcx/ocr_model/maerec_b.pth',
        help='The recognition weight file.')
    parser.add_argument(
        '--det_config',
        type=str,
        default='/home/wcx/wcx/GroundingDINO/LVLM/mmocr/configs/textdet/dbnetpp/dbnetpp_resnet50-oclip_fpnc_1200e_icdar2015.py',  # noqa: E501
        help='The detection config file.')
    parser.add_argument(
        '--det_weight',
        type=str,
        default='/newdisk3/wcx/ocr_model/dbnetpp.pth',
        help='The detection weight file.')
    parser.add_argument(
        '--device',
        type=str,
        default='cuda:0',
        help='The device used for inference.')
    args = parser.parse_args(argv)
    return args

# Parse the CLI options once at import time and build the single MMOCR
# inferencer (by default a DBNet++ detector plus a MAERec recognizer) that
# run_mmocr() reuses on every call.
args = arg_parse()
mmocr_inferencer = MMOCRInferencer(
    args.det_config, args.det_weight,  # detector config / checkpoint
    args.rec_config, args.rec_weight,  # recognizer config / checkpoint
    device=args.device,
)

def run_mmocr(image_path, use_detector=False,
              output_dir="/home/wcx/wcx/Union14M/results"):
    """Run MMOCR on one image file and format the predictions as text.

    Switches the shared module-level ``mmocr_inferencer`` into the requested
    mode (this mutates its ``mode`` attribute) and runs it once.

    Args:
        image_path (str): Path of the image to read. It is opened with PIL
            and converted to an RGB ``np.ndarray`` before inference.
        use_detector (bool, optional): If True, run text detection followed
            by recognition ('det_rec' mode). Otherwise run recognition only
            on the whole image ('rec' mode). Defaults to False.
        output_dir (str, optional): Directory where the visualization is
            written in 'det_rec' mode, under the input file's base name.
            The directory must already exist.

    Returns:
        tuple: ``(visualization, out_results)``.
            In 'det_rec' mode, ``visualization`` is the string ``"Done"``
            once the image has been written. ``out_results`` has one
            ``"<text>: <polygon>"`` line per instance, with polygon
            coordinates cast to int32.
            In 'rec' mode, ``visualization`` is None. ``out_results`` is a
            two-line string with the first recognized text and its score
            (two decimals).

    Raises:
        OSError: If OpenCV fails to write the visualization image.
    """
    # Use a context manager so the underlying file handle is closed;
    # ``convert`` produces an independent in-memory copy.
    with Image.open(image_path) as pil_img:
        img = np.array(pil_img.convert("RGB"))

    mode = 'det_rec' if use_detector else 'rec'
    # Reuse the already-built models and only toggle which stages run.
    mmocr_inferencer.mode = mode
    result = mmocr_inferencer(img, return_vis=True)
    visualization = result['visualization'][0]
    result = result['predictions'][0]

    if mode == 'det_rec':
        det_results = [
            f'{rec_text}: {np.array(det_polygon).astype(np.int32).tolist()}'
            for rec_text, det_polygon in zip(result['rec_texts'],
                                             result['det_polygons'])
        ]
        out_results = '\n'.join(det_results)
        # NOTE(review): the array is saved exactly as MMOCR returned it.
        # cv2.imwrite treats 3-channel data as BGR while the input above is
        # RGB, and an RGB->BGR cvtColor was deliberately disabled here.
        # Confirm the saved colours look right.
        save_path = os.path.join(output_dir, os.path.basename(image_path))
        # cv2.imwrite reports failure (e.g. missing directory) via its
        # return value instead of raising, so check it explicitly.
        if not cv2.imwrite(save_path, np.array(visualization)):
            raise OSError(f'Failed to write visualization to {save_path}')
        visualization = "Done"
    else:
        rec_text = result['rec_texts'][0]
        rec_score = result['rec_scores'][0]
        out_results = f'pred: {rec_text} \n score: {rec_score:.2f}'
        visualization = None
    return visualization, out_results

# Demo run: recognition-only OCR (no text detector) on one sample image,
# printing the returned visualization value and then the formatted result.
image_path = "/home/wcx/wcx/Union14M/image/temp.jpg"
vis, res = run_mmocr(image_path, use_detector=False)
print(vis, res, sep="\n")
# if __name__ == '__main__':
#     args = arg_parse()
#     mmocr_inferencer = MMOCRInferencer(
#         args.det_config,
#         args.det_weight,
#         args.rec_config,
#         args.rec_weight,
#         device=args.device)
    
    

    # with gr.Blocks() as demo:
    #     with gr.Row():
    #         with gr.Column(scale=1):
    #             gr.HTML("""
    #                 <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
    #                 <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
    #                     MAERec: A MAE-pretrained Scene Text Recognizer
    #                 </h1>
    #                 <h3 style="font-weight: 450; font-size: 1rem; margin: 0rem"> 
    #                 [<a href="https://arxiv.org/abs/2305.10855" style="color:blue;">arXiv</a>] 
    #                 [<a href="https://github.com/Mountchicken/Union14M" style="color:green;">Code</a>]
    #                 </h3>
    #                 <h2 style="text-align: left; font-weight: 600; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
    #                 MAERec is a scene text recognition model composed of a ViT backbone and a Transformer decoder in auto-regressive
    #                 style. It shows an outstanding performance in scene text recognition, especially when pre-trained on the
    #                 Union14M-U through MAE.
    #                 </h2>
    #                 <h2 style="text-align: left; font-weight: 600; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
    #                 In this demo, we combine MAERec with DBNet++ to build an
    #                 end-to-end scene text recognition model.
    #                 </h2>
    #                 </div>
    #                 """)
    #             gr.Image('github/maerec.png')
    #         with gr.Column(scale=1):
    #             input_image = gr.Image(label='Input Image')
    #             output_image = gr.Image(label='Output Image')
    #             use_detector = gr.Checkbox(
    #                 label=
    #                 'Use Scene Text Detector or Not (Disabled for Recognition Only)',
    #                 default=True)
    #             det_results = gr.Textbox(label='Detection Results')
    #             mmocr = gr.Button('Run MMOCR')
    #             gr.Markdown("## Image Examples")
    #     with gr.Row():
    #         gr.Examples(
    #             examples=[
    #                 'github/author.jpg', 'github/gradio1.jpeg',
    #                 'github/Art_Curve_178.jpg', 'github/cute_3.jpg',
    #                 'github/cute_168.jpg', 'github/hiercurve_2229.jpg',
    #                 'github/ic15_52.jpg', 'github/ic15_698.jpg',
    #                 'github/Art_Curve_352.jpg'
    #             ],
    #             inputs=input_image,
    #         )
    #     mmocr.click(
    #         fn=run_mmocr,
    #         inputs=[input_image, use_detector],
    #         outputs=[output_image, det_results])
    # demo.launch(debug=True)