# jianzhichun's picture
# update
# ccea8da
import gradio as gr
import cv2
import numpy as np
import torch
from openai import OpenAI
from PIL import Image
import base64
import io
import os
# Module-level OpenAI client used by recognize_text. The API key is read from
# the "openaikey" environment variable (os.getenv yields None when it is unset).
client = OpenAI(api_key=os.getenv("openaikey"))
# import easyocr
# ----------------------
# Step 1: 图像预处理
# ----------------------
def preprocess_image(image):
    """Run the preprocessing chain and return every intermediate image.

    The stages run in order, each consuming the previous stage's output:
    grayscale conversion, contrast enhancement, Gaussian blur, edge
    detection, and binary thresholding.

    Args:
        image: Input color image as a NumPy array.

    Returns:
        A 5-tuple ``(gray, contrast_enhanced, blurred, edges, binary)``.
    """
    gray = convert_to_grayscale(image)
    equalized = enhance_contrast(gray)
    smoothed = apply_gaussian_blur(equalized)
    edges = detect_edges(smoothed)
    mask = threshold_image(edges)
    return gray, equalized, smoothed, edges, mask
def convert_to_grayscale(image):
    """Convert a color image to single-channel grayscale.

    The conversion uses OpenCV's ``COLOR_BGR2GRAY`` code, i.e. the input is
    treated as BGR-ordered.

    NOTE(review): the Gradio input in this file is declared with
    ``type="numpy"``, which presumably delivers RGB data. Confirm which
    channel order the detector weights expect before changing this.
    """
    conversion_code = cv2.COLOR_BGR2GRAY
    return cv2.cvtColor(image, conversion_code)
def enhance_contrast(gray_image):
    """Apply CLAHE (clip limit 2.0, 8x8 tile grid) to a grayscale image."""
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = equalizer.apply(gray_image)
    return enhanced
def apply_gaussian_blur(image):
    """Smooth the image with a 5x5 Gaussian kernel.

    Sigma is passed as 0, which lets OpenCV derive it from the kernel size.
    """
    kernel_size = (5, 5)
    sigma_x = 0
    return cv2.GaussianBlur(image, kernel_size, sigma_x)
def detect_edges(image):
    """Return the Sobel gradient magnitude, min-max rescaled to 0..255 uint8."""
    # Horizontal and vertical derivatives, computed in float64 so negative
    # gradients are not clipped.
    grad_x = cv2.Sobel(image, cv2.CV_64F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(image, cv2.CV_64F, 0, 1, ksize=3)
    magnitude = cv2.magnitude(grad_x, grad_y)
    scaled = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX)
    return scaled.astype(np.uint8)
def threshold_image(image):
    """Binarize the image: pixels above 50 become 255, all others 0."""
    # cv2.threshold returns (threshold_used, thresholded_image); only the
    # image is needed here.
    outcome = cv2.threshold(image, 50, 255, cv2.THRESH_BINARY)
    return outcome[1]
# ----------------------
# Step 2: YOLO 检测并提取区域
# ----------------------
def detect_with_yolo(image, model_path='best.pt', conf_threshold=0.5):
    """Run YOLO object detection on an image.

    Args:
        image: Image to run the detector on.
        model_path: Path to the custom YOLO weights file.
        conf_threshold: Confidence threshold assigned to the model.

    Returns:
        Whatever ``perform_yolo_detection`` returns: the annotated image and
        the list of detected boxes.
    """
    detector = load_yolo_model(model_path, conf_threshold)
    outcome = perform_yolo_detection(detector, image)
    return outcome
# Process-wide cache of loaded detectors, keyed by weights path, so the hub
# repository and the weights are not fetched and rebuilt on every request.
_YOLO_MODEL_CACHE = {}


def load_yolo_model(model_path, conf_threshold):
    """Return a YOLOv5 model for ``model_path`` with its confidence threshold set.

    The model is loaded through ``torch.hub`` the first time a given
    ``model_path`` is requested and reused afterwards. The first load still
    passes ``force_reload=True`` (as before), so a fresh copy of the hub
    repository is fetched once per process instead of once per call.

    Args:
        model_path: Path to the custom weights file.
        conf_threshold: Value assigned to ``model.conf`` on every call, so
            callers may vary it without reloading the model.

    Returns:
        The loaded (possibly cached) model object.

    Note:
        The cached model is shared between calls. If the weights file is
        replaced on disk, the process must be restarted to pick it up.
    """
    model = _YOLO_MODEL_CACHE.get(model_path)
    if model is None:
        model = torch.hub.load(
            'ultralytics/yolov5', 'custom', path=model_path, force_reload=True
        )
        _YOLO_MODEL_CACHE[model_path] = model
    model.conf = conf_threshold
    return model
