import logging
import os
import tempfile

import cv2
import gradio as gr
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from ultralytics import YOLO

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Load the YOLO model from the Hugging Face Hub
def load_model(repo_id):
    download_dir = snapshot_download(repo_id)
    # Directory name ("best_int8_openvino_model 3", space included) is taken
    # as-is from the model repo
    path = os.path.join(download_dir, "best_int8_openvino_model 3")
    model = YOLO(path, task="detect")
    # Move to GPU if available; exported formats such as OpenVINO run through
    # their own runtime and do not support .to(), so guard against that
    if torch.cuda.is_available():
        try:
            model = model.to("cuda")
        except (TypeError, AttributeError):
            logger.info("Model format does not support .to('cuda'); using default device")
    return model


# Load the model once at startup so both interfaces share it
detection_model = load_model("CharmainChua/windowsandcurtains")
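# Note: an OpenVINO export is a directory (network .xml/.bin plus metadata),
# which is why snapshot_download() fetches the whole repo rather than a
# single file.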


# Predict for a single image
def predict_image(pil_img):
    result = detection_model.predict(pil_img, conf=0.5, iou=0.6)
    img_bgr = result[0].plot()  # annotated frame, BGR channel order
    out_pil_img = Image.fromarray(img_bgr[..., ::-1])  # BGR -> RGB for PIL
    return out_pil_img
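# Example (assumed local smoke test; "sample.jpg" is a placeholder path):
#   annotated = predict_image(Image.open("sample.jpg"))
#   annotated.save("sample_annotated.jpg")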


# Predict for a video, reading and processing frames in batches
def predict_video(video_path, batch_size=16):
    if video_path is None:
        return None
    # Cap the batch size to limit memory usage
    batch_size = min(batch_size, 24)
    # Create a temporary file for the output video
    temp_output = tempfile.NamedTemporaryFile(delete=False, suffix=".avi")
    output_file = temp_output.name
    logger.info(f"Creating output file: {output_file}")
    # Open the input video
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Error opening video file: {video_path}")
    # Read video properties
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Some containers report an fps of 0; fall back to an assumed default so
    # the writer does not produce a broken file
    if fps <= 0:
        fps = 30
    # Ensure dimensions are even numbers (some codecs require this)
    width = width if width % 2 == 0 else width - 1
    height = height if height % 2 == 0 else height - 1
    logger.info(f"Video properties: {width}x{height} @ {fps}fps, {total_frames} frames")
    # Use the MJPG codec for broad compatibility
    fourcc = cv2.VideoWriter_fourcc(*"MJPG")
    out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
    if not out.isOpened():
        raise ValueError("Failed to create output video file")
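    # Note: browsers often cannot play MJPG/.avi inline, so the Gradio video
    # player may fall back to a download link. If inline playback matters, an
    # mp4 container is an alternative (codec availability depends on the
    # OpenCV build, so this is an assumption rather than a guaranteed fix):
    #   temp_output = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    #   fourcc = cv2.VideoWriter_fourcc(*"mp4v")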
    try:
        frames = []
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
            frame_count += 1
            # Process a full batch (or the final batch of the video)
            if len(frames) == batch_size or frame_count == total_frames:
                process_and_write_batch(frames, out, width, height, detection_model)
                frames = []  # clear the batch
            if frame_count % 100 == 0:
                logger.info(f"Processed {frame_count}/{total_frames} frames")
        # Flush any leftover frames; container frame counts can be inaccurate,
        # so the total_frames check above may never trigger
        if frames:
            process_and_write_batch(frames, out, width, height, detection_model)
    except Exception as e:
        logger.error(f"Error during video processing: {str(e)}")
        raise
    finally:
        # Always release capture and writer handles
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()
    if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
        logger.info(f"Successfully created output video: {output_file}")
        return output_file
    else:
        logger.error("Failed to create output video file")
        return None


def process_and_write_batch(frames, out, width, height, model):
    logger.info(f"Processing batch of {len(frames)} frames")
    for frame in frames:
        try:
            # Convert the BGR frame to a PIL image in RGB order
            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            # Disable gradient tracking for faster inference
            with torch.no_grad():
                result = model.predict(pil_img, conf=0.5, iou=0.6)
            # Draw the detections onto the frame
            img_bgr = result[0].plot()
            # Resize to match the writer's (even) dimensions
            img_resized = cv2.resize(img_bgr, (width, height))
            # Write the annotated frame to the output video
            out.write(img_resized)
        except Exception as e:
            logger.error(f"Error processing/writing frame: {e}")
            raise
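    # Note: the loop above runs inference one frame at a time. Ultralytics'
    # predict() also accepts a list of images in a single call, which can make
    # better use of the batch (a sketch, assuming the batch fits in memory):
    #   pil_imgs = [Image.fromarray(cv2.cvtColor(f, cv2.COLOR_BGR2RGB)) for f in frames]
    #   for res in model.predict(pil_imgs, conf=0.5, iou=0.6):
    #       out.write(cv2.resize(res.plot(), (width, height)))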


# Gradio interfaces
image_interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Image Object Detection",
    description="Upload an image to detect objects. The output image will have bounding boxes drawn around detected objects.",
)

video_interface = gr.Interface(
    fn=predict_video,
    inputs=[
        gr.Video(label="Upload a Video"),
        gr.Number(value=24, label="Batch Size (default: 24, max: 24)"),
    ],
    outputs=gr.Video(label="Processed Video"),
    title="Video Object Detection with Batch Processing",
    description="Upload a video to detect objects using batch processing. The output video will have bounding boxes drawn around detected objects.",
)

# Combine both interfaces into one tabbed app
app = gr.TabbedInterface(
    interface_list=[image_interface, video_interface],
    tab_names=["Image Detection", "Video Detection"],
)
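
# Note: share=True creates a temporary public gradio.live link for local runs;
# on Hugging Face Spaces the app is already publicly hosted and the flag has
# no effect there.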
if __name__ == "__main__":
    app.launch(share=True)