Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from PIL import Image | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import numpy as np | |
| from torchvision import transforms | |
| from ultralytics import YOLO | |
# Load the YOLO model
@st.cache_resource
def load_model(weights="best.pt"):
    """Load the YOLO detector once and reuse it across Streamlit reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the weights file would be re-read from disk each time.
    ``st.cache_resource`` keeps one model object per distinct ``weights``
    argument.

    NOTE(review): a ``st.cache_resource`` object is shared by all user
    sessions; if the Space must serve concurrent users, confirm that sharing
    one ultralytics model across threads is acceptable.

    Args:
        weights: Path to the YOLO weights file. Defaults to ``'best.pt'``,
            the file this app has always loaded.

    Returns:
        The ``YOLO`` model instance built from ``weights``.
    """
    return YOLO(weights)
# Define YOLO pre-processing function
def process_image(image, size=(416, 416)):
    """Turn a PIL image into a normalised, batched tensor.

    This only *pre-processes* the image; it performs no inference.

    Args:
        image: A PIL image. Images in any mode other than ``"RGB"`` (for
            example RGBA PNGs or greyscale ``"L"`` images) are converted to
            RGB first, because the 3-channel ``Normalize`` below fails on
            4- or 1-channel tensors.
        size: ``(height, width)`` handed to ``transforms.Resize``. Defaults
            to ``(416, 416)``, the size this function always used.

    Returns:
        A float tensor of shape ``(1, 3, height, width)``, normalised with
        the ImageNet mean/std.

    NOTE(review): ultralytics ``YOLO`` models perform their own input
    preprocessing and are not documented to expect ImageNet normalisation;
    confirm the intended consumer before feeding this tensor to the model.
    """
    # Objects without a ``mode`` attribute (e.g. tensors) are passed through.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    preprocess = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # Add a leading batch dimension: (3, H, W) -> (1, 3, H, W).
    return preprocess(image).unsqueeze(0)
# Main Streamlit code
def _draw_boxes(image, boxes, color=(0, 255, 0), thickness=2):
    """Return a copy of *image* with every ``[x_min, y_min, x_max, y_max]`` box drawn.

    ``cv2.rectangle`` requires integer pixel coordinates, so the float box
    corners are rounded first. Drawing happens on a copy so the caller's
    array is left untouched.
    """
    annotated = image.copy()
    for x_min, y_min, x_max, y_max in boxes:
        cv2.rectangle(
            annotated,
            (int(round(x_min)), int(round(y_min))),
            (int(round(x_max)), int(round(y_max))),
            color,
            thickness,
        )
    return annotated


def main():
    """Run the app: upload an image, run YOLO on it, show the detected boxes."""
    st.title('YOLO Image Detection')

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    try:
        # Decode with PIL and force 3-channel uint8 RGB. plt.imread returned
        # float32 (often RGBA) arrays for PNGs and may return read-only
        # arrays for JPEGs, both of which break cv2 drawing.
        pil_image = Image.open(uploaded_file).convert("RGB")
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        st.error("Could not read the uploaded file as an image.")
        return

    model = load_model()
    st.image(pil_image, caption='Original Image', use_column_width=True)

    # Pass the PIL image: ultralytics treats numpy input as BGR but PIL as RGB.
    results = model(pil_image)
    # Ultralytics YOLO returns a list of Results objects (one per image); the
    # YOLOv5-hub style `results.pred` attribute does not exist here.
    result = results[0]
    boxes = result.boxes.xyxy.tolist() if result.boxes is not None else []

    annotated = _draw_boxes(np.array(pil_image), boxes)
    st.image(annotated, caption='Processed Image', use_column_width=True)
if __name__ == "__main__":
    # Start the app only when this file is executed directly, not on import.
    main()