File size: 19,161 Bytes
6a87530
 
 
 
 
 
9aa8a22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d92004b
967037e
d92004b
 
 
 
9aa8a22
bc686e7
d92004b
 
 
bc686e7
9aa8a22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc686e7
d92004b
9aa8a22
 
 
 
 
 
 
 
d92004b
 
bc686e7
9aa8a22
 
 
 
d92004b
9aa8a22
 
 
d92004b
9aa8a22
d92004b
 
9aa8a22
 
 
 
 
 
 
 
 
 
 
 
 
 
d92004b
 
 
 
9aa8a22
d92004b
 
 
 
 
 
9aa8a22
d92004b
 
 
 
 
 
 
 
 
9aa8a22
 
 
 
 
 
 
 
d92004b
 
9aa8a22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d92004b
9aa8a22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d92004b
9aa8a22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d92004b
9aa8a22
 
 
 
 
 
 
 
 
d92004b
9aa8a22
 
 
d92004b
9aa8a22
 
 
 
 
 
 
 
d92004b
9aa8a22
 
 
 
 
d92004b
9aa8a22
 
 
 
 
d92004b
 
 
 
 
 
9aa8a22
 
d92004b
 
 
 
 
 
 
 
 
 
 
 
 
9aa8a22
d92004b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
import subprocess

# Run the environment bootstrap script before the heavy imports below.
# The chmod is best-effort (its exit status is not checked); the script itself
# is launched through bash, and a non-zero exit aborts start-up (check=True).
SETUP_SCRIPT = 'setup.sh'
subprocess.run(['chmod', '+x', SETUP_SCRIPT])
subprocess.run(['bash', SETUP_SCRIPT], check=True)

# import streamlit as st
# import cv2
# import mediapipe as mp
# import numpy as np
# import tempfile
# import os

# # Initialize MediaPipe Pose
# mp_pose = mp.solutions.pose
# pose = mp_pose.Pose(static_image_mode=False, model_complexity=1, enable_segmentation=True, min_detection_confidence=0.5, min_tracking_confidence=0.5)
# mp_drawing = mp.solutions.drawing_utils

# def calculate_angle_between_vectors(v1, v2):
#     unit_vector_1 = v1 / np.linalg.norm(v1)
#     unit_vector_2 = v2 / np.linalg.norm(v2)
#     dot_product = np.dot(unit_vector_1, unit_vector_2)
#     angle = np.arccos(dot_product)
#     return np.degrees(angle)

# def process_video(video_path):
#     cap = cv2.VideoCapture(video_path)
#     output_dir = tempfile.mkdtemp()

#     current_phase = "Not Setup phase"
#     prev_wrist_left_y = None
#     prev_wrist_right_y = None
#     top_backswing_detected = False
#     mid_downswing_detected = False
#     ball_impact_detected = False
#     top_backswing_frame = -2
#     mid_downswing_frame = -2
#     ball_impact_frame = -2

#     BALL_IMPACT_DURATION = 2  # Duration in frames to display Ball Impact phase

#     MIN_MOVEMENT_THRESHOLD = 0.01
#     HIP_NEAR_THRESHOLD = 0.05
#     MID_SWING_THRESHOLD = 0.05

#     saved_phases = set()

#     while cap.isOpened():
#         ret, frame = cap.read()
#         if not ret:
#             break

#         frame_no = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
#         image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
#         result = pose.process(image_rgb)
#         h, w, c = frame.shape

#         if result.pose_landmarks:
#             mp_drawing.draw_landmarks(
#                 frame, result.pose_landmarks, mp_pose.POSE_CONNECTIONS,
#                 mp_drawing.DrawingSpec(color=(255, 0, 0), thickness=2, circle_radius=2),
#                 mp_drawing.DrawingSpec(color=(255, 0, 255), thickness=2, circle_radius=2)
#             )

