# Scraped page header (not part of the program): Spaces: Running
import base64
import io
import os

import cv2
import gradio as gr
import numpy as np
import torch
from openai import OpenAI
from PIL import Image
# Shared OpenAI client used by recognize_text(). The API key is read from the
# "openaikey" environment variable; this needs `os` to be imported at the top
# of the file, otherwise the module fails with NameError on import.
client = OpenAI(api_key=os.getenv("openaikey"))
# import easyocr
# ----------------------
# Step 1: Image preprocessing
# ----------------------
def preprocess_image(image):
    """Run the preprocessing chain and return every intermediate image.

    The stages are applied in order, each feeding the next: grayscale
    conversion, contrast enhancement, Gaussian blur, edge detection and
    binary thresholding.

    Args:
        image: Input image as a NumPy array.

    Returns:
        A 5-tuple ``(gray, contrast_enhanced, blurred, edges, binary)`` holding
        the output of each stage in pipeline order.
    """
    gray = convert_to_grayscale(image)
    enhanced = enhance_contrast(gray)
    blurred = apply_gaussian_blur(enhanced)
    edges = detect_edges(blurred)
    binary = threshold_image(edges)
    return gray, enhanced, blurred, edges, binary
def convert_to_grayscale(image):
    """Return a single-channel (2-D) grayscale version of ``image``.

    Args:
        image: NumPy image array. May be 2-D (already grayscale) or 3-D with
            1, 3 or 4 channels.

    Returns:
        A 2-D grayscale array. A 2-D input is returned unchanged.
    """
    # Already grayscale: cv2.cvtColor would raise on a 2-D input.
    if image.ndim == 2:
        return image

    channels = image.shape[2]
    if channels == 1:
        # Drop the singleton channel axis; keep the result contiguous for OpenCV.
        return np.ascontiguousarray(image[:, :, 0])
    if channels == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)

    # NOTE(review): 3-channel input is treated as BGR. Gradio's numpy images
    # are presumably RGB, which would swap the R/B weights here — confirm
    # against the caller before switching to COLOR_RGB2GRAY.
    return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
def enhance_contrast(gray_image):
    """Apply CLAHE to ``gray_image`` and return the result.

    The CLAHE operator is created with ``clipLimit=2.0`` and an 8x8 tile grid.
    """
    return cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray_image)
def apply_gaussian_blur(image):
    """Return ``image`` blurred with a 5x5 Gaussian kernel.

    ``sigmaX`` is passed as 0.
    """
    kernel_size = (5, 5)
    sigma_x = 0
    return cv2.GaussianBlur(image, kernel_size, sigma_x)
def detect_edges(image):
    """Return the Sobel gradient magnitude of ``image`` scaled to uint8.

    Horizontal and vertical 3x3 Sobel derivatives are computed in float64,
    combined into a magnitude, min-max normalized to the 0-255 range and cast
    to ``np.uint8``.
    """
    grad_x = cv2.Sobel(image, cv2.CV_64F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(image, cv2.CV_64F, 0, 1, ksize=3)
    magnitude = cv2.magnitude(grad_x, grad_y)
    scaled = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX)
    return scaled.astype(np.uint8)
def threshold_image(image):
    """Binarize ``image`` with a fixed threshold.

    Uses ``cv2.THRESH_BINARY`` with threshold 50 and maximum value 255, and
    returns only the thresholded image (the threshold value returned by
    OpenCV is discarded).
    """
    return cv2.threshold(image, 50, 255, cv2.THRESH_BINARY)[1]
# ----------------------
# Step 2: YOLO detection and region extraction
# ----------------------
def detect_with_yolo(image, model_path='best.pt', conf_threshold=0.5):
    """Run YOLO object detection on ``image``.

    Args:
        image: Image to run the detector on.
        model_path: Path to the custom YOLO weights file.
        conf_threshold: Confidence threshold assigned to the loaded model.

    Returns:
        The ``(annotated_image, boxes)`` pair produced by
        ``perform_yolo_detection``.
    """
    yolo_model = load_yolo_model(model_path, conf_threshold)
    annotated, boxes = perform_yolo_detection(yolo_model, image)
    return annotated, boxes
# In-process cache of loaded YOLO models, keyed by weights path. Without it
# every request re-downloaded the hub repository and re-loaded the weights.
_YOLO_MODEL_CACHE = {}


def load_yolo_model(model_path, conf_threshold):
    """Return the custom YOLOv5 model for ``model_path`` with its threshold set.

    The model is loaded through ``torch.hub`` the first time a given
    ``model_path`` is requested and reused on later calls, so the forced hub
    reload happens at most once per path per process.

    Args:
        model_path: Path to the custom weights file passed to the hub loader.
        conf_threshold: Value assigned to ``model.conf`` on every call.

    Returns:
        The loaded model. The same object is shared between calls that use
        the same ``model_path``, so ``model.conf`` reflects the most recent
        call.
    """
    model = _YOLO_MODEL_CACHE.get(model_path)
    if model is None:
        # Keep force_reload=True for the first load, as the original did.
        model = torch.hub.load(
            'ultralytics/yolov5', 'custom', path=model_path, force_reload=True
        )
        _YOLO_MODEL_CACHE[model_path] = model
    model.conf = conf_threshold
    return model
