Spaces:
Runtime error
Runtime error
File size: 27,256 Bytes
c2dad70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 |
# Copyright (c) 2019-present, Meta, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# First author is Simon Rouard.
import random
import typing as tp
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from einops import rearrange
def create_sin_embedding(
    length: int, dim: int, shift: int = 0, device="cpu", max_period=10000
):
    """Build a 1D sinusoidal positional embedding in TBC layout.

    Args:
        length: number of positions T.
        dim: embedding size C; must be even.
        shift: offset added to every position index.
        device: device on which the tensors are created.
        max_period: base of the geometric progression of periods.

    Returns:
        Tensor of shape (length, 1, dim). The first ``dim // 2`` channels are
        the cosines and the last ``dim // 2`` channels are the sines.
    """
    # We aim for TBC format
    assert dim % 2 == 0
    half_dim = dim // 2
    pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
    adim = torch.arange(half_dim, device=device).view(1, 1, -1)
    # With dim == 2 there is a single frequency and ``half_dim - 1`` is 0;
    # guard the denominator so the exponent is 0 rather than 0 / 0 = NaN.
    denom = max(half_dim - 1, 1)
    phase = pos / (max_period ** (adim / denom))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
    """Build a 2D sinusoidal positional embedding.

    The first half of the channels encodes the position along ``width`` and
    the second half the position along ``height``; inside each half, even
    channels hold sines and odd channels hold cosines.

    :param d_model: dimension of the model, must be a multiple of 4
    :param height: height of the positions
    :param width: width of the positions
    :param device: device the result is moved to before being returned
    :param max_period: base of the geometric progression of periods
    :return: position tensor of shape (1, d_model, height, width)
    :raises ValueError: if ``d_model`` is not a multiple of 4
    """
    if d_model % 4 != 0:
        # The check is on divisibility by 4 (two axes, sin/cos pairs each),
        # so say so instead of the previous, misleading "odd dimension".
        raise ValueError(
            "Cannot use 2D sin/cos positional encoding with a dimension "
            "that is not a multiple of 4 (got dim={:d})".format(d_model)
        )
    pe = torch.zeros(d_model, height, width)
    # Each dimension use half of d_model
    half = d_model // 2
    div_term = torch.exp(
        torch.arange(0.0, half, 2) * -(math.log(max_period) / half)
    )
    pos_w = torch.arange(0.0, width).unsqueeze(1)
    pos_h = torch.arange(0.0, height).unsqueeze(1)

    # (width, half / 2) -> (half / 2, 1, width), repeated along the height axis.
    angles_w = (pos_w * div_term).transpose(0, 1).unsqueeze(1)
    pe[0:half:2, :, :] = torch.sin(angles_w).repeat(1, height, 1)
    pe[1:half:2, :, :] = torch.cos(angles_w).repeat(1, height, 1)

    # (height, half / 2) -> (half / 2, height, 1), repeated along the width axis.
    angles_h = (pos_h * div_term).transpose(0, 1).unsqueeze(2)
    pe[half::2, :, :] = torch.sin(angles_h).repeat(1, 1, width)
    pe[half + 1::2, :, :] = torch.cos(angles_h).repeat(1, 1, width)

    return pe[None, :].to(device)
