ggirishg commited on
Commit
9aa8a22
·
verified ·
1 Parent(s): 218ef70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +367 -76
app.py CHANGED
@@ -4,65 +4,235 @@ import subprocess
4
  subprocess.run(['chmod', '+x', 'setup.sh'])
5
  subprocess.run(['bash', 'setup.sh'], check=True)
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import streamlit as st
8
  import cv2
9
  import mediapipe as mp
10
  import numpy as np
11
  import tempfile
12
  import os
 
13
 
14
  # Initialize MediaPipe Pose
15
  mp_pose = mp.solutions.pose
16
- pose = mp_pose.Pose(static_image_mode=False, model_complexity=1, enable_segmentation=True, min_detection_confidence=0.5, min_tracking_confidence=0.5)
17
  mp_drawing = mp.solutions.drawing_utils
18
 
19
- def calculate_angle_between_vectors(v1, v2):
20
- unit_vector_1 = v1 / np.linalg.norm(v1)
21
- unit_vector_2 = v2 / np.linalg.norm(v2)
22
- dot_product = np.dot(unit_vector_1, unit_vector_2)
23
- angle = np.arccos(dot_product)
24
- return np.degrees(angle)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def process_video(video_path):
 
 
 
 
 
 
 
 
27
  cap = cv2.VideoCapture(video_path)
28
  output_dir = tempfile.mkdtemp()
29
 
30
- current_phase = "Not Setup phase"
31
- prev_wrist_left_y = None
32
- prev_wrist_right_y = None
33
- top_backswing_detected = False
34
- mid_downswing_detected = False
35
- ball_impact_detected = False
36
- top_backswing_frame = -2
37
- mid_downswing_frame = -2
38
- ball_impact_frame = -2
39
-
40
- BALL_IMPACT_DURATION = 2 # Duration in frames to display Ball Impact phase
41
 
42
- MIN_MOVEMENT_THRESHOLD = 0.01
43
- HIP_NEAR_THRESHOLD = 0.05
44
- MID_SWING_THRESHOLD = 0.05
45
 
 
46
  saved_phases = set()
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  while cap.isOpened():
49
  ret, frame = cap.read()
50
  if not ret:
51
  break
52
-
53
- frame_no = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
54
  image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
55
  result = pose.process(image_rgb)
56
  h, w, c = frame.shape
57
 
58
  if result.pose_landmarks:
59
- mp_drawing.draw_landmarks(
60
- frame, result.pose_landmarks, mp_pose.POSE_CONNECTIONS,
61
- mp_drawing.DrawingSpec(color=(255, 0, 0), thickness=2, circle_radius=2),
62
- mp_drawing.DrawingSpec(color=(255, 0, 255), thickness=2, circle_radius=2)
63
- )
64
-
65
  landmarks = result.pose_landmarks.landmark
 
66
  wrist_left_y = landmarks[mp_pose.PoseLandmark.LEFT_WRIST].y
67
  wrist_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_WRIST].y
68
  hip_left_y = landmarks[mp_pose.PoseLandmark.LEFT_HIP].y
@@ -72,67 +242,188 @@ def process_video(video_path):
72
 
73
  hip_y_avg = (hip_left_y + hip_right_y) / 2
74
  shoulder_y_avg = (shoulder_left_y + shoulder_right_y) / 2
 
 
 
 
 
 
 
 
75
  mid_swing_y = (shoulder_y_avg + hip_y_avg) / 2
76
 
