File size: 6,519 Bytes
bc686e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315f1eb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import streamlit as st
import cv2
import mediapipe as mp
import numpy as np
import tempfile
import os

# Shortcuts to the MediaPipe pose solution and its landmark-drawing helpers.
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils

# Module-level pose estimator, configured for streamed (non-static) frames.
pose = mp_pose.Pose(
    static_image_mode=False,
    model_complexity=1,
    enable_segmentation=True,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)

def calculate_angle_between_vectors(v1, v2):
    """Return the angle between two vectors, in degrees.

    Args:
        v1: First vector (NumPy array or any sequence of numbers).
        v2: Second vector, with the same length as ``v1``.

    Returns:
        The angle in degrees, in the range [0, 180]. NaN is returned when
        either vector has zero length, because the angle is undefined.
    """
    # Accept plain lists/tuples as well as arrays.
    vec_1 = np.asarray(v1, dtype=float)
    vec_2 = np.asarray(v2, dtype=float)

    norm_1 = np.linalg.norm(vec_1)
    norm_2 = np.linalg.norm(vec_2)
    if norm_1 == 0 or norm_2 == 0:
        # Avoid a division by zero; a zero-length vector has no direction.
        return np.nan

    cosine = np.dot(vec_1 / norm_1, vec_2 / norm_2)
    # Rounding can push the cosine marginally outside [-1, 1] for (anti)parallel
    # vectors, which would make arccos return NaN; clamp it back into range.
    cosine = np.clip(cosine, -1.0, 1.0)
    return np.degrees(np.arccos(cosine))