def create_sin_embedding_cape(
    length: int,
    dim: int,
    batch_size: int,
    mean_normalize: bool,
    augment: bool,  # True during training
    max_global_shift: float = 0.0,  # delta max
    max_local_shift: float = 0.0,  # epsilon max
    max_scale: float = 1.0,
    device: str = "cpu",
    max_period: float = 10000.0,
):
    """Build a sinusoidal embedding with optional random position augmentation.

    Positions can be centred (``mean_normalize``) and, when ``augment`` is
    set, perturbed per batch item by a global shift, a per-position local
    shift and a log-uniform scale, all drawn with ``np.random.uniform``.

    Args:
        length: number of positions T.
        dim: embedding size C; must be even.
        batch_size: number of independent position tracks B.
        mean_normalize: subtract the mean position along the time axis.
        augment: apply the random shifts / scaling.
        max_global_shift: bound of the uniform global shift.
        max_local_shift: bound of the uniform per-position shift.
        max_scale: the log-scale is drawn in [-log(max_scale), log(max_scale)].
        device: device of the returned tensor.
        max_period: base of the geometric progression of periods.

    Returns:
        float32 tensor of shape (length, batch_size, dim), cosines first and
        sines last along the channel axis.
    """
    # We aim for TBC format
    assert dim % 2 == 0
    pos = 1.0 * torch.arange(length).view(-1, 1, 1)  # (length, 1, 1)
    pos = pos.repeat(1, batch_size, 1)  # (length, batch_size, 1)
    if mean_normalize:
        pos -= torch.nanmean(pos, dim=0, keepdim=True)

    if augment:
        delta = np.random.uniform(
            -max_global_shift, +max_global_shift, size=[1, batch_size, 1]
        )
        delta_local = np.random.uniform(
            -max_local_shift, +max_local_shift, size=[length, batch_size, 1]
        )
        log_lambdas = np.random.uniform(
            -np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1]
        )
        # Convert the NumPy draws explicitly instead of relying on implicit
        # Tensor/ndarray arithmetic; values and RNG consumption are unchanged.
        shift = torch.from_numpy(delta + delta_local)
        scale = torch.from_numpy(np.exp(log_lambdas))
        pos = (pos + shift) * scale

    pos = pos.to(device)

    half_dim = dim // 2
    adim = torch.arange(half_dim, device=device).view(1, 1, -1)
    # With dim == 2 there is a single frequency and ``half_dim - 1`` is 0;
    # guard the denominator so the exponent is 0 rather than 0 / 0 = NaN.
    denom = max(half_dim - 1, 1)
    phase = pos / (max_period ** (adim / denom))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1).float()
def get_causal_mask(length):
    """Return a (length, length) boolean matrix that is True strictly above the diagonal.

    Entry ``[i, j]`` is True when ``j > i``.
    """
    idx = torch.arange(length)
    columns = idx.unsqueeze(0)  # (1, length)
    rows = idx.unsqueeze(1)  # (length, 1)
    return columns > rows
def get_elementary_mask(
    T1,
    T2,
    mask_type,
    sparse_attn_window,
    global_window,
    mask_random_seed,
    sparsity,
    device,
):
    """Build one boolean attention mask of shape (T2, T1).

    When the input of the Decoder has length T1 and the output T2,
    the mask matrix has shape (T2, T1).

    ``mask_type`` selects the pattern:
      * "global": the first ``global_window`` columns and the first
        ``int(global_window * T2 / T1)`` rows are True;
      * "diag": a band of half-width ``sparse_attn_window`` around the
        (rescaled) diagonal;
      * "jmask": columns at triangular-number offsets from the diagonal;
      * "random": entries of a seeded uniform draw that exceed ``sparsity``.
    """
    assert mask_type in ["diag", "jmask", "random", "global"]

    if mask_type == "global":
        grid = torch.zeros(T2, T1, dtype=torch.bool)
        grid[:, :global_window] = True
        n_rows = int(global_window * T2 / T1)
        grid[:n_rows, :] = True
    elif mask_type == "diag":
        grid = torch.zeros(T2, T1, dtype=torch.bool)
        row_idx = torch.arange(T2)[:, None]
        offsets = torch.arange(-sparse_attn_window, sparse_attn_window + 1)
        col_idx = (T1 / T2 * row_idx + offsets).long().clamp(0, T1 - 1)
        ones = torch.ones(1, dtype=torch.bool).expand_as(col_idx)
        grid.scatter_(1, col_idx, ones)
    elif mask_type == "jmask":
        # Work on a grid padded by one cell on each side; out-of-range
        # offsets are clamped into the padding, which is cropped afterwards.
        grid = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool)
        row_idx = torch.arange(T2 + 2)[:, None]
        steps = torch.arange(0, int((2 * T1) ** 0.5 + 1))
        steps = (steps * (steps + 1) / 2).int()  # triangular numbers
        steps = torch.cat([-steps.flip(0)[:-1], steps])
        col_idx = (T1 / T2 * row_idx + steps).long().clamp(0, T1 + 1)
        ones = torch.ones(1, dtype=torch.bool).expand_as(col_idx)
        grid.scatter_(1, col_idx, ones)
        grid = grid[1:-1, 1:-1]
    elif mask_type == "random":
        rng = torch.Generator(device=device)
        rng.manual_seed(mask_random_seed)
        draws = torch.rand(T1 * T2, generator=rng, device=device)
        grid = draws.reshape(T2, T1) > sparsity

    return grid.to(device)
