File size: 5,556 Bytes
4e19122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
import torch.nn as nn
from .RPE.KPRPE.kprpe_shared import get_rpe_config
from .RPE.KPRPE import relative_keypoints


def make_kprpe_shared(rpe_config, depth, num_heads):
    """Build the keypoint projection module shared by KPRPE attention blocks.

    Args:
        rpe_config: config object with attributes ``rpe_on`` (must be 'k'),
            ``ratio``, ``method``, ``mode``, ``shared_head``, ``ctx_type``
            and ``num_keypoints``.
        depth: number of transformer blocks (used by the ``*_unshared`` variants,
            which produce a separate bias table per block).
        num_heads: number of attention heads (used by the ``*_splithead`` variants,
            which produce a separate bias table per head).

    Returns:
        Tuple ``(keypoint_linear, num_buckets)`` where ``keypoint_linear`` maps
        flattened keypoint coordinates ``(..., 2 * num_keypoints)`` to RPE bucket
        logits, and ``num_buckets`` is the per-table bucket count.

    Raises:
        ValueError: if ``rpe_config.ctx_type`` is not a supported variant.
    """
    assert rpe_config.rpe_on == 'k'
    num_buckets = get_rpe_config(
        ratio=rpe_config.ratio,
        method=rpe_config.method,
        mode=rpe_config.mode,
        shared_head=rpe_config.shared_head,
        skip=0,
        rpe_on=rpe_config.rpe_on,
    )['rpe_k']['num_buckets']

    in_features = 2 * rpe_config.num_keypoints
    ctx_type = rpe_config.ctx_type

    # How many bucket tables each variant emits per token:
    # 1 (shared), one per block (unshared), one per head (splithead), or both.
    multiplier = {
        'rel_keypoint': 1,
        'keypoint': 1,
        'rel_keypoint_unshared': depth,
        'rel_keypoint_splithead': num_heads,
        'rel_keypoint_splithead_unshared': num_heads * depth,
        'rel_keypoint_v2': 1,
        'rel_keypoint_unshared_v2': depth,
    }
    if ctx_type not in multiplier:
        raise ValueError(f'Not support ctx_type: {ctx_type}')
    out_features = num_buckets * multiplier[ctx_type]

    if ctx_type.endswith('_v2'):
        # v2 variants use a small MLP instead of a single linear layer.
        keypoint_linear = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(inplace=True),
            nn.LayerNorm(256),
            nn.Linear(256, out_features),
        )
        last_layer = keypoint_linear[-1]
    else:
        keypoint_linear = nn.Linear(in_features, out_features)
        last_layer = keypoint_linear

    # Zero-init the output layer so the keypoint bias starts as a no-op and
    # is learned gradually during training.
    last_layer.weight.data.zero_()
    last_layer.bias.data.zero_()

    return keypoint_linear, num_buckets



def make_kprpe_input(keypoints, x, keypoint_linear, rpe_config, mask_ratio, depth, num_heads, num_buckets):
    """Project keypoints into the extra-context dict(s) consumed by KPRPE attention.

    Args:
        keypoints: keypoint coordinate tensor, shape (B, K, 2) — TODO confirm
            against `relative_keypoints.make_rel_keypoints`.
        x: patch-token tensor; only its batch size (and, inside the helper,
            its token count) is used here.
        keypoint_linear: projection module built by `make_kprpe_shared`.
        rpe_config: config with `get('ctx_type', '')` and `num_keypoints`.
        mask_ratio: must be 0 for all relative-keypoint variants (masking
            would misalign tokens with the relative keypoint grid).
        depth: number of transformer blocks (for `*_unshared` variants).
        num_heads: number of attention heads (for `*_splithead` variants).
        num_buckets: bucket count returned by `make_kprpe_shared`.

    Returns:
        A dict ``{'rel_keypoints': tensor}`` shared by all blocks, or — for the
        ``*_unshared`` variants — a list of ``depth`` such dicts, one per block.

    Raises:
        ValueError: if ``ctx_type`` is not a supported variant.
    """
    B = x.shape[0]
    ctx_type = rpe_config.get('ctx_type', '')
    num_kp = rpe_config.num_keypoints

    if ctx_type == 'keypoint':
        # Absolute keypoints: flatten (B, K, 2) -> (B, 1, 2K), project, and
        # add a broadcastable head axis -> (B, 1, 1, num_buckets).
        kp = keypoints.flatten(1).unsqueeze(1)
        return {'rel_keypoints': keypoint_linear(kp).unsqueeze(1)}

    per_block_groups = {
        # ctx_type -> number of groups along the split axis, or None when the
        # projection is shared across blocks (no per-depth chunking).
        'rel_keypoint': None,
        'rel_keypoint_v2': None,
        'rel_keypoint_splithead': None,
        'rel_keypoint_unshared': depth,
        'rel_keypoint_unshared_v2': depth,
        'rel_keypoint_splithead_unshared': num_heads * depth,
    }
    if ctx_type not in per_block_groups:
        raise ValueError(f'Not support ctx_type: {ctx_type}')

    assert mask_ratio == 0
    # Shared first step for every relative variant: per-token relative
    # keypoint offsets, truncated to 2*num_kp coords, then projected.
    rel = relative_keypoints.make_rel_keypoints(keypoints, x)[:, :, :2 * num_kp]
    rel = keypoint_linear(rel)  # (B, N, num_buckets * groups)

    if ctx_type == 'rel_keypoint_splithead':
        # One table per head, shared across blocks: (B, num_heads, N, num_buckets).
        return {'rel_keypoints': rel.view(B, -1, num_heads, num_buckets).transpose(1, 2)}

    groups = per_block_groups[ctx_type]
    if groups is None:
        # Single shared table; unsqueeze a broadcastable head axis.
        return {'rel_keypoints': rel.unsqueeze(1)}  # (B, 1, N, num_buckets)

    # Unshared variants: split the projection into one context per block.
    rel = rel.view(B, -1, groups, num_buckets).transpose(1, 2)
    chunks = torch.chunk(rel, depth, dim=1)
    return [{'rel_keypoints': chunk} for chunk in chunks]