def perform_yolo_detection(model, image):
    """Run ``model`` on ``image`` and draw the resulting detections.

    The first entry of the model's pandas ``xyxy`` results is handed to
    ``draw_yolo_detections`` together with ``model.names``.

    Returns:
        The ``(annotated_image, detected_regions)`` pair returned by
        ``draw_yolo_detections``.
    """
    detections = model(image).pandas().xyxy[0]
    return draw_yolo_detections(image, detections, model.names)
def draw_yolo_detections(image, detections, class_names):
    """Draw YOLO detections on a copy of ``image`` and collect their boxes.

    Args:
        image: NumPy image; single-channel input is expanded to 3 channels
            before drawing.
        detections: Table whose rows provide ``xmin``, ``ymin``, ``xmax``,
            ``ymax``, ``confidence`` and ``class`` (iterated with
            ``iterrows()``).
        class_names: Lookup from integer class id to class name.

    Returns:
        ``(annotated_image, boxes)`` where ``annotated_image`` has been passed
        through a BGR->RGB conversion and ``boxes`` is a list of dicts with
        the keys ``"box"`` (``(x1, y1, x2, y2)`` ints), ``"class"`` and
        ``"confidence"``.
    """
    annotated_image = image.copy()
    # Rectangles and text are drawn in colour, so grayscale needs 3 channels.
    if annotated_image.ndim == 2 or annotated_image.shape[2] == 1:
        annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_GRAY2BGR)

    boxes = []
    for _, row in detections.iterrows():
        x1, y1, x2, y2 = (
            int(row[key]) for key in ('xmin', 'ymin', 'xmax', 'ymax')
        )
        conf, cls = row['confidence'], int(row['class'])
        class_name = class_names[cls]
        label = f"{class_name} {conf:.2f}"
        boxes.append({"box": (x1, y1, x2, y2), "class": class_name, "confidence": conf})

        # Draw the bounding box and its label.
        cv2.rectangle(annotated_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Put the label above the box, or just inside it when the box is so
        # close to the top edge that the text would fall outside the image.
        label_y = y1 - 10 if y1 - 10 > 10 else y1 + 15
        cv2.putText(
            annotated_image, label, (x1, label_y),
            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2,
        )

    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    return annotated_image, boxes
def crop_regions_from_boxes(original_image, boxes):
    """Crop the detected boxes out of ``original_image``.

    Box coordinates are clamped to the image bounds before slicing, because a
    negative index would otherwise wrap around in NumPy slicing and yield a
    wrong or empty crop.

    Args:
        original_image: NumPy image array indexed as ``[row, column, ...]``.
        boxes: Iterable of dicts whose ``"box"`` entry is ``(x1, y1, x2, y2)``.

    Returns:
        A list of ``(cropped_region, (x1, y1, x2, y2))`` tuples, one per box,
        in input order. The coordinates are the box's original, unclamped
        values.
    """
    height, width = original_image.shape[:2]
    regions = []
    for box_info in boxes:
        x1, y1, x2, y2 = box_info["box"]
        left, right = min(max(x1, 0), width), min(max(x2, 0), width)
        top, bottom = min(max(y1, 0), height), min(max(y2, 0), height)
        regions.append((original_image[top:bottom, left:right], (x1, y1, x2, y2)))
    return regions
# ----------------------
# Step 3: Text recognition
# ----------------------
def convert_to_base64(region):
    """Encode a cropped image region as a Base64 PNG string.

    Args:
        region: NumPy array holding the image region.

    Returns:
        The PNG-encoded image as a Base64 string (UTF-8 decoded).
    """
    # The array is made C-contiguous before it is handed to PIL.
    contiguous = np.ascontiguousarray(region)
    # Serialize to PNG entirely in memory.
    with io.BytesIO() as png_stream:
        Image.fromarray(contiguous).save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
def recognize_text(region):
    """Ask the OpenAI vision model for the numbers shown in ``region``.

    The region is encoded with ``convert_to_base64`` and sent as an inline
    data URL to the ``gpt-4o-mini`` chat completions endpoint through the
    module-level ``client``.

    Args:
        region: NumPy array holding the cropped image region.

    Returns:
        The message content of the first completion choice.
    """
    base64_image = convert_to_base64(region)
    # convert_to_base64 produces PNG bytes, so the data URL must declare
    # image/png (it previously claimed image/jpeg).
    image_url = f"data:image/png;base64,{base64_image}"
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Directly output the numbers in the graph.",
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": image_url},
                    },
                ],
            }
        ],
        max_tokens=300,
    )
    print(response.choices)
    return response.choices[0].message.content