def get_mask(
    T1,
    T2,
    mask_type,
    sparse_attn_window,
    global_window,
    mask_random_seed,
    sparsity,
    device,
):
    """
    Return a SparseCSRTensor mask that is the union of elementary masks.

    mask_type can be a combination of multiple masks joined by underscores,
    for instance "diag_jmask_random"; an entry of the result is set when it
    is set in at least one of the elementary masks.
    """
    from xformers.sparse import SparseCSRTensor

    elementary = []
    for kind in mask_type.split("_"):
        elementary.append(
            get_elementary_mask(
                T1,
                T2,
                kind,
                sparse_attn_window,
                global_window,
                mask_random_seed,
                sparsity,
                device,
            )
        )
    union = torch.stack(elementary).any(dim=0)
    # Add a leading dimension of size 1 before converting to CSR.
    return SparseCSRTensor.from_dense(union[None])
class ScaledEmbedding(nn.Module):
    """Embedding whose stored weights are shrunk and whose outputs are boosted.

    At construction the underlying ``nn.Embedding`` weights are multiplied by
    ``scale / boost``; every read (``weight`` and ``forward``) multiplies by
    ``boost``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        scale: float = 1.0,
        boost: float = 3.0,
    ):
        super().__init__()
        self.boost = boost
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.mul_(scale / boost)

    @property
    def weight(self):
        """Effective embedding table: stored weights times ``boost``."""
        return self.boost * self.embedding.weight

    def forward(self, x):
        """Look up indices ``x`` and multiply the result by ``boost``."""
        looked_up = self.embedding(x)
        return looked_up * self.boost
class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales diagonally the residual outputs, close to 0 initially, then learnt.
    """

    def __init__(self, channels: int, init: float = 0, channel_last=False):
        """
        channel_last = False corresponds to (B, C, T) tensors
        channel_last = True corresponds to (T, B, C) tensors
        """
        super().__init__()
        self.channel_last = channel_last
        # One learnable gain per channel, every entry starting at ``init``.
        self.scale = nn.Parameter(torch.full((channels,), float(init)))

    def forward(self, x):
        """Multiply ``x`` by the per-channel gain along its channel axis."""
        if not self.channel_last:
            # Channels sit on the second-to-last axis: add a trailing axis to broadcast.
            return self.scale[:, None] * x
        return self.scale * x
class MyGroupNorm(nn.GroupNorm):
    """``nn.GroupNorm`` applied to tensors whose channel axis is last."""

    def __init__(self, *args, **kwargs):
        # Same constructor arguments as nn.GroupNorm.
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """
        x: (B, T, C)
        if num_groups=1: Normalisation on all T and C together for each B
        """
        # nn.GroupNorm expects channels on axis 1: swap in, normalise, swap back.
        channels_first = x.transpose(1, 2)
        normalised = super().forward(channels_first)
        return normalised.transpose(1, 2)
