"""Gradio demo: detect and recognise cast characters on bolt holes.

Pipeline: OpenCV preprocessing -> YOLOv5 region detection -> GPT-4o-mini
OCR on the first detected region, with every intermediate image shown in a
gallery.
"""

import base64
import io
import logging
from functools import lru_cache

import cv2
import gradio as gr
import numpy as np
import torch
from openai import OpenAI
from PIL import Image

logger = logging.getLogger(__name__)

# SECURITY: API keys must never be hard-coded in source. The key that was
# previously committed here must be treated as leaked and revoked. With no
# explicit api_key, the OpenAI SDK reads it from the OPENAI_API_KEY
# environment variable (and raises at start-up if it is missing).
client = OpenAI()

OCR_MODEL = "gpt-4o-mini"
OCR_PROMPT = "Directly output the numbers in the graph."


# ----------------------
# Step 1: image preprocessing
# ----------------------
def preprocess_image(image: np.ndarray):
    """Run the preprocessing pipeline on an input image.

    Steps: grayscale -> CLAHE contrast enhancement -> Gaussian blur ->
    Sobel edge magnitude -> fixed-threshold binarisation.

    Args:
        image: uint8 image. Gradio (``type="numpy"``) supplies RGB
            (H, W, 3); 2-D grayscale and RGBA inputs are also accepted.

    Returns:
        Tuple ``(gray, contrast_enhanced, blurred, sobel_normalized, binary)``
        of 2-D uint8 arrays, all the same height/width as ``image``.
    """
    gray_image = convert_to_grayscale(image)
    contrast_enhanced = enhance_contrast(gray_image)
    blurred_image = apply_gaussian_blur(contrast_enhanced)
    sobel_normalized = detect_edges(blurred_image)
    binary_image = threshold_image(sobel_normalized)
    return gray_image, contrast_enhanced, blurred_image, sobel_normalized, binary_image


def convert_to_grayscale(image: np.ndarray) -> np.ndarray:
    """Convert an RGB, RGBA or already-gray image to a 2-D grayscale array.

    Gradio delivers RGB arrays, so RGB (not BGR) luminance weights are used.
    """
    if image.ndim == 2:
        return image.copy()
    channels = image.shape[2]
    if channels == 1:
        return image[:, :, 0].copy()
    if channels == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)


def enhance_contrast(gray_image: np.ndarray, clip_limit: float = 2.0,
                     tile_grid_size: tuple = (8, 8)) -> np.ndarray:
    """Apply CLAHE (adaptive histogram equalisation) to a grayscale image."""
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    return clahe.apply(gray_image)


def apply_gaussian_blur(image: np.ndarray, ksize: tuple = (5, 5)) -> np.ndarray:
    """Denoise with a Gaussian blur; sigma is derived from the kernel size."""
    return cv2.GaussianBlur(image, ksize, 0)


def detect_edges(image: np.ndarray, ksize: int = 3) -> np.ndarray:
    """Return the Sobel gradient magnitude, min-max scaled to uint8 [0, 255]."""
    sobel_x = cv2.Sobel(image, cv2.CV_64F, 1, 0, ksize=ksize)
    sobel_y = cv2.Sobel(image, cv2.CV_64F, 0, 1, ksize=ksize)
    sobel = cv2.magnitude(sobel_x, sobel_y)
    return cv2.normalize(sobel, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)


def threshold_image(image: np.ndarray, thresh: int = 50) -> np.ndarray:
    """Binarise: pixels above ``thresh`` become 255, all others 0."""
    _, binary_image = cv2.threshold(image, thresh, 255, cv2.THRESH_BINARY)
    return binary_image


# ----------------------
# Step 2: YOLO detection and region extraction
# ----------------------
def detect_with_yolo(image, model_path="best.pt", conf_threshold=0.5):
    """Run YOLO detection on ``image``.

    Returns:
        ``(annotated_rgb_image, boxes)``; see ``draw_yolo_detections``.
    """
    model = load_yolo_model(model_path, conf_threshold)
    return perform_yolo_detection(model, image)


