# EasyDetect: pipeline/tool/scene_text_model.py
# Last change: "Update pipeline/tool/scene_text_model.py" by sunnychenxiwang
# (commit 09efc91, verified; 2.4 kB).
import os
import warnings

import cv2
import numpy as np
from PIL import Image

# import sys
# sys.path.append("pipeline/mmocr")
# from mmocr.apis.inferencers import MMOCRInferencer
from pipeline.mmocr.mmocr.apis.inferencers import MMOCRInferencer
# BUILD MMOCR
class MAERec:
    """Scene-text OCR tool built on MMOCR's ``MMOCRInferencer``.

    The default configs/checkpoints pair a DBNet++ (ResNet-50 + oCLIP) text
    detector with a MAERec-B text recognizer (Union14M config).
    """

    # Directory the det_rec visualization was hard-coded to in the original
    # tool; kept as the default of ``execute(vis_dir=...)`` for compatibility.
    DEFAULT_VIS_DIR = "/home/wcx/wcx/Union14M/results"

    def __init__(self,
                 det_config="pipeline/mmocr/configs/textdet/dbnetpp/"
                            "dbnetpp_resnet50-oclip_fpnc_1200e_icdar2015.py",
                 det_weights="models/dbnetpp.pth",
                 rec_config="pipeline/mmocr/configs/textrecog/maerec/"
                            "maerec_b_union14m.py",
                 rec_weights="models/maerec_b.pth",
                 device=None):
        """Build the underlying MMOCR inferencer.

        Args:
            det_config (str): Text-detection config file.
            det_weights (str): Text-detection checkpoint.
            rec_config (str): Text-recognition config file.
            rec_weights (str): Text-recognition checkpoint.
            device (str | None): Forwarded to ``MMOCRInferencer``
                (e.g. ``"cuda:0"``). ``None`` keeps MMOCR's default device
                selection, which is what the original call relied on.
        """
        self.mmocr_inferencer = MMOCRInferencer(
            det_config, det_weights, rec_config, rec_weights, device=device)

    def execute(self, image_path, use_detector=False, vis_dir=DEFAULT_VIS_DIR):
        """Run MMOCR on the image stored at ``image_path``.

        Args:
            image_path (str): Path of the image to read (converted to RGB).
            use_detector (bool, optional): If True, run text detection followed
                by recognition ('det_rec'). Otherwise run recognition only, on
                the whole image ('rec'). Defaults to False.
            vis_dir (str | None, optional): Directory that receives the
                det_rec visualization, named after the basename of
                ``image_path``. ``None`` skips writing it. This is ignored in
                'rec' mode.

        Returns:
            tuple: ``(visualization, text)``.
                * 'rec' mode: ``(None, recognized_text)``.
                * 'det_rec' mode: ``("Done", detections)``. ``detections`` has
                  one ``"<text>: <polygon>"`` line per detected region, with
                  polygon coordinates truncated to int. It is an empty string
                  when nothing is detected.
        """
        # convert() forces a full decode, so the file handle can be released
        # as soon as the with-block exits.
        with Image.open(image_path) as pil_img:
            img = np.array(pil_img.convert("RGB"))

        mode = 'det_rec' if use_detector else 'rec'
        save_vis = use_detector and vis_dir is not None
        # The same inferencer serves both modes; its ``mode`` attribute is
        # switched per call, as the original tool did.
        self.mmocr_inferencer.mode = mode
        # Only ask MMOCR to render a visualization when it will be written.
        result = self.mmocr_inferencer(img, return_vis=save_vis)
        prediction = result['predictions'][0]  # single input image

        if mode == 'rec':
            # Recognition-only: one text for the whole image. Returned as-is
            # (the old "pred: ...[6:]" round-trip left a trailing space).
            return None, prediction['rec_texts'][0]

        # det_rec: return every detection, not a 6-char-truncated first line.
        out_results = '\n'.join(
            f'{rec_text}: {np.array(det_polygon).astype(np.int32).tolist()}'
            for rec_text, det_polygon in zip(prediction['rec_texts'],
                                             prediction['det_polygons']))

        if save_vis:
            # cv2.imwrite expects BGR channel order; the inferencer's
            # visualization is treated as RGB (as in the original code).
            vis_bgr = cv2.cvtColor(np.array(result['visualization'][0]),
                                   cv2.COLOR_RGB2BGR)
            out_path = os.path.join(vis_dir, os.path.basename(image_path))
            if not cv2.imwrite(out_path, vis_bgr):
                # Still best-effort, but no longer silent.
                warnings.warn(
                    f"MAERec: could not write visualization to {out_path!r}")
        return "Done", out_results
if __name__ == '__main__':
    # Usage: python scene_text_model.py [IMAGE_PATH]
    # Without an argument, fall back to the original developer sample image.
    import sys

    image_path = (sys.argv[1] if len(sys.argv) > 1
                  else "/newdisk3/wcx/MLLM/text-to-image/dalle3/582.jpg")
    scene_text_model = MAERec()
    vis, res = scene_text_model.execute(image_path)
    print(vis)
    print(res)