"""Plot start / target / stopover poses for LAFAN1 samples loaded with ``train=False``."""

import argparse
import os
from pathlib import Path

import imageio
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from sklearn.preprocessing import LabelEncoder

from cmib.data.lafan1_dataset import LAFAN1Dataset
from cmib.data.utils import write_json
from cmib.lafan1.utils import quat_ik
from cmib.model.network import TransformerModel
from cmib.model.preprocess import (lerp_input_repr, replace_constant,
                                   slerp_input_repr, vectorize_representation)
from cmib.model.skeleton import (Skeleton, sk_joints_to_remove, sk_offsets,
                                 joint_names, sk_parents)
from cmib.vis.pose import plot_pose_with_stop


def test(opt: argparse.Namespace, device: torch.device) -> None:
    """Plot the start, target and stopover poses of every dataset sample.

    The checkpoint with the greatest ``os.path.getctime`` under
    ``runs/train/<opt.exp_name>/weights`` is loaded only to read the
    ``from_idx``, ``target_idx`` and ``horizon`` values stored in it.  For each
    sample, ``plot_pose_with_stop`` is called with ``prefix='gt'`` and
    ``save_dir=<opt.save_path>/sampler``.

    Args:
        opt: Parsed options. This function reads ``exp_name``, ``data_path``,
            ``processed_data_dir`` and ``save_path``.
        device: Device passed to ``torch.load``, ``Skeleton`` and
            ``LAFAN1Dataset``.

    Raises:
        FileNotFoundError: If the weights directory does not exist
            (raised by ``os.listdir``).
        ValueError: If the weights directory contains no entries.
    """
    # NOTE(review): the run directory is hard-coded to 'runs/train';
    # ``opt.project`` is not consulted here.
    save_dir = Path(os.path.join('runs', 'train', opt.exp_name))
    wdir = save_dir / 'weights'
    weights_paths = [wdir / weight for weight in os.listdir(wdir)]
    if not weights_paths:
        # Same exception type that max() would raise on an empty sequence,
        # but with a message that names the directory that was searched.
        raise ValueError(f"No checkpoint files found in {wdir}")
    latest_weight = max(weights_paths, key=os.path.getctime)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ckpt = torch.load(latest_weight, map_location=device)
    print(f"Loaded weight: {latest_weight}")

    # Load Skeleton
    skeleton_mocap = Skeleton(offsets=sk_offsets, parents=sk_parents, device=device)
    skeleton_mocap.remove_joints(sk_joints_to_remove)

    # Load LAFAN Dataset
    Path(opt.processed_data_dir).mkdir(parents=True, exist_ok=True)
    lafan_dataset = LAFAN1Dataset(lafan_path=opt.data_path,
                                  processed_data_dir=opt.processed_data_dir,
                                  train=False, device=device)
    global_pos = lafan_dataset.data['global_pos']
    total_data = global_pos.shape[0]

    # In-betweening frame indices come from the checkpoint.
    from_idx, target_idx = ckpt['from_idx'], ckpt['target_idx']  # default: 9-40, max: 48
    horizon = ckpt['horizon']
    print(f"HORIZON: {horizon}")

    # Every sample is rendered; narrow this list to plot a subset.
    test_idx = list(range(total_data))

    # The output directory does not change per sample, so create it once.
    save_path = os.path.join(opt.save_path, 'sampler')
    Path(save_path).mkdir(parents=True, exist_ok=True)

    for i, sample_idx in enumerate(test_idx):
        start_pose = global_pos[sample_idx, from_idx]
        target_pose = global_pos[sample_idx, target_idx]
        # NOTE(review): the stopover pose is read at ``from_idx``, so it is
        # the same frame as ``start_pose`` -- confirm this is intended.
        gt_stopover_pose = global_pos[sample_idx, from_idx]

        # NOTE(review): ``target_pose`` is passed twice, as in the original
        # call; verify against the signature of ``plot_pose_with_stop``.
        plot_pose_with_stop(start_pose, target_pose, target_pose,
                            gt_stopover_pose, i, skeleton_mocap,
                            save_dir=save_path, prefix='gt')
        print(f"ID {sample_idx}: completed.")


def parse_opt() -> argparse.Namespace:
    """Parse the command-line options of this script.

    Returns:
        The namespace produced by ``argparse``.  ``project``,
        ``skeleton_path`` and ``motion_type`` are accepted but are not read by
        ``test`` in this file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--project', default='runs/train', help='project/name')
    parser.add_argument('--exp_name', default='slerp_40', help='experiment name')
    parser.add_argument('--data_path', type=str,
                        default='ubisoft-laforge-animation-dataset/output/BVH',
                        help='BVH dataset path')
    parser.add_argument('--skeleton_path', type=str,
                        default='ubisoft-laforge-animation-dataset/output/BVH/walk1_subject1.bvh',
                        help='path to reference skeleton')
    parser.add_argument('--processed_data_dir', type=str,
                        default='processed_data_original/',
                        help='path to save pickled processed data')
    parser.add_argument('--save_path', type=str, default='runs/test',
                        help='path to save model')
    parser.add_argument('--motion_type', type=str, default='jumps',
                        help='motion type')
    opt = parser.parse_args()
    return opt


if __name__ == "__main__":
    opt = parse_opt()
    device = torch.device("cpu")
    test(opt, device)