Spaces:
Runtime error
Runtime error
# Visualisation code for the SMPL-X model. This code is useful if you already have predictions.
import os | |
import sys | |
import os.path as osp | |
import numpy as np | |
import smplx | |
from smplx.joint_names import JOINT_NAMES | |
import torch | |
# Resolve the directory that contains this script.
try:
    CUR_DIR = osp.split(osp.abspath(__file__))[0]
except NameError:
    # __file__ is not defined (e.g. in an interactive session); fall back to
    # the current working directory.
    CUR_DIR = os.getcwd()

# Put the sibling 'common' and 'main' directories at the front of sys.path,
# with 'common' ahead of 'main'.
sys.path[:0] = [osp.join(CUR_DIR, '..', name) for name in ('common', 'main')]
import matplotlib.pyplot as plt | |
from mpl_toolkits.mplot3d import Axes3D | |
# Lookup table from each joint name in JOINT_NAMES to its index in that list.
JOINT_NAMES_DICT = {
    joint_name: joint_index
    for joint_index, joint_name in enumerate(JOINT_NAMES)
}
# Build the neutral-gender SMPL-X model from .npz model files.
# Update model_path so that it points at your SMPL-X models.
model_path = 'common/utils/human_model_files'
model = smplx.create(
    model_path,
    model_type='smplx',
    gender='neutral',
    ext='npz',
)
def _load_param(npz, key):
    """Return ``npz[key]`` as a float32 tensor, adding a batch axis to 1-D data."""
    tensor = torch.tensor(npz[key], dtype=torch.float32)
    return tensor.unsqueeze(0) if tensor.ndim == 1 else tensor


# Load the parameters from the .npz file.
# np.load keeps the archive open for lazy access, so read everything inside a
# `with` block to make sure the file handle is released afterwards.
# torch.tensor copies the data, so the tensors stay valid after the file closes.
with np.load('/home/sahand/Downloads/smplx/00047_9.npz') as data:
    betas = _load_param(data, 'betas')
    body_pose = torch.tensor(data['body_pose'], dtype=torch.float32)
    global_orient = _load_param(data, 'global_orient')
    transl = _load_param(data, 'transl')
    expression = _load_param(data, 'expression')

# Reshape body_pose to (1, N, 3). This also supplies the batch dimension, so no
# separate unsqueeze is needed beforehand.
body_pose = body_pose.reshape(1, -1, 3)

# Forward pass through the model. Gradients are not needed for visualisation,
# so skip building the autograd graph.
with torch.no_grad():
    output = model(
        betas=betas,
        body_pose=body_pose,
        global_orient=global_orient,
        transl=transl,
        expression=expression,
    )
# Extract joint positions as a NumPy array; squeeze() drops size-1 axes.
joints = output.joints.detach().cpu().numpy().squeeze()
print(joints.shape)

# Knee and ankle joints (left and right), looked up by name so they stay
# consistent with the name-based bone connections instead of relying on
# hard-coded row numbers.
left_knee = joints[JOINT_NAMES_DICT["left_knee"]]
right_knee = joints[JOINT_NAMES_DICT["right_knee"]]
left_ankle = joints[JOINT_NAMES_DICT["left_ankle"]]
right_ankle = joints[JOINT_NAMES_DICT["right_ankle"]]
# Skeleton edges as (start_index, end_index) pairs. The edges are written as
# joint-name pairs and resolved to indices through JOINT_NAMES_DICT.
bone_connections = [
    (JOINT_NAMES_DICT[start_name], JOINT_NAMES_DICT[end_name])
    for start_name, end_name in (
        # Spine
        ("pelvis", "spine1"), ("spine1", "spine2"), ("spine2", "spine3"),
        # Left leg
        ("pelvis", "left_hip"), ("left_hip", "left_knee"), ("left_knee", "left_ankle"),
        # Right leg
        ("pelvis", "right_hip"), ("right_hip", "right_knee"), ("right_knee", "right_ankle"),
        # Heels
        ("left_ankle", "left_heel"),
        ("right_ankle", "right_heel"),
        # Left foot and toes
        ("left_ankle", "left_foot"),
        ("left_foot", "left_big_toe"), ("left_foot", "left_small_toe"),
        # Right foot and toes
        ("right_ankle", "right_foot"),
        ("right_foot", "right_big_toe"), ("right_foot", "right_small_toe"),
        # Add more bones if necessary
    )
]
# Visualize the 3D skeleton
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Plot all joints
ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], c='blue', marker='o')

# Highlight the knee and ankle joints: knees use 'x' markers, ankles use 'o';
# left is red, right is green.
for point, colour, symbol, caption in (
    (left_knee, 'red', 'x', 'Left Knee'),
    (right_knee, 'green', 'x', 'Right Knee'),
    (left_ankle, 'red', 'o', 'Left Ankle'),
    (right_ankle, 'green', 'o', 'Right Ankle'),
):
    ax.scatter([point[0]], [point[1]], [point[2]],
               c=colour, marker=symbol, s=100, label=caption)

# Draw each bone as a black line between its two joints
for start, end in bone_connections:
    segment = joints[[start, end]]
    ax.plot(segment[:, 0], segment[:, 1], segment[:, 2], 'k-')

# Set axis labels, show the legend for the highlighted joints, and display
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.legend()
plt.show()