77
- # Ensure the current phase persists for a few more milliseconds if it's Ball Impact
78
- if ball_impact_detected and frame_no <= ball_impact_frame + BALL_IMPACT_DURATION:
79
- current_phase = "Ball impact phase"
80
- elif (abs(wrist_left_y - hip_y_avg) < HIP_NEAR_THRESHOLD and abs(wrist_right_y - hip_y_avg) < HIP_NEAR_THRESHOLD):
81
- if prev_wrist_left_y is not None and prev_wrist_right_y is not None:
82
- if (abs(wrist_left_y - prev_wrist_left_y) < MIN_MOVEMENT_THRESHOLD and abs(wrist_right_y - prev_wrist_right_y) < MIN_MOVEMENT_THRESHOLD):
83
- if mid_downswing_detected and frame_no > mid_downswing_frame:
84
- current_phase = "Ball impact phase"
85
- ball_impact_detected = True
86
- ball_impact_frame = frame_no
87
- else:
88
- current_phase = "Setup phase"
89
- top_backswing_detected = False
90
- mid_downswing_detected = False
91
- else:
92
- current_phase = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  else:
94
- if mid_downswing_detected and frame_no > mid_downswing_frame:
95
- current_phase = "Ball impact phase"
96
- ball_impact_detected = True
97
- ball_impact_frame = frame_no
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  else:
99
- current_phase = "Setup phase"
100
- top_backswing_detected = False
101
- mid_downswing_detected = False
102
- elif (abs(wrist_left_y - mid_swing_y) < MID_SWING_THRESHOLD and abs(wrist_right_y - mid_swing_y) < MID_SWING_THRESHOLD and not top_backswing_detected and not ball_impact_detected):
103
- current_phase = "Mid backswing phase"
104
- elif (wrist_left_y < shoulder_left_y and wrist_right_y < shoulder_right_y and not mid_downswing_detected and not ball_impact_detected):
105
- current_phase = "Top backswing phase"
106
- top_backswing_detected = True
107
- top_backswing_frame = frame_no
108
- elif (abs(wrist_left_y - mid_swing_y) < MID_SWING_THRESHOLD and abs(wrist_right_y - mid_swing_y) < MID_SWING_THRESHOLD and top_backswing_detected and frame_no > top_backswing_frame):
109
- current_phase = "Mid downswing phase"
110
- mid_downswing_detected = True
111
- mid_downswing_frame = frame_no
112
- elif (wrist_left_y < shoulder_left_y and wrist_right_y < shoulder_right_y and ball_impact_detected and frame_no > ball_impact_frame):
113
- current_phase = "Follow through phase"
 
 
114
  else:
115
- current_phase = ""
 
 
 
 
 
 
 
 
116
 
117
- prev_wrist_left_y = wrist_left_y
118
- prev_wrist_right_y = wrist_right_y
 
119
 
120
- cv2.putText(frame, f"Phase: {current_phase}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
 
 
 
 
 
 
 
121
 
122
- # Save the frame for each detected phase
123
- if current_phase and current_phase not in saved_phases:
124
- phase_filename = os.path.join(output_dir, f"{current_phase.replace(' ', '_')}.png")
 
 
125
  cv2.imwrite(phase_filename, frame)
126
- saved_phases.add(current_phase)
 
 
 
 
127
 
128
  cap.release()
129
  cv2.destroyAllWindows()
130
  pose.close()
131
-
132
  return output_dir
133
 
134
- st.title("Golf Swing Phase Detection")
135
- st.write("Upload a video to detect different phases of a golf swing.")
136
 
137
  video_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "mkv"])
138
 
@@ -146,6 +437,6 @@ if video_file is not None:
146
 
147
  st.write("Detected phases saved to:", output_dir)
148
  st.write("Example frames from detected phases:")
149
-
150
  for phase_image in os.listdir(output_dir):
151
  st.image(os.path.join(output_dir, phase_image), caption=phase_image)
 
4
  subprocess.run(['chmod', '+x', 'setup.sh'])
5
  subprocess.run(['bash', 'setup.sh'], check=True)
6
 
