from ultralytics import YOLO
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
import os
import cv2
import torch
import numpy as np
import tempfile
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the YOLO model from the Hugging Face Hub
def load_model(repo_id):
    download_dir = snapshot_download(repo_id)
    path = os.path.join(download_dir, "best_int8_openvino_model 3")
    model = YOLO(path, task="detect")
    # Note: this is an OpenVINO export, which the OpenVINO runtime executes on
    # CPU; .to('cuda') only applies to native PyTorch (.pt) weights and raises
    # an error on exported models, so no device move is done here.
    return model

# Load the model globally so both interfaces share a single instance
detection_model = load_model("CharmainChua/windowsandcurtains")

# Predict for image
def predict_image(pil_img):
    result = detection_model.predict(pil_img, conf=0.5, iou=0.6)
    img_bgr = result[0].plot()  # annotated frame in BGR order
    out_pil_img = Image.fromarray(img_bgr[..., ::-1])  # BGR -> RGB for PIL
    return out_pil_img
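
# Optional local smoke test: a minimal sketch assuming a hypothetical
# "sample.jpg" on disk; uncomment to exercise the detection pipeline
# without launching Gradio.
# predict_image(Image.open("sample.jpg")).save("sample_annotated.jpg")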

# Predict for video
def predict_video(video_path, batch_size=16):
    if video_path is None:
        return None

    # Limit batch size to reduce memory usage; gr.Number passes a float
    batch_size = min(int(batch_size), 24)

    # Create a temporary file for the annotated output video
    temp_output = tempfile.NamedTemporaryFile(delete=False, suffix=".avi")
    output_file = temp_output.name
    temp_output.close()  # cv2.VideoWriter reopens the path itself
    logger.info(f"Creating output file: {output_file}")

    # Load video file
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Error opening video file: {video_path}")

    # Get video properties
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # fall back if the container reports 0 fps
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Ensure dimensions are even numbers (required by some codecs)
    width = width if width % 2 == 0 else width - 1
    height = height if height % 2 == 0 else height - 1
    logger.info(f"Video properties: {width}x{height} @ {fps}fps, {total_frames} frames")

    # Use MJPG codec for better compatibility
    fourcc = cv2.VideoWriter_fourcc(*"MJPG")
    out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
    if not out.isOpened():
        raise ValueError("Failed to create output video file")

    try:
        frames = []
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
            frame_count += 1

            # Process a full batch
            if len(frames) == batch_size:
                process_and_write_batch(frames, out, width, height, detection_model)
                frames = []  # Clear the batch

            if frame_count % 100 == 0:
                logger.info(f"Processed {frame_count}/{total_frames} frames")

        # Flush any leftover frames; CAP_PROP_FRAME_COUNT can be inaccurate,
        # so don't rely on frame_count == total_frames to catch the last batch
        if frames:
            process_and_write_batch(frames, out, width, height, detection_model)
    except Exception as e:
        logger.error(f"Error during video processing: {e}")
        raise
    finally:
        # Clean up
        cap.release()
        out.release()
        cv2.destroyAllWindows()

    if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
        logger.info(f"Successfully created output video: {output_file}")
        return output_file
    else:
        logger.error("Failed to create output video file")
        return None

def process_and_write_batch(frames, out, width, height, model):
    logger.info(f"Processing batch of {len(frames)} frames")
    try:
        # Ultralytics accepts a list of BGR numpy arrays directly, so the
        # whole batch goes through a single predict() call instead of one
        # call per frame
        with torch.no_grad():  # Disable gradient computation for faster inference
            results = model.predict(frames, conf=0.5, iou=0.6)
        for result in results:
            # Plot detections on the frame (returned in BGR order)
            img_bgr = result.plot()
            # Resize frame to match original dimensions
            img_resized = cv2.resize(img_bgr, (width, height))
            # Write the frame to the output video
            out.write(img_resized)
    except Exception as e:
        logger.error(f"Error processing/writing frame: {e}")
        raise
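
# Standalone usage sketch (hypothetical "input.mp4", not part of the Space):
# uncomment to run detection on a local file and print the annotated output path.
# print(predict_video("input.mp4", batch_size=16))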

# Gradio Interfaces
image_interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Image Object Detection",
    description="Upload an image to detect objects. The output image will have bounding boxes drawn around detected objects.",
)

video_interface = gr.Interface(
    fn=predict_video,
    inputs=[
        gr.Video(label="Upload a Video"),
        gr.Number(value=24, label="Batch Size (default: 24, max: 24)"),
    ],
    outputs=gr.Video(label="Processed Video"),
    title="Video Object Detection with Batch Processing",
    description="Upload a video to detect objects using batch processing. The output video will have bounding boxes drawn around detected objects.",
)

# Combine interfaces into one app
app = gr.TabbedInterface(
    interface_list=[image_interface, video_interface],
    tab_names=["Image Detection", "Video Detection"],
)

if __name__ == "__main__":
    app.launch(share=True)