# Source: Hugging Face file view by jianzhichun, commit 2dabc44 ("update"),
# file size 8.19 kB (page UI labels "raw" / "history blame" omitted).
import base64
import functools
import io
import logging
import os

import cv2
import gradio as gr
import numpy as np
import torch
from openai import OpenAI
from PIL import Image
# Module-level OpenAI client used by recognize_text(). The API key is read from
# the "openaikey" environment variable; os.getenv returns None when it is unset.
# `os` must be imported at the top of the file (it was previously missing,
# which made importing this module fail with NameError).
client = OpenAI(api_key=os.getenv("openaikey"))
# ----------------------
# Step 1: 图像预处理
# ----------------------
def preprocess_image(image):
    """Run the preprocessing pipeline and return every intermediate stage.

    Stages, in order: grayscale conversion, CLAHE contrast enhancement,
    Gaussian blur (denoising), Sobel edge magnitude, binary threshold. Each
    stage consumes the previous stage's output.

    Returns:
        A 5-tuple ``(gray, contrast_enhanced, blurred, edges, binary)``.
    """
    stages = (
        convert_to_grayscale,
        enhance_contrast,
        apply_gaussian_blur,
        detect_edges,
        threshold_image,
    )
    produced = []
    current = image
    for stage in stages:
        current = stage(current)
        produced.append(current)
    return tuple(produced)
def convert_to_grayscale(image):
    """Convert *image* to a single-channel (2-D) grayscale array.

    Multi-channel input is converted with ``cv2.COLOR_BGR2GRAY`` as before.
    Input that is already grayscale (2-D, or a single trailing channel) used
    to make ``cv2.cvtColor`` raise; it is now returned as a 2-D array. A 2-D
    input is returned unchanged, without a copy.

    NOTE(review): the Gradio upload (``gr.Image(type="numpy")``) is probably
    RGB, so BGR2GRAY swaps the red/blue weights. Kept unchanged because the
    YOLO weights may have been trained on exactly this preprocessing; confirm
    before switching to ``cv2.COLOR_RGB2GRAY``.
    """
    if image.ndim == 2:
        return image
    if image.shape[2] == 1:
        return image[:, :, 0]
    return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
def enhance_contrast(gray_image):
    """Boost local contrast with CLAHE (clip limit 2.0, 8x8 tile grid)."""
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = equalizer.apply(gray_image)
    return enhanced
def apply_gaussian_blur(image):
    """Denoise *image* with a 5x5 Gaussian kernel.

    A sigma of 0 lets OpenCV derive the standard deviation from the kernel size.
    """
    kernel_size = (5, 5)
    sigma = 0
    return cv2.GaussianBlur(image, kernel_size, sigma)
def detect_edges(image):
    """Return the Sobel gradient magnitude of *image*, min-max scaled to uint8.

    Horizontal and vertical 3x3 Sobel derivatives are computed in float64,
    combined with ``cv2.magnitude``, stretched to [0, 255] and cast to uint8.
    """
    grad_x, grad_y = (
        cv2.Sobel(image, cv2.CV_64F, dx, dy, ksize=3)
        for dx, dy in ((1, 0), (0, 1))
    )
    magnitude = cv2.magnitude(grad_x, grad_y)
    stretched = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX)
    return stretched.astype(np.uint8)
def threshold_image(image):
    """Binarize *image*: pixels above 50 become 255, all others become 0."""
    threshold_result = cv2.threshold(image, 50, 255, cv2.THRESH_BINARY)
    # cv2.threshold returns (used_threshold, output_image); keep the image.
    return threshold_result[1]
# ----------------------
# Step 2: YOLO 检测并提取区域
# ----------------------
def detect_with_yolo(image, model_path='best.pt', conf_threshold=0.5):
    """Detect objects in *image* using the YOLOv5 weights at *model_path*.

    Args:
        image: NumPy image passed straight to the model.
        model_path: Path to the custom YOLOv5 weights file.
        conf_threshold: Minimum confidence assigned to ``model.conf``.

    Returns:
        ``(annotated_image, boxes)`` as produced by ``perform_yolo_detection``.
    """
    detector = load_yolo_model(model_path, conf_threshold)
    annotated_image, detected_boxes = perform_yolo_detection(detector, image)
    return annotated_image, detected_boxes
@functools.lru_cache(maxsize=8)
def _load_yolo_hub_model(model_path):
    """Load (and memoize per *model_path*) the custom YOLOv5 hub model."""
    # force_reload=True re-downloads the ultralytics/yolov5 hub repo; because of
    # the cache this now happens once per process instead of on every request.
    return torch.hub.load('ultralytics/yolov5', 'custom', path=model_path, force_reload=True)


