File size: 3,064 Bytes
e57ccf6
7013115
 
 
e57ccf6
7013115
e57ccf6
7013115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e57ccf6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
import cv2
import numpy as np
from rapidocr_onnxruntime import RapidOCR

# Shared OCR engine. Elsewhere in this file it is called as engine(image) and
# its return value is unpacked as a (result, _) pair.
engine = RapidOCR()

# Field name -> rectangle to read from the aligned image.
# process_images treats each value as [x1, y1, x2, y2]: it draws a rectangle
# with corners (x1, y1) and (x2, y2) and runs OCR on rows y1:y2, columns x1:x2.
# NOTE(review): presumably pixel coordinates measured on the template image,
# since the aligned image is warped to the template's size — confirm.
info_points = {
    "customer_name": [156, 109, 928, 168],
    "amount": [157, 397, 606, 461],
    "price": [155, 341, 607, 399],
    "plateNumber": [740, 173, 928, 227]
}

def find_reference_points(template_image, target_image):
    """Pair up text anchors that OCR finds in both images.

    Both images are run through the module-level ``engine``. Each template
    text line whose recognised string exactly equals a target text line
    contributes one correspondence. The anchor of a line is the point at
    index 1 of its detection box (presumably the top-right corner — confirm
    against RapidOCR's box ordering).

    When the same string occurs several times, occurrences are paired in OCR
    order and every target line is used at most once, so a single target
    point is never mapped to two different template points.

    Args:
        template_image: Image handed to ``engine`` as the reference layout.
        target_image: Image handed to ``engine`` to be matched against the
            template.

    Returns:
        A pair ``(template_points, target_points)`` of float32 arrays with
        shape ``(N, 2)``; row ``i`` of both arrays belongs to the same text.
        ``N`` is 0 when either image yields no text or nothing matches.
    """
    template_result, _ = engine(template_image)
    target_result, _ = engine(target_image)

    # The engine yields None rather than an empty list when it detects no
    # text; normalise so the loops below simply do nothing in that case.
    template_result = template_result or []
    target_result = target_result or []

    # text -> anchors of target lines carrying that text, not yet paired.
    unclaimed = {}
    for target_word in target_result:
        target_x, target_y = target_word[0][1]
        unclaimed.setdefault(target_word[1], []).append((target_x, target_y))

    reference_points_template = []
    reference_points_target = []
    for template_word in template_result:
        candidates = unclaimed.get(template_word[1])
        if not candidates:
            continue
        template_x, template_y = template_word[0][1]
        reference_points_template.append((template_x, template_y))
        # Consume the earliest unpaired target occurrence of this text.
        reference_points_target.append(candidates.pop(0))

    # reshape keeps the (N, 2) layout even when no pairs were found.
    return (
        np.array(reference_points_template, dtype=np.float32).reshape(-1, 2),
        np.array(reference_points_target, dtype=np.float32).reshape(-1, 2),
    )

def align_images(template_image, target_image):
    """Warp the target image into the template image's coordinate frame.

    Matching text anchors from ``find_reference_points`` are used to estimate
    a homography (RANSAC, 5.0 reprojection threshold) that maps target points
    onto template points. The target is then warped to the template's width
    and height.

    Args:
        template_image: Reference image; its ``shape`` sets the output size.
        target_image: Image to be aligned to the template.

    Returns:
        The warped target image, or ``target_image`` unchanged when fewer
        than four correspondences are available or no homography could be
        estimated.
    """
    src_pts, dst_pts = find_reference_points(template_image, target_image)

    # A homography has 8 degrees of freedom, so it needs at least four
    # point pairs; without them return the image unaligned.
    if len(src_pts) < 4 or len(dst_pts) < 4:
        return target_image

    homography, _ = cv2.findHomography(dst_pts, src_pts, cv2.RANSAC, 5.0)

    # findHomography yields None when it cannot estimate a model (e.g. a
    # degenerate point configuration); warpPerspective would fail on it.
    if homography is None:
        return target_image

    height, width = template_image.shape[:2]
    return cv2.warpPerspective(target_image, homography, (width, height))

def process_images(template_image, target_image):
    """Align the target to the template and read every configured field.

    The inputs are converted from RGB to BGR, the target is aligned with
    ``align_images``, and each rectangle in ``info_points`` is cropped from
    the aligned image and passed to the OCR ``engine``. The rectangles are
    drawn on a separate RGB copy that is returned for display.

    Args:
        template_image: RGB image array of the template, as supplied by the
            Gradio image input.
        target_image: RGB image array of the document to read.

    Returns:
        A pair ``(annotated_image, info_dict)``: the aligned image in RGB
        with a green rectangle around every field, and a dict mapping each
        key of ``info_points`` to the text of the first OCR line found in
        that field, or ``""`` when nothing was recognised there.
    """
    template_bgr = cv2.cvtColor(template_image, cv2.COLOR_RGB2BGR)
    target_bgr = cv2.cvtColor(target_image, cv2.COLOR_RGB2BGR)

    aligned_bgr = align_images(template_bgr, target_bgr)

    # Rectangles go on this display copy only, so their borders never end up
    # inside a crop that is sent to OCR (neighbouring boxes overlap).
    annotated = cv2.cvtColor(aligned_bgr, cv2.COLOR_BGR2RGB)

    info_dict = {}
    for key, (x1, y1, x2, y2) in info_points.items():
        # Crop from the BGR image, the same channel order the engine is
        # given during alignment.
        field = aligned_bgr[y1:y2, x1:x2]

        ocr_result = None
        # The crop is empty when the unaligned fallback image is smaller
        # than the configured box; skip OCR instead of feeding it no pixels.
        if field.size:
            ocr_result, _ = engine(field)

        # The engine returns None when the region contains no readable text.
        info_dict[key] = ocr_result[0][1] if ocr_result else ""

        cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)

    return annotated, info_dict

# Build the Gradio interface: the two image inputs (template, target) are
# passed to process_images, and its two return values are shown in the image
# output and the textbox respectively.
demo = gr.Interface(
    fn=process_images,
    inputs=[
        gr.Image(label="模板图像"),
        gr.Image(label="目标图像")
    ],
    outputs=[
        gr.Image(label="对齐后的图像"),
        gr.Textbox(label="识别信息")
    ],
    title="磅单提取工具",
    description="上传一张模板图像和一张目标图像,提取关键信息。"
)

# Start the Gradio app. There is no __main__ guard, so this runs whenever the
# module is executed or imported.
demo.launch()