import torch


def get_tomato_slice(idx):
    """Return the feature-column indices that belong to joint ``idx``.

    Joint 0 owns columns 0-3 and 463-465.  Every other joint owns three
    runs of consecutive columns:

    * 3 columns starting at ``4 + (idx - 1) * 3``
    * 6 columns starting at ``157 + (idx - 1) * 6``
    * 3 columns starting at ``463 + idx * 3``

    NOTE(review): the physical meaning of each run (position / rotation /
    velocity, ...) is not visible in this file -- confirm against the
    feature layout definition before relying on it.

    Args:
        idx: Joint index (0 is the root joint).

    Returns:
        A list of integer column indices, in ascending order.
    """
    if idx == 0:
        return [0, 1, 2, 3, 463, 464, 465]

    first = 4 + (idx - 1) * 3
    second = 157 + (idx - 1) * 6
    third = 463 + idx * 3
    return [
        *range(first, first + 3),
        *range(second, second + 6),
        *range(third, third + 3),
    ]


def get_part_slice(idx_list, func):
    """Concatenate the column indices of several joints.

    Args:
        idx_list: Iterable of joint indices.
        func: Callable mapping one joint index to a list of column indices
            (e.g. ``get_tomato_slice``).

    Returns:
        A flat list with the columns of every joint, in ``idx_list`` order.
    """
    return [column for idx in idx_list for column in func(idx)]


def expand_mask_to_all(mask, body_scale, hand_scale, face_scale):
    """Expand a per-part mask to a scaled per-column mask of width 669.

    ``mask`` is flattened to ``(B, T, P)``; entry ``p`` along the last axis
    is copied to every feature column of part ``p`` and then multiplied
    in place by that part's scale.  Part order along the last axis is:
    root, head, stem, left arm, right arm, left leg, right leg, left hand,
    right hand, face.

    Args:
        mask: Tensor whose first two dimensions are ``(B, T)`` and whose
            remaining dimensions flatten to at least 10 entries.  It may be
            non-contiguous.
        body_scale: Multiplier applied to the seven body parts.
        hand_scale: Multiplier applied to both hand parts.
        face_scale: Multiplier applied to the face columns (619-668).

    Returns:
        A new ``(B, T, 669)`` tensor with the same dtype and device as
        ``mask``.

    Raises:
        IndexError: If the flattened mask has fewer than 10 part entries.
    """
    func = get_tomato_slice
    feature_dim = 669

    # (columns, scale) pairs; the position in this tuple is the part's index
    # along the last axis of ``mask``.
    parts = (
        (get_part_slice([0], func), body_scale),                # root
        (get_part_slice([12, 15], func), body_scale),           # head
        (get_part_slice([3, 6, 9], func), body_scale),          # stem
        (get_part_slice([14, 17, 19, 21], func), body_scale),   # left arm
        (get_part_slice([13, 16, 18, 20], func), body_scale),   # right arm
        (get_part_slice([2, 5, 8, 11], func), body_scale),      # left leg
        (get_part_slice([1, 4, 7, 10], func), body_scale),      # right leg
        (get_part_slice(range(22, 37), func), hand_scale),      # left hand
        (get_part_slice(range(37, 52), func), hand_scale),      # right hand
        (list(range(619, feature_dim)), face_scale),            # face
    )

    B, T = mask.shape[0], mask.shape[1]
    # reshape (not view) so non-contiguous masks are accepted; for
    # contiguous input it is still a zero-copy view.
    mask = mask.reshape(B, T, -1)

    # Allocate directly with the mask's dtype/device instead of creating a
    # default CPU tensor and converting it afterwards.
    all_mask = mask.new_zeros((B, T, feature_dim))

    for part_idx, (columns, scale) in enumerate(parts):
        # The (B, T, 1) value broadcasts across all columns of the part.
        all_mask[:, :, columns] = mask[:, :, part_idx].unsqueeze(-1)
        # Scale in place so the result keeps the mask's dtype.
        all_mask[:, :, columns] *= scale

    return all_mask