7
+ # import streamlit as st
8
+ # import cv2
9
+ # import mediapipe as mp
10
+ # import numpy as np
11
+ # import tempfile
12
+ # import os
13
+
14
+ # # Initialize MediaPipe Pose
15
+ # mp_pose = mp.solutions.pose
16
+ # pose = mp_pose.Pose(static_image_mode=False, model_complexity=1, enable_segmentation=True, min_detection_confidence=0.5, min_tracking_confidence=0.5)
17
+ # mp_drawing = mp.solutions.drawing_utils
18
+
19
+ # def calculate_angle_between_vectors(v1, v2):
20
+ # unit_vector_1 = v1 / np.linalg.norm(v1)
21
+ # unit_vector_2 = v2 / np.linalg.norm(v2)
22
+ # dot_product = np.dot(unit_vector_1, unit_vector_2)
23
+ # angle = np.arccos(dot_product)
24
+ # return np.degrees(angle)
25
+
26
+ # def process_video(video_path):
27
+ # cap = cv2.VideoCapture(video_path)
28
+ # output_dir = tempfile.mkdtemp()
29
+
30
+ # current_phase = "Not Setup phase"
31
+ # prev_wrist_left_y = None
32
+ # prev_wrist_right_y = None
33
+ # top_backswing_detected = False
34
+ # mid_downswing_detected = False
35
+ # ball_impact_detected = False
36
+ # top_backswing_frame = -2
37
+ # mid_downswing_frame = -2
38
+ # ball_impact_frame = -2
39
+
40
+ # BALL_IMPACT_DURATION = 2 # Duration in frames to display Ball Impact phase
41
+
42
+ # MIN_MOVEMENT_THRESHOLD = 0.01
43
+ # HIP_NEAR_THRESHOLD = 0.05
44
+ # MID_SWING_THRESHOLD = 0.05
45
+
46
+ # saved_phases = set()
47
+
48
+ # while cap.isOpened():
49
+ # ret, frame = cap.read()
50
+ # if not ret:
51
+ # break
52
+
53
+ # frame_no = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
54
+ # image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
55
+ # result = pose.process(image_rgb)
56
+ # h, w, c = frame.shape
57
+
58
+ # if result.pose_landmarks:
59
+ # mp_drawing.draw_landmarks(
60
+ # frame, result.pose_landmarks, mp_pose.POSE_CONNECTIONS,
61
+ # mp_drawing.DrawingSpec(color=(255, 0, 0), thickness=2, circle_radius=2),
62
+ # mp_drawing.DrawingSpec(color=(255, 0, 255), thickness=2, circle_radius=2)
63
+ # )
64
+
65
+ # landmarks = result.pose_landmarks.landmark
66
+ # wrist_left_y = landmarks[mp_pose.PoseLandmark.LEFT_WRIST].y
67
+ # wrist_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_WRIST].y
68
+ # hip_left_y = landmarks[mp_pose.PoseLandmark.LEFT_HIP].y
69
+ # hip_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_HIP].y
70
+ # shoulder_left_y = landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER].y
71
+ # shoulder_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_SHOULDER].y
72
+
73
+ # hip_y_avg = (hip_left_y + hip_right_y) / 2
74
+ # shoulder_y_avg = (shoulder_left_y + shoulder_right_y) / 2
75
+ # mid_swing_y = (shoulder_y_avg + hip_y_avg) / 2
76
+
77
+ # # Ensure the current phase persists for a few more milliseconds if it's Ball Impact
78
+ # if ball_impact_detected and frame_no <= ball_impact_frame + BALL_IMPACT_DURATION:
79
+ # current_phase = "Ball impact phase"
80
+ # elif (abs(wrist_left_y - hip_y_avg) < HIP_NEAR_THRESHOLD and abs(wrist_right_y - hip_y_avg) < HIP_NEAR_THRESHOLD):
81
+ # if prev_wrist_left_y is not None and prev_wrist_right_y is not None:
82
+ # if (abs(wrist_left_y - prev_wrist_left_y) < MIN_MOVEMENT_THRESHOLD and abs(wrist_right_y - prev_wrist_right_y) < MIN_MOVEMENT_THRESHOLD):
83
+ # if mid_downswing_detected and frame_no > mid_downswing_frame:
84
+ # current_phase = "Ball impact phase"
85
+ # ball_impact_detected = True
86
+ # ball_impact_frame = frame_no
87
+ # else:
88
+ # current_phase = "Setup phase"
89
+ # top_backswing_detected = False
90
+ # mid_downswing_detected = False
91
+ # else:
92
+ # current_phase = ""
93
+ # else:
94
+ # if mid_downswing_detected and frame_no > mid_downswing_frame:
95
+ # current_phase = "Ball impact phase"
96
+ # ball_impact_detected = True
97
+ # ball_impact_frame = frame_no
98
+ # else:
99
+ # current_phase = "Setup phase"
100
+ # top_backswing_detected = False
101
+ # mid_downswing_detected = False
102
+ # elif (abs(wrist_left_y - mid_swing_y) < MID_SWING_THRESHOLD and abs(wrist_right_y - mid_swing_y) < MID_SWING_THRESHOLD and not top_backswing_detected and not ball_impact_detected):
103
+ # current_phase = "Mid backswing phase"
104
+ # elif (wrist_left_y < shoulder_left_y and wrist_right_y < shoulder_right_y and not mid_downswing_detected and not ball_impact_detected):
105
+ # current_phase = "Top backswing phase"
106
+ # top_backswing_detected = True
107
+ # top_backswing_frame = frame_no
108
+ # elif (abs(wrist_left_y - mid_swing_y) < MID_SWING_THRESHOLD and abs(wrist_right_y - mid_swing_y) < MID_SWING_THRESHOLD and top_backswing_detected and frame_no > top_backswing_frame):
109
+ # current_phase = "Mid downswing phase"
110
+ # mid_downswing_detected = True
111
+ # mid_downswing_frame = frame_no
112
+ # elif (wrist_left_y < shoulder_left_y and wrist_right_y < shoulder_right_y and ball_impact_detected and frame_no > ball_impact_frame):
113
+ # current_phase = "Follow through phase"
114
+ # else:
115
+ # current_phase = ""
116
+
117
+ # prev_wrist_left_y = wrist_left_y
118
+ # prev_wrist_right_y = wrist_right_y
119
+
120
+ # cv2.putText(frame, f"Phase: {current_phase}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
121
+
122
+ # # Save the frame for each detected phase
123
+ # if current_phase and current_phase not in saved_phases:
124
+ # phase_filename = os.path.join(output_dir, f"{current_phase.replace(' ', '_')}.png")
125
+ # cv2.imwrite(phase_filename, frame)
126
+ # saved_phases.add(current_phase)
127
+
128
+ # cap.release()
129
+ # cv2.destroyAllWindows()
130
+ # pose.close()
131
+
132
+ # return output_dir
133
+
134
+ # st.title("Golf Swing Phase Detection")
135
+ # st.write("Upload a video to detect different phases of a golf swing.")
136
+
137
+ # video_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "mkv"])
138
+
139
+ # if video_file is not None:
140
+ # tfile = tempfile.NamedTemporaryFile(delete=False)
141
+ # tfile.write(video_file.read())
142
+ # tfile_path = tfile.name
143
+
144
+ # st.write("Processing video...")
145
+ # output_dir = process_video(tfile_path)
146
+
147
+ # st.write("Detected phases saved to:", output_dir)
148
+ # st.write("Example frames from detected phases:")
149
+
150
+ # for phase_image in os.listdir(output_dir):
151
+ # st.image(os.path.join(output_dir, phase_image), caption=phase_image)
152
  import streamlit as st
