# app.py
from functools import lru_cache

import cv2
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms, models
from ultralytics import YOLO

# Detection settings
PLAYER_CLASS_ID = 4    # YOLO class index for "player" in the HockeyAI model
CONF_THRESHOLD = 0.5   # Minimum detection confidence to annotate
ARROW_LENGTH = 80      # Length of the direction arrow, in pixels

# Class labels for the direction classifier (must match training order)
CLASS_LABELS = [
    "Bottom", "Bottom_Left", "Bottom_Right", "Left",
    "Right", "Top", "Top_Left", "Top_Right"
]

# Preprocessing for SqueezeNet (ImageNet normalization)
TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


@lru_cache(maxsize=1)
def load_models():
    """Load both models once and cache them for subsequent calls."""
    # Initialize YOLO detector
    yolo_model = YOLO('HockeyAI.pt')

    # Initialize SqueezeNet: replace the final 1x1 conv so it outputs
    # 8 direction classes, then load the fine-tuned weights
    squeezenet_model = models.squeezenet1_1(weights=None)
    squeezenet_model.classifier[1] = torch.nn.Conv2d(512, 8, kernel_size=1)
    squeezenet_model.num_classes = 8
    squeezenet_model.load_state_dict(
        torch.load('best_model_squezenet.pth', map_location='cpu')
    )
    squeezenet_model.eval()
    return yolo_model, squeezenet_model


def process_image(input_image):
    if input_image is None:
        return None

    # Accept either a file path or an RGB numpy array (Gradio passes numpy)
    if isinstance(input_image, str):
        image = cv2.imread(input_image)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    else:
        image = input_image.copy()

    yolo_model, squeezenet_model = load_models()

    # Run YOLO detection
    results = yolo_model(image)

    # Annotate each detection
    for box in results[0].boxes:
        xyxy = box.xyxy[0].cpu().numpy()
        conf = float(box.conf[0].cpu().numpy())
        cls = int(box.cls[0].cpu().numpy())

        # Process only confident player detections
        if cls != PLAYER_CLASS_ID or conf <= CONF_THRESHOLD:
            continue

        x1, y1, x2, y2 = map(int, xyxy)
        # Clamp to image bounds so the crop below cannot use negative indices
        x1, y1 = max(x1, 0), max(y1, 0)
        if x2 <= x1 or y2 <= y1:
            continue

        cropped_array = image[y1:y2, x1:x2]
        if cropped_array.size == 0:
            continue

        # Classify the movement direction of the cropped player
        cropped_image = Image.fromarray(cropped_array)
        image_tensor = TRANSFORM(cropped_image).unsqueeze(0)
        with torch.no_grad():
            output = squeezenet_model(image_tensor)
            direction_class = torch.argmax(output, dim=1).item()

        # Draw the bounding box and confidence score
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(image, f"{conf:.2f}", (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        # Draw the direction arrow from the box center
        center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
        direction = CLASS_LABELS[direction_class]

        # Calculate the arrow endpoint from the predicted direction
        end_x, end_y = center_x, center_y
        if "Top" in direction:
            end_y = center_y - ARROW_LENGTH
        elif "Bottom" in direction:
            end_y = center_y + ARROW_LENGTH
        if "Left" in direction:
            end_x = center_x - ARROW_LENGTH
        elif "Right" in direction:
            end_x = center_x + ARROW_LENGTH

        cv2.arrowedLine(image, (center_x, center_y), (end_x, end_y),
                        (255, 0, 0), 4, tipLength=0.4)

    return image


def gradio_interface():
    """Build the Gradio UI."""
    with gr.Blocks() as iface:
        gr.Markdown("# Player Direction Detection")
        gr.Markdown("Upload an image to detect players and their movement directions")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="numpy")
            with gr.Column():
                output_image = gr.Image(label="Output Image")

        # Re-run processing whenever the input image changes
        input_image.change(
            fn=process_image,
            inputs=[input_image],
            outputs=[output_image],
        )

        # Example images (expected alongside app.py)
        gr.Examples(
            examples=["example-1.jpg", "example-2.jpg"],
            inputs=input_image,
            outputs=output_image,
            fn=process_image,
            cache_examples=True,
        )

    return iface


if __name__ == "__main__":
    iface = gradio_interface()
    iface.launch()
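
# --- Optional smoke test (a minimal sketch, not part of the app) ---
# Runs the full pipeline once without the Gradio UI, e.g. from an
# interactive session after `import app`. It assumes one of the example
# images referenced in gr.Examples above ("example-1.jpg") exists next to
# app.py; any local image path works, since process_image also accepts
# file paths.
#
#   >>> import app, cv2
#   >>> annotated = app.process_image("example-1.jpg")  # returns an RGB array
#   >>> cv2.imwrite("annotated.jpg", cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))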