#             landmarks = result.pose_landmarks.landmark
#             wrist_left_y = landmarks[mp_pose.PoseLandmark.LEFT_WRIST].y
#             wrist_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_WRIST].y
#             hip_left_y = landmarks[mp_pose.PoseLandmark.LEFT_HIP].y
#             hip_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_HIP].y
#             shoulder_left_y = landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER].y
#             shoulder_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_SHOULDER].y

#             hip_y_avg = (hip_left_y + hip_right_y) / 2
#             shoulder_y_avg = (shoulder_left_y + shoulder_right_y) / 2
#             mid_swing_y = (shoulder_y_avg + hip_y_avg) / 2

#             # Ensure the current phase persists for a few more milliseconds if it's Ball Impact
#             if ball_impact_detected and frame_no <= ball_impact_frame + BALL_IMPACT_DURATION:
#                 current_phase = "Ball impact phase"
#             elif (abs(wrist_left_y - hip_y_avg) < HIP_NEAR_THRESHOLD and abs(wrist_right_y - hip_y_avg) < HIP_NEAR_THRESHOLD):
#                 if prev_wrist_left_y is not None and prev_wrist_right_y is not None:
#                     if (abs(wrist_left_y - prev_wrist_left_y) < MIN_MOVEMENT_THRESHOLD and abs(wrist_right_y - prev_wrist_right_y) < MIN_MOVEMENT_THRESHOLD):
#                         if mid_downswing_detected and frame_no > mid_downswing_frame:
#                             current_phase = "Ball impact phase"
#                             ball_impact_detected = True
#                             ball_impact_frame = frame_no
#                         else:
#                             current_phase = "Setup phase"
#                         top_backswing_detected = False
#                         mid_downswing_detected = False
#                     else:
#                         current_phase = ""
#                 else:
#                     if mid_downswing_detected and frame_no > mid_downswing_frame:
#                         current_phase = "Ball impact phase"
#                         ball_impact_detected = True
#                         ball_impact_frame = frame_no
#                     else:
#                         current_phase = "Setup phase"
#                     top_backswing_detected = False
#                     mid_downswing_detected = False
#             elif (abs(wrist_left_y - mid_swing_y) < MID_SWING_THRESHOLD and abs(wrist_right_y - mid_swing_y) < MID_SWING_THRESHOLD and not top_backswing_detected and not ball_impact_detected):
#                 current_phase = "Mid backswing phase"
#             elif (wrist_left_y < shoulder_left_y and wrist_right_y < shoulder_right_y and not mid_downswing_detected and not ball_impact_detected):
#                 current_phase = "Top backswing phase"
#                 top_backswing_detected = True
#                 top_backswing_frame = frame_no
#             elif (abs(wrist_left_y - mid_swing_y) < MID_SWING_THRESHOLD and abs(wrist_right_y - mid_swing_y) < MID_SWING_THRESHOLD and top_backswing_detected and frame_no > top_backswing_frame):
#                 current_phase = "Mid downswing phase"
#                 mid_downswing_detected = True
#                 mid_downswing_frame = frame_no
#             elif (wrist_left_y < shoulder_left_y and wrist_right_y < shoulder_right_y and ball_impact_detected and frame_no > ball_impact_frame):
#                 current_phase = "Follow through phase"
#             else:
#                 current_phase = ""

#             prev_wrist_left_y = wrist_left_y
#             prev_wrist_right_y = wrist_right_y

#             cv2.putText(frame, f"Phase: {current_phase}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)

#             # Save the frame for each detected phase
#             if current_phase and current_phase not in saved_phases:
#                 phase_filename = os.path.join(output_dir, f"{current_phase.replace(' ', '_')}.png")
#                 cv2.imwrite(phase_filename, frame)
#                 saved_phases.add(current_phase)

#     cap.release()
#     cv2.destroyAllWindows()
#     pose.close()

#     return output_dir

# st.title("Golf Swing Phase Detection")
# st.write("Upload a video to detect different phases of a golf swing.")

