3907578Y / app.py
mintheinwin's picture
update app
1283eb6 verified
raw
history blame
3.29 kB
import gradio as gr
from huggingface_hub import snapshot_download
from ultralytics import YOLO
import os
from PIL import Image
import cv2
import numpy as np
import tempfile
# Hugging Face Hub repository that holds the exported detection model.
# Personal (public) copy, kept here for reference:
#   MODEL_REPO_ID = "mintheinwin/3907578Y"
# Organization copy, currently in use:
MODEL_REPO_ID = "ITI107-2024S2/3907578Y"
# Load model
def load_model(repo_id):
    """Download the model repository and build a YOLO detector from it.

    Args:
        repo_id: Hub repository id passed to ``snapshot_download``.

    Returns:
        A ``YOLO`` instance created with ``task="detect"`` from the
        ``best_int8_openvino_model`` folder of the downloaded snapshot.
    """
    snapshot_root = snapshot_download(repo_id)
    weights_dir = os.path.join(snapshot_root, "best_int8_openvino_model")
    return YOLO(weights_dir, task="detect")
# Download and load the detector once at import time; the prediction
# functions in this file read this module-level instance.
detection_model = load_model(MODEL_REPO_ID)
# Student identification line rendered in the UI via gr.Markdown.
student_info = "Student Id: 3907578Y, Name: Min Thein Win"
# Prediction for images
def predict_image(pil_img):
    """Run detection on one image and return the annotated picture.

    Args:
        pil_img: Image handed straight to ``detection_model.predict``.

    Returns:
        A PIL image built from the plotted result with its last axis
        reversed (the plot is treated as BGR and flipped to RGB).
    """
    detections = detection_model.predict(pil_img, conf=0.5, iou=0.5)
    first_result = detections[0]
    annotated_bgr = first_result.plot()  # array with boxes/labels drawn
    annotated_rgb = annotated_bgr[..., ::-1]  # reverse channel order
    return Image.fromarray(annotated_rgb)
# Prediction for videos
def predict_video(video):
    """Run detection on every frame of a video and save an annotated copy.

    Frames are written to the output file as they are processed, so memory
    use does not grow with the length of the video.

    Args:
        video: Path of the input video, as accepted by ``cv2.VideoCapture``.

    Returns:
        Path of ``annotated_video.mp4`` inside a freshly created temp
        directory.

    Raises:
        ValueError: If the video cannot be opened or yields no frames.
    """
    cap = cv2.VideoCapture(video)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video file: {video}")

    # Keep the source frame rate so playback speed is unchanged; fall back
    # to 20 fps when the container reports nothing usable (0, negative, NaN).
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not (fps > 0):
        fps = 20

    output_path = os.path.join(tempfile.mkdtemp(), "annotated_video.mp4")
    writer = None
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Detection
            result = detection_model.predict(frame, conf=0.5, iou=0.5)
            annotated_frame = result[0].plot()
            if writer is None:
                # The writer needs the frame size, which is only known
                # once the first annotated frame exists.
                height, width = annotated_frame.shape[:2]
                writer = cv2.VideoWriter(
                    output_path,
                    cv2.VideoWriter_fourcc(*"mp4v"),
                    fps,
                    (width, height),
                )
            writer.write(annotated_frame)
    finally:
        # Release native handles even if prediction raises mid-stream.
        cap.release()
        if writer is not None:
            writer.release()

    if writer is None:
        raise ValueError("The video contains no readable frames.")
    return output_path
# Unified prediction function
def unified_predict(file):
    """Dispatch an input to the image or the video predictor.

    Args:
        file: Either a PIL image or a string path to a video file.

    Returns:
        The annotated PIL image for image input, or the path of the
        annotated video for video input.

    Raises:
        ValueError: If the input is neither a PIL image nor a path ending
            in ``.mp4``, ``.avi`` or ``.mov`` (matched case-insensitively).
    """
    if isinstance(file, Image.Image):
        # If the input is a PIL Image, treat it as an image
        return predict_image(file)
    # Lower-case before matching so names such as "CLIP.MP4" are accepted.
    if isinstance(file, str) and file.lower().endswith((".mp4", ".avi", ".mov")):
        # If the input is a video file path, treat it as a video
        return predict_video(file)
    raise ValueError("Unsupported file type. Please upload an image or a video.")
# UI Interface
with gr.Blocks() as interface:
    gr.Markdown("# Wild Animal Detection (Tiger/Lion)")
    gr.Markdown(student_info)

    # Unified Section
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Upload an Image or Video:")
            input_file = gr.File(label="Input File")
        with gr.Column():
            gr.Markdown("### Output Results:")
            # Gradio has no generic ``gr.Output`` component, so the result
            # is routed to a dedicated image slot or video slot instead.
            output_image = gr.Image(label="Output Image", type="pil")
            output_video = gr.Video(label="Output Video")

    with gr.Row():
        clear_btn = gr.Button("CLEAR")
        submit_btn = gr.Button("SUBMIT")

    def process_file(file):
        """Run detection on the uploaded file.

        Args:
            file: Value delivered by the ``gr.File`` input: ``None`` when
                nothing is uploaded, otherwise a path string or an object
                exposing the path as ``.name``.

        Returns:
            ``(annotated_image, None)`` for an image upload or
            ``(None, annotated_video_path)`` for a video upload, matching
            the ``[output_image, output_video]`` outputs.

        Raises:
            gr.Error: If nothing was uploaded or the extension is not one
                of the supported image/video types.
        """
        if file is None:
            raise gr.Error("Please upload an image or a video.")
        # Accept both a plain path string and a file wrapper with ``.name``.
        path = getattr(file, "name", file)
        # Compare the lower-cased extension so "PHOTO.JPG" is accepted too.
        extension = os.path.splitext(path)[1].lower()
        if extension in (".jpg", ".jpeg", ".png"):
            with Image.open(path) as pil_image:
                return predict_image(pil_image), None
        if extension in (".mp4", ".avi", ".mov"):
            return None, predict_video(path)
        raise gr.Error("Unsupported file type. Please upload an image or a video.")

    def clear_all():
        """Reset the upload widget and both output widgets."""
        return None, None, None

    submit_btn.click(
        fn=process_file,
        inputs=input_file,
        outputs=[output_image, output_video],
    )
    clear_btn.click(
        fn=clear_all,
        inputs=None,
        outputs=[input_file, output_image, output_video],
    )

# Launch app
interface.launch(share=True)