Dan Biagini commited on
Commit
fb6287f
·
1 Parent(s): ca8be5b

fix multi run image reads and provide a sample for gringos

Browse files
src/hockey_object_detection.py CHANGED
@@ -9,46 +9,77 @@ def get_model():
9
  repo_id = "danbiagini/hockey_breeds_v2"
10
  return hf_hub_download(repo_id=repo_id, filename="hockey_breeds-v2-101623.pt")
11
 
 
12
  def run_inference(img, model, thresh=0.5):
13
  model = YOLO(model_f)
14
  st.session_state.results = model(img)
15
  return draw_hockey_boxes(img, st.session_state.results, thresh)
16
 
 
17
  def draw_hockey_boxes(frame, results, thresh=0.5):
18
- colors = {0: (0, 255, 0), 1: (255, 0, 0), 2: (0, 0, 255), 3: (128, 0, 0), 4: (0, 128, 0), 5: (0, 0, 128), 6: (0, 64, 0), 7: (64, 0, 0), 8: (0, 0, 64)}
 
19
  font_scale = frame.shape[0] / 500
20
  objects = []
21
-
22
  for name in results:
23
  for box in name.boxes.data.tolist():
24
  x1, y1, x2, y2, score, class_id = box
25
  objects.append((name.names[int(class_id)], score))
26
 
27
  if score > thresh:
28
- cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), colors[(class_id % 9)], 3)
 
29
  cv2.putText(frame, f'{name.names[int(class_id)].upper()}: {score:.2f}', (int(x1), int(y1 - 10)),
30
  cv2.FONT_HERSHEY_SIMPLEX, font_scale, colors[(class_id % 9)], 3, cv2.LINE_AA)
31
  else:
32
- print(f'Found an object under confidence threshold {thresh} type: {name.names[class_id]}, score:{score}, x1, y2:{x1}, {y2}')
 
33
  return objects
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
 
 
 
 
36
  if 'results' not in st.session_state:
37
  st.session_state.results = []
38
 
 
 
 
 
 
 
 
 
39
  st.set_page_config(page_title='Hockey Breeds v2 - Objects', layout="wide",
40
  page_icon=":frame_with_picture:")
41
 
42
  st.title('Hockey Breeds v2 - Objects')
43
- intro = '''The first version of Hockey Breeds was fun and educational, but not useful for analyzing hockey videos. The second version is to a proof of concept
44
  with the ability to recognize individual "objects" within an image, which paves the way to ultimately tracking those objects through game play.'''
45
 
46
  st.markdown(intro)
47
- st.subheader('Object Detection')
 
48
 
49
  desc = '''Hockey Breed detector v2 uses a state of the art (circa 2023) computer vision approach.
50
 
51
- I used the same training images as the first version of the Hockey Breeds model, but change the ML algorithm to use YOLO object detection (YOLO v8).
52
  The output will be a set of hockey objects (defined by "bounding boxes") with labels for any hockey image uploaded.
53
 
54
  **Object List**:
@@ -63,38 +94,51 @@ The output will be a set of hockey objects (defined by "bounding boxes") with la
63
  st.markdown(desc)
64
 
65
  st.subheader("Sample")
66
- st.image('src/images/samples/v2/v2-sample1-090124.png', caption='Sample image with hockey objects detected')
 
67
 
68
  st.subheader("Validation Results")
69
 
70
  st.markdown('''Validation of the model\'s performance was done using 15 images not included in the training set. The model had many issues; it did poorly with detecting *pucks* and *sticks* vs backgrounds and even goalies and skaters. It did very well on detecting referees.''')
71
- st.image("src/images/artifacts/confusion_matrix_v2.png", caption="Confusion Matrix for Hockey Breeds v2", )
 
 
 
 
 
72
 
73
- st.subheader("Try It Out")
 
 
74
 
75
- img = st.file_uploader("Upload an image for object detection", type=["jpg", "jpeg", "png"])
 
 
76
 
 
77
  if img is not None:
78
- thresh =st.slider('Set the object confidence threshold', min_value=0.0, max_value=1.0, value=0.5, step=0.01)
 
 
 
79
  with st.status("Detecting hockey objects..."):
80
  st.write("Loading model...")
81
  model_f = get_model()
82
 
83
- st.write("Processing image...")
84
- frame = cv2.imdecode(np.frombuffer(img.read(), np.uint8), 1)
85
-
86
  st.write("Running inference on image...")
87
- objects = run_inference(frame, model_f, thresh)
 
88
  st.dataframe(objects, column_config={
89
- "0": "Object",
90
- "1": "Confidence"
91
  })
92
 
93
  # check if the results list is empty
94
  if len(st.session_state.results) == 0:
95
- st.image(img, caption='Uploaded Image')
 
96
  else:
97
- st.image(frame, caption='Uploaded Image')
 
 
98
 
99
- else:
100
- st.session_state.results = []
 
9
  repo_id = "danbiagini/hockey_breeds_v2"
10
  return hf_hub_download(repo_id=repo_id, filename="hockey_breeds-v2-101623.pt")
11
 
12
+
13
  def run_inference(img, model, thresh=0.5):
14
  model = YOLO(model_f)
15
  st.session_state.results = model(img)
16
  return draw_hockey_boxes(img, st.session_state.results, thresh)
17
 
18
+
19
  def draw_hockey_boxes(frame, results, thresh=0.5):
