# Yolo-weed-detection / sample.py
# Author: Rohankumar31, commit f203369 ("Update sample.py", verified)
import streamlit as st
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import torch
import numpy as np
from torchvision import transforms
from ultralytics import YOLO
# Load the YOLO model
@st.cache_resource
def load_model(weights_path='best.pt'):
    """Load the YOLO detector once and reuse it across Streamlit reruns.

    ``st.cache_resource`` is Streamlit's cache for shared, unhashable global
    resources such as ML models: every rerun gets the same object back
    instead of reloading the weights. The deprecated ``st.cache`` instead
    hashed the returned value to detect mutation, which is unsuitable for a
    model object.

    Args:
        weights_path: Path to the YOLO weights file. Defaults to 'best.pt'.

    Returns:
        An ``ultralytics.YOLO`` model instance.
    """
    return YOLO(weights_path)
# Define YOLO processing function
def process_image(image, size=(416, 416)):
    """Convert an image into a resized, normalized, batched tensor.

    This only preprocesses; it does not run any model.

    Args:
        image: A PIL image. Images in a non-RGB mode (RGBA, palette,
            grayscale) are converted to RGB first. Without that conversion
            the 3-channel ``Normalize`` below fails on them.
        size: Target (height, width) passed to ``transforms.Resize``.
            Defaults to (416, 416).

    Returns:
        A float tensor of shape (1, 3, height, width) for RGB PIL input.
    """
    if isinstance(image, Image.Image):
        # PNG uploads are frequently RGBA or palette images.
        image = image.convert("RGB")
    # NOTE(review): these look like ImageNet mean/std. Ultralytics YOLO
    # models normally take [0, 1] inputs without this normalization, so
    # confirm before feeding this output to the detector.
    preprocess = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(image)
    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    return input_tensor.unsqueeze(0)
# Main Streamlit code
def main():
    """Render the app: upload an image, run YOLO on it, and show detections.

    Until a file is uploaded, only the title and the uploader are shown.
    After an upload, the original image is displayed. A copy of it is then
    shown with each predicted bounding box drawn as a green rectangle.
    """
    st.title('YOLO Image Detection')
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    model = load_model()

    # Decode with PIL and force 3-channel RGB. plt.imread returned float32
    # in [0, 1] for PNG but uint8 for JPEG (and RGBA for some PNGs), so the
    # drawing below behaved differently depending on the upload format.
    pil_image = Image.open(uploaded_file).convert("RGB")
    image = np.array(pil_image)  # writable uint8 (H, W, 3) RGB copy
    st.image(image, caption='Original Image', use_column_width=True)

    # Pass the PIL image rather than the array: ultralytics interprets raw
    # numpy arrays as BGR (OpenCV order) but PIL images as RGB.
    results = model(pil_image)

    # ultralytics returns a list of Results objects (one per image), not a
    # YOLOv5 hub result with `.pred`. Box corners live in `result.boxes.xyxy`.
    annotated = image.copy()
    for result in results:
        if result.boxes is None:
            continue
        for x_min, y_min, x_max, y_max in result.boxes.xyxy.tolist():
            # OpenCV drawing functions require integer pixel coordinates.
            cv2.rectangle(
                annotated,
                (int(x_min), int(y_min)),
                (int(x_max), int(y_max)),
                (0, 255, 0),
                2,
            )
    st.image(annotated, caption='Processed Image', use_column_width=True)
# Entry point: launch the app only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()