def load_yolo_model(model_path, conf_threshold):
    """Return the YOLOv5 model for *model_path* with its confidence threshold set.

    Previously every call ran ``torch.hub.load(..., force_reload=True)``, so each
    Gradio upload re-fetched the hub repository and reloaded the weights. The
    hub load is now cached per ``model_path``.

    NOTE: the returned model object is shared between calls, and ``model.conf``
    is assigned on that shared instance, so the most recent caller's threshold
    applies.

    Args:
        model_path: Path to the custom YOLOv5 weights file.
        conf_threshold: Confidence threshold stored on ``model.conf``.

    Returns:
        The (cached) hub model.
    """
    model = _load_yolo_hub_model(model_path)
    model.conf = conf_threshold
    return model
def perform_yolo_detection(model, image):
    """Run *model* on *image* and draw the detections for the first image.

    Returns:
        ``(annotated_image, boxes)`` from ``draw_yolo_detections``.
    """
    detection_frame = model(image).pandas().xyxy[0]
    return draw_yolo_detections(image, detection_frame, model.names)
def draw_yolo_detections(image, detections, class_names):
    """Draw YOLO detections on a copy of *image* and collect their boxes.

    Args:
        image: NumPy image. 2-D or single-channel input is promoted to three
            channels with ``cv2.COLOR_GRAY2BGR`` so the boxes can be drawn in color.
        detections: DataFrame with ``xmin``, ``ymin``, ``xmax``, ``ymax``,
            ``confidence`` and ``class`` columns (as produced by
            ``results.pandas().xyxy[0]``).
        class_names: Indexable mapping from integer class id to class name.

    Returns:
        ``(annotated_image, boxes)``: the annotated copy converted with
        ``cv2.COLOR_BGR2RGB`` for display, and a list of dicts with keys
        ``"box"`` (``(x1, y1, x2, y2)`` ints), ``"class"`` and ``"confidence"``.
        Regions are not cropped here; see ``crop_regions_from_boxes``.
    """
    annotated_image = image.copy()
    if annotated_image.ndim == 2 or annotated_image.shape[2] == 1:
        annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_GRAY2BGR)

    boxes = []
    for _, row in detections.iterrows():
        x1, y1, x2, y2 = (int(row[key]) for key in ("xmin", "ymin", "xmax", "ymax"))
        conf = row['confidence']
        class_name = class_names[int(row['class'])]
        boxes.append({"box": (x1, y1, x2, y2), "class": class_name, "confidence": conf})

        cv2.rectangle(annotated_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # y1 - 10 puts the text baseline above row 0 for boxes touching the
        # top edge, which clips the label; keep it inside the image.
        label_y = max(y1 - 10, 10)
        cv2.putText(annotated_image, f"{class_name} {conf:.2f}", (x1, label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    return annotated_image, boxes
def crop_regions_from_boxes(original_image, boxes):
    """Crop each detected box out of *original_image*.

    Coordinates are clamped to the image bounds first. A negative slice start
    would otherwise index from the end of the axis and silently return the
    wrong region (or an empty one).

    Args:
        original_image: NumPy image of shape (H, W[, C]).
        boxes: Iterable of dicts whose ``"box"`` value is ``(x1, y1, x2, y2)``
            in pixel coordinates, as returned by ``draw_yolo_detections``.

    Returns:
        A list of ``(crop, (x1, y1, x2, y2))`` tuples with the clamped box.
        Each crop is a NumPy slice (a view into *original_image*, not a copy).
    """
    height, width = original_image.shape[:2]
    regions = []
    for box_info in boxes:
        x1, y1, x2, y2 = box_info["box"]
        x1, x2 = (max(0, min(v, width)) for v in (x1, x2))
        y1, y2 = (max(0, min(v, height)) for v in (y1, y2))
        regions.append((original_image[y1:y2, x1:x2], (x1, y1, x2, y2)))
    return regions
# ----------------------
# Step 4: 文字识别
# ----------------------
def convert_to_base64(region):
    """Encode an image array as a base64 string of PNG file bytes.

    Args:
        region: NumPy image array. It is made C-contiguous before conversion
            because crops produced by slicing may not be contiguous.

    Returns:
        The PNG bytes, base64-encoded and decoded to a ``str``.
    """
    pil_image = Image.fromarray(np.ascontiguousarray(region))
    with io.BytesIO() as png_stream:
        pil_image.save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
def recognize_text(region):
    """Read the numbers in an image region using an OpenAI vision model.

    No local OCR engine is involved. The region is PNG-encoded by
    ``convert_to_base64`` and sent to ``gpt-4o-mini`` through the module-level
    ``client``, with a prompt asking for the numbers only.

    Args:
        region: Image array accepted by ``convert_to_base64``.

    Returns:
        The model's reply text, or ``""`` if the reply has no content.
    """
    base64_image = convert_to_base64(region)
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Directly output the numbers in the graph."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            # convert_to_base64 produces PNG bytes, so the
                            # data URL must declare image/png (was image/jpeg).
                            "url": f"data:image/png;base64,{base64_image}",
                        },
                    },
                ],
            }
        ],
        max_tokens=300,
    )
    logging.getLogger(__name__).debug("OpenAI response choices: %s", response.choices)
    return response.choices[0].message.content or ""