def perform_yolo_detection(model, image):
    """Run the model on one image and draw its detections.

    Args:
        model: Callable detector whose result exposes ``pandas().xyxy`` and
            which has a ``names`` attribute mapping class ids to names.
        image: Image passed to the model and used as the drawing canvas.

    Returns:
        The ``(annotated_image, boxes)`` pair produced by
        ``draw_yolo_detections``.
    """
    inference = model(image)
    # Entry 0 of the xyxy list is the detection table for the single image
    # passed in above.
    detection_table = inference.pandas().xyxy[0]
    return draw_yolo_detections(image, detection_table, model.names)
def draw_yolo_detections(image, detections, class_names):
    """Draw YOLO detections on a copy of the image and collect the boxes.

    Args:
        image: Source image; it is copied, never modified in place. A
            single-channel image is expanded to three channels so the
            colored overlay can be drawn.
        detections: pandas DataFrame with the columns ``xmin``, ``ymin``,
            ``xmax``, ``ymax``, ``confidence`` and ``class``.
        class_names: Mapping or sequence indexed by integer class id.

    Returns:
        ``(annotated_image, boxes)`` where ``annotated_image`` has had its
        first and third channels swapped (``COLOR_BGR2RGB``) and ``boxes`` is
        a list of dicts with the keys ``"box"`` (``(x1, y1, x2, y2)`` ints),
        ``"class"`` and ``"confidence"``.
    """
    annotated_image = image.copy()
    if annotated_image.ndim == 2 or annotated_image.shape[2] == 1:
        annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_GRAY2BGR)

    box_color = (0, 255, 0)
    boxes = []
    for _, row in detections.iterrows():
        x1, y1, x2, y2 = (
            int(row[column]) for column in ('xmin', 'ymin', 'xmax', 'ymax')
        )
        conf, cls = row['confidence'], int(row['class'])
        class_name = class_names[cls]
        label = f"{class_name} {conf:.2f}"
        boxes.append({"box": (x1, y1, x2, y2), "class": class_name, "confidence": conf})

        cv2.rectangle(annotated_image, (x1, y1), (x2, y2), box_color, 2)
        # putText anchors the text at its bottom-left corner. When the box
        # sits at the top edge there is no room above it, so the label is
        # drawn just inside the box instead of being clipped off-canvas.
        label_y = y1 - 10 if y1 - 10 > 12 else y1 + 20
        cv2.putText(
            annotated_image, label, (x1, label_y),
            cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color, 2,
        )

    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    return annotated_image, boxes
def crop_regions_from_boxes(original_image, boxes):
    """Cut the detected boxes out of the original image.

    Args:
        original_image: NumPy image indexed as ``[row, column, ...]``.
        boxes: Iterable of dicts whose ``"box"`` entry is ``(x1, y1, x2, y2)``.

    Returns:
        A list of ``(cropped_region, (x1, y1, x2, y2))`` tuples, in the same
        order as ``boxes``.
    """
    def _crop(bounds):
        left, top, right, bottom = bounds
        return original_image[top:bottom, left:right], (left, top, right, bottom)

    return [_crop(entry["box"]) for entry in boxes]
# ----------------------
# Step 4: 文字识别
# ----------------------
def convert_to_base64(region):
    """Encode a cropped region as a Base64 string of PNG bytes.

    Args:
        region: NumPy array holding the image region.

    Returns:
        The PNG-encoded image as a Base64 ``str``.
    """
    # Make the array C-contiguous before handing it to PIL.
    contiguous = np.ascontiguousarray(region)
    # Encode to PNG entirely in memory.
    png_stream = io.BytesIO()
    Image.fromarray(contiguous).save(png_stream, format="PNG")
    return base64.b64encode(png_stream.getvalue()).decode("utf-8")
def recognize_text(region):
    """Recognize the characters in an image region via the OpenAI chat API.

    The region is PNG-encoded by ``convert_to_base64`` and sent inline as a
    data URL together with a short text instruction to the ``gpt-4o-mini``
    model.

    Args:
        region: NumPy array holding the cropped image region.

    Returns:
        The message content of the first choice in the response.
    """
    base64_image = convert_to_base64(region)
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Directly output the numbers in the graph.",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            # The payload is PNG, so the data URL must say
                            # so (it previously claimed image/jpeg).
                            "url": f"data:image/png;base64,{base64_image}",
                        },
                    },
                ],
            }
        ],
        max_tokens=300,
    )
    print(response.choices)
    return response.choices[0].message.content
