# Golf Swing Phase Detection (Streamlit app).
# Recovered from a hosted "Spaces" page listing (status: Sleeping;
# file size 6,519 bytes; revisions bc686e7 / 315f1eb).
import streamlit as st
import cv2
import mediapipe as mp
import numpy as np
import tempfile
import os
# Initialize MediaPipe Pose.
# Segmentation is disabled: nothing in this file reads
# ``result.segmentation_mask``, so computing it for every frame is wasted work.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(
    static_image_mode=False,  # treat input as a video stream (see MediaPipe Pose docs)
    model_complexity=1,
    enable_segmentation=False,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)
mp_drawing = mp.solutions.drawing_utils
def calculate_angle_between_vectors(v1, v2):
    """Return the angle between two vectors, in degrees.

    Args:
        v1: First vector (any array-like of numbers).
        v2: Second vector, same length as ``v1``.

    Returns:
        The angle between ``v1`` and ``v2`` in degrees, in the range [0, 180].

    Raises:
        ValueError: If either vector has zero length (the angle is undefined).
    """
    v1 = np.asarray(v1, dtype=float)
    v2 = np.asarray(v2, dtype=float)
    norm_1 = np.linalg.norm(v1)
    norm_2 = np.linalg.norm(v2)
    if norm_1 == 0 or norm_2 == 0:
        raise ValueError("cannot compute the angle of a zero-length vector")
    cos_angle = np.dot(v1, v2) / (norm_1 * norm_2)
    # Floating-point rounding can push the cosine slightly outside [-1, 1]
    # (e.g. for parallel vectors), where arccos would return NaN.
    return np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0)))
