# pose-estimation / app.py — Hugging Face Space by shihabsarar29
# Commit 386f26f ("Create app.py", verified); file size 5.7 kB.
import gradio as gr
import os
from ultralytics import YOLO
import numpy as np
import json
from PIL import Image, ImageDraw
# Keypoints needed for rigging.
# Keys are indices into the 17-point COCO keypoint layout produced by the
# YOLOv8 pose model (0 = nose, 7/8 = elbows, 9/10 = wrists, 13/14 = knees);
# each value holds the label used for that point in the JSON output.
# Index 0 (the nose) is reported under the label "chin (nose)".
KEYPOINTS = {
    index: {"name": label}
    for index, label in (
        (0, "chin (nose)"),
        (7, "left_elbow"),
        (8, "right_elbow"),
        (9, "left_wrist"),
        (10, "right_wrist"),
        (13, "left_knee"),
        (14, "right_knee"),
    )
}
# Initialize model
model = None
def load_model():
"""Load the YOLO pose estimation model"""
global model
if model is None:
model_path = 'yolov8s-pose.pt'
if os.path.exists(model_path):
try:
model = YOLO(model_path)
print("Model loaded successfully")
except Exception as e:
print(f"Error loading model: {e}")
model = None
else:
print(f"Model file not found: {model_path}")
return model
def process_image(input_image):
"""
Process an image for pose estimation and return keypoint coordinates
Args:
input_image: Input image (PIL Image or numpy array)
Returns:
Tuple of (visualization image, JSON results string)
"""
# Load model if not already loaded
if load_model() is None:
return None, json.dumps({"error": "Model not available"})
try:
# Convert to PIL if needed
if not isinstance(input_image, np.ndarray):
input_image = np.array(input_image)
# Run inference
results = model.predict(input_image, verbose=False)
# Process keypoint data
keypoint_data = {}
if not results or len(results) == 0:
return input_image, json.dumps({"error": "No pose detection results found"})
result = results[0]
if not hasattr(result, "keypoints") or result.keypoints is None:
return input_image, json.dumps({"error": "No keypoints detected in the image"})
try:
keypoints = result.keypoints.data.cpu().numpy()
except AttributeError:
return input_image, json.dumps({"error": "Error accessing keypoints data"})
if len(keypoints) == 0:
return input_image, json.dumps({"error": "No people detected in the image"})
# Get first person's keypoints
kp = keypoints[0]
# Extract keypoints
for idx, keypoint_info in KEYPOINTS.items():
if idx < len(kp) and kp[idx][2] > 0.5: # Confidence threshold
x, y, conf = kp[idx]
keypoint_data[keypoint_info["name"]] = {
"x": int(x),
"y": int(y),
"confidence": float(conf)
}
# Add groin point (midpoint between points 11 and 12)
if len(kp) > 12 and kp[11][2] > 0.5 and kp[12][2] > 0.5:
groin_x = int((kp[11][0] + kp[12][0]) / 2)
groin_y = int((kp[11][1] + kp[12][1]) / 2)
groin_conf = (float(kp[11][2]) + float(kp[12][2])) / 2
keypoint_data["groin"] = {
"x": groin_x,
"y": groin_y,
"confidence": groin_conf
}
# Create visualization image
vis_image = Image.fromarray(input_image.copy())
draw = ImageDraw.Draw(vis_image)
# Draw keypoints
for point_name, point_data in keypoint_data.items():
x, y = point_data["x"], point_data["y"]
# Draw a circle at each keypoint
radius = 5
draw.ellipse(
[(x - radius, y - radius), (x + radius, y + radius)],
fill="red"
)
# Add text label
draw.text((x + 10, y), point_name, fill="black")
return np.array(vis_image), json.dumps({"keypoints": keypoint_data}, indent=2)
except Exception as e:
return input_image, json.dumps({"error": f"Error processing image: {str(e)}"})
# Create Gradio interface
def create_gradio_app():
with gr.Blocks() as demo:
gr.Markdown("# YOLO Pose Estimation API")
gr.Markdown("Upload an image to detect pose keypoints")
with gr.Row():
with gr.Column():
input_image = gr.Image(type="numpy", label="Input Image")
submit_btn = gr.Button("Process Image")
with gr.Column():
output_image = gr.Image(label="Visualization")
output_json = gr.JSON(label="Keypoint Data")
submit_btn.click(
fn=process_image,
inputs=[input_image],
outputs=[output_image, output_json]
)
# Add API documentation
gr.Markdown("""
## API Usage
This Gradio app also provides a REST API endpoint at `/api/predict`.
Example usage:
```python
import requests
# Send a POST request to the API endpoint
response = requests.post(
"YOUR_HUGGINGFACE_SPACE_URL/api/predict",
files={"input_image": open("image.jpg", "rb")}
)
# Process results
if response.status_code == 200:
results = response.json()
keypoints = results.get("keypoints", {})
print(keypoints)
else:
print(f"Error: {response.text}")
```
""")
return demo
demo = create_gradio_app()
# Launch app
if __name__ == "__main__":
demo.launch()
else:
# For Hugging Face Spaces
demo.launch(share=False)