"""Gradio app that runs YOLO object detection on uploaded images and videos."""

import logging
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from ultralytics import YOLO

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Detection thresholds shared by the image and video paths.
CONF_THRESHOLD = 0.5
IOU_THRESHOLD = 0.6

# Upper bound on buffered frames, to limit memory usage.
MAX_BATCH_SIZE = 24
DEFAULT_BATCH_SIZE = 16

# Frame rate used when the container does not report a usable one.
FALLBACK_FPS = 30.0


def load_model(repo_id):
    """Download a model snapshot from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub repository id containing the exported model folder.

    Returns:
        A ``YOLO`` detection model loaded from the OpenVINO export folder
        inside the downloaded snapshot.
    """
    download_dir = snapshot_download(repo_id)
    path = os.path.join(download_dir, "best_int8_openvino_model 3")
    # NOTE: the model is an OpenVINO export, not a PyTorch ``.pt`` checkpoint.
    # It is deliberately not moved with ``.to('cuda')``: Ultralytics only
    # supports that call for PyTorch models, so attempting it on a CUDA host
    # would fail at startup.
    return YOLO(path, task="detect")


# Load the model globally
detection_model = load_model("CharmainChua/windowsandcurtains")


def predict_image(pil_img):
    """Run detection on a PIL image and return the annotated PIL image.

    Args:
        pil_img: Input image, or ``None`` when nothing was uploaded.

    Returns:
        The annotated image, or ``None`` if no image was supplied.
    """
    if pil_img is None:
        return None

    result = detection_model.predict(
        pil_img, conf=CONF_THRESHOLD, iou=IOU_THRESHOLD
    )
    img_bgr = result[0].plot()
    # plot() yields BGR channel order; reverse the last axis for PIL (RGB).
    return Image.fromarray(img_bgr[..., ::-1])


def _normalize_batch_size(batch_size):
    """Coerce a UI-supplied batch size to an int in [1, MAX_BATCH_SIZE]."""
    try:
        size = int(batch_size)
    except (TypeError, ValueError, OverflowError):
        # None, NaN, infinity or non-numeric input: fall back to the default.
        size = DEFAULT_BATCH_SIZE
    return max(1, min(size, MAX_BATCH_SIZE))


def _remove_quietly(path):
    """Delete *path* if it exists, ignoring filesystem errors."""
    try:
        if path and os.path.exists(path):
            os.remove(path)
    except OSError:
        logger.warning("Could not remove temporary file: %s", path)


def predict_video(video_path, batch_size=DEFAULT_BATCH_SIZE):
    """Annotate every frame of a video and write the result to a temp file.

    Args:
        video_path: Path of the uploaded video, or ``None``.
        batch_size: Number of frames buffered before they are processed.
            Coerced to an integer and clamped to ``[1, MAX_BATCH_SIZE]``.

    Returns:
        Path of the annotated ``.avi`` file, ``None`` if no video was given
        or the output file ended up missing/empty.

    Raises:
        ValueError: If the input cannot be opened or the writer cannot be
            created.
    """
    if video_path is None:
        return None

    batch_size = _normalize_batch_size(batch_size)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Error opening video file: {video_path}")

    out = None
    output_file = None
    succeeded = False
    try:
        # Get video properties
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Keep the fractional frame rate (e.g. 29.97) so playback speed is
        # preserved; ``not fps > 0`` also catches 0 and NaN.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = FALLBACK_FPS

        # Ensure dimensions are even numbers
        width -= width % 2
        height -= height % 2

        logger.info(
            "Video properties: %sx%s @ %sfps, %s frames",
            width, height, fps, total_frames,
        )

        # Reserve a file name, then close our handle so the writer can open
        # the path itself and no descriptor is leaked.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".avi") as tmp:
            output_file = tmp.name
        logger.info("Creating output file: %s", output_file)

        # Use MJPG codec for better compatibility
        fourcc = cv2.VideoWriter_fourcc(*"MJPG")
        out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
        if not out.isOpened():
            raise ValueError("Failed to create output video file")

        frames = []
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frames.append(frame)
            frame_count += 1

            if len(frames) >= batch_size:
                process_and_write_batch(
                    frames, out, width, height, detection_model
                )
                frames = []

            if frame_count % 100 == 0:
                logger.info(
                    "Processed %s/%s frames", frame_count, total_frames
                )

        # Flush the final partial batch. The reported frame count is only
        # metadata and may not match the number of frames actually decoded,
        # so it must not be used to decide when the last batch is complete.
        if frames:
            process_and_write_batch(
                frames, out, width, height, detection_model
            )

        succeeded = True
    except Exception as e:
        logger.error(f"Error during video processing: {str(e)}")
        raise
    finally:
        # Clean up. No GUI windows are created, so there is nothing to
        # destroy; only the capture and writer need releasing.
        cap.release()
        if out is not None:
            out.release()
        if not succeeded:
            _remove_quietly(output_file)

    if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
        logger.info("Successfully created output video: %s", output_file)
        return output_file

    logger.error("Failed to create output video file")
    return None


def process_and_write_batch(frames, out, width, height, model):
    """Detect objects on each buffered frame and append it to the writer.

    Args:
        frames: BGR frames as read by ``cv2.VideoCapture``.
        out: Opened ``cv2.VideoWriter`` receiving the annotated frames.
        width: Output frame width expected by the writer.
        height: Output frame height expected by the writer.
        model: Detection model exposing ``predict``.

    Raises:
        Exception: Any error from detection or writing is logged and
            re-raised.
    """
    logger.info("Processing batch of %s frames", len(frames))
    for frame in frames:
        try:
            # Convert frame to PIL Image
            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

            # Disable gradient computation for faster inference
            with torch.no_grad():
                result = model.predict(
                    pil_img, conf=CONF_THRESHOLD, iou=IOU_THRESHOLD
                )

            # Plot results on the frame
            img_bgr = result[0].plot()

            # The writer only accepts frames of exactly (width, height);
            # resize when the plotted frame differs (e.g. odd dimensions
            # were rounded down to even).
            if img_bgr.shape[1] != width or img_bgr.shape[0] != height:
                img_bgr = cv2.resize(img_bgr, (width, height))

            # Write the frame to the output video
            out.write(img_bgr)
        except Exception as e:
            logger.error(f"Error processing/writing frame: {e}")
            raise


# Gradio Interfaces
image_interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Image Object Detection",
    description="Upload an image to detect objects. The output image will have bounding boxes drawn around detected objects.",
)

video_interface = gr.Interface(
    fn=predict_video,
    inputs=[
        gr.Video(label="Upload a Video"),
        gr.Number(value=24, label="Batch Size (default: 24, max: 24)"),
    ],
    outputs=gr.Video(label="Processed Video"),
    title="Video Object Detection with Batch Processing",
    description="Upload a video to detect objects using batch processing. The output video will have bounding boxes drawn around detected objects.",
)

# Combine interfaces into one app
app = gr.TabbedInterface(
    interface_list=[image_interface, video_interface],
    tab_names=["Image Detection", "Video Detection"],
)

if __name__ == "__main__":
    app.launch(share=True)