ggirishg commited on
Commit
d92004b
·
verified ·
1 Parent(s): f10bc31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -15
app.py CHANGED
@@ -1,20 +1,145 @@
1
- import gradio
2
  import cv2
 
 
 
 
3
 
 
 
 
 
4
 
5
- def inference(img):
6
- blur = cv2.blur(img,(5,5))
7
- return blur
 
 
 
8
 
9
- # For information on Interfaces, head to https://gradio.app/docs/
10
- # For user guides, head to https://gradio.app/guides/
11
- # For Spaces usage, head to https://huggingface.co/docs/hub/spaces
12
- iface = gradio.Interface(
13
- fn=inference,
14
- inputs='image',
15
- outputs='image',
16
- title='Hello World',
17
- description='The simplest interface!',
18
- examples=["llama.jpg"])
19
 
20
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
  import cv2
3
+ import mediapipe as mp
4
+ import numpy as np
5
+ import tempfile
6
+ import os
7
 
8
+ # Initialize MediaPipe Pose
9
+ mp_pose = mp.solutions.pose
10
+ pose = mp_pose.Pose(static_image_mode=False, model_complexity=1, enable_segmentation=True, min_detection_confidence=0.5, min_tracking_confidence=0.5)
11
+ mp_drawing = mp.solutions.drawing_utils
12
 
13
+ def calculate_angle_between_vectors(v1, v2):
14
+ unit_vector_1 = v1 / np.linalg.norm(v1)
15
+ unit_vector_2 = v2 / np.linalg.norm(v2)
16
+ dot_product = np.dot(unit_vector_1, unit_vector_2)
17
+ angle = np.arccos(dot_product)
18
+ return np.degrees(angle)
19
 
20
+ def process_video(video_path):
21
+ cap = cv2.VideoCapture(video_path)
22
+ output_dir = tempfile.mkdtemp()
 
 
 
 
 
 
 
23
 