def process_video(video_path):
    """Detect golf-swing phases in a video and save one annotated frame per phase.

    Every frame is run through MediaPipe Pose. The vertical (y) positions of
    the wrists are compared with the hips, the shoulders and the midpoint
    between them to label the frame with a swing phase. The first frame seen
    for each non-empty phase label is written as a PNG, with the pose skeleton
    and the phase label drawn on it.

    Args:
        video_path: Path of the video file to analyse.

    Returns:
        Path of a newly created temporary directory containing one
        ``<Phase_name>.png`` per detected phase. The directory is empty when
        the video cannot be opened or no phase is detected.
    """
    BALL_IMPACT_DURATION = 2  # Frames for which "Ball impact phase" keeps being shown
    MIN_MOVEMENT_THRESHOLD = 0.01  # Max per-frame wrist y change still counted as "still"
    HIP_NEAR_THRESHOLD = 0.05  # Max wrist-to-hip y distance counted as "at the hips"
    MID_SWING_THRESHOLD = 0.05  # Max wrist y distance from the shoulder/hip midpoint

    output_dir = tempfile.mkdtemp()

    # State carried from frame to frame.
    prev_wrist_left_y = None
    prev_wrist_right_y = None
    top_backswing_detected = False
    mid_downswing_detected = False
    ball_impact_detected = False
    top_backswing_frame = -2
    mid_downswing_frame = -2
    ball_impact_frame = -2
    saved_phases = set()

    cap = cv2.VideoCapture(video_path)
    try:
        # A fresh estimator per call: the function stays callable more than once
        # and tracking state never leaks from one video into the next.
        with mp_pose.Pose(
            static_image_mode=False,
            model_complexity=1,
            enable_segmentation=True,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
        ) as pose_estimator:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                frame_no = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                # OpenCV decodes frames as BGR; MediaPipe expects RGB.
                image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                result = pose_estimator.process(image_rgb)

                if not result.pose_landmarks:
                    # No person found in this frame: nothing to classify or save.
                    continue

                mp_drawing.draw_landmarks(
                    frame, result.pose_landmarks, mp_pose.POSE_CONNECTIONS,
                    mp_drawing.DrawingSpec(color=(255, 0, 0), thickness=2, circle_radius=2),
                    mp_drawing.DrawingSpec(color=(255, 0, 255), thickness=2, circle_radius=2)
                )

                landmarks = result.pose_landmarks.landmark
                wrist_left_y = landmarks[mp_pose.PoseLandmark.LEFT_WRIST].y
                wrist_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_WRIST].y
                hip_left_y = landmarks[mp_pose.PoseLandmark.LEFT_HIP].y
                hip_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_HIP].y
                shoulder_left_y = landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER].y
                shoulder_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_SHOULDER].y

                hip_y_avg = (hip_left_y + hip_right_y) / 2
                shoulder_y_avg = (shoulder_left_y + shoulder_right_y) / 2
                mid_swing_y = (shoulder_y_avg + hip_y_avg) / 2

                # Geometric cues shared by several phases.
                wrists_near_hips = (
                    abs(wrist_left_y - hip_y_avg) < HIP_NEAR_THRESHOLD
                    and abs(wrist_right_y - hip_y_avg) < HIP_NEAR_THRESHOLD
                )
                wrists_at_mid_swing = (
                    abs(wrist_left_y - mid_swing_y) < MID_SWING_THRESHOLD
                    and abs(wrist_right_y - mid_swing_y) < MID_SWING_THRESHOLD
                )
                # Image y grows downwards, so a smaller y means higher in the frame.
                wrists_above_shoulders = (
                    wrist_left_y < shoulder_left_y and wrist_right_y < shoulder_right_y
                )
                if prev_wrist_left_y is None or prev_wrist_right_y is None:
                    # First frame with landmarks: no motion estimate yet, treat as still.
                    wrists_still = True
                else:
                    wrists_still = (
                        abs(wrist_left_y - prev_wrist_left_y) < MIN_MOVEMENT_THRESHOLD
                        and abs(wrist_right_y - prev_wrist_right_y) < MIN_MOVEMENT_THRESHOLD
                    )

                if ball_impact_detected and frame_no <= ball_impact_frame + BALL_IMPACT_DURATION:
                    # Keep showing the impact label for a few frames after detection.
                    current_phase = "Ball impact phase"
                elif wrists_near_hips:
                    if wrists_still:
                        # Still hands at the hips mean impact if a downswing was
                        # just seen, otherwise the address/setup position.
                        if mid_downswing_detected and frame_no > mid_downswing_frame:
                            current_phase = "Ball impact phase"
                            ball_impact_detected = True
                            ball_impact_frame = frame_no
                        else:
                            current_phase = "Setup phase"
                        top_backswing_detected = False
                        mid_downswing_detected = False
                    else:
                        current_phase = ""
                elif wrists_at_mid_swing and not top_backswing_detected and not ball_impact_detected:
                    current_phase = "Mid backswing phase"
                elif wrists_above_shoulders and not mid_downswing_detected and not ball_impact_detected:
                    current_phase = "Top backswing phase"
                    top_backswing_detected = True
                    top_backswing_frame = frame_no
                elif wrists_at_mid_swing and top_backswing_detected and frame_no > top_backswing_frame:
                    current_phase = "Mid downswing phase"
                    mid_downswing_detected = True
                    mid_downswing_frame = frame_no
                elif wrists_above_shoulders and ball_impact_detected and frame_no > ball_impact_frame:
                    current_phase = "Follow through phase"
                else:
                    current_phase = ""

                prev_wrist_left_y = wrist_left_y
                prev_wrist_right_y = wrist_right_y

                cv2.putText(frame, f"Phase: {current_phase}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)

                # Save the first annotated frame of each detected phase.
                if current_phase and current_phase not in saved_phases:
                    phase_filename = os.path.join(output_dir, f"{current_phase.replace(' ', '_')}.png")
                    # Only mark the phase as saved when the image was really
                    # written, so a failed write is retried on a later frame.
                    if cv2.imwrite(phase_filename, frame):
                        saved_phases.add(current_phase)
    finally:
        # Release the capture even when pose estimation or drawing raises.
        # No GUI window is ever opened here, so there is nothing to destroy.
        cap.release()

    return output_dir

# --- Streamlit page: upload a video, run phase detection, show the results ---
st.title("Golf Swing Phase Detection")
st.write("Upload a video to detect different phases of a golf swing.")

video_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "mkv"])

if video_file is not None:
    # Keep the upload's extension so the video backend can recognise the container.
    upload_suffix = os.path.splitext(video_file.name)[1]

    # OpenCV opens the video by path, so the upload must be on disk. Leaving the
    # `with` block closes the file, which flushes all buffered bytes before the
    # file is read; delete=False keeps it on disk after closing.
    with tempfile.NamedTemporaryFile(delete=False, suffix=upload_suffix) as tfile:
        tfile.write(video_file.read())
        tfile_path = tfile.name

    st.write("Processing video...")
    try:
        output_dir = process_video(tfile_path)
    finally:
        # The temporary copy of the upload is no longer needed after processing.
        try:
            os.remove(tfile_path)
        except OSError:
            pass  # Best-effort cleanup; never mask an error raised by processing.

    st.write("Detected phases saved to:", output_dir)
    st.write("Example frames from detected phases:")

    # Sort so the images appear in a stable order on every run.
    for phase_image in sorted(os.listdir(output_dir)):
        st.image(os.path.join(output_dir, phase_image), caption=phase_image)