# ----------------------
# Gradio 接口
# ----------------------
def gradio_wrapper(image):
    """Run the full pipeline on one uploaded image for the Gradio UI.

    Steps: ``preprocess_image``; YOLO detection on the binarized edge map
    (``detect_with_yolo``); cropping the detections out of the *original*
    image; and reading the first crop with ``recognize_text``.

    Args:
        image: NumPy image from the ``gr.Image(type="numpy")`` upload.

    Returns:
        ``(steps_info, final_result)``. ``steps_info`` is a list of dicts with
        ``"title"``, ``"image"`` (3-channel array) and ``"description"`` keys,
        one per pipeline stage. ``final_result`` is the recognized text of the
        first detected region, or a "no region detected" message when YOLO
        finds nothing (this case previously raised IndexError).
    """
    gray_image, contrast_enhanced, blurred_image, sobel_normalized, binary_image = preprocess_image(image)

    # YOLO runs on the binarized edge map; crops come from the original image.
    yolo_annotated_image, boxes = detect_with_yolo(binary_image, model_path='best.pt')
    detected_regions = crop_regions_from_boxes(image, boxes)

    if detected_regions:
        detected_region_image = detected_regions[0][0]
        # Only the first region's text is displayed, so only that region is
        # sent to the paid, network-bound recognizer.
        final_result = recognize_text(detected_region_image)
    else:
        detected_region_image = np.zeros((100, 100, 3), dtype=np.uint8)
        final_result = "未检测到铸字区域"  # "No cast-character region detected"

    # The gallery shows 3-channel images; promote the single-channel stages.
    gray_image = cv2.cvtColor(gray_image, cv2.COLOR_GRAY2RGB)
    enhanced_image = cv2.cvtColor(contrast_enhanced, cv2.COLOR_GRAY2RGB)
    denoised_image = cv2.cvtColor(blurred_image, cv2.COLOR_GRAY2RGB)
    edge_image = cv2.cvtColor(sobel_normalized, cv2.COLOR_GRAY2RGB)
    binary_image = cv2.cvtColor(binary_image, cv2.COLOR_GRAY2RGB)

    steps_info = [
        {"title": "灰度图", "image": gray_image, "description": "这是图片的灰度处理结果。"},
        {"title": "对比度增强图", "image": enhanced_image, "description": "这是增强对比度后的图片。"},
        {"title": "去噪图", "image": denoised_image, "description": "这是经过去噪处理的图片。"},
        {"title": "边缘增强图", "image": edge_image, "description": "这是图片边缘增强后的效果。"},
        {"title": "二值图", "image": binary_image, "description": "这是二值化处理后的图片。"},
        {"title": "YOLO 检测标注图", "image": yolo_annotated_image, "description": "这是通过 YOLO 检测标注后的结果。"},
        {"title": "区域", "image": detected_region_image, "description": "这是识别到的带字螺栓孔。"},
    ]
    return steps_info, final_result
# Gradio 接口函数:处理上传的图片并返回步骤图和描述
def process_and_display(image):
    """Gradio callback: process an uploaded image and build gallery items.

    Args:
        image: Value of the ``gr.Image(type="numpy")`` input. This is ``None``
            when the upload is cleared, which also triggers ``upload.change``.

    Returns:
        ``(steps, result)``. ``steps`` is a list of ``(image, caption)`` tuples
        for ``gr.Gallery``, with captions numbered from 1. ``result`` is the
        recognized text. For ``None`` input, returns ``([], None)`` to clear
        both outputs instead of crashing in preprocessing.
    """
    if image is None:
        return [], None
    steps_info, result = gradio_wrapper(image)
    steps = [
        (step["image"], f"{index}.{step['title']}:\n{step['description']}")
        for index, step in enumerate(steps_info, start=1)
    ]
    return steps, result
# 创建 Gradio 界面
# Build the Gradio UI: a title and intro, a left column with the upload widget
# and the final result label, and a right column with a gallery of pipeline
# steps. Component creation order determines the layout, so statement order
# here matters.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align:center'><strong>中信戴卡铸字识别系统</strong></h1>")
    gr.Markdown("<p style='text-align:center'>上传图片后,系统将按照步骤检测和识别图片中的铸字。</p>")
    with gr.Row():
        # Left column: image upload and the final recognition result.
        with gr.Column(scale=1):
            upload = gr.Image(type="numpy", label="上传图片")
            final_result = gr.Label(label="最终文字识别结果")
        # Right column: gallery of the intermediate detection steps.
        with gr.Column(scale=2):
            gr.Markdown("### 检测步骤展示")
            gallery = gr.Gallery(label="")
            # NOTE(review): step_desc is not an output of any event visible
            # here; it renders as an empty Markdown component.
            step_desc = gr.Markdown()
    # Re-run the pipeline whenever the uploaded image changes; the callback
    # fills the gallery and the final result label.
    upload.change(
        fn=process_and_display,
        inputs=upload,
        outputs=[gallery, final_result],
    )
# Launch the app only when run as a script.
if __name__ == "__main__":
    demo.launch()