import gradio as gr
import os
from ultralytics import YOLO
import numpy as np
import json
from PIL import Image, ImageDraw

# Define keypoints we need for rigging
KEYPOINTS = {
    0: {"name": "chin (nose)"},
    7: {"name": "left_elbow"},
    8: {"name": "right_elbow"},
    9: {"name": "left_wrist"},
    10: {"name": "right_wrist"},
    13: {"name": "left_knee"},
    14: {"name": "right_knee"}
}
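# Keypoint indices follow the COCO-17 ordering used by YOLOv8 pose models
# (0 nose, 5/6 shoulders, 7/8 elbows, 9/10 wrists, 11/12 hips, 13/14 knees, 15/16 ankles).
# Indices 11 and 12 (hips) are read directly in process_image() to derive a "groin" midpoint.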

# Initialize model
model = None


def load_model():
    """Load the YOLO pose estimation model"""
    global model
    if model is None:
        model_path = 'yolov8s-pose.pt'
        if os.path.exists(model_path):
            try:
                model = YOLO(model_path)
                print("Model loaded successfully")
            except Exception as e:
                print(f"Error loading model: {e}")
                model = None
        else:
            print(f"Model file not found: {model_path}")
    return model


def process_image(input_image):
    """
    Process an image for pose estimation and return keypoint coordinates.

    Args:
        input_image: Input image (PIL Image or numpy array)

    Returns:
        Tuple of (visualization image, JSON results string)
    """
    # Load model if not already loaded
    if load_model() is None:
        return None, json.dumps({"error": "Model not available"})

    try:
        # Convert to a numpy array if needed
        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image)

        # Run inference
        results = model.predict(input_image, verbose=False)

        # Process keypoint data
        keypoint_data = {}
        if not results or len(results) == 0:
            return input_image, json.dumps({"error": "No pose detection results found"})

        result = results[0]
        if not hasattr(result, "keypoints") or result.keypoints is None:
            return input_image, json.dumps({"error": "No keypoints detected in the image"})

        try:
            keypoints = result.keypoints.data.cpu().numpy()
        except AttributeError:
            return input_image, json.dumps({"error": "Error accessing keypoints data"})

        if len(keypoints) == 0:
            return input_image, json.dumps({"error": "No people detected in the image"})

        # Get the first person's keypoints
        kp = keypoints[0]

        # Extract the keypoints needed for rigging
        for idx, keypoint_info in KEYPOINTS.items():
            if idx < len(kp) and kp[idx][2] > 0.5:  # Confidence threshold
                x, y, conf = kp[idx]
                keypoint_data[keypoint_info["name"]] = {
                    "x": int(x),
                    "y": int(y),
                    "confidence": float(conf)
                }

        # Add groin point (midpoint between hip keypoints 11 and 12)
        if len(kp) > 12 and kp[11][2] > 0.5 and kp[12][2] > 0.5:
            groin_x = int((kp[11][0] + kp[12][0]) / 2)
            groin_y = int((kp[11][1] + kp[12][1]) / 2)
            groin_conf = (float(kp[11][2]) + float(kp[12][2])) / 2
            keypoint_data["groin"] = {
                "x": groin_x,
                "y": groin_y,
                "confidence": groin_conf
            }

        # Create visualization image
        vis_image = Image.fromarray(input_image.copy())
        draw = ImageDraw.Draw(vis_image)

        # Draw keypoints
        for point_name, point_data in keypoint_data.items():
            x, y = point_data["x"], point_data["y"]
            # Draw a circle at each keypoint
            radius = 5
            draw.ellipse(
                [(x - radius, y - radius), (x + radius, y + radius)],
                fill="red"
            )
            # Add text label
            draw.text((x + 10, y), point_name, fill="black")

        return np.array(vis_image), json.dumps({"keypoints": keypoint_data}, indent=2)

    except Exception as e:
        return input_image, json.dumps({"error": f"Error processing image: {str(e)}"})
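

# Quick local check of the pipeline above (a sketch; "test.jpg" is a hypothetical file,
# not part of this Space):
#   from PIL import Image
#   vis, kp_json = process_image(np.array(Image.open("test.jpg")))
#   print(kp_json)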


# Create Gradio interface
def create_gradio_app():
    with gr.Blocks() as demo:
        gr.Markdown("# YOLO Pose Estimation API")
        gr.Markdown("Upload an image to detect pose keypoints")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="numpy", label="Input Image")
                submit_btn = gr.Button("Process Image")
            with gr.Column():
                output_image = gr.Image(label="Visualization")
                output_json = gr.JSON(label="Keypoint Data")

        submit_btn.click(
            fn=process_image,
            inputs=[input_image],
            outputs=[output_image, output_json]
        )

        # Add API documentation
        gr.Markdown("""
## API Usage

This Gradio app also provides a REST API endpoint at `/api/predict`.

Example usage:
```python
import base64
import requests

# Gradio's JSON API expects the inputs in a "data" list; image inputs are sent
# as base64-encoded data URLs rather than multipart file uploads
with open("image.jpg", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

response = requests.post(
    "YOUR_HUGGINGFACE_SPACE_URL/api/predict",
    json={"data": [f"data:image/jpeg;base64,{encoded}"]}
)

# The response "data" list mirrors the outputs above:
# data[0] is the visualization image, data[1] is the keypoint JSON
if response.status_code == 200:
    results = response.json()
    keypoints = results["data"][1]
    print(keypoints)
else:
    print(f"Error: {response.text}")
```
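
Alternatively, the `gradio_client` package wraps this protocol for you. A minimal sketch
(the Space ID and `api_name` below are placeholders; check this Space's "Use via API" page
for the exact endpoint name):

```python
from gradio_client import Client  # newer client versions may require gradio_client.handle_file() for images

client = Client("YOUR_USERNAME/YOUR_SPACE")

# Returns the two outputs in order: the visualization image and the keypoint data
vis_image, keypoints = client.predict("image.jpg", api_name="/predict")
print(keypoints)
```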
""")
    return demo


demo = create_gradio_app()

# Launch app
if __name__ == "__main__":
    demo.launch()
else:
    # For Hugging Face Spaces
    demo.launch(share=False)