class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """``nn.TransformerEncoderLayer`` with optional extras.

    Adds, on top of the PyTorch layer:
      * group normalisation instead of layer normalisation (``group_norm``),
      * a final normalisation when ``norm_first`` is set (``norm_out``),
      * LayerScale on both residual branches (``layer_scale``),
      * sparse self-attention, either with a fixed mask built by ``get_mask``
        or with ``auto_sparsity``.

    ``group_norm`` and ``norm_out`` are used both as flags and, through
    ``int(...)``, as the number of groups.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=F.relu,
        group_norm=0,
        norm_first=False,
        norm_out=False,
        layer_norm_eps=1e-5,
        layer_scale=False,
        init_values=1e-4,
        device=None,
        dtype=None,
        sparse=False,
        mask_type="diag",
        mask_random_seed=42,
        sparse_attn_window=500,
        global_window=50,
        auto_sparsity=False,
        sparsity=0.95,
        batch_first=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=batch_first,
            norm_first=norm_first,
            device=device,
            dtype=dtype,
        )
        self.sparse = sparse
        self.auto_sparsity = auto_sparsity
        if sparse:
            if not auto_sparsity:
                self.mask_type = mask_type
                self.sparse_attn_window = sparse_attn_window
                self.global_window = global_window
            self.sparsity = sparsity

        if group_norm:
            self.norm1 = MyGroupNorm(
                int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs
            )
            self.norm2 = MyGroupNorm(
                int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs
            )

        self.norm_out = None
        # Logical ``and``: the former bitwise ``&`` evaluated ``True & 2`` to 0
        # and silently dropped the output norm for an even group count.
        if self.norm_first and norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)

        self.gamma_1 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )
        self.gamma_2 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )

        if sparse:
            self.self_attn = MultiheadAttention(
                d_model,
                nhead,
                dropout=dropout,
                batch_first=batch_first,
                auto_sparsity=sparsity if auto_sparsity else 0,
            )
            # Placeholder mask; replaced on the first forward call whose
            # sequence length does not match its last dimension.
            self.src_mask = torch.zeros(1, 1)
            self.mask_random_seed = mask_random_seed

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        Apply self-attention then the feed-forward block, each with a residual.

        src has shape (T, B, C) when the layer was built with
        batch_first=False and (B, T, C) when built with batch_first=True.
        With fixed-mask sparse attention ``src_mask`` must be None: the mask
        is built (and cached on the module) from the sequence length.
        """
        device = src.device
        x = src

        if self.sparse and not self.auto_sparsity:
            assert src_mask is None
            # The sequence axis depends on the attention module's layout;
            # always reading axis 0 used the batch size when batch_first=True.
            seq_len = x.shape[1] if self.self_attn.batch_first else x.shape[0]
            src_mask = self.src_mask
            if src_mask.shape[-1] != seq_len:
                src_mask = get_mask(
                    seq_len,
                    seq_len,
                    self.mask_type,
                    self.sparse_attn_window,
                    self.global_window,
                    self.mask_random_seed,
                    self.sparsity,
                    device,
                )
                self.src_mask = src_mask

        if self.norm_first:
            x = x + self.gamma_1(
                self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            )
            x = x + self.gamma_2(self._ff_block(self.norm2(x)))
            if self.norm_out is not None:
                x = self.norm_out(x)
        else:
            x = self.norm1(
                x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask))
            )
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x
class CrossTransformerEncoderLayer(nn.Module):
    """Cross-attention encoder layer: queries attend to a second sequence.

    The layer applies cross-attention from ``q`` to ``k`` (``k`` is used as
    both key and value) followed by a feed-forward block, each with a
    residual connection, in pre-norm or post-norm order depending on
    ``norm_first``. It supports the same extras as
    ``MyTransformerEncoderLayer``: group norm, output norm, LayerScale and
    sparse attention.

    ``group_norm`` and ``norm_out`` are used both as flags and, through
    ``int(...)``, as the number of groups.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation=F.relu,
        layer_norm_eps: float = 1e-5,
        layer_scale: bool = False,
        init_values: float = 1e-4,
        norm_first: bool = False,
        group_norm: bool = False,
        norm_out: bool = False,
        sparse=False,
        mask_type="diag",
        mask_random_seed=42,
        sparse_attn_window=500,
        global_window=50,
        sparsity=0.95,
        auto_sparsity=None,
        device=None,
        dtype=None,
        batch_first=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.sparse = sparse
        self.auto_sparsity = auto_sparsity
        if sparse:
            if not auto_sparsity:
                self.mask_type = mask_type
                self.sparse_attn_window = sparse_attn_window
                self.global_window = global_window
            self.sparsity = sparsity

        self.cross_attn: nn.Module
        self.cross_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1: nn.Module
        self.norm2: nn.Module
        self.norm3: nn.Module
        if group_norm:
            self.norm1 = MyGroupNorm(
                int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs
            )
            self.norm2 = MyGroupNorm(
                int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs
            )
            self.norm3 = MyGroupNorm(
                int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs
            )
        else:
            self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        # Logical ``and``: the former bitwise ``&`` evaluated ``True & 2`` to 0
        # and silently dropped the output norm for an even group count.
        if self.norm_first and norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)

        self.gamma_1 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )
        self.gamma_2 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = self._get_activation_fn(activation)
        else:
            self.activation = activation

        if sparse:
            self.cross_attn = MultiheadAttention(
                d_model,
                nhead,
                dropout=dropout,
                batch_first=batch_first,
                auto_sparsity=sparsity if auto_sparsity else 0,
            )
            if not auto_sparsity:
                # Placeholder mask; replaced on the first forward call whose
                # sequence lengths do not match its last two dimensions.
                self.mask = torch.zeros(1, 1)
                self.mask_random_seed = mask_random_seed

    def forward(self, q, k, mask=None):
        """
        Args:
            q: tensor of shape (T, B, C), or (B, T, C) if built with batch_first=True
            k: tensor of shape (S, B, C), or (B, S, C) if built with batch_first=True
            mask: tensor of shape (T, S); must be None with fixed-mask sparse
                attention, where the mask is built and cached by the layer
        """
        device = q.device

        if self.sparse and not self.auto_sparsity:
            assert mask is None
            # The sequence axis depends on the attention module's layout;
            # always reading axis 0 used the batch size when batch_first=True.
            seq_axis = 1 if self.cross_attn.batch_first else 0
            T = q.shape[seq_axis]
            S = k.shape[seq_axis]
            mask = self.mask
            if mask.shape[-1] != S or mask.shape[-2] != T:
                mask = get_mask(
                    S,
                    T,
                    self.mask_type,
                    self.sparse_attn_window,
                    self.global_window,
                    self.mask_random_seed,
                    self.sparsity,
                    device,
                )
                self.mask = mask

        if self.norm_first:
            x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
            x = x + self.gamma_2(self._ff_block(self.norm3(x)))
            if self.norm_out is not None:
                x = self.norm_out(x)
        else:
            x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x

    # cross-attention block
    def _ca_block(self, q, k, attn_mask=None):
        """Attend from ``q`` to ``k`` (keys and values) and apply dropout."""
        x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x):
        """Two-layer feed-forward network with activation and dropout."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

    def _get_activation_fn(self, activation):
        """Map the legacy names "relu" / "gelu" to the functional activation.

        Raises:
            RuntimeError: for any other name.
        """
        if activation == "relu":
            return F.relu
        elif activation == "gelu":
            return F.gelu

        raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
