import gradio as gr
import os
from ultralytics import YOLO
import numpy as np
import json
from PIL import Image, ImageDraw

# Define keypoints we need for rigging
KEYPOINTS = {
    0: {"name": "chin (nose)"},
    7: {"name": "left_elbow"},
    8: {"name": "right_elbow"},
    9: {"name": "left_wrist"},
    10: {"name": "right_wrist"},
    13: {"name": "left_knee"},
    14: {"name": "right_knee"}
}
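
# YOLOv8 pose models output keypoints in the standard COCO-17 order:
#   0 nose, 1/2 eyes, 3/4 ears, 5/6 shoulders, 7/8 elbows, 9/10 wrists,
#   11/12 hips, 13/14 knees, 15/16 ankles (left before right in each pair).
# The mapping above keeps only the joints needed for rigging; the hip
# indices (11, 12) are used later to derive a synthetic "groin" point.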

# Initialize model
model = None

def load_model():
    """Load the YOLO pose estimation model, downloading weights if needed."""
    global model
    if model is None:
        model_path = 'yolov8s-pose.pt'
        if not os.path.exists(model_path):
            # Ultralytics downloads official model weights automatically when
            # the file is not already present locally.
            print(f"Model weights not found locally, downloading: {model_path}")
        try:
            model = YOLO(model_path)
            print("Model loaded successfully")
        except Exception as e:
            print(f"Error loading model: {e}")
            model = None
    return model

def process_image(input_image):
    """
    Process an image for pose estimation and return keypoint coordinates
    
    Args:
        input_image: Input image (PIL Image or numpy array)
    
    Returns:
        Tuple of (visualization image, JSON results string)
    """
    # Load model if not already loaded
    if load_model() is None:
        return None, json.dumps({"error": "Model not available"})
    
    try:
        # Ensure the input is a NumPy array (handles PIL Images passed directly)
        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image)
        
        # Run inference
        results = model.predict(input_image, verbose=False)
        
        # Process keypoint data
        keypoint_data = {}
        
        if not results or len(results) == 0:
            return input_image, json.dumps({"error": "No pose detection results found"})
            
        result = results[0]
        
        if not hasattr(result, "keypoints") or result.keypoints is None:
            return input_image, json.dumps({"error": "No keypoints detected in the image"})
            
        try:
            keypoints = result.keypoints.data.cpu().numpy()
        except AttributeError:
            return input_image, json.dumps({"error": "Error accessing keypoints data"})
            
        if len(keypoints) == 0:
            return input_image, json.dumps({"error": "No people detected in the image"})
            
        # Get first person's keypoints
        kp = keypoints[0]
        
        # Extract keypoints
        for idx, keypoint_info in KEYPOINTS.items():
            if idx < len(kp) and kp[idx][2] > 0.5:  # Confidence threshold
                x, y, conf = kp[idx]
                keypoint_data[keypoint_info["name"]] = {
                    "x": int(x),
                    "y": int(y),
                    "confidence": float(conf)
                }
        
        # Add a synthetic groin point: midpoint of the left and right hips (COCO indices 11 and 12)
        if len(kp) > 12 and kp[11][2] > 0.5 and kp[12][2] > 0.5:
            groin_x = int((kp[11][0] + kp[12][0]) / 2)
            groin_y = int((kp[11][1] + kp[12][1]) / 2)
            groin_conf = (float(kp[11][2]) + float(kp[12][2])) / 2
            keypoint_data["groin"] = {
                "x": groin_x,
                "y": groin_y,
                "confidence": groin_conf
            }
        
        # Create visualization image
        vis_image = Image.fromarray(input_image.copy())
        draw = ImageDraw.Draw(vis_image)
        
        # Draw keypoints
        for point_name, point_data in keypoint_data.items():
            x, y = point_data["x"], point_data["y"]
            # Draw a circle at each keypoint
            radius = 5
            draw.ellipse(
                [(x - radius, y - radius), (x + radius, y + radius)],
                fill="red"
            )
            # Add text label
            draw.text((x + 10, y), point_name, fill="black")
        
        return np.array(vis_image), json.dumps({"keypoints": keypoint_data}, indent=2)
    
    except Exception as e:
        return input_image, json.dumps({"error": f"Error processing image: {str(e)}"})
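
# Example of the JSON string produced on success (keypoint values are
# illustrative; only points above the 0.5 confidence threshold appear):
# {
#   "keypoints": {
#     "left_wrist": {"x": 412, "y": 318, "confidence": 0.93},
#     "right_knee": {"x": 357, "y": 540, "confidence": 0.91},
#     "groin":      {"x": 305, "y": 402, "confidence": 0.88}
#   }
# }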

# Create Gradio interface
def create_gradio_app():
    with gr.Blocks() as demo:
        gr.Markdown("# YOLO Pose Estimation API")
        gr.Markdown("Upload an image to detect pose keypoints")
        
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="numpy", label="Input Image")
                submit_btn = gr.Button("Process Image")
            
            with gr.Column():
                output_image = gr.Image(label="Visualization")
                output_json = gr.JSON(label="Keypoint Data")
        
        submit_btn.click(
            fn=process_image,
            inputs=[input_image],
            outputs=[output_image, output_json]
        )
        
        # Add API documentation
        gr.Markdown("""
        ## API Usage
        
        This app can also be called programmatically. The simplest client is the
        `gradio_client` Python package; the exact endpoint name depends on your
        Gradio version, so run `client.view_api()` to confirm it.

        Example usage:
        ```python
        from gradio_client import Client

        # Point the client at the deployed Space (URL or "user/space-name").
        client = Client("YOUR_HUGGINGFACE_SPACE_URL")

        # The click event is bound to `process_image`, so the endpoint is
        # typically exposed as "/process_image". On recent gradio_client
        # versions, wrap the file path with gradio_client.handle_file().
        vis_image, keypoints = client.predict(
            "image.jpg",
            api_name="/process_image"
        )
        print(keypoints)
        ```
        """)
    
    return demo

demo = create_gradio_app()

# Launch app
if __name__ == "__main__":
    demo.launch()
else:
    # For Hugging Face Spaces
    demo.launch(share=False)
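
# Notes for running locally (a sketch; assumes a standard Python setup):
#   pip install gradio ultralytics numpy pillow
#   python app.py   # adjust if this file has a different name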