File size: 5,556 Bytes
4e19122 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import torch
import torch.nn as nn
from .RPE.KPRPE.kprpe_shared import get_rpe_config
from .RPE.KPRPE import relative_keypoints
def make_kprpe_shared(rpe_config, depth, num_heads):
    """Build the learnable keypoint projection used by KP-RPE.

    The bucket count comes from ``get_rpe_config(...)['rpe_k']['num_buckets']``.
    The projection's input width is ``2 * rpe_config.num_keypoints``. Its output
    width depends on ``rpe_config.ctx_type``:

    * ``'rel_keypoint'``, ``'rel_keypoint_v2'``, ``'keypoint'``: ``num_buckets``
    * ``'rel_keypoint_unshared'``, ``'rel_keypoint_unshared_v2'``:
      ``num_buckets * depth``
    * ``'rel_keypoint_splithead'``: ``num_buckets * num_heads``
    * ``'rel_keypoint_splithead_unshared'``: ``num_buckets * num_heads * depth``

    The ``*_v2`` variants use a small MLP:
    ``Linear(in, 256) -> ReLU -> LayerNorm(256) -> Linear(256, out)``.
    All other variants use a single ``nn.Linear``. In both cases the final Linear
    is zero-initialised, so the module outputs all zeros until it is trained.

    Args:
        rpe_config: config object. Reads ``rpe_on`` (must be ``'k'``), ``ratio``,
            ``method``, ``mode``, ``shared_head``, ``ctx_type`` and
            ``num_keypoints``.
        depth: output multiplier for the ``*_unshared`` variants (presumably the
            number of transformer blocks -- confirm with caller).
        num_heads: output multiplier for the ``splithead`` variants (presumably
            the number of attention heads -- confirm with caller).

    Returns:
        ``(keypoint_linear, num_buckets)``.

    Raises:
        AssertionError: if ``rpe_config.rpe_on != 'k'``.
        ValueError: if ``rpe_config.ctx_type`` is not a supported value.
    """
    assert rpe_config.rpe_on == 'k', (
        f"KP-RPE expects rpe_on == 'k', got {rpe_config.rpe_on!r}"
    )
    num_buckets = get_rpe_config(
        ratio=rpe_config.ratio,
        method=rpe_config.method,
        mode=rpe_config.mode,
        shared_head=rpe_config.shared_head,
        skip=0,
        rpe_on=rpe_config.rpe_on,
    )['rpe_k']['num_buckets']

    ctx_type = rpe_config.ctx_type
    # Output width = num_buckets times the number of bucket tables packed into it.
    # Each multiplier is evaluated only for the ctx_type that needs it, so unused
    # depth/num_heads arguments are never touched.
    if ctx_type in ('rel_keypoint', 'rel_keypoint_v2', 'keypoint'):
        out_features = num_buckets
    elif ctx_type in ('rel_keypoint_unshared', 'rel_keypoint_unshared_v2'):
        out_features = num_buckets * depth
    elif ctx_type == 'rel_keypoint_splithead':
        out_features = num_buckets * num_heads
    elif ctx_type == 'rel_keypoint_splithead_unshared':
        out_features = num_buckets * num_heads * depth
    else:
        raise ValueError(f'Not support ctx_type: {ctx_type}')

    in_features = 2 * rpe_config.num_keypoints
    if ctx_type in ('rel_keypoint_v2', 'rel_keypoint_unshared_v2'):
        keypoint_linear = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(inplace=True),
            nn.LayerNorm(256),
            nn.Linear(256, out_features),
        )
        output_layer = keypoint_linear[-1]
    else:
        keypoint_linear = nn.Linear(in_features, out_features)
        output_layer = keypoint_linear

    # Zero weight and bias on the last Linear: the projection outputs zeros
    # regardless of its input until training updates it.
    nn.init.zeros_(output_layer.weight)
    nn.init.zeros_(output_layer.bias)
    return keypoint_linear, num_buckets
def make_kprpe_input(keypoints, x, keypoint_linear, rpe_config, mask_ratio, depth, num_heads, num_buckets):
    """Project keypoints into the ``extra_ctx`` consumed by KP-RPE attention.

    Args:
        keypoints: keypoint tensor for the batch. For ``rel_*`` ctx_types it is
            passed to ``relative_keypoints.make_rel_keypoints(keypoints, x)``.
            For ``'keypoint'`` it is flattened from dim 1 onward and projected
            directly.
        x: token tensor. Only ``x.shape[0]`` (the batch size) is read here, and
            ``x`` is forwarded to ``make_rel_keypoints``.
        keypoint_linear: projection built by ``make_kprpe_shared`` for the same
            ``ctx_type``.
        rpe_config: config supporting ``.get('ctx_type', '')`` and
            ``.num_keypoints``.
        mask_ratio: must be 0 for every ``rel_*`` ctx_type (not checked for
            ``'keypoint'``).
        depth: number of per-layer outputs for the ``*_unshared`` variants.
        num_heads: number of per-head tables for the ``splithead`` variants.
        num_buckets: bucket count returned by ``make_kprpe_shared``.

    Returns:
        ``{'rel_keypoints': tensor}`` shared by every layer, or, for the
        ``*_unshared`` variants, a list of ``depth`` such dicts (one per layer).

    Raises:
        ValueError: if ``ctx_type`` is missing or unsupported.
        AssertionError: if ``mask_ratio != 0`` for a ``rel_*`` ctx_type.
    """
    B = x.shape[0]
    ctx_type = rpe_config.get('ctx_type', '')
    num_kp = rpe_config.num_keypoints

    if ctx_type == 'keypoint':
        # Absolute keypoints: a single projected vector per sample.
        flat = keypoints.flatten(1).unsqueeze(1)  # (B, 1, flattened keypoints)
        return {'rel_keypoints': keypoint_linear(flat).unsqueeze(1)}  # (B, 1, 1, out)

    rel_types = (
        'rel_keypoint', 'rel_keypoint_v2',
        'rel_keypoint_unshared', 'rel_keypoint_unshared_v2',
        'rel_keypoint_splithead', 'rel_keypoint_splithead_unshared',
    )
    if ctx_type not in rel_types:
        raise ValueError(f'Not support ctx_type: {ctx_type}')

    assert mask_ratio == 0, f'{ctx_type} requires mask_ratio == 0, got {mask_ratio}'
    # Shape notes below assume make_rel_keypoints returns (B, N, C) -- TODO confirm.
    rel = relative_keypoints.make_rel_keypoints(keypoints, x)[:, :, :2 * num_kp]
    projected = keypoint_linear(rel)  # (B, N, out_features of keypoint_linear)

    if ctx_type in ('rel_keypoint', 'rel_keypoint_v2'):
        # One bucket table shared by all heads and all layers: (B, 1, N, num_buckets).
        return {'rel_keypoints': projected.unsqueeze(1)}
    if ctx_type == 'rel_keypoint_splithead':
        # One table per head, shared by all layers: (B, num_heads, N, num_buckets).
        return {'rel_keypoints': projected.view(B, -1, num_heads, num_buckets).transpose(1, 2)}

    # *_unshared variants: every layer's tables are packed along the output
    # width. Unpack to (B, tables, N, num_buckets) and split dim 1 into `depth`
    # contiguous groups. For splithead_unshared, table index = layer * num_heads + head.
    if ctx_type == 'rel_keypoint_splithead_unshared':
        tables = num_heads * depth
    else:
        tables = depth
    packed = projected.view(B, -1, tables, num_buckets).transpose(1, 2)
    return [{'rel_keypoints': layer_tables} for layer_tables in torch.chunk(packed, depth, dim=1)]