@lru_cache(maxsize=4)
def _load_yolo_weights(model_path, force_reload=False):
    """Load a custom YOLOv5 model once per (path, force_reload) and cache it."""
    return torch.hub.load("ultralytics/yolov5", "custom", path=model_path,
                          force_reload=force_reload)


def load_yolo_model(model_path, conf_threshold, force_reload=False):
    """Return a (cached) YOLOv5 model with its confidence threshold set.

    Previously the model was reloaded, with ``force_reload=True``, on every
    upload. That re-fetched the hub repository and weights per request. It
    is now loaded once per process; pass ``force_reload=True`` to refresh
    the hub cache.

    NOTE: the cached model object is shared, so ``conf`` is set on it per
    call. Concurrent requests using *different* thresholds would race.
    """
    model = _load_yolo_weights(model_path, force_reload)
    model.conf = conf_threshold
    return model


def perform_yolo_detection(model, image):
    """Run ``model`` on ``image`` and draw the resulting detections."""
    results = model(image)
    detections = results.pandas().xyxy[0]
    annotated_image, detected_regions = draw_yolo_detections(image, detections, model.names)
    return annotated_image, detected_regions


def draw_yolo_detections(image, detections, class_names):
    """Draw YOLO boxes/labels on a copy of ``image`` and collect box metadata.

    Args:
        image: grayscale or 3-channel (treated as BGR) uint8 array.
        detections: DataFrame with xmin/ymin/xmax/ymax/confidence/class columns.
        class_names: mapping from class index to class name.

    Returns:
        ``(annotated_rgb_image, boxes)``, where each box is a dict with keys
        ``"box"`` (x1, y1, x2, y2), ``"class"`` and ``"confidence"``.
    """
    annotated_image = image.copy()
    if annotated_image.ndim == 2 or annotated_image.shape[2] == 1:
        annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_GRAY2BGR)

    boxes = []
    for _, row in detections.iterrows():
        x1, y1, x2, y2 = int(row["xmin"]), int(row["ymin"]), int(row["xmax"]), int(row["ymax"])
        conf, cls = row["confidence"], int(row["class"])
        label = f"{class_names[cls]} {conf:.2f}"
        boxes.append({"box": (x1, y1, x2, y2), "class": class_names[cls], "confidence": conf})

        cv2.rectangle(annotated_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Keep the label inside the image for boxes touching the top edge.
        text_y = max(y1 - 10, 10)
        cv2.putText(annotated_image, label, (x1, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    return annotated_image, boxes


def crop_regions_from_boxes(original_image, boxes):
    """Crop the detected boxes out of ``original_image``.

    Coordinates are clamped to the image bounds, because a negative slice
    index would silently wrap around. Empty (zero-area) crops are skipped.

    Returns:
        List of ``(cropped_array, (x1, y1, x2, y2))`` tuples using the
        clamped coordinates.
    """
    height, width = original_image.shape[:2]
    regions = []
    for box_info in boxes:
        x1, y1, x2, y2 = box_info["box"]
        x1, x2 = max(0, min(x1, width)), max(0, min(x2, width))
        y1, y2 = max(0, min(y1, height)), max(0, min(y2, height))
        if x2 <= x1 or y2 <= y1:
            continue
        regions.append((original_image[y1:y2, x1:x2], (x1, y1, x2, y2)))
    return regions


# ----------------------
# Step 3: text recognition
# ----------------------
def convert_to_base64(region):
    """Encode an image region (NumPy array) as a base64 PNG string."""
    # PIL needs a C-contiguous buffer; slices of a larger image are not.
    pil_image = Image.fromarray(np.ascontiguousarray(region))
    buffer = io.BytesIO()
    pil_image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def recognize_text(region):
    """Send ``region`` to the OpenAI vision model and return the reply text.

    Returns an empty string if the model returns no content.
    """
    base64_image = convert_to_base64(region)
    response = client.chat.completions.create(
        model=OCR_MODEL,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": OCR_PROMPT},
                    {
                        "type": "image_url",
                        # The payload is PNG-encoded, so declare image/png
                        # (it was previously mislabelled image/jpeg).
                        "image_url": {"url": f"data:image/png;base64,{base64_image}"},
                    },
                ],
            }
        ],
        max_tokens=300,
    )
    logger.debug("OCR response choices: %s", response.choices)
    return response.choices[0].message.content or ""