# ----------------------
# Gradio interface
# ----------------------
def gradio_wrapper(image):
    """Run the full pipeline on ``image`` and collect per-step results.

    Steps: preprocess the image, run YOLO on the binary image, crop the
    detected boxes out of the original ``image``, and recognize the text in
    every cropped region.

    Args:
        image: Input image as a NumPy array.

    Returns:
        ``(steps_info, final_result)`` where ``steps_info`` is a list of dicts
        with ``"title"``, ``"image"`` and ``"description"`` keys, one per
        pipeline step, and ``final_result`` is the text recognized in the
        first detected region, or a "nothing detected" message when YOLO
        found no region.
    """
    # Step 1: image preprocessing.
    gray_image, contrast_enhanced, blurred_image, sobel_normalized, binary_image = preprocess_image(image)

    # Step 2: YOLO detection runs on the binary image; crops come from the
    # original image.
    yolo_annotated_image, boxes = detect_with_yolo(binary_image, model_path='best.pt')
    detected_regions = crop_regions_from_boxes(image, boxes)
    recognized_texts = [recognize_text(region) for region, _ in detected_regions]

    # The gallery shows every step as a 3-channel image.
    gray_image = cv2.cvtColor(gray_image, cv2.COLOR_GRAY2RGB)
    enhanced_image = cv2.cvtColor(contrast_enhanced, cv2.COLOR_GRAY2RGB)
    denoised_image = cv2.cvtColor(blurred_image, cv2.COLOR_GRAY2RGB)
    edge_image = cv2.cvtColor(sobel_normalized, cv2.COLOR_GRAY2RGB)
    binary_image = cv2.cvtColor(binary_image, cv2.COLOR_GRAY2RGB)

    if detected_regions:
        detected_region_image = detected_regions[0][0]
        final_result = recognized_texts[0]
    else:
        # No detections: show a black placeholder and a message instead of
        # raising IndexError on the empty result list.
        detected_region_image = np.zeros((100, 100, 3), dtype=np.uint8)
        final_result = "未检测到文字区域"

    steps_info = [
        {"title": "灰度图", "image": gray_image, "description": "这是图片的灰度处理结果。"},
        {"title": "对比度增强图", "image": enhanced_image, "description": "这是增强对比度后的图片。"},
        {"title": "去噪图", "image": denoised_image, "description": "这是经过去噪处理的图片。"},
        {"title": "边缘增强图", "image": edge_image, "description": "这是图片边缘增强后的效果。"},
        {"title": "二值图", "image": binary_image, "description": "这是二值化处理后的图片。"},
        {"title": "YOLO 检测标注图", "image": yolo_annotated_image, "description": "这是通过 YOLO 检测标注后的结果。"},
        {"title": "区域", "image": detected_region_image, "description": "这是识别到的带字螺栓孔。"},
    ]
    return steps_info, final_result
# Gradio callback: process the uploaded image and return the step images with their descriptions
def process_and_display(image):
    """Gradio callback: turn an uploaded image into gallery items and a result.

    Args:
        image: Uploaded image as a NumPy array, or ``None`` when the image
            input has been cleared.

    Returns:
        ``(steps, result)`` where ``steps`` is a list of ``(image, caption)``
        tuples, one per pipeline step, captioned
        ``"<n>.<title>:\\n<description>"`` with ``n`` starting at 1, and
        ``result`` is the recognized text. Returns ``([], "")`` when ``image``
        is ``None``.
    """
    # The change event also fires when the image is cleared; there is nothing
    # to process then, and the pipeline would crash on None.
    if image is None:
        return [], ""

    steps_info, result = gradio_wrapper(image)
    steps = [
        (step["image"], f"{number}.{step['title']}:\n{step['description']}")
        for number, step in enumerate(steps_info, start=1)
    ]
    return steps, result
# Build the Gradio interface
with gr.Blocks() as demo:
    # Page title and short usage hint.
    gr.Markdown("<h1 style='text-align:center'><strong>中信戴卡铸字识别系统</strong></h1>")
    gr.Markdown("<p style='text-align:center'>上传图片后,系统将按照步骤检测和识别图片中的铸字。</p>")

    with gr.Row():
        # Left column: image upload and the final recognition result.
        with gr.Column(scale=1):
            upload = gr.Image(type="numpy", label="上传图片")
            final_result = gr.Label(label="最终文字识别结果")

        # Right column: gallery of the intermediate detection steps.
        with gr.Column(scale=2):
            gr.Markdown("### 检测步骤展示")
            gallery = gr.Gallery(label="")
            step_desc = gr.Markdown()

    # Re-run the pipeline whenever the uploaded image changes.
    upload.change(fn=process_and_display, inputs=upload, outputs=[gallery, final_result])


if __name__ == "__main__":
    demo.launch()