# ----------------- MULTI-BLOCKS MODELS: -----------------------
class CrossTransformerEncoder(nn.Module):
    """Two-branch transformer alternating self-attention and cross-attention.

    One branch (``layers``, the "spectrogram" layers) processes a 4D input
    ``x`` of shape (B, C, Fr, T1) flattened to (B, T1 * Fr, C); the other
    (``layers_t``, the "temporal" layers) processes a 3D input ``xt`` of
    shape (B, C, T2) rearranged to (B, T2, C). At every depth both branches
    use the same kind of layer: a ``MyTransformerEncoderLayer`` applied to
    each branch separately, or a ``CrossTransformerEncoderLayer`` in which
    each branch attends to the other.

    ``emb`` selects the positional embedding of the temporal branch
    ("sin", "cape" or "scaled"); the spectrogram branch always uses the 2D
    sinusoidal embedding. ``weight_decay`` and ``lr`` are only consumed by
    ``make_optim_group``.
    """

    def __init__(
        self,
        dim: int,
        emb: str = "sin",
        hidden_scale: float = 4.0,
        num_heads: int = 8,
        num_layers: int = 6,
        cross_first: bool = False,
        dropout: float = 0.0,
        max_positions: int = 1000,
        norm_in: bool = True,
        norm_in_group: bool = False,
        group_norm: int = False,
        norm_first: bool = False,
        norm_out: bool = False,
        max_period: float = 10000.0,
        weight_decay: float = 0.0,
        lr: tp.Optional[float] = None,
        layer_scale: bool = False,
        gelu: bool = True,
        sin_random_shift: int = 0,
        weight_pos_embed: float = 1.0,
        cape_mean_normalize: bool = True,
        cape_augment: bool = True,
        cape_glob_loc_scale: tp.Sequence[float] = (5000.0, 1.0, 1.4),
        sparse_self_attn: bool = False,
        sparse_cross_attn: bool = False,
        mask_type: str = "diag",
        mask_random_seed: int = 42,
        sparse_attn_window: int = 500,
        global_window: int = 50,
        auto_sparsity: bool = False,
        sparsity: float = 0.95,
    ):
        super().__init__()
        assert dim % num_heads == 0

        hidden_dim = int(dim * hidden_scale)

        self.num_layers = num_layers
        # classic parity = 1 means that if idx%2 == 1 there is a
        # classical encoder else there is a cross encoder
        self.classic_parity = 1 if cross_first else 0
        self.emb = emb
        self.max_period = max_period
        self.weight_decay = weight_decay
        self.weight_pos_embed = weight_pos_embed
        self.sin_random_shift = sin_random_shift
        if emb == "cape":
            self.cape_mean_normalize = cape_mean_normalize
            self.cape_augment = cape_augment
            # Private copy: [max_global_shift, max_local_shift, max_scale].
            # Copying keeps instances from sharing (and mutating) one default.
            self.cape_glob_loc_scale = list(cape_glob_loc_scale)
        if emb == "scaled":
            self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)

        self.lr = lr

        activation: tp.Any = F.gelu if gelu else F.relu

        self.norm_in: nn.Module
        self.norm_in_t: nn.Module
        # ``norm_in`` takes precedence over ``norm_in_group``.
        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
            self.norm_in_t = nn.LayerNorm(dim)
        elif norm_in_group:
            self.norm_in = MyGroupNorm(int(norm_in_group), dim)
            self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
        else:
            self.norm_in = nn.Identity()
            self.norm_in_t = nn.Identity()

        # spectrogram layers
        self.layers = nn.ModuleList()
        # temporal layers
        self.layers_t = nn.ModuleList()

        kwargs_common = {
            "d_model": dim,
            "nhead": num_heads,
            "dim_feedforward": hidden_dim,
            "dropout": dropout,
            "activation": activation,
            "group_norm": group_norm,
            "norm_first": norm_first,
            "norm_out": norm_out,
            "layer_scale": layer_scale,
            "mask_type": mask_type,
            "mask_random_seed": mask_random_seed,
            "sparse_attn_window": sparse_attn_window,
            "global_window": global_window,
            "sparsity": sparsity,
            "auto_sparsity": auto_sparsity,
            "batch_first": True,
        }

        kwargs_classic_encoder = dict(kwargs_common)
        kwargs_classic_encoder.update({
            "sparse": sparse_self_attn,
        })
        kwargs_cross_encoder = dict(kwargs_common)
        kwargs_cross_encoder.update({
            "sparse": sparse_cross_attn,
        })

        for idx in range(num_layers):
            if idx % 2 == self.classic_parity:
                self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
                self.layers_t.append(
                    MyTransformerEncoderLayer(**kwargs_classic_encoder)
                )
            else:
                self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
                self.layers_t.append(
                    CrossTransformerEncoderLayer(**kwargs_cross_encoder)
                )

    def forward(self, x, xt):
        """Run both branches through the layer stack.

        Args:
            x: tensor of shape (B, C, Fr, T1).
            xt: tensor of shape (B, C, T2).

        Returns:
            Tuple ``(x, xt)`` with the same shapes as the inputs.
        """
        B, C, Fr, T1 = x.shape
        pos_emb_2d = create_2d_sin_embedding(
            C, Fr, T1, x.device, self.max_period
        )  # (1, C, Fr, T1)
        pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
        x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
        x = self.norm_in(x)
        x = x + self.weight_pos_embed * pos_emb_2d

        B, C, T2 = xt.shape
        xt = rearrange(xt, "b c t2 -> b t2 c")  # now (B, T2, C)
        pos_emb = self._get_pos_embedding(T2, B, C, x.device)
        pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
        xt = self.norm_in_t(xt)
        xt = xt + self.weight_pos_embed * pos_emb

        for idx in range(self.num_layers):
            if idx % 2 == self.classic_parity:
                x = self.layers[idx](x)
                xt = self.layers_t[idx](xt)
            else:
                # Both cross layers must see the other branch as it was
                # before this depth, hence the saved copy of x.
                old_x = x
                x = self.layers[idx](x, xt)
                xt = self.layers_t[idx](xt, old_x)

        x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
        xt = rearrange(xt, "b t2 c -> b c t2")
        return x, xt

    def _get_pos_embedding(self, T, B, C, device):
        """Return the temporal positional embedding in (T, B or 1, C) layout.

        Raises:
            ValueError: if ``self.emb`` is not "sin", "cape" or "scaled".
        """
        if self.emb == "sin":
            shift = random.randrange(self.sin_random_shift + 1)
            pos_emb = create_sin_embedding(
                T, C, shift=shift, device=device, max_period=self.max_period
            )
        elif self.emb == "cape":
            if self.training:
                pos_emb = create_sin_embedding_cape(
                    T,
                    C,
                    B,
                    device=device,
                    max_period=self.max_period,
                    mean_normalize=self.cape_mean_normalize,
                    augment=self.cape_augment,
                    max_global_shift=self.cape_glob_loc_scale[0],
                    max_local_shift=self.cape_glob_loc_scale[1],
                    max_scale=self.cape_glob_loc_scale[2],
                )
            else:
                # No random augmentation outside of training.
                pos_emb = create_sin_embedding_cape(
                    T,
                    C,
                    B,
                    device=device,
                    max_period=self.max_period,
                    mean_normalize=self.cape_mean_normalize,
                    augment=False,
                )
        elif self.emb == "scaled":
            pos = torch.arange(T, device=device)
            pos_emb = self.position_embeddings(pos)[:, None]
        else:
            # Previously fell through to an UnboundLocalError on return.
            raise ValueError(
                "emb should be one of sin/cape/scaled, not {}".format(self.emb)
            )

        return pos_emb

    def make_optim_group(self):
        """Return an optimizer parameter group for this module.

        The group carries ``weight_decay`` and, when ``lr`` was given, ``lr``.
        """
        group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
        if self.lr is not None:
            group["lr"] = self.lr
        return group