# ----------------------
# Gradio 接口
# ----------------------
def gradio_wrapper(image):
    """Run the full pipeline on one image and describe each step.

    Steps: preprocess the image, run YOLO on the binary image, crop the
    detected boxes out of the original image, and recognize the text in the
    first detected region.

    Args:
        image: Input color image as a NumPy array.

    Returns:
        ``(steps_info, final_result)`` where ``steps_info`` is a list of
        dicts with the keys ``"title"``, ``"image"`` and ``"description"``,
        and ``final_result`` is the recognized text, or a fallback message
        when nothing was detected.
    """
    # Step 1: preprocessing
    gray_image, contrast_enhanced, blurred_image, sobel_normalized, binary_image = preprocess_image(image)

    # Step 2: YOLO detection on the binary image; crops come from the original
    yolo_annotated_image, boxes = detect_with_yolo(binary_image, model_path='best.pt')
    detected_regions = crop_regions_from_boxes(image, boxes)

    if detected_regions:
        # Only the first region is displayed and reported, so only that one
        # is sent to the recognition service.
        detected_region_image = detected_regions[0][0]
        final_result = recognize_text(detected_region_image)
    else:
        # Nothing detected: show a blank placeholder and a message instead
        # of indexing into an empty list. The message reads
        # "no text region detected".
        detected_region_image = np.zeros((100, 100, 3), dtype=np.uint8)
        final_result = "未检测到文字区域"

    # The single-channel intermediates are expanded to RGB for display.
    gray_image = cv2.cvtColor(gray_image, cv2.COLOR_GRAY2RGB)
    enhanced_image = cv2.cvtColor(contrast_enhanced, cv2.COLOR_GRAY2RGB)
    denoised_image = cv2.cvtColor(blurred_image, cv2.COLOR_GRAY2RGB)
    edge_image = cv2.cvtColor(sobel_normalized, cv2.COLOR_GRAY2RGB)
    binary_image = cv2.cvtColor(binary_image, cv2.COLOR_GRAY2RGB)

    steps_info = [
        {"title": "灰度图", "image": gray_image, "description": "这是图片的灰度处理结果。"},
        {"title": "对比度增强图", "image": enhanced_image, "description": "这是增强对比度后的图片。"},
        {"title": "去噪图", "image": denoised_image, "description": "这是经过去噪处理的图片。"},
        {"title": "边缘增强图", "image": edge_image, "description": "这是图片边缘增强后的效果。"},
        {"title": "二值图", "image": binary_image, "description": "这是二值化处理后的图片。"},
        {"title": "YOLO 检测标注图", "image": yolo_annotated_image, "description": "这是通过 YOLO 检测标注后的结果。"},
        {"title": "区域", "image": detected_region_image, "description": "这是识别到的带字螺栓孔。"},
    ]
    return steps_info, final_result
# Gradio 接口函数:处理上传的图片并返回步骤图和描述
def process_and_display(image):
    """Gradio callback: process an uploaded image into gallery items and text.

    Args:
        image: The uploaded image as a NumPy array, or ``None`` when the
            upload component has been cleared.

    Returns:
        ``(steps, result)`` where ``steps`` is a list of ``(image, caption)``
        tuples for the gallery and ``result`` is the recognized text. When
        ``image`` is ``None`` the outputs are cleared with ``([], None)``.
    """
    # The change event also fires when the image is cleared; there is
    # nothing to process then, so reset both outputs instead of crashing.
    if image is None:
        return [], None

    steps_info, result = gradio_wrapper(image)
    steps = [
        (step["image"], f"{index}.{step['title']}:\n{step['description']}")
        for index, step in enumerate(steps_info, start=1)
    ]
    return steps, result
# Build the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align:center'><strong>中信戴卡铸字识别系统</strong></h1>")
    gr.Markdown("<p style='text-align:center'>上传图片后,系统将按照步骤检测和识别图片中的铸字。</p>")
    with gr.Row():
        # Left: image upload and the final recognition result
        with gr.Column(scale=1):
            upload = gr.Image(type="numpy", label="上传图片")
            final_result = gr.Label(label="最终文字识别结果")
        # Right: gallery showing each processing step
        with gr.Column(scale=2):
            gr.Markdown("### 检测步骤展示")
            gallery = gr.Gallery(label="")
            # NOTE(review): step_desc is created but not used as an input or
            # output of any event in the code visible here.
            step_desc = gr.Markdown()
    # Run the pipeline whenever the uploaded image changes
    upload.change(
        fn=process_and_display,
        inputs=upload,
        outputs=[gallery, final_result],
    )
if __name__ == "__main__":
    demo.launch()