# ----------------------
# Gradio interface
# ----------------------
def _gray_to_rgb(gray):
    """Expand a 2-D grayscale image to 3-channel RGB for display."""
    return cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)


def gradio_wrapper(image):
    """Run the full pipeline and collect the images for every step.

    Returns:
        ``(steps_info, final_result)``. ``steps_info`` is a list of dicts with
        ``title``, ``image`` and ``description`` keys. ``final_result`` is the
        OCR text of the first detected region, or a notice when nothing was
        detected.
    """
    gray_image, contrast_enhanced, blurred_image, sobel_normalized, binary_image = preprocess_image(image)

    yolo_annotated_image, boxes = detect_with_yolo(binary_image, model_path="best.pt")
    detected_regions = crop_regions_from_boxes(image, boxes)
    # NOTE: every region is sent to the (paid) OCR API, but only the first
    # result is shown below.
    recognized_texts = [recognize_text(region) for region, _ in detected_regions]

    if detected_regions:
        detected_region_image = detected_regions[0][0]
    else:
        detected_region_image = np.zeros((100, 100, 3), dtype=np.uint8)
    # Previously indexed [0] unconditionally, which raised IndexError when
    # YOLO found nothing.
    final_result = recognized_texts[0] if recognized_texts else "未检测到铸字区域"

    steps_info = [
        {"title": "灰度图", "image": _gray_to_rgb(gray_image), "description": "这是图片的灰度处理结果。"},
        {"title": "对比度增强图", "image": _gray_to_rgb(contrast_enhanced), "description": "这是增强对比度后的图片。"},
        {"title": "去噪图", "image": _gray_to_rgb(blurred_image), "description": "这是经过去噪处理的图片。"},
        {"title": "边缘增强图", "image": _gray_to_rgb(sobel_normalized), "description": "这是图片边缘增强后的效果。"},
        {"title": "二值图", "image": _gray_to_rgb(binary_image), "description": "这是二值化处理后的图片。"},
        {"title": "YOLO 检测标注图", "image": yolo_annotated_image, "description": "这是通过 YOLO 检测标注后的结果。"},
        {"title": "区域", "image": detected_region_image, "description": "这是识别到的带字螺栓孔。"},
    ]
    return steps_info, final_result


def process_and_display(image):
    """Gradio callback: return ``(gallery_items, result_text)`` for an upload."""
    if image is None:
        # ``change`` also fires when the user clears the image widget.
        return [], None
    steps_info, result = gradio_wrapper(image)
    steps = [
        (step["image"], f"{i+1}.{step['title']}:\n{step['description']}")
        for i, step in enumerate(steps_info)
    ]
    return steps, result


with gr.Blocks() as demo:
    gr.Markdown("上传图片后,系统将按照步骤检测和识别图片中的铸字。")
    with gr.Row():
        # Left column: image upload and the final recognition result.
        with gr.Column(scale=1):
            upload = gr.Image(type="numpy", label="上传图片")
            final_result = gr.Label(label="最终文字识别结果")
        # Right column: gallery of the intermediate processing steps.
        with gr.Column(scale=2):
            gr.Markdown("### 检测步骤展示")
            gallery = gr.Gallery(label="")
            step_desc = gr.Markdown()

    # Re-run the pipeline whenever the uploaded image changes.
    upload.change(
        fn=process_and_display,
        inputs=upload,
        outputs=[gallery, final_result],
    )

if __name__ == "__main__":
    demo.launch()