jianzhichun committed · Commit 6138e21 · 1 Parent(s): af3701b
update
Browse files:
- app.py (+218, -0)
- best.pt (+3, -0)
- requirements.txt (+4, -0)
app.py ADDED
@@ -0,0 +1,218 @@
import base64
import io
import os

import cv2
import gradio as gr
import numpy as np
import torch
from openai import OpenAI
from PIL import Image

# The original hard-coded an API key here; read it from the environment instead
# so the secret never lands in the repository.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

# import easyocr

# ----------------------
# Step 1: Image preprocessing
# ----------------------
def preprocess_image(image):
    """Apply contrast enhancement, denoising, and edge detection to the input image."""
    gray_image = convert_to_grayscale(image)
    contrast_enhanced = enhance_contrast(gray_image)
    blurred_image = apply_gaussian_blur(contrast_enhanced)
    sobel_normalized = detect_edges(blurred_image)
    binary_image = threshold_image(sobel_normalized)
    return gray_image, contrast_enhanced, blurred_image, sobel_normalized, binary_image

def convert_to_grayscale(image):
    return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

def enhance_contrast(gray_image):
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return clahe.apply(gray_image)

def apply_gaussian_blur(image):
    return cv2.GaussianBlur(image, (5, 5), 0)

def detect_edges(image):
    sobel_x = cv2.Sobel(image, cv2.CV_64F, 1, 0, ksize=3)
    sobel_y = cv2.Sobel(image, cv2.CV_64F, 0, 1, ksize=3)
    sobel = cv2.magnitude(sobel_x, sobel_y)
    return cv2.normalize(sobel, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

def threshold_image(image):
    _, binary_image = cv2.threshold(image, 50, 255, cv2.THRESH_BINARY)
    return binary_image

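# A quick way to sanity-check the preprocessing pipeline from a REPL (a
# sketch; "sample.jpg" is a placeholder path, not a file from this repo):
#
#     img = cv2.imread("sample.jpg")
#     gray, enhanced, blurred, edges, binary = preprocess_image(img)
#     cv2.imwrite("binary.png", binary)
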
# ----------------------
# Step 2: YOLO detection and region extraction
# ----------------------
def detect_with_yolo(image, model_path='best.pt', conf_threshold=0.5):
    """Run object detection on the image with a YOLO model."""
    model = load_yolo_model(model_path, conf_threshold)
    return perform_yolo_detection(model, image)

def load_yolo_model(model_path, conf_threshold):
    model = torch.hub.load('ultralytics/yolov5', 'custom', path=model_path, force_reload=True)
    model.conf = conf_threshold
    return model

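# Note: force_reload=True makes torch.hub re-fetch the yolov5 repo and rebuild
# the model on every uploaded image, which is slow. A minimal sketch of a
# module-level cache that loads each weights file only once (_MODEL_CACHE and
# load_yolo_model_cached are names introduced here, not part of the original
# app):
_MODEL_CACHE = {}

def load_yolo_model_cached(model_path, conf_threshold):
    if model_path not in _MODEL_CACHE:
        _MODEL_CACHE[model_path] = torch.hub.load('ultralytics/yolov5', 'custom', path=model_path)
    model = _MODEL_CACHE[model_path]
    model.conf = conf_threshold
    return model
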
def perform_yolo_detection(model, image):
    results = model(image)
    detections = results.pandas().xyxy[0]
    annotated_image, detected_regions = draw_yolo_detections(image, detections, model.names)
    return annotated_image, detected_regions

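# Note: results.pandas() is provided by the yolov5 hub model's Detections
# object and needs the pandas package at runtime; pandas is not listed in
# requirements.txt, so it presumably arrives as a transitive dependency of
# ultralytics.
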
def draw_yolo_detections(image, detections, class_names):
    """Draw YOLO detections on the image and return the detected regions."""
    annotated_image = image.copy()

    if len(annotated_image.shape) == 2 or annotated_image.shape[2] == 1:
        annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_GRAY2BGR)

    boxes = []
    for _, row in detections.iterrows():
        x1, y1, x2, y2 = int(row['xmin']), int(row['ymin']), int(row['xmax']), int(row['ymax'])
        conf, cls = row['confidence'], int(row['class'])
        label = f"{class_names[cls]} {conf:.2f}"

        # Record the detection; the actual crop happens later in crop_regions_from_boxes
        boxes.append({"box": (x1, y1, x2, y2), "class": class_names[cls], "confidence": conf})

        # Draw the bounding box and label
        cv2.rectangle(annotated_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(annotated_image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    return annotated_image, boxes

def crop_regions_from_boxes(original_image, boxes):
    """Crop the YOLO-detected regions out of the original image."""
    regions = []
    for box_info in boxes:
        x1, y1, x2, y2 = box_info["box"]
        cropped_region = original_image[y1:y2, x1:x2]
        regions.append((cropped_region, (x1, y1, x2, y2)))
    return regions

# ----------------------
# Step 3: Text recognition
# ----------------------
def convert_to_base64(region):
    """
    Convert a cropped region to Base64.
    Input: region (NumPy array)
    Output: Base64-encoded string
    """
    # Ensure the input is a C-contiguous NumPy array
    region = np.ascontiguousarray(region)

    # Convert to a PIL image
    pil_image = Image.fromarray(region)

    # Save the image to an in-memory byte stream
    buffer = io.BytesIO()
    pil_image.save(buffer, format="PNG")  # save as PNG
    buffer.seek(0)

    # Encode as Base64
    img_base64 = base64.b64encode(buffer.read()).decode("utf-8")
    return img_base64

def recognize_text(region):
    """Recognize text with OCR (here, a gpt-4o-mini vision call)."""
    base64_image = convert_to_base64(region)
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Directly output the numbers in the graph.",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            # The buffer holds a PNG, so declare the matching
                            # MIME type (the original said image/jpeg).
                            "url": f"data:image/png;base64,{base64_image}",
                        },
                    },
                ],
            }
        ],
        max_tokens=300,
    )
    print(response.choices)
    return response.choices[0].message.content

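# Example round trip (a sketch; needs OPENAI_API_KEY and network access;
# "crop.png" is a placeholder for any cropped region such as one returned by
# crop_regions_from_boxes):
#
#     crop = cv2.imread("crop.png")
#     print(recognize_text(crop))
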
# ----------------------
# Gradio interface
# ----------------------
def gradio_wrapper(image):
    # Step 1: image preprocessing
    gray_image, contrast_enhanced, blurred_image, sobel_normalized, binary_image = preprocess_image(image)

    # Step 2: YOLO detection on the binary image
    yolo_annotated_image, boxes = detect_with_yolo(binary_image, model_path='best.pt')

    # Crop from the original (color) image, not the binary one
    detected_regions = crop_regions_from_boxes(image, boxes)

    recognized_texts = [recognize_text(region[0]) for region in detected_regions]

    gray_image = cv2.cvtColor(gray_image, cv2.COLOR_GRAY2RGB)
    enhanced_image = cv2.cvtColor(contrast_enhanced, cv2.COLOR_GRAY2RGB)
    denoised_image = cv2.cvtColor(blurred_image, cv2.COLOR_GRAY2RGB)
    edge_image = cv2.cvtColor(sobel_normalized, cv2.COLOR_GRAY2RGB)
    binary_image = cv2.cvtColor(binary_image, cv2.COLOR_GRAY2RGB)
    detected_region_image = detected_regions[0][0] if detected_regions else np.zeros((100, 100, 3), dtype=np.uint8)
    # Guard against the no-detection case instead of indexing an empty list
    final_result = recognized_texts[0] if recognized_texts else ""
    steps_info = [
        {"title": "Grayscale", "image": gray_image, "description": "The grayscale version of the image."},
        {"title": "Contrast enhanced", "image": enhanced_image, "description": "The image after contrast enhancement."},
        {"title": "Denoised", "image": denoised_image, "description": "The image after denoising."},
        {"title": "Edge enhanced", "image": edge_image, "description": "The image after edge enhancement."},
        {"title": "Binarized", "image": binary_image, "description": "The image after binarization."},
        {"title": "YOLO detections", "image": yolo_annotated_image, "description": "The result annotated by YOLO detection."},
        {"title": "Region", "image": detected_region_image, "description": "The detected bolt hole bearing cast characters."},
    ]
    return steps_info, final_result

# Gradio callback: process the uploaded image and return step images with descriptions
def process_and_display(image):
    steps_info, result = gradio_wrapper(image)
    steps = [
        (
            step["image"],
            f"{i+1}. {step['title']}:\n{step['description']}"
        )
        for i, step in enumerate(steps_info)
    ]
    return steps, result

# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align:center'><strong>CITIC Dicastal Cast Character Recognition System</strong></h1>")
    gr.Markdown("<p style='text-align:center'>After an image is uploaded, the system detects and recognizes the cast characters in it step by step.</p>")

    with gr.Row():
        # Left: image upload and final result
        with gr.Column(scale=1):
            upload = gr.Image(type="numpy", label="Upload image")
            final_result = gr.Label(label="Final text recognition result")

        # Right: detection steps
        with gr.Column(scale=2):
            gr.Markdown("### Detection steps")
            gallery = gr.Gallery(label="")
            step_desc = gr.Markdown()

    # Trigger processing when an image is uploaded
    upload.change(
        fn=process_and_display,
        inputs=upload,
        outputs=[gallery, final_result],
    )

if __name__ == "__main__":
    demo.launch()
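# To run this Space locally (a sketch; assumes best.pt sits next to app.py and
# OPENAI_API_KEY is exported in the environment):
#     pip install -r requirements.txt
#     python app.py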
best.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b55e72c5a822eb6739f4097663e5f24756627b84585f130e5eceaca0bf275531
size 14442024
requirements.txt ADDED
@@ -0,0 +1,4 @@
gradio
pillow
ultralytics
openai