import json

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset


class SGDataset(Dataset):
    def __init__(
        self,
        path_data_definition,
        path_processed_data,
        window,
        style_encoding_type,
        example_window_length,
    ):
        """PyTorch Dataset instance.

        Args:
            path_data_definition : Path to the data_definition JSON file
            path_processed_data : Path to the processed_data npz file
            window : Length of the input/output slice, in frames
            style_encoding_type : "label" or "example"
            example_window_length : Length of the example window, in frames
        """

        # Dataset description: joint hierarchy, label names, and frame time.
        with open(path_data_definition, "r") as f:
            details = json.load(f)

        self.details = details
        self.njoints = len(details["bone_names"])
        self.nlabels = len(details["label_names"])
        self.label_names = details["label_names"]
        self.bone_names = details["bone_names"]
        self.parents = torch.LongTensor(details["parents"])
        self.dt = details["dt"]
        self.window = window
        self.style_encoding_type = style_encoding_type
        self.example_window_length = example_window_length

        # Preprocessed audio features, animation targets, and train/valid ranges.
        processed_data = np.load(path_processed_data)

        self.ranges_train = processed_data["ranges_train"]
        self.ranges_valid = processed_data["ranges_valid"]
        self.ranges_train_labels = processed_data["ranges_train_labels"]
        self.ranges_valid_labels = processed_data["ranges_valid_labels"]

        self.X_audio_features = torch.as_tensor(
            processed_data["X_audio_features"], dtype=torch.float32
        )
        self.Y_root_pos = torch.as_tensor(processed_data["Y_root_pos"], dtype=torch.float32)
        self.Y_root_rot = torch.as_tensor(processed_data["Y_root_rot"], dtype=torch.float32)
        self.Y_root_vel = torch.as_tensor(processed_data["Y_root_vel"], dtype=torch.float32)
        self.Y_root_vrt = torch.as_tensor(processed_data["Y_root_vrt"], dtype=torch.float32)
        self.Y_lpos = torch.as_tensor(processed_data["Y_lpos"], dtype=torch.float32)
        self.Y_ltxy = torch.as_tensor(processed_data["Y_ltxy"], dtype=torch.float32)
        self.Y_lvel = torch.as_tensor(processed_data["Y_lvel"], dtype=torch.float32)
        self.Y_lvrt = torch.as_tensor(processed_data["Y_lvrt"], dtype=torch.float32)
        self.Y_gaze_pos = torch.as_tensor(processed_data["Y_gaze_pos"], dtype=torch.float32)

        # Normalization statistics for network inputs and outputs.
        self.audio_input_mean = torch.as_tensor(
            processed_data["audio_input_mean"], dtype=torch.float32
        )
        self.audio_input_std = torch.as_tensor(
            processed_data["audio_input_std"], dtype=torch.float32
        )
        self.anim_input_mean = torch.as_tensor(
            processed_data["anim_input_mean"], dtype=torch.float32
        )
        self.anim_input_std = torch.as_tensor(
            processed_data["anim_input_std"], dtype=torch.float32
        )
        self.anim_output_mean = torch.as_tensor(
            processed_data["anim_output_mean"], dtype=torch.float32
        )
        self.anim_output_std = torch.as_tensor(
            processed_data["anim_output_std"], dtype=torch.float32
        )

        # Build sliding-window indices over the training ranges: R holds the
        # frame indices of each window, L its one-hot style label, and S the
        # index of the source clip the window was cut from.
        R = []
        L = []
        S = []
        for sample_number, ((range_start, range_end), range_label) in enumerate(
            zip(self.ranges_train, self.ranges_train_labels)
        ):
            one_hot_label = np.zeros(self.nlabels, dtype=np.float32)
            one_hot_label[range_label] = 1.0

            for ri in range(range_start, range_end - window):
                R.append(np.arange(ri, ri + window))
                L.append(one_hot_label)
                S.append(sample_number)

        self.R = torch.as_tensor(np.array(R), dtype=torch.long)
        self.L = torch.as_tensor(np.array(L), dtype=torch.float32)
        self.S = torch.as_tensor(S, dtype=torch.short)

    @property
    def example_window_length(self):
        return self._example_window_length

    @example_window_length.setter
    def example_window_length(self, a):
        self._example_window_length = a

    def __len__(self):
        return len(self.R)

    def __getitem__(self, index):
        # Frame indices of this window and its one-hot style label.
        Rwindow = self.R[index].contiguous()
        Rlabel = self.L[index]

        # Frame range of the source clip this window was cut from.
        RInd = self.S[index]
        sample_range = self.ranges_train[RInd]

        W_audio_features = self.X_audio_features[Rwindow]

        W_root_pos = self.Y_root_pos[Rwindow]
        W_root_rot = self.Y_root_rot[Rwindow]
        W_root_vel = self.Y_root_vel[Rwindow]
        W_root_vrt = self.Y_root_vrt[Rwindow]
        W_lpos = self.Y_lpos[Rwindow]
        W_ltxy = self.Y_ltxy[Rwindow]
        W_lvel = self.Y_lvel[Rwindow]
        W_lvrt = self.Y_lvrt[Rwindow]
        W_gaze_pos = self.Y_gaze_pos[Rwindow]

        if self.style_encoding_type == "label":
            style = Rlabel
        elif self.style_encoding_type == "example":
            style = self.get_example(Rwindow, sample_range, self.example_window_length)
        else:
            raise ValueError(f"Unknown style_encoding_type: {self.style_encoding_type}")

        return (
            W_audio_features,
            W_root_pos,
            W_root_rot,
            W_root_vel,
            W_root_vrt,
            W_lpos,
            W_ltxy,
            W_lvel,
            W_lvrt,
            W_gaze_pos,
            style,
        )

    def get_shapes(self):
        num_audio_features = self.X_audio_features.shape[1]
        pose_input_size = len(self.anim_input_std)
        pose_output_size = len(self.anim_output_std)
        dimensions = dict(
            num_audio_features=num_audio_features,
            pose_input_size=pose_input_size,
            pose_output_size=pose_output_size,
        )
        return dimensions

    def get_means_stds(self, device):
        return (
            self.audio_input_mean.to(device),
            self.audio_input_std.to(device),
            self.anim_input_mean.to(device),
            self.anim_input_std.to(device),
            self.anim_output_mean.to(device),
            self.anim_output_std.to(device),
        )

    def get_example(self, Rwindow, sample_range, example_window_length):
        # Extend the window symmetrically by ext_window frames on each side;
        # when one side hits the clip boundary, borrow the shortfall from the
        # other side so the extended slice stays example_window_length frames
        # long whenever the clip allows it.
        ext_window = (example_window_length - self.window) // 2
        ws = min(ext_window, Rwindow[0] - sample_range[0])
        we = min(ext_window, sample_range[1] - Rwindow[-1])
        s_ext = ws + ext_window - we
        w_ext = we + ext_window - ws
        start = max(Rwindow[0] - s_ext, sample_range[0])
        end = min(Rwindow[-1] + w_ext, sample_range[1]) + 1
        end = min(end, len(self.Y_root_vel))

        # Flatten each feature over joints and concatenate along the channel axis.
        S_root_vel = self.Y_root_vel[start:end].reshape(end - start, -1)
        S_root_vrt = self.Y_root_vrt[start:end].reshape(end - start, -1)
        S_lpos = self.Y_lpos[start:end].reshape(end - start, -1)
        S_ltxy = self.Y_ltxy[start:end].reshape(end - start, -1)
        S_lvel = self.Y_lvel[start:end].reshape(end - start, -1)
        S_lvrt = self.Y_lvrt[start:end].reshape(end - start, -1)
        example_feature_vec = torch.cat(
            [
                S_root_vel,
                S_root_vrt,
                S_lpos,
                S_ltxy,
                S_lvel,
                S_lvrt,
                torch.zeros_like(S_root_vel),
            ],
            dim=1,
        )

        # If the clip was still too short, pad by repeating the trailing frames.
        curr_len = len(example_feature_vec)
        if curr_len < example_window_length:
            example_feature_vec = torch.cat(
                [example_feature_vec, example_feature_vec[-example_window_length + curr_len:]],
                dim=0,
            )
        return example_feature_vec
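
    # Illustrative walk-through of the extension arithmetic above (hypothetical
    # numbers, not values from this repository): with window=120 and
    # example_window_length=240, ext_window = 60. If the window starts only 10
    # frames after the clip start, ws = 10, so the missing 50 frames are
    # borrowed from the end side instead (w_ext = 110), and the extended slice
    # is still 10 + 120 + 110 = 240 frames.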

    def get_sample(self, dataset, length=None, range_index=None):
        # Fetch one full (or length-limited) clip from the train or valid split.
        if dataset == "train":
            if range_index is None:
                range_index = np.random.randint(len(self.ranges_train))
            (s, e), label = self.ranges_train[range_index], self.ranges_train_labels[range_index]
        elif dataset == "valid":
            if range_index is None:
                range_index = np.random.randint(len(self.ranges_valid))
            (s, e), label = self.ranges_valid[range_index], self.ranges_valid_labels[range_index]
        else:
            raise ValueError(f"Unknown dataset split: {dataset}")

        if length is not None:
            # length is given in seconds; the data is assumed to be at 60 fps.
            e = min(s + length * 60, e)

        return (
            self.X_audio_features[s:e][np.newaxis],
            self.Y_root_pos[s:e][np.newaxis],
            self.Y_root_rot[s:e][np.newaxis],
            self.Y_root_vel[s:e][np.newaxis],
            self.Y_root_vrt[s:e][np.newaxis],
            self.Y_lpos[s:e][np.newaxis],
            self.Y_ltxy[s:e][np.newaxis],
            self.Y_lvel[s:e][np.newaxis],
            self.Y_lvrt[s:e][np.newaxis],
            self.Y_gaze_pos[s:e][np.newaxis],
            label,
            [s, e],
            range_index,
        )

    def get_stats(self):
        from rich.console import Console
        from rich.table import Table

        console = Console(record=True)

        df = pd.DataFrame()
        df["Dataset"] = ["Train", "Validation", "Total"]
        pd.set_option("display.max_rows", None, "display.max_columns", None)
        table = Table(title="Data Info", show_lines=True, row_styles=["magenta"])
        table.add_column("Dataset")
        data_len = 0
        for i in range(self.nlabels):
            # Frame counts are halved, assuming each clip appears twice in the
            # ranges (original plus mirrored copy); seconds assume 60 fps.
            ind_mask = self.ranges_train_labels == i
            ranges = self.ranges_train[ind_mask]
            num_train_frames = np.sum(ranges[:, 1] - ranges[:, 0]) // 2

            ind_mask = self.ranges_valid_labels == i
            ranges = self.ranges_valid[ind_mask]
            num_valid_frames = np.sum(ranges[:, 1] - ranges[:, 0]) // 2

            total = num_train_frames + num_valid_frames
            df[self.label_names[i]] = [
                f"{num_train_frames} frames - {num_train_frames / 60:.1f} secs",
                f"{num_valid_frames} frames - {num_valid_frames / 60:.1f} secs",
                f"{total} frames - {total / 60:.1f} secs",
            ]
            table.add_column(self.label_names[i])
            data_len += total

        for i in range(3):
            table.add_row(*list(df.iloc[i]))
        console.print(table)
        dimensions = self.get_shapes()
        console.print(f"Total length of dataset is {data_len} frames - {data_len / 60:.1f} seconds")
        console.print("Num features: ", dimensions)
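

# Minimal usage sketch (illustrative): the file paths below are hypothetical
# placeholders and the window sizes are plausible defaults, not values taken
# from this repository.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = SGDataset(
        path_data_definition="data/data_definition.json",  # hypothetical path
        path_processed_data="data/processed_data.npz",     # hypothetical path
        window=120,
        style_encoding_type="label",
        example_window_length=240,
    )
    dataset.get_stats()

    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    batch = next(iter(loader))
    # Each batch follows the order returned by __getitem__, ending with the
    # style encoding (a one-hot label here, since style_encoding_type="label").
    audio_features, *_, style = batch
    print(audio_features.shape, style.shape)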