import streamlit as st
from PIL import Image
import cv2
import torch
import numpy as np
from torchvision import transforms
from ultralytics import YOLO

# Load the YOLO model once and cache it across Streamlit reruns
@st.cache_resource
def load_model():
    # Replace 'best.pt' with the path to your YOLO model weights
    model = YOLO('best.pt')
    return model
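
# Optional variant (a sketch, not used below): parameterize the weights path so
# different checkpoints can be loaded without editing load_model. The default
# 'best.pt' mirrors the assumption above; st.cache_resource caches one model
# per distinct path argument.
@st.cache_resource
def load_model_from(weights_path: str = 'best.pt'):
    return YOLO(weights_path)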

# Optional manual preprocessing helper (currently unused: the ultralytics
# model resizes and normalizes internally when called directly).
# Expects a PIL Image and returns a normalized 1x3x416x416 torch batch.
def process_image(image):
    preprocess = transforms.Compose([
        transforms.Resize((416, 416)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    input_tensor = preprocess(image)
    input_batch = input_tensor.unsqueeze(0)

    # Manual inference and post-processing (e.g., decoding raw outputs and
    # drawing bounding boxes) would go here if the batch were fed to the
    # network directly instead of relying on the ultralytics pipeline.

    return input_batch
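
# Alternative rendering sketch (assumes ultralytics' Results.plot() helper):
# instead of drawing rectangles manually with cv2 as done in main() below,
# each Results object can render its own annotated image. plot() returns a
# BGR array, so it is converted to RGB before display. Not called anywhere;
# kept as a reference.
def render_with_plot(results):
    annotated_bgr = results[0].plot()  # draws boxes, labels and confidences
    return cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)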

# Main Streamlit code
def main():
    st.title('YOLO Image Detection')
    
    # Upload image file
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    
    if uploaded_file is not None:
        # Load YOLO model
        model = load_model()
        
        # Read the uploaded file as an RGB numpy array
        image = np.array(Image.open(uploaded_file).convert('RGB'))
        st.image(image, caption='Original Image', use_column_width=True)

        # Run inference; ultralytics returns a list of Results objects
        results = model(image)
        for result in results:
            for box in result.boxes:
                # Bounding box coordinates as [x_min, y_min, x_max, y_max]
                x_min, y_min, x_max, y_max = map(int, box.xyxy[0].tolist())
                cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
        st.image(image, caption='Processed Image', use_column_width=True)


if __name__ == '__main__':
    main()
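
# To launch the app locally (assuming this file is saved as app.py and the
# 'best.pt' weights sit in the same directory):
#   streamlit run app.py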