sahandv commited on
Commit
f9e76ec
·
1 Parent(s): 16051be

Added visualise.py for visualising the predictions

Browse files
Files changed (1) hide show
  1. visualise.py +98 -0
visualise.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Visualisation code for SMPL-X model. This code is useful if you already have predictions.
2
+
3
+ import os
4
+ import sys
5
+ import os.path as osp
6
+ import numpy as np
7
+ import smplx
8
+ from smplx.joint_names import JOINT_NAMES
9
+ import torch
10
+ try:
11
+ CUR_DIR = osp.dirname(os.path.abspath(__file__))
12
+ except NameError:
13
+ CUR_DIR = os.getcwd()
14
+ sys.path.insert(0, osp.join(CUR_DIR, '..', 'main'))
15
+ sys.path.insert(0, osp.join(CUR_DIR , '..', 'common'))
16
+ import matplotlib.pyplot as plt
17
+ from mpl_toolkits.mplot3d import Axes3D
18
+
19
+ JOINT_NAMES_DICT = {name: i for i, name in enumerate(JOINT_NAMES)}
20
+
21
+ # Load the SMPL-X model
22
+ model_path = 'common/utils/human_model_files' # Update with the path to your SMPL-X models
23
+ model = smplx.create(model_path, model_type='smplx', gender='neutral', ext='npz')
24
+
25
+ # Load the parameters from the .npz file
26
+ data = np.load('/home/sahand/Downloads/smplx/00047_9.npz')
27
+
28
+ betas = torch.tensor(data['betas'], dtype=torch.float32)
29
+ body_pose = torch.tensor(data['body_pose'], dtype=torch.float32)
30
+ global_orient = torch.tensor(data['global_orient'], dtype=torch.float32)
31
+ transl = torch.tensor(data['transl'], dtype=torch.float32)
32
+ expression = torch.tensor(data['expression'], dtype=torch.float32)
33
+
34
+ # Add missing dimensions to the tensors
35
+ if betas.ndim == 1:
36
+ betas = betas.unsqueeze(0)
37
+ if body_pose.ndim == 2:
38
+ body_pose = body_pose.unsqueeze(0)
39
+ if global_orient.ndim == 1:
40
+ global_orient = global_orient.unsqueeze(0)
41
+ if transl.ndim == 1:
42
+ transl = transl.unsqueeze(0)
43
+ if expression.ndim == 1:
44
+ expression = expression.unsqueeze(0)
45
+
46
+ # Reshape body_pose to include the batch dimension
47
+ body_pose = body_pose.view(1, -1, 3)
48
+
49
+ # Forward pass through the model
50
+ output = model(betas=betas, body_pose=body_pose, global_orient=global_orient, transl=transl, expression=expression)
51
+
52
+ # Extract joint positions
53
+ joints = output.joints.detach().cpu().numpy().squeeze()
54
+ print(joints.shape)
55
+ # Ankle joints (left and right)
56
+ left_knee = joints[4] # Index for left ankle in SMPL-X
57
+ right_knee = joints[5] # Index for right ankle in SMPL-X
58
+ left_ankle = joints[7] # Index for left ankle in SMPL-X
59
+ right_ankle = joints[8] # Index for right ankle in SMPL-X
60
+
61
+ bone_connections = [
62
+ (JOINT_NAMES_DICT["pelvis"], JOINT_NAMES_DICT["spine1"]), (JOINT_NAMES_DICT["spine1"], JOINT_NAMES_DICT["spine2"]), (JOINT_NAMES_DICT["spine2"], JOINT_NAMES_DICT["spine3"]), # Spine
63
+ (JOINT_NAMES_DICT["pelvis"], JOINT_NAMES_DICT["left_hip"]), (JOINT_NAMES_DICT["left_hip"], JOINT_NAMES_DICT["left_knee"]), (JOINT_NAMES_DICT["left_knee"], JOINT_NAMES_DICT["left_ankle"]), # Left leg
64
+ (JOINT_NAMES_DICT["pelvis"], JOINT_NAMES_DICT["right_hip"]), (JOINT_NAMES_DICT["right_hip"], JOINT_NAMES_DICT["right_knee"]), (JOINT_NAMES_DICT["right_knee"], JOINT_NAMES_DICT["right_ankle"]), # Right leg
65
+ (JOINT_NAMES_DICT["left_ankle"], JOINT_NAMES_DICT["left_heel"]),
66
+ (JOINT_NAMES_DICT["right_ankle"], JOINT_NAMES_DICT["right_heel"]),
67
+ (JOINT_NAMES_DICT["left_ankle"], JOINT_NAMES_DICT["left_foot"]),
68
+ (JOINT_NAMES_DICT["left_foot"], JOINT_NAMES_DICT["left_big_toe"]), (JOINT_NAMES_DICT["left_foot"], JOINT_NAMES_DICT["left_small_toe"]),
69
+ (JOINT_NAMES_DICT["right_ankle"], JOINT_NAMES_DICT["right_foot"]),
70
+ (JOINT_NAMES_DICT["right_foot"], JOINT_NAMES_DICT["right_big_toe"]), (JOINT_NAMES_DICT["right_foot"], JOINT_NAMES_DICT["right_small_toe"]),
71
+ # Add more bones if necessary
72
+ ]
73
+
74
+ # Visualize the 3D skeleton
75
+ fig = plt.figure()
76
+ ax = fig.add_subplot(111, projection='3d')
77
+
78
+ # Plot all joints
79
+ ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], c='blue', marker='o')
80
+ # Highlight ankle joints
81
+ ax.scatter([left_knee[0]], [left_knee[1]], [left_knee[2]], c='red', marker='x', s=100, label='Left Knee')
82
+ ax.scatter([right_knee[0]], [right_knee[1]], [right_knee[2]], c='green', marker='x', s=100, label='Right Knee')
83
+ ax.scatter([left_ankle[0]], [left_ankle[1]], [left_ankle[2]], c='red', marker='o', s=100, label='Left Ankle')
84
+ ax.scatter([right_ankle[0]], [right_ankle[1]], [right_ankle[2]], c='green', marker='o', s=100, label='Right Ankle')
85
+
86
+ # Draw bones
87
+ for bone in bone_connections:
88
+ start, end = bone
89
+ ax.plot([joints[start, 0], joints[end, 0]],
90
+ [joints[start, 1], joints[end, 1]],
91
+ [joints[start, 2], joints[end, 2]], 'k-')
92
+
93
+ # Set labels
94
+ ax.set_xlabel('X')
95
+ ax.set_ylabel('Y')
96
+ ax.set_zlabel('Z')
97
+ ax.legend()
98
+ plt.show()