# Spaces: Sleeping  (hosting-page status text captured with this file; not part of the program)
import os
import tempfile

import cv2
import gradio as gr
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from ultralytics import YOLO
def load_model(repo_id, weights_filename="best.pt"):
    """Download a Hugging Face repository and load its YOLO weights.

    Args:
        repo_id: Hugging Face Hub repository id, e.g. ``"user/repo"``.
        weights_filename: Name of the weights file inside the downloaded
            snapshot. Defaults to ``"best.pt"``.

    Returns:
        The loaded ``YOLO`` model, placed on CUDA when available and on the
        CPU otherwise.

    Raises:
        FileNotFoundError: If the weights file is not present in the
            downloaded snapshot.
    """
    download_dir = snapshot_download(repo_id)
    print(f"Model downloaded to: {download_dir}")

    model_path = os.path.join(download_dir, weights_filename)
    if not os.path.isfile(model_path):
        # Fail here with a clear message instead of somewhere inside YOLO().
        raise FileNotFoundError(
            f"'{weights_filename}' not found in repository '{repo_id}' "
            f"(looked in {download_dir})"
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    detection_model = YOLO(model_path)
    # Call .to() for its effect and return the model object itself, so the
    # result does not depend on .to() returning the model.
    detection_model.to(device)
    return detection_model
# Load both detectors once at start-up (traffic cones first, then license
# plates); the processing functions below pick one of them per request.
traffic_cones_model, license_plate_model = [
    load_model(repo_id)
    for repo_id in ("ExStella/Traffic-cones", "ExStella/License-plate")
]
def process_image(img, model_type):
    """Run the selected detector on one image and return the annotated copy.

    Args:
        img: Image to analyse; passed unchanged to the model's ``predict``.
        model_type: ``"Traffic Cones"`` selects the traffic-cone model; any
            other value selects the license-plate model.

    Returns:
        A ``PIL.Image`` built from the annotated prediction plot.
    """
    if model_type == "Traffic Cones":
        detector = traffic_cones_model
    else:
        detector = license_plate_model

    detections = detector.predict(img, conf=0.5, iou=0.6)
    annotated_bgr = detections[0].plot()
    # The plot is in BGR channel order; reverse the last axis to get RGB,
    # which is what PIL expects.
    annotated_rgb = annotated_bgr[..., ::-1]
    return Image.fromarray(annotated_rgb)
def process_video(video_path, model_type):
    """Run the selected detector on every frame of a video.

    Args:
        video_path: Path of the input video file.
        model_type: ``"Traffic Cones"`` selects the traffic-cone model; any
            other value selects the license-plate model.

    Returns:
        Path of a newly created ``.mp4`` temporary file holding the annotated
        video. The file is not deleted automatically.

    Raises:
        ValueError: If the input video cannot be opened.
    """
    model = traffic_cones_model if model_type == "Traffic Cones" else license_plate_model

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video file: {video_path}")

    # Only the path is needed; close our own handle immediately so no file
    # descriptor is left open while OpenCV writes to the file.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        temp_output = tmp.name

    out = None
    completed = False
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # The frame rate is unavailable (0 or NaN); fall back to a common
            # default so the writer still receives a usable rate.
            fps = 30.0

        # Define video writer
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # Perform object detection and write the annotated frame
            results = model.predict(frame, conf=0.5, iou=0.6)
            out.write(results[0].plot())
        completed = True
    finally:
        # Release the capture and writer even when inference fails part-way.
        cap.release()
        if out is not None:
            out.release()
        # Do not leave a partial output file behind after a failure.
        if not completed and os.path.exists(temp_output):
            os.remove(temp_output)

    return temp_output
def predict(input_file, model_type):
    """Dispatch an uploaded file to image or video processing.

    Args:
        input_file: The upload to process. May be a path string, an
            ``os.PathLike`` object, or an object exposing the file path as
            ``.name``. NOTE(review): which of these Gradio passes depends on
            the Gradio version and the ``gr.File`` ``type`` setting; confirm
            against the deployed version.
        model_type: Detection type selected in the UI, forwarded unchanged to
            the processing function.

    Returns:
        ``(image, None)`` for image uploads and ``(None, video_path)`` for
        video uploads, matching the two output components of the interface.

    Raises:
        ValueError: If no file was provided or its extension is not a
            supported image or video format.
    """
    if input_file is None:
        raise ValueError("No file provided. Please upload an image or video.")

    if isinstance(input_file, str):
        file_path = input_file
    elif isinstance(input_file, os.PathLike):
        file_path = os.fspath(input_file)
    else:
        file_path = input_file.name

    # Compare case-insensitively so uploads such as "PHOTO.JPG" are accepted.
    lowered = file_path.lower()

    if lowered.endswith(('.jpg', '.jpeg', '.png')):
        # Image input; the source image is closed once processing has
        # produced its own annotated copy.
        with Image.open(file_path) as img:
            processed_image = process_image(img, model_type)
        return processed_image, None  # Return image and None for video

    if lowered.endswith(('.mp4', '.avi', '.mov')):
        # Video input
        processed_video = process_video(file_path, model_type)
        return None, processed_video  # Return None for image and video path

    raise ValueError("Unsupported file format. Please upload an image or video.")
# Gradio interface: one file upload plus a detector selector as inputs, and an
# image slot and a video slot as outputs (predict fills exactly one of them).
upload_input = gr.File(label="Upload an image or video (JPG, PNG, MP4, AVI, etc.)")
detector_choice = gr.Radio(["Traffic Cones", "License Plate"], label="Choose Detection Type")
image_output = gr.Image(type="pil", label="Processed Image")  # Output for images
video_output = gr.Video(label="Processed Video")  # Output for videos

demo = gr.Interface(
    fn=predict,
    inputs=[upload_input, detector_choice],
    outputs=[image_output, video_output],
    title="Object Detection for Traffic Cones and License Plates",
    description="Upload an image or video to perform object detection. Select between Traffic Cones or License Plates detection.",
)
demo.launch(share=True)