# video_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "mkv"])

# if video_file is not None:
#     tfile = tempfile.NamedTemporaryFile(delete=False)
#     tfile.write(video_file.read())
#     tfile_path = tfile.name

#     st.write("Processing video...")
#     output_dir = process_video(tfile_path)

#     st.write("Detected phases saved to:", output_dir)
#     st.write("Example frames from detected phases:")
    
#     for phase_image in os.listdir(output_dir):
#         st.image(os.path.join(output_dir, phase_image), caption=phase_image)
import streamlit as st
import cv2
import mediapipe as mp
import numpy as np
import tempfile
import os
from collections import deque

# Shorthand aliases for the MediaPipe pose solution and its drawing helpers.
# The Pose estimator itself is created inside process_video().
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils

# Swing-phase labels used as state-machine states. The strings are also drawn
# onto the frames and used (spaces replaced by underscores) as image file names.
SETUP = "Setup phase"
MID_BACKSWING = "Mid backswing phase"
TOP_BACKSWING = "Top backswing phase"
MID_DOWNSWING = "Mid downswing phase"
BALL_IMPACT = "Ball impact phase"
FOLLOW_THROUGH = "Follow through phase"
UNKNOWN = "Unknown"

# Tunable parameters for the phase-detection logic
NUM_FRAMES_STABLE = 5    # Consecutive pose frames a state must persist before it is reported/saved
VEL_THRESHOLD = 0.003     # Minimum per-frame change in wrist y to count as moving up or down (tune as needed)
MID_POINT_RATIO = 0.5     # Ratio for mid-swing line (between shoulders and hips); NOTE(review): process_video averages shoulder and hip directly and does not read this
BALL_IMPACT_DURATION = 5  # Frames that must pass after the impact frame before follow-through can be detected

def smooth_positions(positions, window=5):
    """Return a smoothed estimate of the most recent position.

    Until at least ``window`` samples are available the latest sample is
    returned unchanged; afterwards the mean of the last ``window`` samples is
    returned.

    Args:
        positions: Sequence of samples, oldest first. Any sequence supporting
            ``len()`` and negative indexing works, including
            ``collections.deque`` (which cannot be sliced directly).
        window: Number of trailing samples to average.

    Returns:
        The last element of ``positions`` while fewer than ``window`` samples
        exist, otherwise ``numpy.mean`` over the trailing window (axis 0).

    Raises:
        IndexError: If ``positions`` is empty and ``window`` is positive.
    """
    if len(positions) < window:
        return positions[-1]
    # Copy to a list first: deque objects do not support slice indexing.
    recent = list(positions)[-window:]
    return np.mean(np.array(recent), axis=0)

