bos_token + readme
modeling_lsg_camembert.py  CHANGED  (+39 -12)
@@ -53,16 +53,16 @@ class LSGCamembertConfig(CamembertConfig):
         self.sparsity_factor = sparsity_factor
         self.sparsity_type = sparsity_type
 
-        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
+        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride", "bos_pooling"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride', 'bos_pooling'], \
                 setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
         if self.sparsity_type in ["stride", "block_stride"]:
-            if self.sparsity_factor > self.
+            if self.sparsity_factor > self.num_attention_heads:
                 logger.warning(
-                "[WARNING CONFIG]: sparsity_factor >
+                "[WARNING CONFIG]: sparsity_factor > num_attention_heads is not recommended for stride/block_stride sparsity"
                 )
 
         if self.num_global_tokens < 1:
@@ -497,15 +497,16 @@ class LSGSelfAttention(BaseSelfAttention):
             "lsh": self.get_sparse_tokens_with_lsh,
             "stride": self.get_sparse_tokens_with_stride,
             "block_stride": self.get_sparse_tokens_with_block_stride,
+            "bos_pooling": self.get_sparse_tokens_with_bos_pooling
         }
 
         self.sparsity_type = config.sparsity_type
-        self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
+        self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda w, x, y, z: (None, None, None))
 
         if config.sparsity_type == "lsh":
             self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
 
-    def get_sparse_tokens_with_norm(self, keys, values, mask):
+    def get_sparse_tokens_with_norm(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -533,7 +534,7 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_pooling(self, keys, values, mask):
+    def get_sparse_tokens_with_pooling(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -556,7 +557,7 @@ class LSGSelfAttention(BaseSelfAttention):
             mask *= torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
-    def get_sparse_tokens_with_stride(self, keys, values, mask):
+    def get_sparse_tokens_with_stride(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -572,7 +573,7 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_block_stride(self, keys, values, mask):
+    def get_sparse_tokens_with_block_stride(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -592,11 +593,14 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_lsh(self, keys, values, mask):
+    def get_sparse_tokens_with_lsh(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
+        if self.sparsity_factor == self.sparse_block_size:
+            return self.get_sparse_tokens_with_bos_pooling(queries, keys, values, mask)
+
         block_size = min(self.block_size, self.sparse_block_size)
         keys = self.chunk(keys, block_size)
         values = self.chunk(values, block_size)
@@ -644,6 +648,29 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
 
+    def get_sparse_tokens_with_bos_pooling(self, queries, keys, values, mask):
+
+        if self.sparsity_factor == 1:
+            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
+
+        queries = queries.unsqueeze(-3)
+        mask = self.chunk(mask.transpose(-1, -2), self.sparsity_factor).transpose(-1, -2)
+        keys = self.chunk(keys, self.sparsity_factor)
+        values = self.chunk(values, self.sparsity_factor)
+
+        n, h, b, t, d = keys.size()
+        scores = (queries[..., :1, :] @ keys.transpose(-1, -2)) / math.sqrt(d)
+        if mask is not None:
+            scores = scores + mask
+
+        scores = torch.softmax(scores, dim=-1)
+        keys = scores @ keys
+        values = scores @ values
+        mask = mask.mean(dim=-1)
+        mask[mask != torch.finfo(mask.dtype).min] = 0
+
+        return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
+
     def forward(
         self,
         hidden_states,
@@ -765,7 +792,7 @@ class LSGSelfAttention(BaseSelfAttention):
         # Get sparse idx
         sparse_key, sparse_value, sparse_mask = (None, None, None)
         if self.sparse_block_size and self.sparsity_factor > 0:
-            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
+            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)
 
         # Expand masks on heads
         attention_mask = attention_mask.expand(-1, h, -1, -1)
@@ -838,7 +865,7 @@ class LSGSelfAttention(BaseSelfAttention):
         sparse_key, sparse_value, sparse_mask = (None, None, None)
 
         if self.sparse_block_size and self.sparsity_factor > 0:
-            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
+            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)
 
         # Expand masks on heads
         attention_mask = attention_mask.expand(-1, h, -1, -1)
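The core of this change is the new get_sparse_tokens_with_bos_pooling method: keys and values are chunked into groups of sparsity_factor tokens, each group is scored against the BOS query (the first query token), and the softmax-weighted sum collapses each group into a single sparse key/value pair. The snippet below is a minimal standalone sketch of that compression on dummy tensors, not the model code itself; the (batch, heads, seq, head_dim) layout, the reshape-based chunking, and the omission of the attention mask are simplifying assumptions made for illustration.

import math
import torch

def bos_pool_sketch(queries, keys, values, sparsity_factor):
    # queries/keys/values: (batch, heads, seq, head_dim); seq assumed divisible by sparsity_factor.
    n, h, t, d = keys.size()
    b = sparsity_factor

    # Group the sequence into blocks of `sparsity_factor` tokens.
    keys_b = keys.reshape(n, h, t // b, b, d)
    values_b = values.reshape(n, h, t // b, b, d)

    # Score every token in a block against the BOS query (first query token).
    bos_query = queries[..., :1, :].unsqueeze(-3)                   # (n, h, 1, 1, d)
    scores = (bos_query @ keys_b.transpose(-1, -2)) / math.sqrt(d)  # (n, h, t//b, 1, b)
    scores = torch.softmax(scores, dim=-1)

    # Attention-pool each block into a single sparse key/value token.
    pooled_keys = (scores @ keys_b).reshape(n, h, -1, d)
    pooled_values = (scores @ values_b).reshape(n, h, -1, d)
    return pooled_keys, pooled_values

# Shape check: a 512-token sequence compressed by a factor of 4 yields 128 sparse tokens.
q = torch.randn(2, 12, 512, 64)
k = torch.randn(2, 12, 512, 64)
v = torch.randn(2, 12, 512, 64)
pk, pv = bos_pool_sketch(q, k, v, sparsity_factor=4)
print(pk.shape, pv.shape)  # torch.Size([2, 12, 128, 64]) torch.Size([2, 12, 128, 64])

In use, nothing else changes for callers: the mode is selected from the config through the sparse_functions dispatch shown in the diff. Assuming the usual LSG loading pattern for Hub checkpoints that ship custom modeling code (the checkpoint path below is a placeholder, not from this commit), the new mode would be enabled roughly like this:

from transformers import AutoModel

# Hypothetical checkpoint path; LSG models rely on remote code, hence trust_remote_code=True.
model = AutoModel.from_pretrained(
    "path/to/lsg-camembert-checkpoint",
    trust_remote_code=True,
    sparsity_type="bos_pooling",
    sparsity_factor=4,
)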