Rohankumar31 commited on
Commit
6cfcd89
·
verified ·
1 Parent(s): dc958df

Update sample.py

Browse files
Files changed (1) hide show
  1. sample.py +37 -40
sample.py CHANGED
@@ -1,60 +1,57 @@
1
  import streamlit as st
2
  from PIL import Image
3
- from io import BytesIO
4
  import torch
5
- from pathlib import Path
6
- from PIL import Image
7
- from torchvision.transforms import functional as F
8
- from utils.plots import plot_one_box
9
-
10
- # Import YOLOv5
11
- from models.experimental import attempt_load
12
- from utils.general import non_max_suppression, scale_coords
13
- from utils.torch_utils import select_device
14
-
15
- # Function to load YOLOv5 model
16
- @st.cache(allow_output_mutation=True)
17
- def load_model(weights='best.pt', device=''):
18
- device = select_device(device)
19
- model = attempt_load(weights, map_location=device)
20
  return model
21
 
22
- # Function to process image through YOLOv5
23
  def process_image(image, model):
24
- # Convert PIL image to Torch tensor
25
- img_tensor = F.to_tensor(image)
26
- img_tensor = img_tensor.unsqueeze(0)
27
-
28
- # Inference
29
- results = model(img_tensor)
30
- results = non_max_suppression(results, conf_thres=0.4, iou_thres=0.5)
31
-
32
- # Draw bounding boxes on the image
33
- for i, det in enumerate(results):
34
- if det is not None and len(det):
35
- det[:, :4] = scale_coords(img_tensor.shape[2:], det[:, :4], image.size)
36
- for *xyxy, conf, cls in reversed(det):
37
- plot_one_box(xyxy, image, label=f'{model.names[int(cls)]} {conf:.2f}', color='red', line_thickness=3)
 
 
 
 
38
 
39
- return image
40
 
41
- # Main function
42
  def main():
43
- st.title('YOLOv5 Image Detection')
44
 
45
  # Upload image file
46
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
47
-
48
  if uploaded_file is not None:
49
- # Load YOLOv5 model
50
  model = load_model()
51
-
52
  # Process uploaded image
53
  image = Image.open(uploaded_file)
54
  st.image(image, caption='Original Image', use_column_width=True)
55
-
56
- output_image = process_image(image, model)
 
57
  st.image(output_image, caption='Processed Image', use_column_width=True)
58
 
59
  if __name__ == '__main__':
60
- main()
 
1
  import streamlit as st
2
  from PIL import Image
 
3
  import torch
4
+ from torchvision import transforms
5
+ from ultralytics import YOLO
6
+
7
+ # Load the YOLO model
8
+ @st.cache
9
+ def load_model():
10
+ # Replace 'model.pt' with the path to your YOLO model file
11
+ model = YOLO('best.pt')
 
 
 
 
 
 
 
12
  return model
13
 
14
+ # Define YOLO processing function
15
  def process_image(image, model):
16
+ # Preprocess the image
17
+ preprocess = transforms.Compose([
18
+ transforms.Resize((416, 416)),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
21
+ ])
22
+ input_tensor = preprocess(image)
23
+ input_batch = input_tensor.unsqueeze(0)
24
+
25
+ # Perform inference
26
+ with torch.no_grad():
27
+ output = model(input_batch)
28
+
29
+ # Post-process the output (e.g., draw bounding boxes)
30
+ # Replace this with your post-processing code
31
+
32
+ # Convert tensor to PIL Image
33
+ output_image = transforms.ToPILImage()(output[0].cpu().squeeze())
34
 
35
+ return output_image
36
 
37
+ # Main Streamlit code
38
  def main():
39
+ st.title('YOLO Image Detection')
40
 
41
  # Upload image file
42
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
43
+
44
  if uploaded_file is not None:
45
+ # Load YOLO model
46
  model = load_model()
47
+
48
  # Process uploaded image
49
  image = Image.open(uploaded_file)
50
  st.image(image, caption='Original Image', use_column_width=True)
51
+
52
+ output = model.predict(image)
53
+ output_image = Image.fromarray(output)
54
  st.image(output_image, caption='Processed Image', use_column_width=True)
55
 
56
  if __name__ == '__main__':
57
+ main()