minchul committed
Commit 687e16e
Parent: e628e01

Upload directory

models/vit_kprpe/RPE/KPRPE/kprpe_shared.py ADDED
@@ -0,0 +1,735 @@
from easydict import EasyDict as edict
import math
import torch
import torch.nn as nn
from .dist import _rp_2d_cross_cols, _rp_2d_cross_rows, _rp_2d_euclidean, _rp_2d_product, _rp_2d_quant

try:
    from ..rpe_ops.rpe_index import RPEIndexFunction
except Exception as e:
    print('Failed to import cuda/cpp RPEIndexFunction')
    RPEIndexFunction = None


def get_absolute_positions(height, width, dtype, device):
    '''Get absolute positions.

    Take height = 3, width = 3 as an example (torch.arange is zero-based):
    rows:        cols:
    0 0 0        0 1 2
    1 1 1        0 1 2
    2 2 2        0 1 2

    return stack([rows, cols], 2)

    Parameters
    ----------
    height, width: int
        The height and width of the feature map
    dtype: torch.dtype
        the data type of the returned value
    device: torch.device
        the device of the returned value

    Return
    ------
    2D absolute positions: torch.Tensor
        The shape is (height, width, 2),
        where 2 represents a 2D position (row, col).
    '''
    rows = torch.arange(height, dtype=dtype, device=device).view(
        height, 1).repeat(1, width)
    cols = torch.arange(width, dtype=dtype, device=device).view(
        1, width).repeat(height, 1)
    return torch.stack([rows, cols], 2)

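# Editor's note: a minimal illustrative sketch (not part of the uploaded file) of what
# get_absolute_positions returns, using the 3x3 example from the docstring above:
#
#   pos = get_absolute_positions(3, 3, torch.long, torch.device('cpu'))
#   pos.shape   # torch.Size([3, 3, 2])
#   pos[1, 2]   # tensor([1, 2]) -> the (row, col) coordinates of the cell at row 1, column 2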

class METHOD:
    """Define iRPE method IDs.
    We divide the implementation of CROSS into CROSS_ROWS and CROSS_COLS.
    """
    EUCLIDEAN = 0
    QUANT = 1
    PRODUCT = 3
    CROSS = 4
    CROSS_ROWS = 41
    CROSS_COLS = 42


# Define a mapping from METHOD_ID to Python function
_METHOD_FUNC = {
    METHOD.EUCLIDEAN: _rp_2d_euclidean,
    METHOD.QUANT: _rp_2d_quant,
    METHOD.PRODUCT: _rp_2d_product,
    METHOD.CROSS_ROWS: _rp_2d_cross_rows,
    METHOD.CROSS_COLS: _rp_2d_cross_cols,
}


def get_num_buckets(method, alpha, beta, gamma):
    """Get the number of buckets storing relative position encodings.
    The buckets do not contain the `skip` token.

    Parameters
    ----------
    method: METHOD
        The method ID of image relative position encoding.
    alpha, beta, gamma: float
        The coefficients of the piecewise index function.

    Returns
    -------
    num_buckets: int
        The number of buckets storing relative position encodings.
    """
    beta_int = int(beta)
    if method == METHOD.PRODUCT:
        # IDs in [0, (2 * beta_int + 1)^2) for the Product method
        num_buckets = (2 * beta_int + 1) ** 2
    else:
        # IDs in [-beta_int, beta_int] for the other methods
        num_buckets = 2 * beta_int + 1
    return num_buckets
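
# Editor's note: an illustrative calculation (not part of the uploaded file).
# With the default ratio of 1.9 (see get_single_rpe_config below), beta = 2 * 1.9 = 3.8,
# so beta_int = 3. For METHOD.PRODUCT this gives (2 * 3 + 1) ** 2 = 49 buckets;
# for the other methods it gives 2 * 3 + 1 = 7 buckets, before the optional extra
# bucket added for skip tokens.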


# (method, alpha, beta, gamma) -> (bucket_ids, num_buckets, height, width)
BUCKET_IDS_BUF = dict()


@torch.no_grad()
def get_bucket_ids_2d_without_skip(method, height, width,
                                   alpha, beta, gamma,
                                   dtype=torch.long, device=torch.device('cpu')):
    """Get bucket IDs for image relative position encodings without the skip token.

    Parameters
    ----------
    method: METHOD
        The method ID of image relative position encoding.
    height, width: int
        The height and width of the feature map.
        The sequence length is equal to `height * width`.
    alpha, beta, gamma: float
        The coefficients of the piecewise index function.
    dtype: torch.dtype
        the data type of the returned `bucket_ids`
    device: torch.device
        the device of the returned `bucket_ids`

    Returns
    -------
    bucket_ids: torch.Tensor, dtype: long
        The bucket IDs which index the corresponding encodings.
        The shape of `bucket_ids` is (L, L),
        where `L = height * width`.
    num_buckets: int
        The number of buckets, excluding the `skip` token.
    L: int
        The sequence length
    """

    key = (method, alpha, beta, gamma, dtype, device)
    value = BUCKET_IDS_BUF.get(key, None)
    if value is None or value[-2] < height or value[-1] < width:
        if value is None:
            max_height, max_width = height, width
        else:
            max_height = max(value[-2], height)
            max_width = max(value[-1], width)
        # relative position encoding mapping function
        func = _METHOD_FUNC.get(method, None)
        if func is None:
            raise NotImplementedError(
                f"[Error] The method ID {method} does not exist.")
        pos = get_absolute_positions(max_height, max_width, dtype, device)

        # compute the offset of a pair of 2D relative positions
        max_L = max_height * max_width
        pos1 = pos.view((max_L, 1, 2))
        pos2 = pos.view((1, max_L, 2))
        # diff: shape of (L, L, 2)
        diff = pos1 - pos2

        # bucket_ids: shape of (L, L)
        bucket_ids = func(diff, alpha=alpha, beta=beta,
                          gamma=gamma, dtype=dtype)
        beta_int = int(beta)
        if method != METHOD.PRODUCT:
            bucket_ids += beta_int
        bucket_ids = bucket_ids.view(
            max_height, max_width, max_height, max_width)

        num_buckets = get_num_buckets(method, alpha, beta, gamma)
        value = (bucket_ids, num_buckets, height, width)
        BUCKET_IDS_BUF[key] = value
    L = height * width
    bucket_ids = value[0][:height, :width, :height, :width].reshape(L, L)
    num_buckets = value[1]

    return bucket_ids, num_buckets, L


@torch.no_grad()
def get_bucket_ids_2d(method, height, width,
                      skip, alpha, beta, gamma,
                      dtype=torch.long, device=torch.device('cpu')):
    """Get bucket IDs for image relative position encodings.

    Parameters
    ----------
    method: METHOD
        The method ID of image relative position encoding.
    height, width: int
        The height and width of the feature map.
        The sequence length is equal to `height * width`.
    skip: int
        The number of extra tokens placed before the spatial tokens.
        When skip is 0, there is no classification token.
        When skip is 1, there is a classification token before the spatial tokens.
        When skip > 1, there are `skip` extra tokens before the spatial tokens.
    alpha, beta, gamma: float
        The coefficients of the piecewise index function.
    dtype: torch.dtype
        the data type of the returned `bucket_ids`
    device: torch.device
        the device of the returned `bucket_ids`

    Returns
    -------
    bucket_ids: torch.Tensor, dtype: long
        The bucket IDs which index the corresponding encodings.
        The shape of `bucket_ids` is (skip + L, skip + L),
        where `L = height * width`.
    num_buckets: int
        The number of buckets, including the `skip` token.
    """
    bucket_ids, num_buckets, L = get_bucket_ids_2d_without_skip(method, height, width,
                                                                alpha, beta, gamma,
                                                                dtype, device)

    # add an extra encoding (id = num_buckets) for the classification token
    if skip > 0:
        new_bids = bucket_ids.new_empty(size=(skip + L, skip + L))

        # if extra tokens exist, we add an extra bucket as their encoding.
        extra_bucket_id = num_buckets
        num_buckets += 1

        new_bids[:skip] = extra_bucket_id
        new_bids[:, :skip] = extra_bucket_id
        new_bids[skip:, skip:] = bucket_ids

        bucket_ids = new_bids
    bucket_ids = bucket_ids.contiguous()
    return bucket_ids, num_buckets
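
# Editor's note: a hedged usage sketch (not part of the uploaded file), assuming a
# 14x14 feature map with one class token (skip=1) and the default PRODUCT coefficients:
#
#   ids, n = get_bucket_ids_2d(METHOD.PRODUCT, height=14, width=14, skip=1,
#                              alpha=1.9, beta=3.8, gamma=15.2)
#   ids.shape   # (1 + 196, 1 + 196); row/column 0 hold the extra class-token bucket
#   n           # 50 = 49 product buckets + 1 extra bucket for the class token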


class iRPE(nn.Module):
    """The implementation of image relative position encoding (excluding the Cross method).

    Parameters
    ----------
    head_dim: int
        The dimension for each head.
    num_heads: int
        The number of parallel attention heads.
    mode: str or None
        The mode of image relative position encoding.
        Choices: [None, 'bias', 'contextual']
    method: METHOD
        The method ID of image relative position encoding.
        The `METHOD` class is defined in `irpe.py`.
    transposed: bool
        Whether to transpose the input feature.
        For iRPE on queries or keys, transposed should be `True`.
        For iRPE on values, transposed should be `False`.
    num_buckets: int
        The number of buckets, which store the encodings.
    initializer: None or an inplace function
        [Optional] The initializer for `lookup_table`.
        Initialize `lookup_table` as zero by default.
    rpe_config: RPEConfig
        The config generated by the function `get_single_rpe_config`.
    """
    # a buffer to store bucket index
    # (key, rp_bucket, _ctx_rp_bucket_flatten)
    _rp_bucket_buf = (None, None, None)

    def __init__(self, head_dim, num_heads=8,
                 mode=None, method=None,
                 transposed=True, num_buckets=None,
                 initializer=None, rpe_config=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        # relative position
        assert mode in [None, 'bias', 'contextual']
        self.mode = mode

        assert method is not None, 'method should be a METHOD ID rather than None'
        self.method = method

        self.transposed = transposed
        self.num_buckets = num_buckets

        if initializer is None:
            def initializer(x): return None
        self.initializer = initializer

        self.reset_parameters()

        self.rpe_config = rpe_config

    @torch.no_grad()
    def reset_parameters(self):
        # initialize the parameters of iRPE
        if self.transposed:
            if self.mode == 'bias':
                self.lookup_table_bias = nn.Parameter(
                    torch.zeros(self.num_heads, self.num_buckets))
                self.initializer(self.lookup_table_bias)
            elif self.mode == 'contextual':
                # the contextual lookup table is shared and initialized from the ViT
                pass
        else:
            if self.mode == 'bias':
                raise NotImplementedError(
                    "[Error] Bias non-transposed RPE does not exist.")
            elif self.mode == 'contextual':
                raise ValueError('may not work, check')

    def forward(self, x, height=None, width=None):
        """Forward function for iRPE.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor whose shape is (B, H, L, head_dim),
            where B is the batch size,
            H is the number of heads,
            L is the sequence length,
            equal to height * width (+1 if a class token exists),
            and head_dim is the dimension of each head.

        Returns
        -------
        rpe_encoding: torch.Tensor
            Image relative position encoding,
            whose shape is (B, H, L, L)
        """
        rp_bucket, self._ctx_rp_bucket_flatten = \
            self._get_rp_bucket(x, height=height, width=width)

        if self.transposed:
            return self.forward_rpe_transpose(x, rp_bucket)
        return self.forward_rpe_no_transpose(x, rp_bucket)

    def _get_rp_bucket(self, x, height=None, width=None):
        """Get the relative position encoding bucket IDs corresponding to the input shape.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor whose shape is (B, H, L, head_dim),
            where B is the batch size,
            H is the number of heads,
            L is the sequence length,
            equal to height * width (+1 if a class token exists),
            and head_dim is the dimension of each head.
        height: int or None
            [Optional] The height of the input
            If not defined, height = floor(sqrt(L))
        width: int or None
            [Optional] The width of the input
            If not defined, width = floor(sqrt(L))

        Returns
        -------
        rp_bucket: torch.Tensor
            relative position encoding bucket IDs
            The shape is (L, L)
        _ctx_rp_bucket_flatten: torch.Tensor or None
            It is a private tensor for efficient computation.
        """
        B, H, L, D = x.shape
        device = x.device
        if height is None:
            E = int(math.sqrt(L))
            height = width = E
        key = (height, width, device)
        # reuse the buffer if the spatial shape and device have not changed.
        if self._rp_bucket_buf[0] == key:
            return self._rp_bucket_buf[1:3]

        skip = L - height * width
        config = self.rpe_config
        if RPEIndexFunction is not None and self.mode == 'contextual' and self.transposed:
            # RPEIndexFunction uses int32 index.
            dtype = torch.int32
        else:
            dtype = torch.long
        rp_bucket, num_buckets = get_bucket_ids_2d(method=self.method,
                                                   height=height, width=width,
                                                   skip=skip, alpha=config.alpha,
                                                   beta=config.beta, gamma=config.gamma,
                                                   dtype=dtype, device=device)
        assert num_buckets == self.num_buckets

        # transposed contextual
        _ctx_rp_bucket_flatten = None
        if self.mode == 'contextual' and self.transposed:
            if RPEIndexFunction is None:
                offset = torch.arange(0, L * self.num_buckets, self.num_buckets,
                                      dtype=rp_bucket.dtype, device=rp_bucket.device).view(-1, 1)
                _ctx_rp_bucket_flatten = (rp_bucket + offset).flatten()
        self._rp_bucket_buf = (key, rp_bucket, _ctx_rp_bucket_flatten)
        return rp_bucket, _ctx_rp_bucket_flatten

    def forward_rpe_transpose(self, x, rp_bucket):
        """Forward function for iRPE (transposed version).
        This version is utilized by RPE on queries or keys.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor whose shape is (B, H, L, head_dim),
            where B is the batch size,
            H is the number of heads,
            L is the sequence length,
            equal to height * width (+1 if a class token exists),
            and head_dim is the dimension of each head.
        rp_bucket: torch.Tensor
            relative position encoding bucket IDs
            The shape is (L, L)

        Weights
        -------
        lookup_table_bias: torch.Tensor
            The shape is (H or 1, num_buckets)

        or

        lookup_table_weight: torch.Tensor
            The shape is (H or 1, head_dim, num_buckets)

        Returns
        -------
        output: torch.Tensor
            Relative position encoding on queries or keys.
            The shape is (B or 1, H, L, L).
        """

        B = len(x)  # batch_size
        L_query, L_key = rp_bucket.shape
        if self.mode == 'bias':
            return self.lookup_table_bias[:, rp_bucket.flatten()]. \
                view(1, self.num_heads, L_query, L_key)

        elif self.mode == 'contextual':
            """
            ret[b, h, i, j] = lookup_table_weight[b, h, i, rp_bucket[i, j]]

            ret[b, h, i * L_key + j] = \
                lookup_table[b, h, i * num_buckets + rp_buckets[i, j]]

            computational cost
            ------------------
            matmul: B * H * L_query * head_dim * num_buckets
            index: L_query + L_query * L_key + B * H * L_query * L_key
            total: O(B * H * L_query * (head_dim * num_buckets + L_key))
            """
            if RPEIndexFunction is not None:
                return RPEIndexFunction.apply(x, rp_bucket)
            else:
                return x.flatten(2)[:, :, self._ctx_rp_bucket_flatten]. \
                    view(B, -1, L_query, L_key)

    def forward_rpe_no_transpose(self, x, rp_bucket):
        """Forward function for iRPE (non-transposed version).
        This version is utilized by RPE on values.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor whose shape is (B, H, L, head_dim),
            where B is the batch size,
            H is the number of heads,
            L is the sequence length,
            equal to height * width (+1 if a class token exists),
            and head_dim is the dimension of each head.
        rp_bucket: torch.Tensor
            relative position encoding bucket IDs
            The shape is (L, L)

        Weights
        -------
        lookup_table_weight: torch.Tensor
            The shape is (H or 1, num_buckets, head_dim)

        Returns
        -------
        output: torch.Tensor
            Relative position encoding on values.
            The shape is (B, H, L, D),
            where D is the output dimension for each head.
        """

        B = len(x)  # batch_size
        L_query, L_key = rp_bucket.shape
        assert self.mode == 'contextual', "Only the contextual mode is " \
            "supported in the non-transposed version"
        weight = self.lookup_table_weight[:, rp_bucket.flatten()]. \
            view(self.num_heads, L_query, L_key, self.head_dim)
        # (H, L_query, B, L_key) @ (H, L_query, L_key, D) = (H, L_query, B, D)
        # -> (B, H, L_query, D)
        return torch.matmul(x.permute(1, 2, 0, 3), weight).permute(2, 0, 1, 3)

    def __repr__(self):
        return 'iRPE(head_dim={rpe.head_dim}, num_heads={rpe.num_heads}, ' \
            'mode="{rpe.mode}", method={rpe.method}, transposed={rpe.transposed}, ' \
            'num_buckets={rpe.num_buckets}, initializer={rpe.initializer}, ' \
            'rpe_config={rpe.rpe_config})'.format(rpe=self)
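
# Editor's note: a hedged usage sketch (not part of the uploaded file). In this KPRPE
# variant the contextual lookup table is shared and supplied by the ViT (see
# reset_parameters above), so the self-contained 'bias' mode is used here; head_dim=64,
# num_heads=8 and the 14x14 grid with one class token are hypothetical values.
#
#   cfg = get_single_rpe_config(ratio=1.9, method=METHOD.PRODUCT, mode='bias',
#                               shared_head=False, skip=1)
#   rpe = iRPE(head_dim=64, num_heads=8, mode='bias', method=METHOD.PRODUCT,
#              transposed=True, num_buckets=cfg.num_buckets, rpe_config=cfg)
#   q = torch.randn(2, 8, 1 + 14 * 14, 64)   # (B, H, L, head_dim), L includes a class token
#   bias = rpe(q)                             # (1, 8, 197, 197) additive attention bias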


class iRPE_Cross(nn.Module):
    """The implementation of image relative position encoding (specific to the Cross method).

    Parameters
    ----------
    head_dim: int
        The dimension for each head.
    num_heads: int
        The number of parallel attention heads.
    mode: str or None
        The mode of image relative position encoding.
        Choices: [None, 'bias', 'contextual']
    method: METHOD
        The method ID of image relative position encoding.
        The `METHOD` class is defined in `irpe.py`.
    transposed: bool
        Whether to transpose the input feature.
        For iRPE on queries or keys, transposed should be `True`.
        For iRPE on values, transposed should be `False`.
    num_buckets: int
        The number of buckets, which store the encodings.
    initializer: None or an inplace function
        [Optional] The initializer for `lookup_table`.
        Initialize `lookup_table` as zero by default.
    rpe_config: RPEConfig
        The config generated by the function `get_single_rpe_config`.
    """

    def __init__(self, method, **kwargs):
        super().__init__()
        assert method == METHOD.CROSS
        self.rp_rows = iRPE(**kwargs, method=METHOD.CROSS_ROWS)
        self.rp_cols = iRPE(**kwargs, method=METHOD.CROSS_COLS)

    def forward(self, x, height=None, width=None):
        """Forward function for iRPE_Cross.
        Compute the encodings on the horizontal and vertical directions separately,
        then sum them.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor whose shape is (B, H, L, head_dim),
            where B is the batch size,
            H is the number of heads,
            L is the sequence length,
            equal to height * width (+1 if a class token exists),
            and head_dim is the dimension of each head.
        height: int or None
            [Optional] The height of the input
            If not defined, height = floor(sqrt(L))
        width: int or None
            [Optional] The width of the input
            If not defined, width = floor(sqrt(L))

        Returns
        -------
        rpe_encoding: torch.Tensor
            Image relative position encoding,
            whose shape is (B, H, L, L)
        """

        rows = self.rp_rows(x, height=height, width=width)
        cols = self.rp_cols(x, height=height, width=width)
        return rows + cols

    def __repr__(self):
        return 'iRPE_Cross(head_dim={rpe.head_dim}, ' \
            'num_heads={rpe.num_heads}, mode="{rpe.mode}", method={rpe.method}, ' \
            'transposed={rpe.transposed}, num_buckets={rpe.num_buckets}, ' \
            'initializer={rpe.initializer}, ' \
            'rpe_config={rpe.rpe_config})'.format(rpe=self.rp_rows)


def get_single_rpe_config(ratio=1.9,
                          method=METHOD.PRODUCT,
                          mode='contextual',
                          shared_head=True,
                          skip=0):
    """Get the config of a single relative position encoding.

    Parameters
    ----------
    ratio: float
        The ratio to control the number of buckets.
    method: METHOD
        The method ID of image relative position encoding.
        The `METHOD` class is defined in `irpe.py`.
    mode: str or None
        The mode of image relative position encoding.
        Choices: [None, 'bias', 'contextual']
    shared_head: bool
        Whether to share the weights among different heads.
    skip: int
        The number of extra tokens placed before the spatial tokens.
        When skip is 0, there is no classification token.
        When skip is 1, there is a classification token before the spatial tokens.
        When skip > 1, there are `skip` extra tokens before the spatial tokens.

    Returns
    -------
    config: RPEConfig
        The config of a single relative position encoding.
    """
    config = edict()
    # whether to share encodings across different heads
    config.shared_head = shared_head
    # mode: None, bias, contextual
    config.mode = mode
    # method: Euclidean, Quant, Cross, Product
    config.method = method
    # the coefficients of the piecewise index function
    config.alpha = 1 * ratio
    config.beta = 2 * ratio
    config.gamma = 8 * ratio

    # set the number of buckets
    config.num_buckets = get_num_buckets(method,
                                         config.alpha,
                                         config.beta,
                                         config.gamma)
    # add an extra bucket for the `skip` token (e.g. class token)
    if skip > 0:
        config.num_buckets += 1
    return config
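
# Editor's note: an illustrative look at the default values (not part of the uploaded file):
#
#   cfg = get_single_rpe_config()    # ratio=1.9, PRODUCT, contextual, skip=0
#   cfg.alpha, cfg.beta, cfg.gamma   # 1.9, 3.8, 15.2
#   cfg.num_buckets                  # 49; becomes 50 when skip > 0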


def get_rpe_config(ratio=1.9,
                   method=METHOD.PRODUCT,
                   mode='contextual',
                   shared_head=True,
                   skip=0,
                   rpe_on='k'):
    """Get the config of relative position encoding on queries, keys and values.

    Parameters
    ----------
    ratio: float
        The ratio to control the number of buckets.
    method: METHOD or str
        The method ID (or name) of image relative position encoding.
        The `METHOD` class is defined in `irpe.py`.
    mode: str or None
        The mode of image relative position encoding.
        Choices: [None, 'bias', 'contextual']
    shared_head: bool
        Whether to share the weights among different heads.
    skip: int
        The number of extra tokens placed before the spatial tokens.
        When skip is 0, there is no classification token.
        When skip is 1, there is a classification token before the spatial tokens.
        When skip > 1, there are `skip` extra tokens before the spatial tokens.
    rpe_on: str
        Where RPE attaches.
        "q": RPE on queries
        "k": RPE on keys
        "v": RPE on values
        "qk": RPE on queries and keys
        "qkv": RPE on queries, keys and values

    Returns
    -------
    config: RPEConfigs
        config.rpe_q: the config of relative position encoding on queries
        config.rpe_k: the config of relative position encoding on keys
        config.rpe_v: the config of relative position encoding on values
    """

    # alias
    if isinstance(method, str):
        method_mapping = dict(
            euc=METHOD.EUCLIDEAN,
            quant=METHOD.QUANT,
            cross=METHOD.CROSS,
            product=METHOD.PRODUCT,
        )
        method = method_mapping[method.lower()]
    if mode == 'ctx':
        mode = 'contextual'
    config = edict()
    # relative position encoding on queries, keys and values
    kwargs = dict(
        ratio=ratio,
        method=method,
        mode=mode,
        shared_head=shared_head,
        skip=skip,
    )
    config.rpe_q = get_single_rpe_config(**kwargs) if 'q' in rpe_on else None
    config.rpe_k = get_single_rpe_config(**kwargs) if 'k' in rpe_on else None
    config.rpe_v = get_single_rpe_config(**kwargs) if 'v' in rpe_on else None
    return config
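
# Editor's note: an illustrative call (not part of the uploaded file) using the string
# aliases handled above:
#
#   cfg = get_rpe_config(ratio=1.9, method='product', mode='ctx', shared_head=True,
#                        skip=1, rpe_on='k')
#   cfg.rpe_q, cfg.rpe_v    # None, None -- only the keys get an RPE config here
#   cfg.rpe_k.num_buckets   # 50 (49 product buckets + 1 for the class token)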


def build_rpe(config, head_dim, num_heads):
    """Build iRPE modules on queries, keys and values.

    Parameters
    ----------
    config: RPEConfigs
        config.rpe_q: the config of relative position encoding on queries
        config.rpe_k: the config of relative position encoding on keys
        config.rpe_v: the config of relative position encoding on values
        None when RPE is not used.
    head_dim: int
        The dimension for each head.
    num_heads: int
        The number of parallel attention heads.

    Returns
    -------
    modules: a list of nn.Module
        The iRPE modules on [queries, keys, values].
        None when RPE is not used.
    """
    if config is None:
        return None, None, None
    rpes = [config.rpe_q, config.rpe_k, config.rpe_v]
    transposeds = [True, True, False]

    def _build_single_rpe(rpe, transposed):
        if rpe is None:
            return None

        rpe_cls = iRPE if rpe.method != METHOD.CROSS else iRPE_Cross
        return rpe_cls(
            head_dim=head_dim,
            num_heads=1 if rpe.shared_head else num_heads,
            mode=rpe.mode,
            method=rpe.method,
            transposed=transposed,
            num_buckets=rpe.num_buckets,
            rpe_config=rpe,
        )
    return [_build_single_rpe(rpe, transposed)
            for rpe, transposed in zip(rpes, transposeds)]
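
# Editor's note: a hedged end-to-end sketch (not part of the uploaded file) showing how
# the pieces above are typically wired together; head_dim=64 and num_heads=8 are
# hypothetical, and the attention module that consumes the returned modules lives
# elsewhere in this repository.
#
#   config = get_rpe_config(method='product', mode='bias', shared_head=True,
#                           skip=1, rpe_on='k')
#   rpe_q, rpe_k, rpe_v = build_rpe(config, head_dim=64, num_heads=8)
#   # rpe_q and rpe_v are None; rpe_k is an iRPE module whose output is added to
#   # the attention logits inside the attention layer.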