domenicCarter commited on
Commit
7013115
1 Parent(s): e57ccf6

feat: first commit

Browse files
Files changed (1) hide show
  1. app.py +89 -3
app.py CHANGED
@@ -1,7 +1,93 @@
1
  import gradio as gr
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  demo.launch()
 
1
  import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ from rapidocr_onnxruntime import RapidOCR
5
 
6
+ engine = RapidOCR()
 
7
 
8
+ info_points = {
9
+ "customer_name": [156, 109, 928, 168],
10
+ "amount": [157, 397, 606, 461],
11
+ "price": [155, 341, 607, 399],
12
+ "plateNumber": [740, 173, 928, 227]
13
+ }
14
+
15
+ def find_reference_points(template_image, target_image):
16
+ # OCR处理模板图像和目标图像
17
+ template_result, _ = engine(template_image)
18
+ target_result, _ = engine(target_image)
19
+
20
+ reference_points_template = []
21
+ reference_points_target = []
22
+
23
+ # 查找匹配的文本块
24
+ for template_word in template_result:
25
+ template_text = template_word[1]
26
+ template_x, template_y = template_word[0][1]
27
+
28
+ for target_word in target_result:
29
+ target_text = target_word[1]
30
+ target_x, target_y = target_word[0][1]
31
+
32
+ if template_text == target_text:
33
+ reference_points_template.append((template_x, template_y))
34
+ reference_points_target.append((target_x, target_y))
35
+ break
36
+
37
+ return np.array(reference_points_template), np.array(reference_points_target)
38
+
39
+ def align_images(template_image, target_image):
40
+ # 找到参考点
41
+ src_pts, dst_pts = find_reference_points(template_image, target_image)
42
+
43
+ if len(src_pts) < 4 or len(dst_pts) < 4:
44
+ return target_image # 如果找不到足够的参考点,返回原始图像
45
+
46
+ # 计算透视变换矩阵
47
+ M, _ = cv2.findHomography(dst_pts, src_pts, cv2.RANSAC, 5.0)
48
+
49
+ # 应用透视变换
50
+ aligned_image = cv2.warpPerspective(target_image, M, (template_image.shape[1], template_image.shape[0]))
51
+
52
+ return aligned_image
53
+
54
+ def process_images(template_image, target_image):
55
+ # 将Gradio的图像格式转换为OpenCV格式
56
+ template_image = cv2.cvtColor(template_image, cv2.COLOR_RGB2BGR)
57
+ # template_image = cv2.imread("../data/template.jpg")
58
+ target_image = cv2.cvtColor(target_image, cv2.COLOR_RGB2BGR)
59
+
60
+ # 对齐图像
61
+ aligned_image = align_images(template_image, target_image)
62
+
63
+ # 将结果转换回RGB格式以供Gradio显示
64
+ aligned_image = cv2.cvtColor(aligned_image, cv2.COLOR_BGR2RGB)
65
+
66
+ # 识别信息
67
+ info_dict = {}
68
+ # 在info_points中绘制矩形框
69
+ for key, value in info_points.items():
70
+ cv2.rectangle(aligned_image, (value[0], value[1]), (value[2], value[3]), (0, 255, 0), 2)
71
+ # ocr识别
72
+ ocr_result, _ = engine(aligned_image[value[1]:value[3], value[0]:value[2]])
73
+ info_dict[key] = ocr_result[0][1]
74
+
75
+ return aligned_image, info_dict
76
+
77
+ # 创建Gradio界面
78
+ demo = gr.Interface(
79
+ fn=process_images,
80
+ inputs=[
81
+ gr.Image(label="模板图像"),
82
+ gr.Image(label="目标图像")
83
+ ],
84
+ outputs=[
85
+ gr.Image(label="对齐后的图像"),
86
+ gr.Textbox(label="识别信息")
87
+ ],
88
+ title="磅单提取工具",
89
+ description="上传一张模板图像和一张目标图像,提取关键信息。"
90
+ )
91
+
92
+ # 启动Gradio应用
93
  demo.launch()