minchul's picture
Upload directory
4e19122 verified
raw
history blame
No virus
5.56 kB
import torch
import torch.nn as nn
from .RPE.KPRPE.kprpe_shared import get_rpe_config
from .RPE.KPRPE import relative_keypoints
def make_kprpe_shared(rpe_config, depth, num_heads):
    """Create the keypoint projection module and look up the RPE bucket count.

    The projection maps a vector of width ``2 * rpe_config.num_keypoints`` to
    ``num_buckets`` values, optionally repeated per layer and/or per head
    depending on ``rpe_config.ctx_type``. Its last linear layer is initialised
    to all zeros, so the module outputs exactly zero until it is trained.

    Args:
        rpe_config: Config object read by attribute. ``rpe_on`` must be
            ``'k'``; ``ratio``, ``method``, ``mode`` and ``shared_head`` are
            forwarded to ``get_rpe_config``; ``ctx_type`` selects the output
            layout; ``num_keypoints`` sets the input width.
        depth: Output-width multiplier used by the ``*_unshared`` ctx types
            (``make_kprpe_input`` later splits the output into ``depth``
            chunks). Presumably the number of transformer blocks -- confirm
            against the caller.
        num_heads: Output-width multiplier used by the ``*_splithead`` ctx
            types. Presumably the number of attention heads -- confirm
            against the caller.

    Returns:
        A tuple ``(keypoint_linear, num_buckets)`` where ``keypoint_linear``
        is an ``nn.Linear`` or, for the ``*_v2`` ctx types, an
        ``nn.Sequential`` MLP, and ``num_buckets`` is the value found at
        ``get_rpe_config(...)['rpe_k']['num_buckets']``.

    Raises:
        ValueError: If ``rpe_config.rpe_on`` is not ``'k'`` or
            ``rpe_config.ctx_type`` is not one of the supported layouts.
    """
    # Explicit raise instead of `assert`: asserts are removed under
    # `python -O`, which would let an unsupported config slip through.
    if rpe_config.rpe_on != 'k':
        raise ValueError(
            f"make_kprpe_shared requires rpe_on == 'k', got: {rpe_config.rpe_on!r}"
        )

    ctx_type = rpe_config.ctx_type
    # How many copies of `num_buckets` the projection emits for each layout.
    bucket_copies = {
        'rel_keypoint': 1,
        'rel_keypoint_unshared': depth,
        'rel_keypoint_unshared_v2': depth,
        'rel_keypoint_splithead': num_heads,
        'rel_keypoint_splithead_unshared': num_heads * depth,
        'rel_keypoint_v2': 1,
        'keypoint': 1,
    }
    # Layouts that use a two-layer MLP instead of a single linear layer.
    mlp_ctx_types = ('rel_keypoint_unshared_v2', 'rel_keypoint_v2')

    # Fail fast on an unknown layout before doing any other work.
    if ctx_type not in bucket_copies:
        raise ValueError(f'Not support ctx_type: {ctx_type}')

    num_buckets = get_rpe_config(
        ratio=rpe_config.ratio,
        method=rpe_config.method,
        mode=rpe_config.mode,
        shared_head=rpe_config.shared_head,
        skip=0,
        rpe_on=rpe_config.rpe_on,
    )['rpe_k']['num_buckets']

    in_features = 2 * rpe_config.num_keypoints
    out_features = num_buckets * bucket_copies[ctx_type]

    if ctx_type in mlp_ctx_types:
        keypoint_linear = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(inplace=True),
            nn.LayerNorm(256),
            nn.Linear(256, out_features),
        )
        output_layer = keypoint_linear[-1]
    else:
        keypoint_linear = nn.Linear(in_features, out_features)
        output_layer = keypoint_linear

    # Zero-init the output layer so the projection starts as an exact no-op.
    nn.init.zeros_(output_layer.weight)
    nn.init.zeros_(output_layer.bias)

    return keypoint_linear, num_buckets