def process_video(video_path, *, show_preview=False):
    """Detect golf-swing phases in a video and save one annotated frame per phase.

    Every frame is run through MediaPipe Pose. The vertical positions of the
    wrists (smoothed with ``smooth_positions``), hips and shoulders drive a
    forward-only state machine:

        Unknown -> Setup -> Mid backswing -> Top backswing -> Mid downswing
        -> Ball impact -> Follow through

    The state machine itself advances as soon as a transition condition holds;
    only the *reported* phase is debounced: a phase is drawn on the frame and
    saved once it has been held for ``NUM_FRAMES_STABLE`` consecutive frames in
    which a pose was found. The first such frame of each phase is written to
    ``<output_dir>/<Phase_name_with_underscores>.png``.

    Args:
        video_path: Path of the video file handed to ``cv2.VideoCapture``.
        show_preview: When true, additionally show every frame in an OpenCV
            window and stop early when ``q`` is pressed. Off by default because
            OpenCV GUI calls fail on headless hosts and this function is called
            from the Streamlit script.

    Returns:
        Path of a newly created temporary directory holding the saved phase
        images (empty if no phase was ever confirmed).
    """
    # Max |wrist - hip| distance (normalized image y) to count as "at hip level".
    hip_near_tolerance = 0.05

    output_dir = tempfile.mkdtemp()

    pose = mp_pose.Pose(
        static_image_mode=False,
        model_complexity=1,
        enable_segmentation=True,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    )
    cap = cv2.VideoCapture(video_path)

    # State machine bookkeeping.
    current_state = UNKNOWN
    last_confirmed_state = UNKNOWN
    state_confirmation_count = 0

    # Recent raw wrist heights, used as input for smoothing.
    wrist_left_positions = deque(maxlen=30)
    wrist_right_positions = deque(maxlen=30)

    # Phases whose snapshot has already been written.
    saved_phases = set()

    # Hip height in the first frame with a detected pose; reference for
    # "wrists at hip level" throughout the video.
    initial_hip_y = None

    # Smoothed average wrist height of the previous pose frame (for velocity).
    prev_avg_wrist_y = None

    # Smallest (i.e. highest on screen) smoothed wrist y seen so far.
    max_wrist_height = None
    top_backswing_reached = False

    ball_impact_frame_no = -1
    frame_count = 0

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1

            result = pose.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

            if result.pose_landmarks:
                landmarks = result.pose_landmarks.landmark
                wrist_left_y = landmarks[mp_pose.PoseLandmark.LEFT_WRIST].y
                wrist_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_WRIST].y
                hip_left_y = landmarks[mp_pose.PoseLandmark.LEFT_HIP].y
                hip_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_HIP].y
                shoulder_left_y = landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER].y
                shoulder_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_SHOULDER].y

                hip_y_avg = (hip_left_y + hip_right_y) / 2
                shoulder_y_avg = (shoulder_left_y + shoulder_right_y) / 2

                if initial_hip_y is None:
                    initial_hip_y = hip_y_avg

                # Horizontal line halfway between shoulders and hips.
                mid_swing_y = (shoulder_y_avg + hip_y_avg) / 2

                wrist_left_positions.append(wrist_left_y)
                wrist_right_positions.append(wrist_right_y)

                smoothed_left_y = smooth_positions(list(wrist_left_positions))
                smoothed_right_y = smooth_positions(list(wrist_right_positions))
                avg_wrist_y = (smoothed_left_y + smoothed_right_y) / 2.0

                # Velocity must compare like with like: smoothed value against
                # the previous smoothed value. Subtracting the previous *raw*
                # sample from a lagging moving average flips the sign of the
                # result during steady motion.
                # NOTE: smooth_positions returns the raw latest sample until
                # its window is full, so the single frame on which the window
                # first fills can still show a one-off jump.
                if prev_avg_wrist_y is None:
                    vel_wrist_y = 0.0
                else:
                    vel_wrist_y = avg_wrist_y - prev_avg_wrist_y
                prev_avg_wrist_y = avg_wrist_y

                # Image y grows downwards, so moving up means y decreases.
                going_up = vel_wrist_y < -VEL_THRESHOLD
                going_down = vel_wrist_y > VEL_THRESHOLD
                low_velocity = abs(vel_wrist_y) < VEL_THRESHOLD
                near_hip = abs(avg_wrist_y - initial_hip_y) < hip_near_tolerance

                if max_wrist_height is None or avg_wrist_y < max_wrist_height:
                    max_wrist_height = avg_wrist_y

                # Work out the state this frame argues for. Each state can only
                # stay put or advance to its successor.
                if current_state == UNKNOWN:
                    # Wait for still wrists at hip level.
                    desired_state = SETUP if (near_hip and low_velocity) else UNKNOWN

                elif current_state == SETUP:
                    # Wrists above the mid line and above the initial hip
                    # height while moving up.
                    rising_past_mid = (
                        avg_wrist_y < mid_swing_y
                        and avg_wrist_y < initial_hip_y
                        and going_up
                    )
                    desired_state = MID_BACKSWING if rising_past_mid else SETUP

                elif current_state == MID_BACKSWING:
                    # Peak: the wrists have been above the mid line and now
                    # start to descend.
                    if (
                        not top_backswing_reached
                        and going_down
                        and max_wrist_height < mid_swing_y
                    ):
                        desired_state = TOP_BACKSWING
                        top_backswing_reached = True
                    else:
                        desired_state = MID_BACKSWING

                elif current_state == TOP_BACKSWING:
                    # Descending below the mid line after the top.
                    if top_backswing_reached and avg_wrist_y > mid_swing_y and going_down:
                        desired_state = MID_DOWNSWING
                    else:
                        desired_state = TOP_BACKSWING

                elif current_state == MID_DOWNSWING:
                    # Back at hip level while still moving down.
                    if top_backswing_reached and near_hip and going_down:
                        desired_state = BALL_IMPACT
                        ball_impact_frame_no = frame_count
                    else:
                        desired_state = MID_DOWNSWING

                elif current_state == BALL_IMPACT:
                    # Follow through only after the impact hold time has
                    # passed and the wrists rise above the mid line again.
                    follow_through = (
                        ball_impact_frame_no > 0
                        and frame_count > ball_impact_frame_no + BALL_IMPACT_DURATION
                        and avg_wrist_y < mid_swing_y
                        and going_up
                    )
                    desired_state = FOLLOW_THROUGH if follow_through else BALL_IMPACT

                else:
                    # FOLLOW_THROUGH is terminal.
                    desired_state = current_state

                # Count how long the current state has been held; a change of
                # state restarts the count.
                if desired_state == current_state:
                    state_confirmation_count += 1
                else:
                    current_state = desired_state
                    state_confirmation_count = 1

                if state_confirmation_count >= NUM_FRAMES_STABLE:
                    last_confirmed_state = current_state

                mp_drawing.draw_landmarks(
                    frame,
                    result.pose_landmarks,
                    mp_pose.POSE_CONNECTIONS,
                    mp_drawing.DrawingSpec(color=(255, 0, 0), thickness=2, circle_radius=2),
                    mp_drawing.DrawingSpec(color=(255, 0, 255), thickness=2, circle_radius=2),
                )

                cv2.putText(
                    frame, f"Phase: {last_confirmed_state}", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA,
                )

                # Save one snapshot per confirmed phase.
                if last_confirmed_state != UNKNOWN and last_confirmed_state not in saved_phases:
                    phase_filename = os.path.join(
                        output_dir, f"{last_confirmed_state.replace(' ', '_')}.png"
                    )
                    cv2.imwrite(phase_filename, frame)
                    saved_phases.add(last_confirmed_state)

            if show_preview:
                cv2.imshow("Pose Estimation", frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Release native resources even if processing fails part-way.
        cap.release()
        if show_preview:
            cv2.destroyAllWindows()
        pose.close()

    return output_dir

# Streamlit page: upload a video, run phase detection, show the saved frames.
st.title("Golf Swing Phase Detection - Improved Logic")
st.write("Upload a video to detect different phases of a golf swing with improved accuracy.")

video_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "mkv"])

if video_file is not None:
    # Write the upload to disk so OpenCV can open it by path. The file must be
    # closed (and thereby flushed) before it is read back, otherwise buffered
    # bytes may be missing from what OpenCV sees.
    with tempfile.NamedTemporaryFile(delete=False) as tfile:
        tfile.write(video_file.read())
        tfile_path = tfile.name

    st.write("Processing video...")
    try:
        output_dir = process_video(tfile_path)
    finally:
        # delete=False leaves removal to us; do not leave uploads behind.
        os.remove(tfile_path)

    st.write("Detected phases saved to:", output_dir)
    st.write("Example frames from detected phases:")

    for phase_image in os.listdir(output_dir):
        st.image(os.path.join(output_dir, phase_image), caption=phase_image)