shihabsarar29 commited on
Commit
386f26f
·
verified ·
1 Parent(s): f53d570

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -0
app.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from ultralytics import YOLO
4
+ import numpy as np
5
+ import json
6
+ from PIL import Image, ImageDraw
7
+
8
+ # Define keypoints we need for rigging
9
+ KEYPOINTS = {
10
+ 0: {"name": "chin (nose)"},
11
+ 7: {"name": "left_elbow"},
12
+ 8: {"name": "right_elbow"},
13
+ 9: {"name": "left_wrist"},
14
+ 10: {"name": "right_wrist"},
15
+ 13: {"name": "left_knee"},
16
+ 14: {"name": "right_knee"}
17
+ }
18
+
19
+ # Initialize model
20
+ model = None
21
+
22
+ def load_model():
23
+ """Load the YOLO pose estimation model"""
24
+ global model
25
+ if model is None:
26
+ model_path = 'yolov8s-pose.pt'
27
+ if os.path.exists(model_path):
28
+ try:
29
+ model = YOLO(model_path)
30
+ print("Model loaded successfully")
31
+ except Exception as e:
32
+ print(f"Error loading model: {e}")
33
+ model = None
34
+ else:
35
+ print(f"Model file not found: {model_path}")
36
+ return model
37
+
38
+ def process_image(input_image):
39
+ """
40
+ Process an image for pose estimation and return keypoint coordinates
41
+
42
+ Args:
43
+ input_image: Input image (PIL Image or numpy array)
44
+
45
+ Returns:
46
+ Tuple of (visualization image, JSON results string)
47
+ """
48
+ # Load model if not already loaded
49
+ if load_model() is None:
50
+ return None, json.dumps({"error": "Model not available"})
51
+
52
+ try:
53
+ # Convert to PIL if needed
54
+ if not isinstance(input_image, np.ndarray):
55
+ input_image = np.array(input_image)
56
+
57
+ # Run inference
58
+ results = model.predict(input_image, verbose=False)
59
+
60
+ # Process keypoint data
61
+ keypoint_data = {}
62
+
63
+ if not results or len(results) == 0:
64
+ return input_image, json.dumps({"error": "No pose detection results found"})
65
+
66
+ result = results[0]
67
+
68
+ if not hasattr(result, "keypoints") or result.keypoints is None:
69
+ return input_image, json.dumps({"error": "No keypoints detected in the image"})
70
+
71
+ try:
72
+ keypoints = result.keypoints.data.cpu().numpy()
73
+ except AttributeError:
74
+ return input_image, json.dumps({"error": "Error accessing keypoints data"})
75
+
76
+ if len(keypoints) == 0:
77
+ return input_image, json.dumps({"error": "No people detected in the image"})
78
+
79
+ # Get first person's keypoints
80
+ kp = keypoints[0]
81
+
82
+ # Extract keypoints
83
+ for idx, keypoint_info in KEYPOINTS.items():
84
+ if idx < len(kp) and kp[idx][2] > 0.5: # Confidence threshold
85
+ x, y, conf = kp[idx]
86
+ keypoint_data[keypoint_info["name"]] = {
87
+ "x": int(x),
88
+ "y": int(y),
89
+ "confidence": float(conf)
90
+ }
91
+
92
+ # Add groin point (midpoint between points 11 and 12)
93
+ if len(kp) > 12 and kp[11][2] > 0.5 and kp[12][2] > 0.5:
94
+ groin_x = int((kp[11][0] + kp[12][0]) / 2)
95
+ groin_y = int((kp[11][1] + kp[12][1]) / 2)
96
+ groin_conf = (float(kp[11][2]) + float(kp[12][2])) / 2
97
+ keypoint_data["groin"] = {
98
+ "x": groin_x,
99
+ "y": groin_y,
100
+ "confidence": groin_conf
101
+ }
102
+
103
+ # Create visualization image
104
+ vis_image = Image.fromarray(input_image.copy())
105
+ draw = ImageDraw.Draw(vis_image)
106
+
107
+ # Draw keypoints
108
+ for point_name, point_data in keypoint_data.items():
109
+ x, y = point_data["x"], point_data["y"]
110
+ # Draw a circle at each keypoint
111
+ radius = 5
112
+ draw.ellipse(
113
+ [(x - radius, y - radius), (x + radius, y + radius)],
114
+ fill="red"
115
+ )
116
+ # Add text label
117
+ draw.text((x + 10, y), point_name, fill="black")
118
+
119
+ return np.array(vis_image), json.dumps({"keypoints": keypoint_data}, indent=2)
120
+
121
+ except Exception as e:
122
+ return input_image, json.dumps({"error": f"Error processing image: {str(e)}"})
123
+
124
+ # Create Gradio interface
125
+ def create_gradio_app():
126
+ with gr.Blocks() as demo:
127
+ gr.Markdown("# YOLO Pose Estimation API")
128
+ gr.Markdown("Upload an image to detect pose keypoints")
129
+
130
+ with gr.Row():
131
+ with gr.Column():
132
+ input_image = gr.Image(type="numpy", label="Input Image")
133
+ submit_btn = gr.Button("Process Image")
134
+
135
+ with gr.Column():
136
+ output_image = gr.Image(label="Visualization")
137
+ output_json = gr.JSON(label="Keypoint Data")
138
+
139
+ submit_btn.click(
140
+ fn=process_image,
141
+ inputs=[input_image],
142
+ outputs=[output_image, output_json]
143
+ )
144
+
145
+ # Add API documentation
146
+ gr.Markdown("""
147
+ ## API Usage
148
+
149
+ This Gradio app also provides a REST API endpoint at `/api/predict`.
150
+
151
+ Example usage:
152
+ ```python
153
+ import requests
154
+
155
+ # Send a POST request to the API endpoint
156
+ response = requests.post(
157
+ "YOUR_HUGGINGFACE_SPACE_URL/api/predict",
158
+ files={"input_image": open("image.jpg", "rb")}
159
+ )
160
+
161
+ # Process results
162
+ if response.status_code == 200:
163
+ results = response.json()
164
+ keypoints = results.get("keypoints", {})
165
+ print(keypoints)
166
+ else:
167
+ print(f"Error: {response.text}")
168
+ ```
169
+ """)
170
+
171
+ return demo
172
+
173
+ demo = create_gradio_app()
174
+
175
+ # Launch app
176
+ if __name__ == "__main__":
177
+ demo.launch()
178
+ else:
179
+ # For Hugging Face Spaces
180
+ demo.launch(share=False)