224748K / app.py
ExStella
torch for gpu usage
1de5900
raw
history blame
3.39 kB
from ultralytics import YOLO
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
import os
import cv2
import tempfile
import torch
def load_model(repo_id, filename="best.pt"):
    """Download a model repository from the Hugging Face Hub and load its YOLO weights.

    Args:
        repo_id: Hub repository id, e.g. ``"ExStella/Traffic-cones"``.
        filename: Weights file inside the repository to load. Defaults to
            ``"best.pt"``, the name this app's repositories are expected to use.

    Returns:
        The loaded ``YOLO`` model, moved to ``"cuda"`` when
        ``torch.cuda.is_available()`` is true and to ``"cpu"`` otherwise.

    Raises:
        FileNotFoundError: If ``filename`` is not present in the downloaded snapshot.
    """
    download_dir = snapshot_download(repo_id)
    print(f"Model downloaded to: {download_dir}")
    model_path = os.path.join(download_dir, filename)
    # Check up front so a missing weights file is reported against the repository
    # it was expected in, rather than being left to YOLO's handling of a bad path.
    if not os.path.isfile(model_path):
        raise FileNotFoundError(
            f"{filename!r} not found in Hugging Face repository {repo_id!r} "
            f"(snapshot downloaded to {download_dir})"
        )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return YOLO(model_path).to(device)
# Load both detectors once, when the module runs; the image and video handlers
# reuse these two model objects for every request.
traffic_cones_model, license_plate_model = (
    load_model("ExStella/Traffic-cones"),
    load_model("ExStella/License-plate"),
)
def process_image(img, model_type):
    """Run the selected detector on one image and return the annotated result.

    Args:
        img: Image handed straight to ``model.predict`` (``predict`` passes a PIL image).
        model_type: ``"Traffic Cones"`` selects the cone detector; any other value
            selects the license-plate detector.

    Returns:
        A PIL ``Image`` with the detections drawn on it.
    """
    if model_type == "Traffic Cones":
        detector = traffic_cones_model
    else:
        detector = license_plate_model
    first_result = detector.predict(img, conf=0.5, iou=0.6)[0]
    # plot() returns a BGR array; reverse the channel axis so PIL receives RGB.
    annotated_rgb = first_result.plot()[..., ::-1]
    return Image.fromarray(annotated_rgb)
def process_video(video_path, model_type, fallback_fps=30.0):
    """Run the selected detector on every frame of a video and save an annotated copy.

    Args:
        video_path: Path of the input video, opened with ``cv2.VideoCapture``.
        model_type: ``"Traffic Cones"`` selects the cone detector; any other value
            selects the license-plate detector.
        fallback_fps: Frame rate for the output when the input does not report a
            positive one via ``CAP_PROP_FPS``.

    Returns:
        Path of a newly created ``.mp4`` temporary file holding the annotated
        frames. This function does not delete it.

    Raises:
        ValueError: If the input video or the output video writer cannot be opened.
    """
    model = traffic_cones_model if model_type == "Traffic Cones" else license_plate_model

    cap = cv2.VideoCapture(video_path)
    try:
        # Without this check an unreadable input would silently produce an empty output.
        if not cap.isOpened():
            raise ValueError(f"Could not open video file: {video_path}")

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report an FPS of 0 (or NaN); `not fps > 0` catches both.
        # A writer opened at such a rate is unlikely to yield a playable file.
        if not fps > 0:
            fps = fallback_fps

        # mkstemp creates the file and returns an open descriptor; close it at once
        # (instead of relying on garbage collection of a discarded NamedTemporaryFile)
        # so cv2.VideoWriter below is the only thing holding the path.
        fd, temp_output = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)

        # NOTE(review): "mp4v" (MPEG-4 Part 2) may not play natively in every
        # browser; confirm the Gradio video output re-encodes it, or switch codecs.
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))
        try:
            if not out.isOpened():
                raise ValueError(f"Could not open video writer for: {temp_output}")
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # verbose=False keeps Ultralytics from logging a line for every frame.
                results = model.predict(frame, conf=0.5, iou=0.6, verbose=False)
                out.write(results[0].plot())  # annotated frame
        finally:
            out.release()
    finally:
        # Release the capture on every path, including when predict() raises mid-video.
        cap.release()
    return temp_output
def predict(input_file, model_type):
    """Route an uploaded file to image or video detection based on its extension.

    Args:
        input_file: Value from the ``gr.File`` component. Depending on the Gradio
            version this is either an object with a ``.name`` path attribute or a
            plain path (``str``/``os.PathLike``); both are accepted.
        model_type: Detection type from the radio: ``"Traffic Cones"`` or
            ``"License Plate"``.

    Returns:
        ``(image, None)`` for image uploads, where ``image`` is the annotated PIL
        image, or ``(None, video_path)`` for video uploads.

    Raises:
        ValueError: If no file was uploaded, no valid detection type was chosen,
            or the file extension is not a supported image/video format.
    """
    if input_file is None:
        raise ValueError("Please upload an image or video.")
    # Reject a missing selection instead of silently running the license-plate model.
    if model_type not in ("Traffic Cones", "License Plate"):
        raise ValueError("Please choose a detection type.")

    if isinstance(input_file, (str, os.PathLike)):
        path = os.fspath(input_file)
    else:
        path = input_file.name  # tempfile-like wrapper exposing the on-disk path

    # Compare the extension case-insensitively so e.g. "IMG_0001.JPG" is accepted.
    extension = os.path.splitext(path)[1].lower()
    if extension in (".jpg", ".jpeg", ".png"):
        # Close the source file once detection has produced the new annotated image.
        with Image.open(path) as img:
            processed_image = process_image(img, model_type)
        return processed_image, None  # Return image and None for video
    if extension in (".mp4", ".avi", ".mov"):
        processed_video = process_video(path, model_type)
        return None, processed_video  # Return None for image and video path
    raise ValueError("Unsupported file format. Please upload an image or video.")
# Gradio interface: one upload plus a detection-type selector. There are two
# outputs because predict() returns an (image, video) pair with one side None.
gr.Interface(
    fn=predict,
    inputs=[
        gr.File(
            label="Upload an image or video (JPG, PNG, MP4, AVI, etc.)",
            # Limit the file picker to the extensions predict() knows how to route.
            file_types=[".jpg", ".jpeg", ".png", ".mp4", ".avi", ".mov"],
        ),
        # Pre-select a choice: with no default value the radio starts unselected,
        # and a request could be submitted without a detection type.
        gr.Radio(
            ["Traffic Cones", "License Plate"],
            value="Traffic Cones",
            label="Choose Detection Type",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Processed Image"),  # Output for images
        gr.Video(label="Processed Video"),  # Output for videos
    ],
    title="Object Detection for Traffic Cones and License Plates",
    description="Upload an image or video to perform object detection. Select between Traffic Cones or License Plates detection.",
).launch(share=True)