24
+ current_phase = "Not Setup phase"
25
+ prev_wrist_left_y = None
26
+ prev_wrist_right_y = None
27
+ top_backswing_detected = False
28
+ mid_downswing_detected = False
29
+ ball_impact_detected = False
30
+ top_backswing_frame = -2
31
+ mid_downswing_frame = -2
32
+ ball_impact_frame = -2
33
+
34
+ BALL_IMPACT_DURATION = 2 # Duration in frames to display Ball Impact phase
35
+
36
+ MIN_MOVEMENT_THRESHOLD = 0.01
37
+ HIP_NEAR_THRESHOLD = 0.05
38
+ MID_SWING_THRESHOLD = 0.05
39
+
40
+ saved_phases = set()
41
+
42
+ while cap.isOpened():
43
+ ret, frame = cap.read()
44
+ if not ret:
45
+ break
46
+
47
+ frame_no = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
48
+ image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
49
+ result = pose.process(image_rgb)
50
+ h, w, c = frame.shape
51
+
52
+ if result.pose_landmarks:
53
+ mp_drawing.draw_landmarks(
54
+ frame, result.pose_landmarks, mp_pose.POSE_CONNECTIONS,
55
+ mp_drawing.DrawingSpec(color=(255, 0, 0), thickness=2, circle_radius=2),
56
+ mp_drawing.DrawingSpec(color=(255, 0, 255), thickness=2, circle_radius=2)
57
+ )
58
+
59
+ landmarks = result.pose_landmarks.landmark
60
+ wrist_left_y = landmarks[mp_pose.PoseLandmark.LEFT_WRIST].y
61
+ wrist_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_WRIST].y
62
+ hip_left_y = landmarks[mp_pose.PoseLandmark.LEFT_HIP].y
63
+ hip_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_HIP].y
64
+ shoulder_left_y = landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER].y
65
+ shoulder_right_y = landmarks[mp_pose.PoseLandmark.RIGHT_SHOULDER].y
66
+
67
+ hip_y_avg = (hip_left_y + hip_right_y) / 2
68
+ shoulder_y_avg = (shoulder_left_y + shoulder_right_y) / 2
69
+ mid_swing_y = (shoulder_y_avg + hip_y_avg) / 2
70
+
71
+ # Ensure the current phase persists for a few more milliseconds if it's Ball Impact
72
+ if ball_impact_detected and frame_no <= ball_impact_frame + BALL_IMPACT_DURATION:
73
+ current_phase = "Ball impact phase"
74
+ elif (abs(wrist_left_y - hip_y_avg) < HIP_NEAR_THRESHOLD and abs(wrist_right_y - hip_y_avg) < HIP_NEAR_THRESHOLD):
75
+ if prev_wrist_left_y is not None and prev_wrist_right_y is not None:
76
+ if (abs(wrist_left_y - prev_wrist_left_y) < MIN_MOVEMENT_THRESHOLD and abs(wrist_right_y - prev_wrist_right_y) < MIN_MOVEMENT_THRESHOLD):
77
+ if mid_downswing_detected and frame_no > mid_downswing_frame:
78
+ current_phase = "Ball impact phase"
79
+ ball_impact_detected = True
80
+ ball_impact_frame = frame_no
81
+ else:
82
+ current_phase = "Setup phase"
83
+ top_backswing_detected = False
84
+ mid_downswing_detected = False
85
+ else:
86
+ current_phase = ""
87
+ else:
88
+ if mid_downswing_detected and frame_no > mid_downswing_frame:
89
+ current_phase = "Ball impact phase"
90
+ ball_impact_detected = True
91
+ ball_impact_frame = frame_no
92
+ else:
93
+ current_phase = "Setup phase"
94
+ top_backswing_detected = False
95
+ mid_downswing_detected = False
96
+ elif (abs(wrist_left_y - mid_swing_y) < MID_SWING_THRESHOLD and abs(wrist_right_y - mid_swing_y) < MID_SWING_THRESHOLD and not top_backswing_detected and not ball_impact_detected):
97
+ current_phase = "Mid backswing phase"
98
+ elif (wrist_left_y < shoulder_left_y and wrist_right_y < shoulder_right_y and not mid_downswing_detected and not ball_impact_detected):
99
+ current_phase = "Top backswing phase"
100
+ top_backswing_detected = True
101
+ top_backswing_frame = frame_no
102
+ elif (abs(wrist_left_y - mid_swing_y) < MID_SWING_THRESHOLD and abs(wrist_right_y - mid_swing_y) < MID_SWING_THRESHOLD and top_backswing_detected and frame_no > top_backswing_frame):
103
+ current_phase = "Mid downswing phase"
104
+ mid_downswing_detected = True
105
+ mid_downswing_frame = frame_no
106
+ elif (wrist_left_y < shoulder_left_y and wrist_right_y < shoulder_right_y and ball_impact_detected and frame_no > ball_impact_frame):
107
+ current_phase = "Follow through phase"
108
+ else:
109
+ current_phase = ""
110
+
111
+ prev_wrist_left_y = wrist_left_y
112
+ prev_wrist_right_y = wrist_right_y
113
+
114
+ cv2.putText(frame, f"Phase: {current_phase}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
115
+
116
+ # Save the frame for each detected phase
117
+ if current_phase and current_phase not in saved_phases:
118
+ phase_filename = os.path.join(output_dir, f"{current_phase.replace(' ', '_')}.png")
119
+ cv2.imwrite(phase_filename, frame)
120
+ saved_phases.add(current_phase)
121
+
122
+ cap.release()
123
+ cv2.destroyAllWindows()
124
+ pose.close()
125
+
126
+ return output_dir
127
+
128
+ st.title("Golf Swing Phase Detection")
129
+ st.write("Upload a video to detect different phases of a golf swing.")
130
+
131
+ video_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "mkv"])
132
+
133
+ if video_file is not None:
134
+ tfile = tempfile.NamedTemporaryFile(delete=False)
135
+ tfile.write(video_file.read())
136
+ tfile_path = tfile.name
137
+
138
+ st.write("Processing video...")
139
+ output_dir = process_video(tfile_path)
140
+
141
+ st.write("Detected phases saved to:", output_dir)
142
+ st.write("Example frames from detected phases:")
143
+
144
+ for phase_image in os.listdir(output_dir):
145
+ st.image(os.path.join(output_dir, phase_image), caption=phase_image)