import gradio as gr
from huggingface_hub import snapshot_download
from ultralytics import YOLO
import os
from PIL import Image
import cv2
import numpy as np
import tempfile
# Public model path location (personal account)
# MODEL_REPO_ID = "mintheinwin/3907578Y"
# Organization's model path location
MODEL_REPO_ID = "ITI107-2024S2/3907578Y"

# Load model
def load_model(repo_id):
    download_dir = snapshot_download(repo_id)
    path = os.path.join(download_dir, "best_int8_openvino_model")
    detection_model = YOLO(path, task="detect")
    return detection_model

detection_model = load_model(MODEL_REPO_ID)
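# Note (assumption about the model repo layout): snapshot_download() caches the repo
# locally (by default under ~/.cache/huggingface/hub) and returns that local path;
# "best_int8_openvino_model" is expected to be an exported OpenVINO model directory
# inside the downloaded repo, which YOLO() can load directly with task="detect".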

# Student ID
student_info = "Student Id: 3907578Y, Name: Min Thein Win"

# Prediction for images
def predict_image(pil_img):
    result = detection_model.predict(pil_img, conf=0.5, iou=0.5)
    img_bgr = result[0].plot()                        # annotated image as a BGR numpy array
    out_pilimg = Image.fromarray(img_bgr[..., ::-1])  # convert BGR -> RGB PIL image
    return out_pilimg
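# Example usage (assumption, local sanity check only; not part of the Space UI):
#   annotated = predict_image(Image.open("tiger.jpg"))  # "tiger.jpg" is a hypothetical local file
#   annotated.save("tiger_annotated.jpg")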

# Prediction for videos
def predict_video(video):
    cap = cv2.VideoCapture(video)
    frames = []
    temp_dir = tempfile.mkdtemp()
    fps = cap.get(cv2.CAP_PROP_FPS) or 20  # fall back to 20 fps if the source rate is unavailable
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # Detection
        result = detection_model.predict(frame, conf=0.5, iou=0.5)
        annotated_frame = result[0].plot()
        frames.append(annotated_frame)
    cap.release()
    if not frames:
        raise gr.Error("No frames could be read from the uploaded video.")

    # Save annotated video
    height, width, _ = frames[0].shape
    output_path = os.path.join(temp_dir, "annotated_video.mp4")
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    for frame in frames:
        out.write(frame)
    out.release()
    return output_path
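# Note: the "mp4v" FourCC used above keeps the output self-contained, but some browsers
# will not play mp4v-encoded MP4 files inline. If playback in the Gradio video component
# fails, re-encoding to H.264 (e.g. via ffmpeg) is a common workaround; whether that is
# needed depends on the codecs available in the Space runtime.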

# Unified prediction helper (routes by input type; the UI below uses process_file instead)
def unified_predict(file):
    if isinstance(file, Image.Image):
        # If the input is a PIL Image, treat it as an image
        return predict_image(file)
    elif isinstance(file, str) and file.lower().endswith((".mp4", ".avi", ".mov")):
        # If the input is a video file path, treat it as a video
        return predict_video(file)
    else:
        raise ValueError("Unsupported file type. Please upload an image or a video.")

# UI Interface
with gr.Blocks() as interface:
    gr.Markdown("# Wild Animal Detection (Tiger/Lion)")
    gr.Markdown(student_info)

    # Unified Section
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Upload an Image or Video:")
            input_file = gr.File(label="Input File")
        with gr.Column():
            gr.Markdown("### Output Results:")
            # gr.Output is not a Gradio component; show results with separate Image/Video outputs
            output_image = gr.Image(label="Annotated Image")
            output_video = gr.Video(label="Annotated Video")

    clear_btn = gr.Button("CLEAR")
    submit_btn = gr.Button("SUBMIT")

    def process_file(file):
        # gr.File may hand the function a filepath string or a tempfile-like object
        # with a .name attribute, depending on the Gradio version
        path = file if isinstance(file, str) else file.name
        if path.lower().endswith((".jpg", ".jpeg", ".png")):
            pil_image = Image.open(path)
            return predict_image(pil_image), None
        elif path.lower().endswith((".mp4", ".avi", ".mov")):
            return None, predict_video(path)
        else:
            raise gr.Error("Unsupported file type. Please upload an image or a video.")

    def clear_all():
        return None, None, None

    submit_btn.click(fn=process_file, inputs=input_file, outputs=[output_image, output_video])
    clear_btn.click(fn=clear_all, inputs=None, outputs=[input_file, output_image, output_video])

# Launch app
interface.launch(share=True)
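# Note (assumption about local use): outside a Hugging Face Space the app can be started
# with `python app.py`; share=True additionally creates a temporary public *.gradio.live
# URL alongside the local server.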