153
  import cv2
154
  import mediapipe as mp
155
  import numpy as np
156
  import tempfile
157
  import os
158
+ from collections import deque
159
 
160
  # Initialize MediaPipe Pose
161
  mp_pose = mp.solutions.pose
 
162
  mp_drawing = mp.solutions.drawing_utils
163
 
164
+ # Define states for the state machine
165
+ SETUP = "Setup phase"
166
+ MID_BACKSWING = "Mid backswing phase"
167
+ TOP_BACKSWING = "Top backswing phase"
168
+ MID_DOWNSWING = "Mid downswing phase"
169
+ BALL_IMPACT = "Ball impact phase"
170
+ FOLLOW_THROUGH = "Follow through phase"
171
+ UNKNOWN = "Unknown"
172
+
173
+ # Parameters for logic
174
+ NUM_FRAMES_STABLE = 5 # Number of frames to confirm a state transition
175
+ VEL_THRESHOLD = 0.003 # Velocity threshold to confirm direction (tune as needed)
176
+ MID_POINT_RATIO = 0.5 # Ratio for mid-swing line (between shoulders and hips)
177
+ BALL_IMPACT_DURATION = 5 # Frames to keep Ball Impact state stable
178
+
179
+ def smooth_positions(positions, window=5):
180
+ """Simple smoothing by averaging the last `window` positions."""
181
+ if len(positions) < window:
182
+ return positions[-1]
183
+ arr = np.array(positions[-window:])
184
+ return np.mean(arr, axis=0)
185
 