20
+ colors = {0: (0, 255, 0), 1: (255, 0, 0), 2: (0, 0, 255), 3: (128, 0, 0), 4: (
21
+ 0, 128, 0), 5: (0, 0, 128), 6: (0, 64, 0), 7: (64, 0, 0), 8: (0, 0, 64)}
22
  font_scale = frame.shape[0] / 500
23
  objects = []
24
+
25
  for name in results:
26
  for box in name.boxes.data.tolist():
27
  x1, y1, x2, y2, score, class_id = box
28
  objects.append((name.names[int(class_id)], score))
29
 
30
  if score > thresh:
31
+ cv2.rectangle(frame, (int(x1), int(y1)),
32
+ (int(x2), int(y2)), colors[(class_id % 9)], 3)
33
  cv2.putText(frame, f'{name.names[int(class_id)].upper()}: {score:.2f}', (int(x1), int(y1 - 10)),
34
  cv2.FONT_HERSHEY_SIMPLEX, font_scale, colors[(class_id % 9)], 3, cv2.LINE_AA)
35
  else:
36
+ print(
37
+ f'Found an object under confidence threshold {thresh} type: {name.names[class_id]}, score:{score}, x1, y2:{x1}, {y2}')
38
  return objects
39
 
40
+ def reset_image():
41
+ st.session_state.img = None
42
+
43
+ def upload_img():
44
+ if st.session_state.upload_img is not None:
45
+ st.session_state.img = st.session_state.upload_img
46
+
47
+ def get_naked_image():
48
+ if st.session_state.img is not None:
49
+ img = st.session_state.img
50
+ img.seek(0)
51
+ return(cv2.imdecode(np.frombuffer(img.read(), np.uint8), 1))
52
+ return None
53
 
54
+ def use_sample_image():
55
+ st.session_state.img = open('src/images/samples/v2/net-chaos.jpg', 'rb')
56
+
57
+ # Init state
58
  if 'results' not in st.session_state:
59
  st.session_state.results = []
60
 
61
+ if 'thresh' not in st.session_state:
62
+ st.session_state.thresh = 0.5
63
+
64
+ if 'img' not in st.session_state:
65
+ st.session_state.img = None
66
+
67
+
68
+ # Top down page rendering
69
  st.set_page_config(page_title='Hockey Breeds v2 - Objects', layout="wide",
70
  page_icon=":frame_with_picture:")
71
 
72
  st.title('Hockey Breeds v2 - Objects')
73
+ intro = '''The first version of Hockey Breeds was fun and educational, but not useful for analyzing hockey videos. The second version is to a proof of concept
74
  with the ability to recognize individual "objects" within an image, which paves the way to ultimately tracking those objects through game play.'''
75
 
76
  st.markdown(intro)
77
+
78
+ st.subheader('Object Detection Technical Details')
79
 
80
  desc = '''Hockey Breed detector v2 uses a state of the art (circa 2023) computer vision approach.
81
 
82
+ I used the same training images as the first version of the Hockey Breeds model, but change the ML algorithm to use YOLO object detection (YOLO v8).
83
  The output will be a set of hockey objects (defined by "bounding boxes") with labels for any hockey image uploaded.
84
 
85
  **Object List**:
 
94
  st.markdown(desc)
95
 
96
  st.subheader("Sample")
97
+ st.image('src/images/samples/v2/v2-sample1-090124.png',
98
+ caption='Sample image with hockey objects detected')
99
 
100
  st.subheader("Validation Results")
101
 
102
  st.markdown('''Validation of the model\'s performance was done using 15 images not included in the training set. The model had many issues; it did poorly with detecting *pucks* and *sticks* vs backgrounds and even goalies and skaters. It did very well on detecting referees.''')
103
+ st.image("src/images/artifacts/confusion_matrix_v2.png",
104
+ caption="Confusion Matrix for Hockey Breeds v2", )
105
+
106
+ st.subheader("Try it out!")
107
+ st.write("Upload an image file to try detecting hockey objects in your own hockey image, or use a sample image below.")
108
+
109
 
110
+ if st.session_state.img is None:
111
+ st.file_uploader("Upload an image and Hockey Breeds v2 will find the hockey objects in the image",
112
+ type=["jpg", "jpeg", "png"], key='upload_img', on_change=upload_img)
113
 
114
+ with st.expander("Sample Images"):
115
+ st.image('src/images/samples/v2/net-chaos.jpg')
116
+ st.button("Use Sample", on_click=use_sample_image)
117
 
118
+ img = get_naked_image()
119
  if img is not None:
120
+
121
+ thresh = st.slider('Set the object confidence threshold', key='thresh',
122
+ min_value=0.0, max_value=1.0, value=0.5, step=0.05)
123
+
124
  with st.status("Detecting hockey objects..."):
125
  st.write("Loading model...")
126
  model_f = get_model()
127
 
 
 
 
128
  st.write("Running inference on image...")
129
+ objects = run_inference(img, model_f, thresh)
130
+
131
  st.dataframe(objects, column_config={
132
+ "0": "Object",
133
+ "1": "Confidence"
134
  })
135
 
136
  # check if the results list is empty
137
  if len(st.session_state.results) == 0:
138
+ st.write('**No hockey objects found in image!**')
139
+ st.image(img, caption='Uploaded Image had no hockey objects')
140
  else:
141
+ st.image(img, caption='Image with hockey object bounding boxes')
142
+
143
+ st.button("Reset Image", on_click=reset_image)
144
 
 
 
src/images/samples/v2/net-chaos.jpg ADDED