# Attention Modules
class MultiheadAttention(nn.Module):
    """Multi-head attention backed by the sparse attention helpers of this file.

    The call signature mirrors ``nn.MultiheadAttention`` so that it can be
    swapped in as ``self_attn`` / ``cross_attn``, but only a subset of the
    arguments is honoured:
      * ``add_bias_kv``, ``add_zero_attn``, ``kdim`` and ``vdim`` are accepted
        and ignored by the constructor;
      * ``key_padding_mask``, ``need_weights``, ``average_attn_weights`` and
        ``is_causal`` are accepted and ignored by ``forward``, which never
        returns attention weights.

    ``auto_sparsity`` must be given explicitly: a truthy value selects
    ``dynamic_sparse_attention`` with that sparsity, a falsy one selects
    ``scaled_dot_product_attention`` with the supplied ``attn_mask``.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        auto_sparsity=None,
    ):
        super().__init__()
        assert auto_sparsity is not None, "sanity check"
        self.num_heads = num_heads
        self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_drop = torch.nn.Dropout(dropout)
        self.proj = torch.nn.Linear(embed_dim, embed_dim, bias)
        self.proj_drop = torch.nn.Dropout(dropout)
        self.batch_first = batch_first
        self.auto_sparsity = auto_sparsity

    def forward(
        self,
        query,
        key,
        value,
        key_padding_mask=None,
        need_weights=True,
        attn_mask=None,
        average_attn_weights=True,
        is_causal=False,
    ):
        """Compute attention of ``query`` over ``key`` / ``value``.

        Inputs are (N, B, C), or (B, N, C) when ``batch_first`` is set; the
        output has the layout of ``query``.

        ``is_causal`` is accepted only for signature compatibility with
        ``nn.MultiheadAttention.forward`` (recent PyTorch releases pass it
        from ``TransformerEncoderLayer._sa_block``); it is not used.

        Returns:
            Tuple ``(output, None)``.
        """
        if not self.batch_first:  # N, B, C
            query = query.permute(1, 0, 2)  # B, N_q, C
            key = key.permute(1, 0, 2)  # B, N_k, C
            value = value.permute(1, 0, 2)  # B, N_k, C
        B, N_q, C = query.shape
        B, N_k, C = key.shape
        head_dim = C // self.num_heads

        # Project, split into heads and fold the heads into the batch axis:
        # (B, N, C) -> (B * num_heads, N, head_dim).
        q = (
            self.q(query)
            .reshape(B, N_q, self.num_heads, head_dim)
            .permute(0, 2, 1, 3)
            .flatten(0, 1)
        )
        k = (
            self.k(key)
            .reshape(B, N_k, self.num_heads, head_dim)
            .permute(0, 2, 1, 3)
            .flatten(0, 1)
        )
        v = (
            self.v(value)
            .reshape(B, N_k, self.num_heads, head_dim)
            .permute(0, 2, 1, 3)
            .flatten(0, 1)
        )

        if self.auto_sparsity:
            assert attn_mask is None
            x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity)
        else:
            x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop)

        # Unfold the heads and merge them back into the channel axis.
        x = x.reshape(B, self.num_heads, N_q, head_dim)
        x = x.transpose(1, 2).reshape(B, N_q, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        if not self.batch_first:
            x = x.permute(1, 0, 2)
        return x, None
def scaled_query_key_softmax(q, k, att_mask):
    """Return softmax-normalised attention weights of ``q`` against ``k``.

    The queries are divided by the square root of the last dimension of
    ``k`` and the scores are computed with xformers' ``masked_matmul`` under
    ``att_mask``; the softmax is taken over the last axis.
    """
    from xformers.ops import masked_matmul

    scale = (k.size(-1)) ** 0.5
    scores = masked_matmul(q / scale, k.transpose(-2, -1), att_mask)
    return torch.nn.functional.softmax(scores, -1)
def scaled_dot_product_attention(q, k, v, att_mask, dropout):
    """Attend from ``q`` to ``k`` / ``v`` under ``att_mask``.

    ``dropout`` is a callable (a dropout module in this file) applied to the
    attention weights before they are multiplied with ``v``.
    """
    weights = dropout(scaled_query_key_softmax(q, k, att_mask=att_mask))
    return weights @ v
def _compute_buckets(x, R):
qq = torch.einsum('btf,bfhi->bhti', x, R)
qq = torch.cat([qq, -qq], dim=-1)
buckets = qq.argmax(dim=-1)
return buckets.permute(0, 2, 1).byte().contiguous()
def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None):
    """Sparse attention whose pattern is derived from random-projection buckets.

    Queries and keys are bucketed with ``_compute_buckets`` using one random
    projection shared by both; ``find_locations`` turns the buckets into a
    sparse layout that ``sparse_memory_efficient_attention`` then uses.

    NOTE: the original source carries a disabled assertion stating that the
    code for the custom sparse kernel is not ready for release yet; both
    xformers ops are imported lazily at call time.
    """
    # assert False, "The code for the custom sparse kernel is not ready for release yet."
    from xformers.ops import find_locations, sparse_memory_efficient_attention

    n_hashes = 32
    proj_size = 4
    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    # The layout search is not differentiated through.
    with torch.no_grad():
        projection = torch.randn(
            1, query.shape[-1], n_hashes, proj_size // 2, device=query.device
        )
        bucket_query = _compute_buckets(query, projection)
        bucket_key = _compute_buckets(key, projection)
        row_offsets, column_indices = find_locations(
            bucket_query, bucket_key, sparsity, infer_sparsity
        )

    return sparse_memory_efficient_attention(
        query, key, value, row_offsets, column_indices, attn_bias
    )
|