186
  def process_video(video_path):
187
+ pose = mp_pose.Pose(
188
+ static_image_mode=False,
189
+ model_complexity=1,
190
+ enable_segmentation=True,
191
+ min_detection_confidence=0.5,
192
+ min_tracking_confidence=0.5,
193
+ )
194
+
195
  cap = cv2.VideoCapture(video_path)
196
  output_dir = tempfile.mkdtemp()
197
 
198
+ # State machine variables
199
+ current_state = UNKNOWN
200
+ last_confirmed_state = UNKNOWN
201
+ state_confirmation_count = 0
 
 
 
 
 
 
 
202
 
203
+ # To store positions and smoothing
204
+ wrist_left_positions = deque(maxlen=30)
205
+ wrist_right_positions = deque(maxlen=30)
206
 
207
+ # For saving phases once
208
  saved_phases = set()
209
 
210
+ # Reference positions (will be recorded from initial frames)
211
+ initial_hip_y = None
212
+ initial_shoulder_y = None
213
+ detected_initial_setup = False
214
+
215
+ # Variables to track top backswing peak
216
+ # We'll store the max height reached during backswing
217
+ max_wrist_height = None
218
+ top_backswing_reached = False
219
+
220
+ # For Ball impact stable frames
221
+ ball_impact_frame_no = -1
222
+ frame_count = 0
223
+
224
  while cap.isOpened():
225
  ret, frame = cap.read()
226
  if not ret:
227
  break
228
+ frame_count += 1
 
229
  image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
230
  result = pose.process(image_rgb)
231
  h, w, c = frame.shape
232
 
233
  if result.pose_landmarks:
 
 
 
 
 
 
234
  landmarks = result.pose_landmarks.landmark
235
+ # Extract relevant landmarks
236
  wrist_left_y = landmarks[mp_pose.PoseLandmark.LEFT_WRIST].y
237
  wrist_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_WRIST].y
238
  hip_left_y = landmarks[mp_pose.PoseLandmark.LEFT_HIP].y
 
242
 
243
  hip_y_avg = (hip_left_y + hip_right_y) / 2
244
  shoulder_y_avg = (shoulder_left_y + shoulder_right_y) / 2
245
+
246
+ # Record initial reference once at the start if not done
247
+ if initial_hip_y is None:
248
+ initial_hip_y = hip_y_avg
249
+ if initial_shoulder_y is None:
250
+ initial_shoulder_y = shoulder_y_avg
251
+
252
+ # Mid swing line (between shoulder and hip)
253
  mid_swing_y = (shoulder_y_avg + hip_y_avg) / 2
254
 
