# tryon / app.py
# nithinm19's picture
# Update app.py
# c25ecb0 verified
# raw
# history blame
# 3.22 kB
import gradio as gr
import cv2
import mediapipe as mp
import numpy as np
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
import torch
# --- MediaPipe pose estimation ---------------------------------------------
# A single Pose instance is created at import time and shared by every
# request. It is configured for still images (static_image_mode=True) with
# model_complexity=2.
mp_drawing = mp.solutions.drawing_utils
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(static_image_mode=True, model_complexity=2)

# --- SegFormer semantic segmentation ---------------------------------------
# The preprocessor and the model weights come from the same checkpoint, so
# the checkpoint id is written once and reused for both loads.
SEGFORMER_CHECKPOINT = "nvidia/segformer-b0-finetuned-ade-512-512"
feature_extractor = SegformerFeatureExtractor.from_pretrained(SEGFORMER_CHECKPOINT)
model = SegformerForSemanticSegmentation.from_pretrained(SEGFORMER_CHECKPOINT)
# Overlay color for each named region. The tuples are written as-is into the
# 3-channel output image; iteration order is the painting order.
PART_COLORS = {
    "head": (0, 255, 0),
    "shoulders": (255, 0, 0),
    "upper_body": (0, 0, 255),
    "arms": (255, 255, 0),
    "lower_body": (255, 0, 255),
}

# Segmentation class ids that are painted with each region's color.
# NOTE(review): the checkpoint loaded in this file is ADE20K-finetuned; in that
# label set ids 0-8 are believed to be scene classes (wall, building, sky,
# floor, ...) rather than body parts. Verify against model.config.id2label
# before relying on the region names below.
PART_LABELS = {
    "head": [0],
    "shoulders": [2],
    "upper_body": [3, 4],
    "arms": [5, 6],
    "lower_body": [7, 8],
}
def segment_image(image):
    """Run SegFormer on ``image`` and return a color-coded mask.

    Args:
        image: Image as a NumPy array whose first two axes are
            (height, width). It is handed unchanged to the SegFormer
            preprocessor.

    Returns:
        An array with the same shape and dtype as ``image``. It is zero
        everywhere except at pixels whose predicted class id is listed in
        ``PART_LABELS``; those pixels carry the matching ``PART_COLORS`` value.
    """
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: without no_grad every request would build (and hold on
    # to) an autograd graph for the whole forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Per-pixel class id at the model's output resolution. Index the batch
    # axis explicitly instead of squeeze(), which would also drop a spatial
    # axis of size 1. argmax yields int64, which OpenCV 4.x cannot convert,
    # so narrow to int32 before resizing.
    class_map = logits.argmax(dim=1)[0].cpu().numpy().astype(np.int32)

    # Nearest-neighbour keeps the label ids intact (no blending of classes).
    height, width = image.shape[:2]
    class_map = cv2.resize(class_map, (width, height), interpolation=cv2.INTER_NEAREST)

    # Paint every mapped class with its region color on a black canvas.
    segmented_image = np.zeros_like(image)
    for part, color in PART_COLORS.items():
        segmented_image[np.isin(class_map, PART_LABELS[part])] = color
    return segmented_image
def estimate_pose(image):
    """Draw pose landmarks on ``image`` and produce its segmentation overlay.

    Args:
        image: RGB image array as delivered by the Gradio ``Image`` input
            declared with ``type="numpy"``, or ``None`` when the form is
            submitted without an upload.

    Returns:
        A tuple ``(pose_image, segmented_image)``:

        * ``pose_image`` is a copy of ``image`` with landmarks and
          connections drawn, or ``image`` itself when no pose is detected.
        * ``segmented_image`` is the result of ``segment_image(image)``.

        Both entries are ``None`` when ``image`` is ``None``.
    """
    # Nothing uploaded: return empty outputs instead of crashing downstream.
    if image is None:
        return None, None

    # The image is already RGB, which is what MediaPipe expects. The previous
    # BGR->RGB conversion actually swapped the channels to BGR before
    # detection, so it is intentionally not applied here.
    results = pose.process(image)

    # The segmentation does not depend on the pose result; compute it once.
    segmented_image = segment_image(image)

    if not results.pose_landmarks:
        return image, segmented_image

    # Draw on a copy so the caller's array is left untouched.
    annotated_image = image.copy()
    mp_drawing.draw_landmarks(
        annotated_image,
        results.pose_landmarks,
        mp_pose.POSE_CONNECTIONS,
        landmark_drawing_spec=mp_drawing.DrawingSpec(color=(0, 255, 0), thickness=2, circle_radius=2),
        connection_drawing_spec=mp_drawing.DrawingSpec(color=(0, 0, 255), thickness=2, circle_radius=2),
    )
    return annotated_image, segmented_image
# Gradio Interface: one uploaded image in, two images out (the pose overlay
# and the segmentation overlay), both produced by estimate_pose.
image_input = gr.Image(type="numpy", label="Upload an Image")
image_outputs = [
    gr.Image(type="numpy", label="Pose Landmarks Image"),
    gr.Image(type="numpy", label="Segmented Body Parts"),
]

interface = gr.Interface(
    fn=estimate_pose,
    inputs=image_input,
    outputs=image_outputs,
    title="Human Pose Estimation and Segmentation",
    description=(
        "Upload an image to detect and visualize human pose landmarks and "
        "segment body parts (head, shoulders, upper body, arms, lower body) "
        "with different colors."
    ),
)

# Start the web app only when this file is run as a script, not on import.
if __name__ == "__main__":
    interface.launch()