def process_video(video_path):
    """Label golf-swing phases in a video and save one annotated frame per phase.

    Every frame is run through MediaPipe Pose. When a pose is found, the
    skeleton is drawn onto the frame and the y-coordinates of the wrists,
    hips and shoulders drive a heuristic state machine that assigns a phase
    label (setup, mid backswing, top backswing, mid downswing, ball impact,
    follow through). The current label is drawn onto every frame. The first
    frame that carries each distinct non-empty label is written into a fresh
    temporary directory. The file is named after the label with spaces
    replaced by underscores (e.g. ``Setup_phase.png``).

    Args:
        video_path: Path to a video file readable by ``cv2.VideoCapture``.

    Returns:
        Path of the temporary directory holding the saved PNG frames. It is
        empty if no frame could be read. The caller owns the directory.
    """
    output_dir = tempfile.mkdtemp()
    current_phase = "Not Setup phase"  # label shown until a pose is first classified
    prev_wrist_left_y = None
    prev_wrist_right_y = None
    top_backswing_detected = False
    mid_downswing_detected = False
    ball_impact_detected = False
    top_backswing_frame = -2
    mid_downswing_frame = -2
    ball_impact_frame = -2
    # Thresholds are in MediaPipe's normalized image coordinates (y grows downward).
    BALL_IMPACT_DURATION = 2  # frames for which "Ball impact phase" is held after detection
    MIN_MOVEMENT_THRESHOLD = 0.01  # max per-frame wrist y change treated as "still"
    HIP_NEAR_THRESHOLD = 0.05  # max |wrist y - hip y| treated as "at the hips"
    MID_SWING_THRESHOLD = 0.05  # max |wrist y - torso midpoint y| treated as "mid swing"
    saved_phases = set()

    cap = cv2.VideoCapture(video_path)
    try:
        # A per-call Pose instance: the detector keeps tracking state between
        # frames, and closing a shared module-level instance (as this function
        # used to) would break every later call.
        with mp_pose.Pose(
            static_image_mode=False,
            model_complexity=1,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
        ) as pose:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frame_no = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                result = pose.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

                # Without a detected pose the previous frame's label is kept.
                if result.pose_landmarks:
                    mp_drawing.draw_landmarks(
                        frame, result.pose_landmarks, mp_pose.POSE_CONNECTIONS,
                        mp_drawing.DrawingSpec(color=(255, 0, 0), thickness=2, circle_radius=2),
                        mp_drawing.DrawingSpec(color=(255, 0, 255), thickness=2, circle_radius=2),
                    )
                    landmarks = result.pose_landmarks.landmark
                    wrist_left_y = landmarks[mp_pose.PoseLandmark.LEFT_WRIST].y
                    wrist_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_WRIST].y
                    hip_left_y = landmarks[mp_pose.PoseLandmark.LEFT_HIP].y
                    hip_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_HIP].y
                    shoulder_left_y = landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER].y
                    shoulder_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_SHOULDER].y
                    hip_y_avg = (hip_left_y + hip_right_y) / 2
                    shoulder_y_avg = (shoulder_left_y + shoulder_right_y) / 2
                    mid_swing_y = (shoulder_y_avg + hip_y_avg) / 2

                    wrists_near_hips = (
                        abs(wrist_left_y - hip_y_avg) < HIP_NEAR_THRESHOLD
                        and abs(wrist_right_y - hip_y_avg) < HIP_NEAR_THRESHOLD
                    )
                    wrists_at_mid_swing = (
                        abs(wrist_left_y - mid_swing_y) < MID_SWING_THRESHOLD
                        and abs(wrist_right_y - mid_swing_y) < MID_SWING_THRESHOLD
                    )
                    # Smaller y is higher in the image.
                    wrists_above_shoulders = (
                        wrist_left_y < shoulder_left_y and wrist_right_y < shoulder_right_y
                    )
                    # With no previous sample the wrists count as still, so the
                    # very first detection can already be classified.
                    wrists_still = (
                        prev_wrist_left_y is None
                        or prev_wrist_right_y is None
                        or (
                            abs(wrist_left_y - prev_wrist_left_y) < MIN_MOVEMENT_THRESHOLD
                            and abs(wrist_right_y - prev_wrist_right_y) < MIN_MOVEMENT_THRESHOLD
                        )
                    )

                    if ball_impact_detected and frame_no <= ball_impact_frame + BALL_IMPACT_DURATION:
                        # Hold the Ball impact label for a few frames after detection.
                        current_phase = "Ball impact phase"
                    elif wrists_near_hips:
                        if not wrists_still:
                            current_phase = ""
                        elif mid_downswing_detected and frame_no > mid_downswing_frame:
                            current_phase = "Ball impact phase"
                            ball_impact_detected = True
                            ball_impact_frame = frame_no
                        else:
                            current_phase = "Setup phase"
                            top_backswing_detected = False
                            mid_downswing_detected = False
                    elif wrists_at_mid_swing and not top_backswing_detected and not ball_impact_detected:
                        current_phase = "Mid backswing phase"
                    elif wrists_above_shoulders and not mid_downswing_detected and not ball_impact_detected:
                        current_phase = "Top backswing phase"
                        top_backswing_detected = True
                        top_backswing_frame = frame_no
                    elif wrists_at_mid_swing and top_backswing_detected and frame_no > top_backswing_frame:
                        current_phase = "Mid downswing phase"
                        mid_downswing_detected = True
                        mid_downswing_frame = frame_no
                    elif wrists_above_shoulders and ball_impact_detected and frame_no > ball_impact_frame:
                        current_phase = "Follow through phase"
                    else:
                        current_phase = ""

                    prev_wrist_left_y = wrist_left_y
                    prev_wrist_right_y = wrist_right_y

                cv2.putText(frame, f"Phase: {current_phase}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
                # Save the first annotated frame for each detected phase.
                if current_phase and current_phase not in saved_phases:
                    phase_filename = os.path.join(output_dir, f"{current_phase.replace(' ', '_')}.png")
                    cv2.imwrite(phase_filename, frame)
                    saved_phases.add(current_phase)
    finally:
        cap.release()
    return output_dir
st.title("Golf Swing Phase Detection")
st.write("Upload a video to detect different phases of a golf swing.")
video_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "mkv"])
if video_file is not None:
    # cv2.VideoCapture needs a path, so persist the upload to a temp file.
    # Keeping the original extension gives the decoder a container hint.
    # Leaving the with-block closes the file, so every byte is flushed
    # before OpenCV opens it.
    suffix = os.path.splitext(video_file.name)[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
        tfile.write(video_file.getvalue())
        tfile_path = tfile.name
    st.write("Processing video...")
    try:
        output_dir = process_video(tfile_path)
    finally:
        # The uploaded copy is no longer needed once processing has finished.
        os.remove(tfile_path)
    st.write("Detected phases saved to:", output_dir)
    st.write("Example frames from detected phases:")
    # Sorted so the gallery order is stable (os.listdir order is arbitrary).
    phase_images = sorted(os.listdir(output_dir))
    if not phase_images:
        st.warning("No swing phases were detected in this video.")
    for phase_image in phase_images:
        st.image(os.path.join(output_dir, phase_image), caption=phase_image)