255
+ # Append current positions
256
+ wrist_left_positions.append(wrist_left_y)
257
+ wrist_right_positions.append(wrist_right_y)
258
+
259
+ # Smooth positions
260
+ smoothed_left_y = smooth_positions(list(wrist_left_positions))
261
+ smoothed_right_y = smooth_positions(list(wrist_right_positions))
262
+
263
+ # Average wrist height
264
+ avg_wrist_y = (smoothed_left_y + smoothed_right_y) / 2.0
265
+
266
+ # Compute velocity as difference from last frame (if possible)
267
+ if len(wrist_left_positions) > 1:
268
+ vel_wrist_y = avg_wrist_y - ((wrist_left_positions[-2] + wrist_right_positions[-2]) / 2.0)
269
+ else:
270
+ vel_wrist_y = 0.0
271
+
272
+ # Define conditions for each phase based on relative positions and movement:
273
+ # We'll define logical checks:
274
+ # 1. Setup: wrists near hip level and minimal movement
275
+ # 2. Mid backswing: wrists have started moving upward from hip level toward shoulder
276
+ # 3. Top backswing: wrists reach a peak (highest point) and start descending
277
+ # 4. Mid downswing: wrists cross mid line going downward
278
+ # 5. Ball impact: wrists around hip level again with downward movement stabilized
279
+ # 6. Follow through: wrists go above shoulders again after impact
280
+
281
+ # Detect initial Setup:
282
+ # Setup if wrists near hips and minimal vertical movement for a few frames
283
+ near_hip = abs(avg_wrist_y - initial_hip_y) < 0.05
284
+ low_velocity = abs(vel_wrist_y) < VEL_THRESHOLD
285
+
286
+ # Mid Backswing check:
287
+ # Movement upward from hip towards shoulder
288
+ # Condition: wrist higher than hip but not yet at top, positive upward velocity
289
+ going_up = (vel_wrist_y < -VEL_THRESHOLD) # remember y is normalized [0..1], top is smaller
290
+ mid_backswing_cond = (avg_wrist_y < mid_swing_y) and (avg_wrist_y < initial_hip_y) and going_up
291
+
292
+ # Top Backswing:
293
+ # Detecting a peak: we track max height during backswing.
294
+ # If currently going_up, update max_wrist_height.
295
+ # Once we detect a change from going_up to going_down, we mark top backswing.
296
+ if max_wrist_height is None or avg_wrist_y < max_wrist_height:
297
+ max_wrist_height = avg_wrist_y
298
+ going_down = (vel_wrist_y > VEL_THRESHOLD)
299
+ # Top backswing if we previously were going up and now start going down
300
+ # and wrists are near or above shoulder level (or at least higher than mid swing).
301
+ top_backswing_cond = top_backswing_reached is False and going_down and (max_wrist_height < mid_swing_y)
302
+
303
+ # Mid Downswing:
304
+ # After top backswing, as we go down again and cross mid swing line downward
305
+ mid_downswing_cond = top_backswing_reached and (avg_wrist_y > mid_swing_y) and going_down
306
+
307
+ # Ball Impact:
308
+ # When wrists return to near hip level while still going down or stabilizing
309
+ # We'll consider ball impact when avg_wrist_y ~ hip level and we've come down from top backswing
310
+ ball_impact_cond = top_backswing_reached and (abs(avg_wrist_y - initial_hip_y) < 0.05) and going_down
311
+
312
+ # Follow Through:
313
+ # After impact, if wrists go up again above shoulder level
314
+ follow_through_cond = (ball_impact_frame_no > 0 and frame_count > ball_impact_frame_no + BALL_IMPACT_DURATION
315
+ and avg_wrist_y < mid_swing_y and going_up)
316
+
317
+ # State machine transitions:
318
+ desired_state = UNKNOWN
319
+
320
+ # Prioritize states in a logical order
321
+ if current_state == UNKNOWN:
322
+ # Try to find a stable setup as a start
323
+ if near_hip and low_velocity:
324
+ desired_state = SETUP
325
+ else:
326
+ desired_state = UNKNOWN
327
+
328
+ elif current_state == SETUP:
329
+ # From setup, if we start going up and cross mid line:
330
+ if mid_backswing_cond:
331
+ desired_state = MID_BACKSWING
332
+ else:
333
+ desired_state = SETUP
334
+
335
+ elif current_state == MID_BACKSWING:
336
+ # If we detect a top backswing condition (peak reached):
337
+ if top_backswing_cond:
338
+ desired_state = TOP_BACKSWING
339
+ top_backswing_reached = True
340
+ else:
341
+ desired_state = MID_BACKSWING
342
+
343
+ elif current_state == TOP_BACKSWING:
344
+ # After top backswing, going down past mid line means mid downswing
345
+ if mid_downswing_cond:
346
+ desired_state = MID_DOWNSWING
347
  else:
