jianzhichun committed
Commit 6138e21 · 1 Parent(s): af3701b
Files changed (3):
  1. app.py +218 -0
  2. best.pt +3 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,218 @@
+ import gradio as gr
+ import cv2
+ import numpy as np
+ import torch
+ import os
+ from openai import OpenAI
+ from PIL import Image
+ import base64
+ import io
+
+ # Read the API key from the environment; secrets should never be hard-coded.
+ client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+
+ # import easyocr
+
+ # ----------------------
+ # Step 1: Image preprocessing
+ # ----------------------
+ def preprocess_image(image):
+     """Apply contrast enhancement, denoising, edge detection, and other preprocessing to the input image."""
+     gray_image = convert_to_grayscale(image)
+     contrast_enhanced = enhance_contrast(gray_image)
+     blurred_image = apply_gaussian_blur(contrast_enhanced)
+     sobel_normalized = detect_edges(blurred_image)
+     binary_image = threshold_image(sobel_normalized)
+     return gray_image, contrast_enhanced, blurred_image, sobel_normalized, binary_image
+
+ def convert_to_grayscale(image):
+     return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+
+ def enhance_contrast(gray_image):
+     clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+     return clahe.apply(gray_image)
+
+ def apply_gaussian_blur(image):
+     return cv2.GaussianBlur(image, (5, 5), 0)
+
+ def detect_edges(image):
+     sobel_x = cv2.Sobel(image, cv2.CV_64F, 1, 0, ksize=3)
+     sobel_y = cv2.Sobel(image, cv2.CV_64F, 0, 1, ksize=3)
+     sobel = cv2.magnitude(sobel_x, sobel_y)
+     return cv2.normalize(sobel, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
+
+ def threshold_image(image):
+     _, binary_image = cv2.threshold(image, 50, 255, cv2.THRESH_BINARY)
+     return binary_image
+
+ # ----------------------
+ # Step 2: YOLO detection and region extraction
+ # ----------------------
+ def detect_with_yolo(image, model_path='best.pt', conf_threshold=0.5):
+     """Run YOLO object detection on the image."""
+     model = load_yolo_model(model_path, conf_threshold)
+     return perform_yolo_detection(model, image)
+
+ def load_yolo_model(model_path, conf_threshold):
+     # Note: force_reload=True re-fetches the YOLOv5 repo, and the model is
+     # reloaded on every request; caching it at module level would be faster.
+     model = torch.hub.load('ultralytics/yolov5', 'custom', path=model_path, force_reload=True)
+     model.conf = conf_threshold
+     return model
+
+ def perform_yolo_detection(model, image):
+     results = model(image)
+     detections = results.pandas().xyxy[0]
+     annotated_image, detected_regions = draw_yolo_detections(image, detections, model.names)
+     return annotated_image, detected_regions
+
+ def draw_yolo_detections(image, detections, class_names):
+     """Draw YOLO detections on the image and return the detected regions."""
+     annotated_image = image.copy()
+
+     # Promote single-channel images to BGR so boxes and labels draw in color.
+     if len(annotated_image.shape) == 2 or annotated_image.shape[2] == 1:
+         annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_GRAY2BGR)
+
+     boxes = []
+     for _, row in detections.iterrows():
+         x1, y1, x2, y2 = int(row['xmin']), int(row['ymin']), int(row['xmax']), int(row['ymax'])
+         conf, cls = row['confidence'], int(row['class'])
+         label = f"{class_names[cls]} {conf:.2f}"
+
+         # Record the detected region
+         boxes.append({"box": (x1, y1, x2, y2), "class": class_names[cls], "confidence": conf})
+
+         # Draw the bounding box and label
+         cv2.rectangle(annotated_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
+         cv2.putText(annotated_image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
+
+     annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
+     return annotated_image, boxes
+
+ def crop_regions_from_boxes(original_image, boxes):
+     """Crop the YOLO-detected regions out of the original image."""
+     regions = []
+     for box_info in boxes:
+         x1, y1, x2, y2 = box_info["box"]
+         cropped_region = original_image[y1:y2, x1:x2]
+         regions.append((cropped_region, (x1, y1, x2, y2)))
+     return regions
+
+
+ # ----------------------
+ # Step 3: Text recognition
+ # ----------------------
+ def convert_to_base64(region):
+     """
+     Convert a cropped region to Base64.
+     Input: region (NumPy array)
+     Output: Base64-encoded string
+     """
+     # Ensure the input is a C-contiguous NumPy array
+     region = np.ascontiguousarray(region)
+
+     # Convert to a PIL image
+     pil_image = Image.fromarray(region)
+
+     # Save the image as PNG to an in-memory byte stream
+     buffer = io.BytesIO()
+     pil_image.save(buffer, format="PNG")
+     buffer.seek(0)
+
+     # Encode as Base64
+     img_base64 = base64.b64encode(buffer.read()).decode("utf-8")
+     return img_base64
+
+ def recognize_text(region):
+     """Recognize text in the region with an OCR call to the OpenAI API."""
+     base64_image = convert_to_base64(region)
+     response = client.chat.completions.create(
+         model="gpt-4o-mini",
+         messages=[
+             {
+                 "role": "user",
+                 "content": [
+                     {
+                         "type": "text",
+                         "text": "Directly output the numbers in the graph."
+                     },
+                     {
+                         "type": "image_url",
+                         "image_url": {
+                             # The buffer holds PNG data, so declare the matching MIME type.
+                             "url": f"data:image/png;base64,{base64_image}",
+                         },
+                     },
+                 ],
+             }
+         ],
+         max_tokens=300,
+     )
+     return response.choices[0].message.content
+
+ # ----------------------
+ # Gradio interface
+ # ----------------------
+ def gradio_wrapper(image):
+     # Step 1: image preprocessing
+     gray_image, contrast_enhanced, blurred_image, sobel_normalized, binary_image = preprocess_image(image)
+
+     # Step 2: YOLO detection
+     yolo_annotated_image, boxes = detect_with_yolo(binary_image, model_path='best.pt')
+
+     # Crop from the original image; preprocessing preserves dimensions, so the
+     # box coordinates from the binarized image transfer directly.
+     detected_regions = crop_regions_from_boxes(image, boxes)
+
+     recognized_texts = [recognize_text(region[0]) for region in detected_regions]
+
+     gray_image = cv2.cvtColor(gray_image, cv2.COLOR_GRAY2RGB)
+     enhanced_image = cv2.cvtColor(contrast_enhanced, cv2.COLOR_GRAY2RGB)
+     denoised_image = cv2.cvtColor(blurred_image, cv2.COLOR_GRAY2RGB)
+     edge_image = cv2.cvtColor(sobel_normalized, cv2.COLOR_GRAY2RGB)
+     binary_image = cv2.cvtColor(binary_image, cv2.COLOR_GRAY2RGB)
+     detected_region_image = detected_regions[0][0] if detected_regions else np.zeros((100, 100, 3), dtype=np.uint8)
+     # Guard the no-detection case so indexing does not raise.
+     final_result = recognized_texts[0] if recognized_texts else "No text detected"
+     steps_info = [
+         {"title": "Grayscale", "image": gray_image, "description": "The grayscale conversion of the input image."},
+         {"title": "Contrast enhanced", "image": enhanced_image, "description": "The image after contrast enhancement."},
+         {"title": "Denoised", "image": denoised_image, "description": "The image after denoising."},
+         {"title": "Edge enhanced", "image": edge_image, "description": "The image after edge enhancement."},
+         {"title": "Binarized", "image": binary_image, "description": "The image after binarization."},
+         {"title": "YOLO detections", "image": yolo_annotated_image, "description": "The image annotated with YOLO detection results."},
+         {"title": "Region", "image": detected_region_image, "description": "The detected bolt hole bearing cast characters."},
+     ]
+     return steps_info, final_result
+
+ # Gradio callback: process the uploaded image and return step images with descriptions
+ def process_and_display(image):
+     steps_info, result = gradio_wrapper(image)
+     steps = [
+         (
+             step["image"],
+             f"{i+1}. {step['title']}:\n{step['description']}"
+         )
+         for i, step in enumerate(steps_info)
+     ]
+     return steps, result
+
+ # Build the Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("<h1 style='text-align:center'><strong>CITIC Dicastal Cast-Character Recognition System</strong></h1>")
+     gr.Markdown("<p style='text-align:center'>Upload an image and the system will detect and recognize the cast characters step by step.</p>")
+
+     with gr.Row():
+         # Left: image upload and final result
+         with gr.Column(scale=1):
+             upload = gr.Image(type="numpy", label="Upload image")
+             final_result = gr.Label(label="Final text recognition result")
+
+         # Right: processing step display
+         with gr.Column(scale=2):
+             gr.Markdown("### Processing steps")
+             gallery = gr.Gallery(label="")
+             step_desc = gr.Markdown()
+
+     # Run the pipeline whenever a new image is uploaded
+     upload.change(
+         fn=process_and_display,
+         inputs=upload,
+         outputs=[gallery, final_result],
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
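For quick testing outside the web UI, a minimal sketch that drives the same pipeline directly (assuming `best.pt` is in the working directory, `OPENAI_API_KEY` is set, and `sample.jpg` is a hypothetical test image):

```python
import cv2

from app import (crop_regions_from_boxes, detect_with_yolo,
                 preprocess_image, recognize_text)

image = cv2.imread("sample.jpg")  # hypothetical test image (BGR, as cv2 loads it)
*_, binary_image = preprocess_image(image)

# Detect on the binarized image, then crop the matching regions from the original.
annotated, boxes = detect_with_yolo(binary_image, model_path="best.pt")
for region, bbox in crop_regions_from_boxes(image, boxes):
    print(bbox, recognize_text(region))
```

Importing from `app` runs the module top level, which builds the Blocks UI but does not launch it, since `demo.launch()` is guarded by `if __name__ == "__main__"`.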
best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b55e72c5a822eb6739f4097663e5f24756627b84585f130e5eceaca0bf275531
+ size 14442024
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ gradio
+ pillow
+ ultralytics
+ openai
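One note on the dependency list: app.py also imports cv2, numpy, and torch directly. Recent ultralytics releases pull these in transitively, but relying on that is an implementation detail of the package; a sketch of a more explicit pin list, if reproducibility matters:

```text
gradio
pillow
ultralytics
openai
opencv-python
torch
numpy
```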