def make_kprpe_input(keypoints, x, keypoint_linear, rpe_config, mask_ratio, depth, num_heads, num_buckets):
    """Project keypoints into the extra context consumed by the RPE layers.

    For ``ctx_type == 'keypoint'`` the keypoints themselves are flattened and
    projected. For every ``rel_keypoint*`` type the input is first converted
    with ``relative_keypoints.make_rel_keypoints(keypoints, x)``, truncated to
    the first ``2 * num_keypoints`` features along the last axis, projected,
    and then reshaped according to the layout.

    Args:
        keypoints: Keypoint tensor. Presumably ``(B, num_keypoints, 2)`` --
            confirm against the caller; this function only applies
            ``flatten(1)`` or hands it to ``make_rel_keypoints``.
        x: Tensor whose first dimension is used as the batch size ``B``; it
            is also passed to ``make_rel_keypoints``.
        keypoint_linear: Projection module, as built by ``make_kprpe_shared``.
        rpe_config: Config providing ``get('ctx_type', '')`` and the
            ``num_keypoints`` attribute.
        mask_ratio: Must be ``0`` for every ``rel_keypoint*`` ctx type; it is
            not inspected for ``'keypoint'``.
        depth: Number of chunks produced for the ``*_unshared`` ctx types.
        num_heads: Size of the head axis for the ``*_splithead`` ctx types.
        num_buckets: Size of the last axis after reshaping.

    Returns:
        For ``'keypoint'``, ``'rel_keypoint'``, ``'rel_keypoint_v2'`` and
        ``'rel_keypoint_splithead'``: a dict ``{'rel_keypoints': tensor}``.
        For the ``*_unshared`` types: a list of ``depth`` such dicts, each
        holding one chunk of the projection split along dim 1.

    Raises:
        ValueError: If ``ctx_type`` is unsupported, or if ``mask_ratio`` is
            non-zero for a ``rel_keypoint*`` ctx type.
    """
    B = x.shape[0]
    ctx_type = rpe_config.get('ctx_type', '')
    num_kp = rpe_config.num_keypoints

    if ctx_type == 'keypoint':
        # Flatten everything after the batch axis, then add singleton axes
        # around the projection so the result is 4-D like the other layouts.
        flat_keypoints = keypoints.flatten(1).unsqueeze(1)
        return {'rel_keypoints': keypoint_linear(flat_keypoints).unsqueeze(1)}

    # Layouts whose projection holds a single slice, broadcast over dim 1.
    single_slice_types = ('rel_keypoint', 'rel_keypoint_v2')
    # Layouts whose projection holds one slice per layer (no head split).
    per_layer_types = ('rel_keypoint_unshared', 'rel_keypoint_unshared_v2')
    # Layouts whose projection is split across heads.
    per_head_types = ('rel_keypoint_splithead', 'rel_keypoint_splithead_unshared')

    if ctx_type not in single_slice_types + per_layer_types + per_head_types:
        raise ValueError(f'Not support ctx_type: {ctx_type}')

    # Explicit raise instead of `assert`: asserts are removed under
    # `python -O`, which would silently accept a non-zero mask ratio.
    if mask_ratio != 0:
        raise ValueError(
            f'ctx_type {ctx_type!r} requires mask_ratio == 0, got: {mask_ratio!r}'
        )

    rel_keypoints = relative_keypoints.make_rel_keypoints(keypoints, x)[:, :, :2 * num_kp]
    projected = keypoint_linear(rel_keypoints)

    if ctx_type in single_slice_types:
        return {'rel_keypoints': projected.unsqueeze(1)}

    if ctx_type == 'rel_keypoint_splithead':
        # (B, N, num_heads * num_buckets) -> (B, num_heads, N, num_buckets)
        per_head = projected.view(B, -1, num_heads, num_buckets).transpose(1, 2)
        return {'rel_keypoints': per_head}

    # Remaining layouts are per-layer: unfold the feature axis into `groups`
    # slices of `num_buckets`, move that axis to dim 1 and cut it into
    # `depth` equal chunks, one per layer.
    groups = depth if ctx_type in per_layer_types else num_heads * depth
    stacked = projected.view(B, -1, groups, num_buckets).transpose(1, 2)
    return [{'rel_keypoints': chunk} for chunk in torch.chunk(stacked, depth, dim=1)]