348
+ desired_state = TOP_BACKSWING
349
+
350
+ elif current_state == MID_DOWNSWING:
351
+ # Reaching ball impact condition
352
+ if ball_impact_cond:
353
+ desired_state = BALL_IMPACT
354
+ ball_impact_frame_no = frame_count
355
+ else:
356
+ desired_state = MID_DOWNSWING
357
+
358
+ elif current_state == BALL_IMPACT:
359
+ # After ball impact, potentially follow through if going upward again
360
+ if follow_through_cond:
361
+ desired_state = FOLLOW_THROUGH
362
+ else:
363
+ # Keep showing ball impact for a few frames
364
+ if frame_count <= ball_impact_frame_no + BALL_IMPACT_DURATION:
365
+ desired_state = BALL_IMPACT
366
  else:
367
+ desired_state = BALL_IMPACT # could default to unknown if no follow through detected
368
+
369
+ elif current_state == FOLLOW_THROUGH:
370
+ # Final phase, usually no more transitions expected
371
+ desired_state = FOLLOW_THROUGH
372
+
373
+ # If we are UNKNOWN and can't find a better match:
374
+ if desired_state == UNKNOWN:
375
+ # Try to match any phase heuristics if no known logic fits
376
+ if near_hip and low_velocity:
377
+ desired_state = SETUP
378
+ else:
379
+ desired_state = UNKNOWN
380
+
381
+ # Confirm state transitions only if stable for several frames
382
+ if desired_state == current_state:
383
+ state_confirmation_count += 1
384
  else:
385
+ # Different desired state
386
+ if desired_state != UNKNOWN:
387
+ # Start counting from scratch for the new state
388
+ current_state = desired_state
389
+ state_confirmation_count = 1
390
+ else:
391
+ # If unknown requested, just switch immediately
392
+ current_state = UNKNOWN
393
+ state_confirmation_count = 1
394
 
395
+ # Once stable enough in a state, set last_confirmed_state
396
+ if state_confirmation_count >= NUM_FRAMES_STABLE:
397
+ last_confirmed_state = current_state
398
 
399
+ # Draw Landmarks
400
+ mp_drawing.draw_landmarks(
401
+ frame,
402
+ result.pose_landmarks,
403
+ mp_pose.POSE_CONNECTIONS,
404
+ mp_drawing.DrawingSpec(color=(255, 0, 0), thickness=2, circle_radius=2),
405
+ mp_drawing.DrawingSpec(color=(255, 0, 255), thickness=2, circle_radius=2)
406
+ )
407
 
408
+ cv2.putText(frame, f"Phase: {last_confirmed_state}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,0,255), 2, cv2.LINE_AA)
409
+
410
+ # Save the frame for each detected phase (once)
411
+ if last_confirmed_state not in saved_phases and last_confirmed_state != UNKNOWN:
412
+ phase_filename = os.path.join(output_dir, f"{last_confirmed_state.replace(' ', '_')}.png")
413
  cv2.imwrite(phase_filename, frame)
414
+ saved_phases.add(last_confirmed_state)
415
+
416
+ cv2.imshow("Pose Estimation", frame)
417
+ if cv2.waitKey(1) & 0xFF == ord('q'):
418
+ break
419
 
420
  cap.release()
421
  cv2.destroyAllWindows()
422
  pose.close()
 
423
  return output_dir
424
 
425
+ st.title("Golf Swing Phase Detection - Improved Logic")
426
+ st.write("Upload a video to detect different phases of a golf swing with improved accuracy.")
427
 
428
  video_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "mkv"])
429
 
 
437
 
438
  st.write("Detected phases saved to:", output_dir)
439
  st.write("Example frames from detected phases:")
440
+
441
  for phase_image in os.listdir(output_dir):
442
  st.image(os.path.join(output_dir, phase_image), caption=phase_image)