|
import torch |
|
import torch.nn as nn |
|
from .RPE.KPRPE.kprpe_shared import get_rpe_config |
|
from .RPE.KPRPE import relative_keypoints |
|
|
|
|
|
def make_kprpe_shared(rpe_config, depth, num_heads):
    """Build the keypoint-projection module for KPRPE.

    Queries the RPE configuration for the number of relative-position
    buckets, then creates a linear (or small MLP) head that maps flattened
    keypoint coordinates (``2 * num_keypoints`` values) to bucket logits.
    The output width is ``num_buckets`` scaled by a per-``ctx_type``
    multiplier (per-layer ``depth``, per-head ``num_heads``, or both).
    The final projection is zero-initialized so KPRPE starts as a no-op.

    Args:
        rpe_config: config object with ``rpe_on``, ``ratio``, ``method``,
            ``mode``, ``shared_head``, ``ctx_type`` and ``num_keypoints``.
        depth: number of transformer blocks (for ``*_unshared`` variants).
        num_heads: number of attention heads (for ``*_splithead`` variants).

    Returns:
        Tuple ``(keypoint_linear, num_buckets)``.

    Raises:
        ValueError: if ``rpe_config.ctx_type`` is not a supported variant.
    """
    assert rpe_config.rpe_on == 'k'
    num_buckets = get_rpe_config(
        ratio=rpe_config.ratio,
        method=rpe_config.method,
        mode=rpe_config.mode,
        shared_head=rpe_config.shared_head,
        skip=0,
        rpe_on=rpe_config.rpe_on,
    )['rpe_k']['num_buckets']

    in_features = 2 * rpe_config.num_keypoints
    ctx_type = rpe_config.ctx_type

    # Output-width multiplier for the single-linear variants.
    linear_multiplier = {
        'rel_keypoint': 1,
        'keypoint': 1,
        'rel_keypoint_unshared': depth,
        'rel_keypoint_splithead': num_heads,
        'rel_keypoint_splithead_unshared': num_heads * depth,
    }
    # Output-width multiplier for the two-layer MLP ("v2") variants.
    mlp_multiplier = {
        'rel_keypoint_v2': 1,
        'rel_keypoint_unshared_v2': depth,
    }

    if ctx_type in linear_multiplier:
        keypoint_linear = nn.Linear(
            in_features, num_buckets * linear_multiplier[ctx_type])
        final_proj = keypoint_linear
    elif ctx_type in mlp_multiplier:
        keypoint_linear = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(inplace=True),
            nn.LayerNorm(256),
            nn.Linear(256, num_buckets * mlp_multiplier[ctx_type]),
        )
        final_proj = keypoint_linear[-1]
    else:
        raise ValueError(f'Not support ctx_type: {rpe_config.ctx_type}')

    # Zero-init so the keypoint bias contributes nothing at start of training.
    final_proj.weight.data.zero_()
    final_proj.bias.data.zero_()

    return keypoint_linear, num_buckets
|
|
|
|
|
|
|
def make_kprpe_input(keypoints, x, keypoint_linear, rpe_config, mask_ratio, depth, num_heads, num_buckets):
    """Compute the KPRPE extra-context input for the attention layers.

    Projects (relative) keypoint coordinates through ``keypoint_linear`` and
    reshapes the result according to ``ctx_type``:

    - ``rel_keypoint`` / ``rel_keypoint_v2``: one shared tensor for all layers.
    - ``*_unshared`` variants: one chunk per transformer block (list of dicts).
    - ``*_splithead`` variants: an extra per-head axis; the unshared flavor is
      additionally chunked per block.
    - ``keypoint``: projects absolute keypoints (no relative encoding).

    Args:
        keypoints: keypoint coordinates, batch-first (flattened to
            ``2 * num_keypoints`` features per token).
        x: patch-token tensor; only ``x.shape[0]`` (batch) and its token
            count (inside ``make_rel_keypoints``) are used.
        keypoint_linear: projection module from :func:`make_kprpe_shared`.
        rpe_config: config providing ``ctx_type`` and ``num_keypoints``.
        mask_ratio: must be 0 for all relative variants (masking unsupported).
        depth: number of transformer blocks.
        num_heads: number of attention heads.
        num_buckets: RPE bucket count from :func:`make_kprpe_shared`.

    Returns:
        A dict ``{'rel_keypoints': tensor}`` for shared variants, or a list
        of such dicts (one per block) for the unshared variants.

    Raises:
        ValueError: if ``ctx_type`` is not a supported variant.
    """
    B = x.shape[0]
    ctx_type = rpe_config.get('ctx_type', '')
    num_kp = rpe_config.num_keypoints

    if ctx_type == 'keypoint':
        # Absolute keypoints: flatten and project, no relative computation.
        keypoints = keypoints.flatten(1).unsqueeze(1)
        keypoints = keypoint_linear(keypoints).unsqueeze(1)
        return {'rel_keypoints': keypoints}

    rel_types = (
        'rel_keypoint', 'rel_keypoint_v2',
        'rel_keypoint_unshared', 'rel_keypoint_unshared_v2',
        'rel_keypoint_splithead', 'rel_keypoint_splithead_unshared',
    )
    if ctx_type not in rel_types:
        raise ValueError(f'Not support ctx_type: {ctx_type}')

    # Common prelude for every relative variant.
    assert mask_ratio == 0
    rel_keypoints = relative_keypoints.make_rel_keypoints(keypoints, x)[:, :, :2 * num_kp]
    rel_keypoints = keypoint_linear(rel_keypoints)

    if ctx_type in ('rel_keypoint', 'rel_keypoint_v2'):
        # Single shared context: add a broadcast axis for heads.
        return {'rel_keypoints': rel_keypoints.unsqueeze(1)}

    if ctx_type in ('rel_keypoint_unshared', 'rel_keypoint_unshared_v2'):
        # Separate context per block: (B, depth, N, num_buckets) -> depth chunks.
        rel_keypoints = rel_keypoints.view(B, -1, depth, num_buckets).transpose(1, 2)
        chunks = torch.chunk(rel_keypoints, depth, dim=1)
        return [{'rel_keypoints': chunk} for chunk in chunks]

    if ctx_type == 'rel_keypoint_splithead':
        # Shared across blocks, separate per head: (B, num_heads, N, num_buckets).
        rel_keypoints = rel_keypoints.view(B, -1, num_heads, num_buckets).transpose(1, 2)
        return {'rel_keypoints': rel_keypoints}

    # 'rel_keypoint_splithead_unshared': per-head AND per-block contexts.
    rel_keypoints = rel_keypoints.view(B, -1, num_heads * depth, num_buckets).transpose(1, 2)
    chunks = torch.chunk(rel_keypoints, depth, dim=1)
    return [{'rel_keypoints': chunk} for chunk in chunks]