# Page-capture residue: Spaces status was "Runtime error" when this file was captured.
import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms
from ultralytics import YOLO
# Load the YOLO model
def load_model(weights_path='best.pt'):
    """Build a YOLO model from a weights file.

    Args:
        weights_path: Path to the YOLO weights file. Defaults to
            ``'best.pt'``, the file this app has always loaded, so
            existing zero-argument callers are unaffected.

    Returns:
        The ``YOLO`` object constructed from ``weights_path``.
    """
    return YOLO(weights_path)
# Define YOLO processing function
def process_image(image):
    """Preprocess an image into a normalized tensor with a batch dimension.

    The image is resized to 416x416, converted to a tensor, normalized
    per channel with the mean/std constants below, and finally given a
    leading batch dimension. No inference happens here; the caller is
    responsible for feeding the result to a model.

    Args:
        image: The image to preprocess. A PIL image that is not already
            in RGB mode (e.g. an RGBA or grayscale PNG) is converted to
            RGB first, because the normalization constants are defined
            for exactly three channels. Other input types are passed to
            the transform pipeline unchanged.

    Returns:
        The preprocessed tensor with a new dimension inserted at
        position 0.
    """
    # Normalize below carries three mean/std values; a 1- or 4-channel
    # image would fail there, so force three channels up front.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    preprocess = transforms.Compose([
        transforms.Resize((416, 416)),
        transforms.ToTensor(),
        # NOTE(review): these look like the widely used ImageNet channel
        # statistics -- confirm they match how the model was trained.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    input_tensor = preprocess(image)

    # Add the batch dimension expected by batched model inputs.
    return input_tensor.unsqueeze(0)
# Main Streamlit code
def main():
    """Run the Streamlit page: upload an image, detect objects, show boxes.

    Shows a file uploader restricted to jpg/jpeg/png. Once a file is
    provided, the model from ``load_model()`` is run on it and the image
    is displayed twice: once untouched and once with a green rectangle
    drawn around every detected box.
    """
    st.title('YOLO Image Detection')

    # Upload image file
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        # Nothing to do until the user picks a file.
        return

    # Load YOLO model
    model = load_model()

    # Decode with PIL and force RGB so that JPEG and PNG uploads (including
    # PNGs with an alpha channel or a palette) all yield the same
    # 3-channel, 8-bit image for both inference and drawing.
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption='Original Image', use_column_width=True)

    # Calling an Ultralytics YOLO model yields one result object per input
    # image; the detections live on each result's ``boxes`` attribute.
    results = model(image)

    # Draw on a separate, writable uint8 array: cv2 modifies its target in
    # place and the PIL image itself must stay untouched.
    annotated = np.array(image)
    for result in results:
        if result.boxes is None:
            # No box output on this result (e.g. a non-detection model).
            continue
        for coords in result.boxes.xyxy.tolist():
            # cv2.rectangle requires integer pixel coordinates, while the
            # model reports [x_min, y_min, x_max, y_max] as floats.
            x_min, y_min, x_max, y_max = (int(round(value)) for value in coords)
            cv2.rectangle(annotated, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)

    st.image(annotated, caption='Processed Image', use_column_width=True)
# Start the app only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()