import os
import tempfile

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from tqdm import tqdm
from ultralytics import YOLO
# Username is fixed; the password comes from an environment variable
APP_USERNAME = "admin"
APP_PASSWORD = os.getenv("APP_PASSWORD", "default_password")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

model = YOLO('kunin-mice-pose.v0.1.5n.engine')
print("Model loaded")
# Authentication state shared across requests
class AuthState:
    def __init__(self):
        self.is_logged_in = False

auth_state = AuthState()
def login(username, password):
    """Validate the login credentials and toggle interface visibility."""
    if username == APP_USERNAME and password == APP_PASSWORD:
        auth_state.is_logged_in = True
        return gr.update(visible=False), gr.update(visible=True), "Login successful"
    return gr.update(visible=True), gr.update(visible=False), "Incorrect username or password"
@spaces.GPU  # on ZeroGPU Spaces, a GPU is allocated per call to functions marked like this
def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
    """
    Run mouse detection on a video; return the annotated video, a trajectory
    plot, a position heatmap, and a text report.
    """
    print("Processing video...")
    if not auth_state.is_logged_in:
        # Must return one value per Gradio output component (four in total)
        return None, None, None, "Please log in first"

    print("Creating temporary output file...")
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
        output_path = tmp_file.name

    print("Reading video metadata...")
    cap = cv2.VideoCapture(video_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # fall back to 30 if FPS metadata is missing
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(process_seconds * fps) if process_seconds else int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    print(f"Video info: {width}x{height} @ {fps}fps, frames to process: {total_frames}")
print("初始化视频写入器...") | |
fourcc = cv2.VideoWriter_fourcc(*'mp4v') | |
video_writer = cv2.VideoWriter( | |
output_path, | |
fourcc, | |
fps, | |
(width, height) | |
) | |
base_size = min(width, height) | |
line_thickness = max(1, int(base_size * 0.002)) | |
print("开始YOLO推理...") | |
results = model.predict( | |
source=video_path, | |
device=device, | |
conf=conf_threshold, | |
save=False, | |
show=False, | |
stream=True, | |
line_width=line_thickness, | |
boxes=True, | |
show_labels=True, | |
show_conf=True, | |
vid_stride=1, | |
max_det=max_det, | |
retina_masks=True, | |
verbose=False | |
) | |
    frame_count = 0
    detection_info = []
    all_positions = []
    heatmap = np.zeros((height, width), dtype=np.float32)

    print("Processing detections...")
    progress_bar = tqdm(total=total_frames, desc="Processing frames")
    for r in results:
        frame = r.plot()

        # Track position only when exactly one mouse (8 keypoints) is detected;
        # the first keypoint is taken as the mouse's position
        if hasattr(r, 'keypoints') and r.keypoints is not None:
            kpts = r.keypoints.data
            if isinstance(kpts, torch.Tensor):
                kpts = kpts.cpu().numpy()
            if kpts.shape == (1, 8, 3):
                x, y = int(kpts[0, 0, 0]), int(kpts[0, 0, 1])
                all_positions.append([x, y])

                # Accumulate a Gaussian blob at the detected position
                if 0 <= x < width and 0 <= y < height:
                    sigma = 10
                    kernel_size = 31
                    temp_heatmap = np.zeros((height, width), dtype=np.float32)
                    temp_heatmap[y, x] = 1
                    temp_heatmap = cv2.GaussianBlur(temp_heatmap, (kernel_size, kernel_size), sigma)
                    heatmap += temp_heatmap

        frame_info = {
            "frame": frame_count + 1,
            "count": len(r.boxes),
            "detections": []
        }
        for box in r.boxes:
            conf = float(box.conf[0])
            cls = int(box.cls[0])
            cls_name = r.names[cls]
            frame_info["detections"].append({
                "class": cls_name,
                "confidence": f"{conf:.2%}"
            })
        detection_info.append(frame_info)

        video_writer.write(frame)
        frame_count += 1
        progress_bar.update(1)
        if process_seconds and frame_count >= total_frames:
            break

    progress_bar.close()
    print("Video processing finished")
    video_writer.release()
print("生成分析报告...") | |
confidences = [float(det['confidence'].strip('%'))/100 for info in detection_info for det in info['detections']] | |
hist, bins = np.histogram(confidences, bins=5) | |
confidence_report = "\n".join([ | |
f"置信度 {bins[i]:.2f}-{bins[i+1]:.2f}: {hist[i]:3d}个检测 ({hist[i]/len(confidences)*100:.1f}%)" | |
for i in range(len(hist)) | |
]) | |
report = f"""视频分析报告: | |
参数设置: | |
- 置信度阈值: {conf_threshold:.2f} | |
- 最大检测数量: {max_det} | |
- 处理时长: {process_seconds}秒 | |
分析结果: | |
- 处理帧数: {frame_count} | |
- 平均每帧检测到的老鼠数: {np.mean([info['count'] for info in detection_info]):.1f} | |
- 最大检测数: {max([info['count'] for info in detection_info])} | |
- 最小检测数: {min([info['count'] for info in detection_info])} | |
置信度分布: | |
{confidence_report} | |
""" | |
    def filter_trajectories(positions, width, height, max_jump_distance=100):
        """Drop out-of-bounds points, bridge implausible jumps by interpolation,
        and smooth the trajectory with a moving-average window."""
        if len(positions) < 3:
            return positions

        filtered_positions = []
        last_valid_pos = None
        for i, pos in enumerate(positions):
            x, y = pos
            if not (0 <= x < width and 0 <= y < height):
                continue
            if last_valid_pos is None:
                filtered_positions.append(pos)
                last_valid_pos = pos
                continue

            distance = np.sqrt((x - last_valid_pos[0])**2 + (y - last_valid_pos[1])**2)
            if distance > max_jump_distance:
                # Implausible jump: look ahead for the next point within range
                # of the last valid position and interpolate towards it
                if len(filtered_positions) > 0:
                    next_valid_pos = None
                    for next_pos in positions[i:]:
                        nx, ny = next_pos
                        if 0 <= nx < width and 0 <= ny < height:
                            next_distance = np.sqrt((nx - last_valid_pos[0])**2 + (ny - last_valid_pos[1])**2)
                            if next_distance <= max_jump_distance:
                                next_valid_pos = next_pos
                                break
                    if next_valid_pos is not None:
                        steps = max(2, int(distance / max_jump_distance))
                        for j in range(1, steps):
                            alpha = j / steps
                            interp_x = int(last_valid_pos[0] * (1 - alpha) + next_valid_pos[0] * alpha)
                            interp_y = int(last_valid_pos[1] * (1 - alpha) + next_valid_pos[1] * alpha)
                            filtered_positions.append([interp_x, interp_y])
                        filtered_positions.append(next_valid_pos)
                        last_valid_pos = next_valid_pos
            else:
                filtered_positions.append(pos)
                last_valid_pos = pos

        # Moving-average smoothing; the half-window at each edge is kept unsmoothed
        window_size = 5
        half = window_size // 2
        smoothed_positions = []
        if len(filtered_positions) >= window_size:
            smoothed_positions.extend(filtered_positions[:half])
            for i in range(half, len(filtered_positions) - half):
                window = filtered_positions[i - half:i + half + 1]
                smoothed_x = int(np.mean([p[0] for p in window]))
                smoothed_y = int(np.mean([p[1] for p in window]))
                smoothed_positions.append([smoothed_x, smoothed_y])
            smoothed_positions.extend(filtered_positions[-half:])
        else:
            smoothed_positions = filtered_positions
        return smoothed_positions
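
    # Example on hypothetical data: with width=640, height=480,
    #     filter_trajectories([[10, 10], [12, 11], [600, 400], [14, 12], [15, 13]], 640, 480)
    # discards the implausible jump to [600, 400] and bridges it by
    # interpolating toward the next plausible point before smoothing.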
print("生成轨迹图...") | |
trajectory_img = np.zeros((height, width, 3), dtype=np.uint8) + 255 | |
points = np.array(all_positions, dtype=np.int32) | |
if len(points) > 1: | |
filtered_points = filter_trajectories(points.tolist(), width, height) | |
points = np.array(filtered_points, dtype=np.int32) | |
for i in range(len(points) - 1): | |
ratio = i / (len(points) - 1) | |
color = ( | |
int((1 - ratio) * 255), | |
50, | |
int(ratio * 255) | |
) | |
cv2.line(trajectory_img, tuple(points[i]), tuple(points[i + 1]), color, 2) | |
cv2.circle(trajectory_img, tuple(points[0]), 8, (0, 255, 0), -1) | |
cv2.circle(trajectory_img, tuple(points[-1]), 8, (0, 0, 255), -1) | |
arrow_interval = max(len(points) // 20, 1) | |
for i in range(0, len(points) - arrow_interval, arrow_interval): | |
pt1 = tuple(points[i]) | |
pt2 = tuple(points[i + arrow_interval]) | |
angle = np.arctan2(pt2[1] - pt1[1], pt2[0] - pt1[0]) | |
cv2.arrowedLine(trajectory_img, pt1, pt2, (100, 100, 100), 1, tipLength=0.2) | |
print("生成热力图...") | |
if np.max(heatmap) > 0: | |
heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX) | |
heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET) | |
alpha = 0.7 | |
heatmap_colored = cv2.addWeighted(heatmap_colored, alpha, np.full_like(heatmap_colored, 255), 1-alpha, 0) | |
print("保存结果图像...") | |
trajectory_path = output_path.replace('.mp4', '_trajectory.png') | |
heatmap_path = output_path.replace('.mp4', '_heatmap.png') | |
cv2.imwrite(trajectory_path, trajectory_img) | |
cv2.imwrite(heatmap_path, heatmap_colored) | |
print("处理完成!") | |
return output_path, trajectory_path, heatmap_path, report | |
# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# 🐭 Mice Behavior Analysis")

    with gr.Group() as login_interface:
        username = gr.Textbox(label="Username")
        password = gr.Textbox(label="Password", type="password")
        login_button = gr.Button("Log in")
        login_msg = gr.Textbox(label="Message", interactive=False)

    with gr.Group(visible=False) as main_interface:
        gr.Markdown("Upload a video to detect and analyze mice behavior")
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Input video")
                process_seconds = gr.Number(
                    label="Duration to process (seconds; 0 = whole video)",
                    value=20
                )
                conf_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.2,
                    step=0.05,
                    label="Confidence threshold",
                    info="Higher is stricter; 0.2-0.5 recommended"
                )
                max_det = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=1,
                    step=1,
                    label="Max detections per frame",
                    info="Maximum number of targets detected per frame"
                )
                process_btn = gr.Button("Start processing")
            with gr.Column():
                video_output = gr.Video(label="Detection result")
                with gr.Row():
                    trajectory_output = gr.Image(label="Trajectory")
                    heatmap_output = gr.Image(label="Heatmap")
                report_output = gr.Textbox(label="Analysis report")
gr.Markdown(""" | |
### 使用说明 | |
1. 上传视频文件 | |
2. 设置处理参数: | |
- 处理时长:需要分析的视频时长(秒) | |
- 置信度阈值:检测的置信度要求(越高越严格) | |
- 最大检测数量:每帧最多检测的目标数量 | |
3. 等待处理完成 | |
4. 查看检测结果视频和分析报告 | |
### 注意事项 | |
- 支持常见视频格式(mp4, avi 等) | |
- 建议视频分辨率不超过 1920x1080 | |
- 处理时间与视频长度和分辨率相关 | |
- 置信度建议范围:0.2-0.5 | |
- 最大检测数量建议根据实际场景设置 | |
""") | |
    login_button.click(
        fn=login,
        inputs=[username, password],
        outputs=[login_interface, main_interface, login_msg]
    )
    process_btn.click(
        fn=process_video,
        inputs=[video_input, process_seconds, conf_threshold, max_det],
        outputs=[video_output, trajectory_output, heatmap_output, report_output]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
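
# To run locally (a sketch, assuming a CUDA GPU that can deserialize the
# TensorRT engine; the password value below is illustrative):
#     export APP_PASSWORD="change-me"
#     python app.py
# Then open http://localhost:7860 and log in as "admin".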