Text Generation
Transformers
PyTorch
English
retnet
custom_code
syncdoth committed (verified)
Commit 7110f83 · 1 Parent(s): 2550057

Upload RetNetForCausalLM

config.json ADDED
@@ -0,0 +1,48 @@
1
+ {
2
+ "_name_or_path": "/nfs/checkpoints/RetNet-410m-bs1024-pile_dedup-copy_exp-skip_reten/hf-iter-050000-ckpt",
3
+ "activation_dropout": 0.0,
4
+ "activation_fn": "gelu",
5
+ "architectures": [
6
+ "RetNetForCausalLM"
7
+ ],
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_retnet.RetNetConfig",
10
+ "AutoModelForCausalLM": "modeling_retnet.RetNetForCausalLM"
11
+ },
12
+ "bos_token_id": 1,
13
+ "decoder_embed_dim": 1024,
14
+ "decoder_ffn_embed_dim": 4096,
15
+ "decoder_layers": 24,
16
+ "decoder_normalize_before": true,
17
+ "decoder_retention_heads": 4,
18
+ "decoder_value_embed_dim": 1024,
19
+ "deepnorm": false,
20
+ "drop_path_rate": 0.0,
21
+ "dropout": 0.0,
22
+ "eos_token_id": 2,
23
+ "forward_impl": "parallel",
24
+ "groupnorm_affine": false,
25
+ "initializer_range": 0.02,
26
+ "is_decoder": true,
27
+ "layernorm_embedding": false,
28
+ "layernorm_eps": 1e-05,
29
+ "max_position_embeddings": 2048,
30
+ "model_type": "retnet",
31
+ "no_scale_embedding": true,
32
+ "output_retentions": false,
33
+ "parallel_residual": true,
34
+ "recurrent_chunk_size": 512,
35
+ "rotary_percentage": 0.25,
36
+ "subln": false,
37
+ "tie_word_embeddings": false,
38
+ "torch_dtype": "float32",
39
+ "transformers_version": "4.31.0",
40
+ "use_bias": true,
41
+ "use_cache": false,
42
+ "use_ffn_rms_norm": false,
43
+ "use_glu": false,
44
+ "use_lm_decay": false,
45
+ "use_rms_norm": false,
46
+ "vocab_size": 50254,
47
+ "z_loss_coeff": 0.0
48
+ }
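
Note: since `auto_map` above registers the custom `RetNetConfig` and `RetNetForCausalLM` classes shipped in this commit, loading the checkpoint requires `trust_remote_code=True`. A minimal loading sketch (the repo id is a placeholder, since the diff does not state where this commit lives; the config was produced with transformers 4.31.0):

from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "syncdoth/RetNet-410m"  # placeholder repo id, not stated in this diff
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)            # -> configuration_retnet.RetNetConfig
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)   # -> modeling_retnet.RetNetForCausalLM
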
configuration_retnet.py ADDED
@@ -0,0 +1,101 @@
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+
4
+ class RetNetConfig(PretrainedConfig):
5
+ model_type = "retnet"
6
+ attribute_map = {
7
+ "hidden_size": "decoder_embed_dim",
8
+ "intermediate_size": "decoder_ffn_embed_dim",
9
+ "num_attention_heads": "decoder_retention_heads",
10
+ "num_hidden_layers": "decoder_layers",
11
+ }
12
+
13
+ def __init__(
14
+ self,
15
+ vocab_size: int = 50257,
16
+ initializer_range: float = 0.02,
17
+ is_decoder: bool = True,
18
+ bos_token_id: int = None,
19
+ pad_token_id: int = None,
20
+ eos_token_id: int = None,
21
+ output_retentions: bool = False,
22
+ use_cache: bool = True,
23
+ forward_impl: str = 'parallel',
24
+ activation_fn: str = "gelu",
25
+ dropout: float = 0.0, # dropout probability
26
+ activation_dropout: float = 0.0, # dropout probability after activation in FFN.
27
+ drop_path_rate: float = 0.0,
28
+ decoder_embed_dim: int = 768, # decoder embedding dimension
29
+ decoder_value_embed_dim: int = 1280, # decoder value embedding dimension
30
+ decoder_ffn_embed_dim: int = 1280, # decoder embedding dimension for FFN
31
+ decoder_layers: int = 12, # num decoder layers
32
+ decoder_retention_heads: int = 3, # num decoder retention heads
33
+ decoder_normalize_before: bool = True, # apply layernorm before each decoder block
34
+ layernorm_embedding: bool = False, # add layernorm to embedding
35
+ no_scale_embedding: bool = True, # if True, don't scale embeddings
36
+ recurrent_chunk_size: int = 512,
37
+ use_glu: bool = True, # use GLU instead of FFN
38
+ z_loss_coeff: float = 0.0, # coefficient for z loss: TODO: 1e-4
39
+ use_lm_decay: bool = False,
40
+ deepnorm: bool = False,
41
+ subln: bool = True,
42
+ use_rms_norm: bool = True,
43
+ groupnorm_affine: bool = False,
44
+ layernorm_eps: float = 1e-6,
45
+ tie_word_embeddings: bool = False,
46
+ use_bias: bool = False,
47
+ parallel_residual: bool = False,
48
+ rotary_percentage: float = 1.0,
49
+ **kwargs):
50
+ self.vocab_size = vocab_size
51
+ self.initializer_range = initializer_range
52
+ self.output_retentions = output_retentions
53
+ self.use_lm_decay = use_lm_decay
54
+ self.use_glu = use_glu
55
+ self.z_loss_coeff = z_loss_coeff
56
+ # size related
57
+ self.decoder_embed_dim = decoder_embed_dim
58
+ self.decoder_value_embed_dim = decoder_value_embed_dim
59
+ self.decoder_retention_heads = decoder_retention_heads
60
+ self.decoder_ffn_embed_dim = decoder_ffn_embed_dim
61
+ self.decoder_layers = decoder_layers
62
+ # normalization related
63
+ self.decoder_normalize_before = decoder_normalize_before
64
+ self.activation_fn = activation_fn
65
+ self.dropout = dropout
66
+ self.drop_path_rate = drop_path_rate
67
+ self.activation_dropout = activation_dropout
68
+ self.no_scale_embedding = no_scale_embedding
69
+ self.layernorm_embedding = layernorm_embedding
70
+ self.deepnorm = deepnorm
71
+ self.subln = subln
72
+ self.use_rms_norm = use_rms_norm
73
+ self.layernorm_eps = layernorm_eps
74
+ self.use_bias = use_bias
75
+ self.groupnorm_affine = groupnorm_affine
76
+ self.parallel_residual = parallel_residual
77
+ # Blockwise
78
+ self.recurrent_chunk_size = recurrent_chunk_size
79
+ self.forward_impl = forward_impl
80
+ # rope
81
+ self.rotary_percentage = rotary_percentage
82
+
83
+ if self.deepnorm:
84
+ self.decoder_normalize_before = False
85
+ self.subln = False
86
+ if self.subln:
87
+ self.decoder_normalize_before = True
88
+ self.deepnorm = False
89
+
90
+ super().__init__(is_decoder=is_decoder,
91
+ bos_token_id=bos_token_id,
92
+ pad_token_id=pad_token_id,
93
+ eos_token_id=eos_token_id,
94
+ use_cache=use_cache,
95
+ tie_word_embeddings=tie_word_embeddings,
96
+ **kwargs)
97
+
98
+ def override(self, args):
99
+ for hp in self.__dict__.keys():
100
+ if getattr(args, hp, None) is not None:
101
+ self.__dict__[hp] = getattr(args, hp, None)
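
For reference, the sketch below shows how the values in config.json map onto a direct `RetNetConfig` construction and how `attribute_map` exposes the standard Hugging Face attribute names as aliases. It is abridged: fields not listed fall back to the constructor defaults above rather than the exact values in config.json.

from configuration_retnet import RetNetConfig  # the file added in this commit

config = RetNetConfig(
    vocab_size=50254,
    decoder_embed_dim=1024,
    decoder_value_embed_dim=1024,
    decoder_ffn_embed_dim=4096,
    decoder_layers=24,
    decoder_retention_heads=4,
    use_glu=False,
    subln=False,
    use_rms_norm=False,
    use_bias=True,
    parallel_residual=True,
    rotary_percentage=0.25,
    layernorm_eps=1e-5,
    bos_token_id=1,
    eos_token_id=2,
)
# attribute_map aliases RetNet-specific names to the standard ones:
print(config.hidden_size)           # 1024 (decoder_embed_dim)
print(config.num_hidden_layers)     # 24   (decoder_layers)
print(config.num_attention_heads)   # 4    (decoder_retention_heads)
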
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.31.0",
6
+ "use_cache": false
7
+ }
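
These generation defaults only pin the special tokens and leave the cache off; the recurrent decoding path is opt-in. A hedged usage sketch follows: passing `use_cache=True` at generate time lets `prepare_inputs_for_generation` (defined in modeling_retnet.py below) switch `forward_impl` to "recurrent" once a kv state exists, so each new token is decoded from a constant-size state. The repo id is a placeholder, and the tokenizer is assumed to be hosted alongside the model; no tokenizer files are part of this commit.

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "syncdoth/RetNet-410m"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(repo_id)  # assumption: a matching tokenizer lives in the repo
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Retentive networks", return_tensors="pt")
# use_cache=True overrides the generation_config default and enables the recurrent kv state
outputs = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
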
modeling_retnet.py ADDED
@@ -0,0 +1,1416 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Dict, List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.utils.checkpoint
9
+ from torch import nn
10
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
11
+ from transformers import top_k_top_p_filtering
12
+ from transformers.modeling_outputs import ModelOutput, SequenceClassifierOutputWithPast
13
+ from transformers.modeling_utils import PreTrainedModel
14
+ from transformers.utils import logging
15
+
16
+ try:
17
+ from apex.normalization import FusedLayerNorm as LayerNorm
18
+ except ModuleNotFoundError:
19
+ from torch.nn import LayerNorm
20
+
21
+ from .configuration_retnet import RetNetConfig
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ # helper functions
27
+ def split_heads(tensors, bsz, seqlen, num_heads):
28
+ assert isinstance(tensors, (tuple, list))
29
+ return [x.view(bsz, seqlen, num_heads, -1).transpose(1, 2) for x in tensors]
30
+
31
+
32
+ def rotate_every_two(x):
33
+ x1 = x[..., ::2]
34
+ x2 = x[..., 1::2]
35
+ x = torch.stack((-x2, x1), dim=-1)
36
+ return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
37
+
38
+
39
+ def theta_shift(x, sin, cos):
40
+ return (x * cos) + (rotate_every_two(x) * sin)
41
+
42
+
43
+ def get_activation_fn(activation):
44
+ if activation == "relu":
45
+ return F.relu
46
+ elif activation == "gelu":
47
+ return F.gelu
48
+ elif activation == "swish":
49
+ return F.silu
50
+ else:
51
+ raise NotImplementedError
52
+
53
+
54
+ class RMSNorm(nn.Module):
55
+
56
+ def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):
57
+ super().__init__()
58
+ self.normalized_shape = dim
59
+ self.eps = eps
60
+ self.elementwise_affine = elementwise_affine
61
+ if self.elementwise_affine:
62
+ self.weight = nn.Parameter(torch.ones(dim))
63
+ else:
64
+ self.register_parameter("weight", None)
65
+
66
+ def _norm(self, x):
67
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
68
+
69
+ def forward(self, x):
70
+ output = self._norm(x.float()).type_as(x)
71
+ if self.weight is not None:
72
+ output = output * self.weight
73
+ return output
74
+
75
+
76
+ class RetNetRelPos(nn.Module):
77
+
78
+ def __init__(self, config: RetNetConfig):
79
+ super().__init__()
80
+ self.config = config
81
+ num_heads = config.decoder_retention_heads
82
+ n_elem = int(config.decoder_embed_dim // num_heads * config.rotary_percentage)
83
+
84
+ angle = 1.0 / (10000**torch.linspace(0, 1, n_elem // 2))
85
+ angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
86
+ # decay (gamma)
87
+ if config.use_lm_decay:
88
+ # NOTE: alternative way described in the paper
89
+ s = torch.log2(torch.tensor(1 / 32))
90
+ e = torch.log2(torch.tensor(1 / 512))
91
+ decay = torch.log2(1 - torch.exp(torch.linspace(s, e, num_heads))) # [h,]
92
+ else:
93
+ decay = torch.log2(1 - 2**(-5 - torch.arange(num_heads, dtype=torch.float)))
94
+ self.register_buffer("angle", angle)
95
+ self.register_buffer("decay", decay)
96
+ self.recurrent_chunk_size = config.recurrent_chunk_size
97
+
98
+ def forward(
99
+ self,
100
+ slen,
101
+ forward_impl="parallel",
102
+ recurrent_chunk_size=None,
103
+ retention_mask=None,
104
+ get_decay_scale=True,
105
+ ):
106
+ if forward_impl == "recurrent":
107
+ sin = torch.sin(self.angle * (slen - 1))
108
+ cos = torch.cos(self.angle * (slen - 1))
109
+ retention_rel_pos = ((sin, cos), self.decay.view(1, -1, 1, 1).exp2())
110
+ elif forward_impl == "chunkwise":
111
+ if recurrent_chunk_size is None:
112
+ recurrent_chunk_size = self.recurrent_chunk_size
113
+ index = torch.arange(slen).to(self.decay)
114
+ sin = torch.sin(index[:, None] * self.angle[None, :])
115
+ cos = torch.cos(index[:, None] * self.angle[None, :])
116
+
117
+ block_index = torch.arange(recurrent_chunk_size).to(self.decay)
118
+ mask = torch.tril(torch.ones(recurrent_chunk_size, recurrent_chunk_size)).to(self.decay)
119
+ mask = torch.masked_fill(block_index[:, None] - block_index[None, :], ~mask.bool(),
120
+ float("inf"))
121
+ mask = torch.exp2(mask * self.decay[:, None, None])
122
+ mask = torch.nan_to_num(mask)
123
+ mask = mask.unsqueeze(0) # [1, h, t, t]
124
+ # TODO: need to handle retention_mask
125
+ # scaling
126
+ value_inner_decay = mask[:, :, -1] / mask[:, :, -1].sum(dim=-1, keepdim=True)
127
+ value_inner_decay = value_inner_decay.unsqueeze(-1)
128
+ scale = mask.sum(dim=-1, keepdim=True).sqrt()
129
+ inner_mask = mask / scale
130
+
131
+ cross_decay = torch.exp2(self.decay * recurrent_chunk_size)
132
+ query_inner_decay = torch.exp2(self.decay[:, None] * (block_index + 1))
133
+ cross_decay = cross_decay[None, :, None, None]
134
+ query_inner_decay = query_inner_decay[None, :, :, None] / (
135
+ scale / mask[:, :, -1].sum(dim=-1)[:, :, None, None])
136
+ # decay_scale (used for kv cache)
137
+ if get_decay_scale:
138
+ decay_scale = self.compute_decay_scale(slen, retention_mask)
139
+ else:
140
+ decay_scale = None
141
+ retention_rel_pos = (
142
+ (sin, cos),
143
+ (
144
+ inner_mask,
145
+ cross_decay,
146
+ query_inner_decay,
147
+ value_inner_decay,
148
+ decay_scale,
149
+ ),
150
+ )
151
+ else: # parallel
152
+ index = torch.arange(slen).to(self.decay)
153
+ sin = torch.sin(index[:, None] * self.angle[None, :])
154
+ cos = torch.cos(index[:, None] * self.angle[None, :])
155
+ mask = torch.tril(torch.ones(slen, slen)).to(self.decay)
156
+ mask = torch.masked_fill(index[:, None] - index[None, :], ~mask.bool(), float("inf"))
157
+ mask = torch.exp2(mask * self.decay[:, None, None])
158
+ mask = torch.nan_to_num(mask)
159
+ mask = mask.unsqueeze(0) # [1, h, t, t]
160
+ if retention_mask is not None:
161
+ # this is required for left padding
162
+ mask = mask * retention_mask.float().view(-1, 1, 1, slen).to(mask)
163
+
164
+ # scaling
165
+ mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
166
+ mask = torch.nan_to_num(mask, nan=0.0)
167
+ # decay_scale (used for kv cache)
168
+ if get_decay_scale:
169
+ decay_scale = self.compute_decay_scale(slen, retention_mask)
170
+ else:
171
+ decay_scale = None
172
+ # mask processing for intra decay
173
+ if retention_mask is not None:
174
+ max_non_zero = (torch.cumsum(retention_mask, dim=-1).max(dim=-1).indices) # [b,]
175
+ intra_decay = mask[range(mask.shape[0]), :, max_non_zero]
176
+ else:
177
+ intra_decay = mask[:, :, -1]
178
+
179
+ retention_rel_pos = ((sin, cos), (mask, intra_decay, decay_scale))
180
+
181
+ return retention_rel_pos
182
+
183
+ def compute_decay_scale(self, slen, retention_mask=None):
184
+ exponent = torch.arange(slen, device=self.decay.device).float()
185
+ decay_scale = self.decay.exp2().view(-1, 1)**exponent.view(1, -1) # [h, t]
186
+ if retention_mask is not None:
187
+ seqlen = retention_mask.sum(dim=-1) # [b,]
188
+ bsz = seqlen.size(0)
189
+ decay_scale = decay_scale.unsqueeze(0).repeat(bsz, 1, 1) # [b, h, t]
190
+ for i, pos in enumerate(seqlen):
191
+ # the formula for decay_scale is `sum(gamma^i) for i in [0, slen).`
192
+ # Since the retention_mask is 0 for padding, we can set the decay_scale
193
+ # to 0 for the padding positions.
194
+ decay_scale[i, :, pos.item():] = 0
195
+ else:
196
+ bsz = 1
197
+ decay_scale = decay_scale.sum(-1).view(bsz, -1, 1, 1) # [b, h, 1, 1]
198
+ return decay_scale
199
+
200
+
201
+ class MultiScaleRetention(nn.Module):
202
+
203
+ def __init__(
204
+ self,
205
+ config: RetNetConfig,
206
+ gate_fn="swish",
207
+ use_bias=False,
208
+ tensor_parallel=False,
209
+ ):
210
+ super().__init__()
211
+ self.config = config
212
+ self.embed_dim = config.decoder_embed_dim
213
+ self.value_dim = config.decoder_value_embed_dim
214
+ self.num_heads = config.decoder_retention_heads
215
+ self.head_dim = self.value_dim // self.num_heads
216
+ self.key_dim = self.embed_dim // self.num_heads
217
+ self.scaling = self.key_dim**-0.5
218
+
219
+ self.gate_fn = get_activation_fn(activation=str(gate_fn))
220
+
221
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias)
222
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias)
223
+ self.v_proj = nn.Linear(self.embed_dim, self.value_dim, bias=use_bias)
224
+ self.g_proj = nn.Linear(self.embed_dim, self.value_dim, bias=use_bias)
225
+
226
+ self.out_proj = nn.Linear(self.value_dim, self.embed_dim, bias=use_bias)
227
+
228
+ if config.groupnorm_affine:
229
+ if config.use_rms_norm:
230
+ self.group_norm = RMSNorm(self.head_dim,
231
+ eps=config.layernorm_eps,
232
+ elementwise_affine=config.groupnorm_affine)
233
+ else:
234
+ self.group_norm = LayerNorm(self.head_dim,
235
+ bias=use_bias,
236
+ eps=config.layernorm_eps,
237
+ elementwise_affine=config.groupnorm_affine)
238
+ else:
239
+ self.group_norm = RMSNorm(self.head_dim,
240
+ eps=config.layernorm_eps,
241
+ elementwise_affine=False)
242
+ self.reset_parameters()
243
+
244
+ if tensor_parallel:
245
+ self.decay_proj = nn.Linear(self.num_heads, self.num_heads, bias=False)
246
+ else:
247
+ self.decay_proj = None
248
+
249
+ def reset_parameters(self):
250
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-2.5)
251
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-2.5)
252
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-2.5)
253
+ nn.init.xavier_uniform_(self.g_proj.weight, gain=2**-2.5)
254
+ nn.init.xavier_uniform_(self.out_proj.weight, gain=2**-1)
255
+
256
+ def parallel_retention(self, q, k, v, decay_mask, use_cache=True):
257
+ """
258
+ q, # bsz * num_head * len * qk_dim
259
+ k, # bsz * num_head * len * qk_dim
260
+ v, # bsz * num_head * len * v_dim
261
+ decay_mask, # (1 or bsz) * num_head * len * len
262
+ """
263
+ decay_mask, intra_decay, scale = decay_mask
264
+ # just return retention_rel_pos projected
265
+ # TODO: for shardformer
266
+ if self.decay_proj is not None:
267
+ decay_mask = self.decay_proj(decay_mask.transpose(-1, -3)).transpose(-3, -1)
268
+
269
+ # [b, h, t, t]
270
+ retention = q @ k.transpose(-1, -2) # (scaled dot-product)
271
+ retention = retention * decay_mask
272
+
273
+ # invariant after normalization
274
+ retention = retention / retention.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1,
275
+ max=5e4)
276
+
277
+ output = retention @ v # [b, h, t, v_dim / h]
278
+ output = output.transpose(1, 2) # [b, t, h, v_dim / h]
279
+
280
+ if self.training or not use_cache: # skip cache
281
+ return output, None, retention
282
+
283
+ if self.decay_proj is not None:
284
+ intra_decay = self.decay_proj(intra_decay.transpose(-1, -2)).transpose(-2, -1)
285
+
286
+ # kv cache: [b, h, t, v_dim, qk_dim]
287
+ current_kv = k.unsqueeze(-2) * v.unsqueeze(-1)
288
+ intra_decay = intra_decay[:, :, :, None, None] # [b, h, t, 1, 1]
289
+ current_kv = (current_kv * intra_decay).sum(2) # [b, h, v_dim, qk_dim]
290
+
291
+ cache = {"prev_key_value": current_kv, "scale": scale}
292
+ return output, cache, retention
293
+
294
+ def recurrent_retention(self, q, k, v, decay, past_key_value=None, retention_mask=None):
295
+ """
296
+ q, k, v, # bsz * num_head * 1 * qkv_dim
297
+ past_key_value:
298
+ - "prev_key_value" # bsz * num_head * v_dim * qk_dim
299
+ - "scale" # (1 or bsz) * num_head * 1 * 1
300
+ decay # (1 or bsz) * num_head * 1 * 1
301
+ retention_mask # bsz * 1
302
+ """
303
+ if retention_mask is not None:
304
+ retention_mask = retention_mask.float().view(-1, 1, 1, 1).to(decay)
305
+ else:
306
+ retention_mask = torch.ones(k.size(0), 1, 1, 1).to(decay)
307
+ # (b, h, v_dim, qk_dim)
308
+ current_kv = k * v.transpose(-1, -2) * retention_mask
309
+
310
+ if past_key_value is not None and "prev_key_value" in past_key_value:
311
+ prev_kv = past_key_value["prev_key_value"]
312
+ prev_scale = past_key_value["scale"]
313
+ scale = torch.where(retention_mask == 0, prev_scale, prev_scale * decay + 1)
314
+ # connect prev_kv and current_kv
315
+ # how much to decay prev_kv
316
+ decay_amount = prev_scale.sqrt() * decay / scale.sqrt()
317
+ decay_amount = torch.where(retention_mask == 0, 1, decay_amount)
318
+ prev_kv = prev_kv * decay_amount # decay prev_kv
319
+ current_kv = current_kv / scale.sqrt() # scale current_kv
320
+ current_kv = torch.nan_to_num(current_kv, nan=0.0) # remove nan, scale might be 0
321
+
322
+ current_kv = prev_kv + current_kv
323
+ else:
324
+ scale = torch.ones_like(decay)
325
+ # when retention_mask is 0 at the beginning, setting scale to 1 will
326
+ # make the first retention use the padding incorrectly. Hence,
327
+ # setting it to 0 here. This is a little ugly, so we might want to
328
+ # change this later. TODO: improve
329
+ scale = torch.where(retention_mask == 0, torch.zeros_like(decay), scale)
330
+
331
+ output = torch.sum(q * current_kv, dim=3).unsqueeze(1) # (b, 1, h, d_v)
332
+
333
+ cache = {"prev_key_value": current_kv, "scale": scale}
334
+ return output, cache
335
+
336
+ def chunkwise_retention(self, q, k, v, decay_mask):
337
+ """
338
+ q, k, v, # bsz * num_head * seqlen * qkv_dim
339
+ past_key_value:
340
+ - "prev_key_value" # bsz * num_head * v_dim * qk_dim
341
+ - "scale" # (1 or bsz) * num_head * 1 * 1
342
+ decay_mask, # 1 * num_head * chunk_size * chunk_size
343
+ cross_decay, # 1 * num_head * 1 * 1
344
+ inner_decay, # 1 * num_head * chunk_size * 1
345
+ """
346
+ # TODO: not working properly
347
+ (
348
+ decay_mask,
349
+ cross_decay,
350
+ query_inner_decay,
351
+ value_inner_decay,
352
+ decay_scale,
353
+ ) = decay_mask
354
+ bsz, _, tgt_len, _ = v.size()
355
+ chunk_len = decay_mask.size(-1)
356
+ assert tgt_len % chunk_len == 0
357
+ num_chunks = tgt_len // chunk_len
358
+
359
+ # [b, n_c, h, t_c, qkv_dim]
360
+ q = q.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
361
+ k = k.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
362
+ v = v.view(bsz, self.num_heads, num_chunks, chunk_len, self.head_dim).transpose(1, 2)
363
+
364
+ k_t = k.transpose(-1, -2)
365
+
366
+ qk_mat = q @ k_t # [b, n_c, h, t_c, t_c]
367
+ qk_mat = qk_mat * decay_mask.unsqueeze(1)
368
+ inner_scale = qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1)
369
+ qk_mat = qk_mat / inner_scale
370
+ # [b, n_c, h, t_c, v_dim]
371
+ inner_output = torch.matmul(qk_mat, v)
372
+
373
+ # reduce kv in one chunk
374
+ # [b, n_c, h, qk_dim, v_dim]
375
+ kv = k_t @ (v * value_inner_decay)
376
+ # kv = kv.view(bsz, num_chunks, self.num_heads, self.key_dim, self.head_dim)
377
+
378
+ kv_recurrent = []
379
+ cross_scale = []
380
+ kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v)
381
+ kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v)
382
+
383
+ # accumulate kv by loop
384
+ for i in range(num_chunks):
385
+ kv_recurrent.append(kv_state / kv_scale)
386
+ cross_scale.append(kv_scale)
387
+ kv_state = kv_state * cross_decay + kv[:, i]
388
+ kv_scale = (kv_state.detach().abs().sum(dim=-2, keepdim=True).max(
389
+ dim=-1, keepdim=True).values.clamp(min=1))
390
+
391
+ kv_recurrent = torch.stack(kv_recurrent, dim=1)
392
+ cross_scale = torch.stack(cross_scale, dim=1)
393
+
394
+ all_scale = torch.maximum(inner_scale, cross_scale)
395
+ align_inner_scale = all_scale / inner_scale
396
+ align_cross_scale = all_scale / cross_scale
397
+
398
+ cross_output = (q * query_inner_decay.unsqueeze(1)) @ kv_recurrent
399
+ output = inner_output / align_inner_scale + cross_output / align_cross_scale
400
+ output = output.transpose(2, 3) # [b, n_c, t_c, h, v_dim]
401
+
402
+ cache = {"prev_key_value": kv_state.transpose(-2, -1), "scale": decay_scale}
403
+ return output, cache
404
+
405
+ def forward(
406
+ self,
407
+ hidden_states: torch.Tensor,
408
+ rel_pos: Tuple[Tuple[torch.Tensor]],
409
+ retention_mask: Optional[torch.Tensor] = None,
410
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
411
+ forward_impl: str = "parallel",
412
+ output_retentions: Optional[bool] = False,
413
+ use_cache: Optional[bool] = True,
414
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]:
415
+ B, T, H = hidden_states.size()
416
+ (sin, cos), decay_mask = rel_pos
417
+ # projections
418
+ q = self.q_proj(hidden_states)
419
+ k = self.k_proj(hidden_states)
420
+ v = self.v_proj(hidden_states)
421
+ g = self.g_proj(hidden_states)
422
+ # multi-head
423
+ q, k, v = split_heads((q, k, v), B, T, self.num_heads)
424
+ k *= self.scaling # for scaled dot product
425
+ # rotate
426
+ # NOTE: theta_shift has a bug on the mps device.
427
+ n_elem = int(self.config.decoder_embed_dim // self.num_heads *
428
+ self.config.rotary_percentage)
429
+ qr = theta_shift(q[..., :n_elem], sin, cos)
430
+ kr = theta_shift(k[..., :n_elem], sin, cos)
431
+ qr = torch.cat([qr, q[..., n_elem:]], dim=-1)
432
+ kr = torch.cat([kr, k[..., n_elem:]], dim=-1)
433
+
434
+ # qr = theta_shift(q, sin, cos)
435
+ # kr = theta_shift(k, sin, cos)
436
+
437
+ # retention
438
+ if forward_impl == "parallel":
439
+ retention_out, curr_kv, retention_weights = self.parallel_retention(
440
+ qr,
441
+ kr,
442
+ v,
443
+ decay_mask,
444
+ use_cache=use_cache,
445
+ )
446
+ elif forward_impl == "recurrent":
447
+ retention_out, curr_kv = self.recurrent_retention(
448
+ qr,
449
+ kr,
450
+ v,
451
+ decay_mask,
452
+ past_key_value=past_key_value,
453
+ retention_mask=retention_mask,
454
+ )
455
+ elif forward_impl == "chunkwise":
456
+ retention_out, curr_kv = self.chunkwise_retention(qr, kr, v, decay_mask)
457
+ else:
458
+ raise ValueError(f"forward_impl {forward_impl} not supported.")
459
+
460
+ # concat heads
461
+ normed = self.group_norm(retention_out).reshape(B, T, self.value_dim)
462
+ # out gate & proj
463
+ out = self.gate_fn(g) * normed
464
+ out = self.out_proj(out)
465
+
466
+ outputs = (out, curr_kv)
467
+ if output_retentions:
468
+ outputs += (retention_weights,) if forward_impl == "parallel" else (None,)
469
+ return outputs
470
+
471
+
472
+ class FeedForwardNetwork(nn.Module):
473
+
474
+ def __init__(
475
+ self,
476
+ embed_dim,
477
+ ffn_dim,
478
+ activation_fn,
479
+ dropout,
480
+ activation_dropout,
481
+ layernorm_eps,
482
+ subln=False,
483
+ use_rms_norm=False,
484
+ ):
485
+ super().__init__()
486
+ self.embed_dim = embed_dim
487
+ self.activation_fn = get_activation_fn(activation=str(activation_fn))
488
+ self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
489
+ self.dropout_module = torch.nn.Dropout(dropout)
490
+ self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
491
+ self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
492
+ if subln:
493
+ if use_rms_norm:
494
+ self.ffn_layernorm = RMSNorm(self.embed_dim, eps=layernorm_eps)
495
+ else:
496
+ self.ffn_layernorm = LayerNorm(self.embed_dim, eps=layernorm_eps)
497
+ else:
498
+ self.ffn_layernorm = None
499
+
500
+ def reset_parameters(self):
501
+ self.fc1.reset_parameters()
502
+ self.fc2.reset_parameters()
503
+ if self.ffn_layernorm is not None:
504
+ self.ffn_layernorm.reset_parameters()
505
+
506
+ def forward(self, x):
507
+ x_shape = x.shape
508
+ x = x.reshape(-1, x.size(-1))
509
+ x = self.fc1(x)
510
+ x = self.activation_fn(x.float()).type_as(x)
511
+ x = self.activation_dropout_module(x)
512
+ if self.ffn_layernorm is not None:
513
+ x = self.ffn_layernorm(x)
514
+ x = self.fc2(x)
515
+ x = x.view(x_shape)
516
+ x = self.dropout_module(x)
517
+ return x
518
+
519
+
520
+ class GLU(nn.Module):
521
+
522
+ def __init__(
523
+ self,
524
+ embed_dim,
525
+ ffn_dim,
526
+ activation_fn,
527
+ dropout,
528
+ activation_dropout,
529
+ ):
530
+ super().__init__()
531
+ self.embed_dim = embed_dim
532
+ self.activation_fn = get_activation_fn(activation=str(activation_fn))
533
+ self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
534
+ self.dropout_module = torch.nn.Dropout(dropout)
535
+ self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False)
536
+ self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False)
537
+ self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False)
538
+
539
+ def reset_parameters(self):
540
+ self.fc1.reset_parameters()
541
+ self.fc2.reset_parameters()
542
+ self.gate.reset_parameters()
543
+
544
+ def forward(self, x):
545
+ x_shape = x.shape
546
+ x = x.reshape(-1, x.size(-1))
547
+ g = self.gate(x)
548
+ x = self.fc1(x)
549
+ x = self.activation_fn(x.float()).type_as(x) * g
550
+ x = self.activation_dropout_module(x)
551
+ x = self.fc2(x)
552
+ x = x.view(x_shape)
553
+ x = self.dropout_module(x)
554
+ return x
555
+
556
+
557
+ # Copied from timm.layers.drop.drop_path
558
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
559
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
560
+
561
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is
562
+ misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion:
563
+ https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
564
+ argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
565
+
566
+ """
567
+ if drop_prob == 0.0 or not training:
568
+ return x
569
+ keep_prob = 1 - drop_prob
570
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
571
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
572
+ if keep_prob > 0.0 and scale_by_keep:
573
+ random_tensor.div_(keep_prob)
574
+ return x * random_tensor
575
+
576
+
577
+ class DropPath(nn.Module):
578
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
579
+
580
+ def __init__(self, drop_prob=None):
581
+ super(DropPath, self).__init__()
582
+ self.drop_prob = drop_prob
583
+
584
+ def forward(self, x):
585
+ return drop_path(x, self.drop_prob, self.training)
586
+
587
+ def extra_repr(self):
588
+ return "p={}".format(self.drop_prob)
589
+
590
+
591
+ class RetNetDecoderLayer(nn.Module):
592
+
593
+ def __init__(self, config: RetNetConfig, depth: int, tensor_parallel: bool = False):
594
+ super().__init__()
595
+ self.config = config
596
+ self.embed_dim = config.decoder_embed_dim
597
+ self.dropout_module = torch.nn.Dropout(config.dropout)
598
+
599
+ if config.drop_path_rate > 0:
600
+ drop_path_prob = np.linspace(0, config.drop_path_rate, config.decoder_layers)[depth]
601
+ self.drop_path = DropPath(drop_path_prob)
602
+ else:
603
+ self.drop_path = None
604
+
605
+ self.retention = MultiScaleRetention(config,
606
+ use_bias=config.use_bias,
607
+ tensor_parallel=tensor_parallel)
608
+
609
+ self.normalize_before = config.decoder_normalize_before
610
+
611
+ norm_cls = RMSNorm if config.use_rms_norm else LayerNorm
612
+ self.retention_layer_norm = norm_cls(self.embed_dim, eps=config.layernorm_eps)
613
+
614
+ self.ffn_dim = config.decoder_ffn_embed_dim
615
+
616
+ self.ffn = self.build_ffn()
617
+
618
+ self.final_layer_norm = norm_cls(self.embed_dim, eps=config.layernorm_eps)
619
+
620
+ if config.deepnorm:
621
+ self.alpha = math.pow(2.0 * config.decoder_layers, 0.25)
622
+ else:
623
+ self.alpha = 1.0
624
+
625
+ def build_ffn(self):
626
+ if self.config.use_glu:
627
+ return GLU(
628
+ self.embed_dim,
629
+ self.ffn_dim,
630
+ self.config.activation_fn,
631
+ self.config.dropout,
632
+ self.config.activation_dropout,
633
+ )
634
+ else:
635
+ return FeedForwardNetwork(
636
+ self.embed_dim,
637
+ self.ffn_dim,
638
+ self.config.activation_fn,
639
+ self.config.dropout,
640
+ self.config.activation_dropout,
641
+ self.config.layernorm_eps,
642
+ self.config.subln,
643
+ self.config.use_rms_norm,
644
+ )
645
+
646
+ def residual_connection(self, x, residual):
647
+ return residual * self.alpha + x
648
+
649
+ def forward(
650
+ self,
651
+ hidden_states: torch.Tensor,
652
+ retention_rel_pos: Tuple[Tuple[torch.Tensor]],
653
+ retention_mask: Optional[torch.Tensor] = None,
654
+ forward_impl: str = "parallel",
655
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
656
+ output_retentions: Optional[bool] = False,
657
+ use_cache: Optional[bool] = True,
658
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]:
659
+ x = hidden_states
660
+ residual = hidden_states
661
+ if self.normalize_before:
662
+ hidden_states = self.retention_layer_norm(hidden_states)
663
+
664
+ msr_outs = self.retention(
665
+ hidden_states,
666
+ retention_rel_pos,
667
+ retention_mask=retention_mask,
668
+ past_key_value=past_key_value,
669
+ forward_impl=forward_impl,
670
+ output_retentions=output_retentions,
671
+ use_cache=use_cache,
672
+ )
673
+ hidden_states = msr_outs[0]
674
+ curr_kv = msr_outs[1]
675
+
676
+ hidden_states = self.dropout_module(hidden_states)
677
+
678
+ if self.drop_path is not None:
679
+ hidden_states = self.drop_path(hidden_states)
680
+
681
+ if not self.config.parallel_residual:
682
+ hidden_states = self.residual_connection(hidden_states, residual)
683
+
684
+ if not self.normalize_before:
685
+ hidden_states = self.retention_layer_norm(hidden_states)
686
+
687
+ residual = hidden_states
688
+
689
+ if self.config.parallel_residual:
690
+ # x + residual + mlp(ln2(x))
691
+ hidden_states_path2 = x
692
+ if self.normalize_before:
693
+ hidden_states_path2 = self.final_layer_norm(hidden_states_path2)
694
+
695
+ hidden_states_path2 = self.ffn(hidden_states_path2)
696
+
697
+ if self.drop_path is not None:
698
+ hidden_states_path2 = self.drop_path(hidden_states_path2)
699
+
700
+ if not self.normalize_before:
701
+ hidden_states_path2 = self.final_layer_norm(hidden_states_path2)
702
+ hidden_states = x + residual + hidden_states_path2
703
+ else:
704
+ if self.normalize_before:
705
+ hidden_states = self.final_layer_norm(hidden_states)
706
+
707
+ hidden_states = self.ffn(hidden_states)
708
+
709
+ if self.drop_path is not None:
710
+ hidden_states = self.drop_path(hidden_states)
711
+
712
+ hidden_states = self.residual_connection(hidden_states, residual)
713
+ if not self.normalize_before:
714
+ hidden_states = self.final_layer_norm(hidden_states)
715
+
716
+ outputs = (hidden_states, curr_kv)
717
+
718
+ if output_retentions:
719
+ outputs += (msr_outs[2],)
720
+ return outputs
721
+
722
+
723
+ class RetNetPreTrainedModel(PreTrainedModel):
724
+ # copied from LlamaPretrainedModel
725
+ config_class = RetNetConfig
726
+ base_model_prefix = "model"
727
+ supports_gradient_checkpointing = True
728
+ _no_split_modules = ["RetNetDecoderLayer"]
729
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
730
+
731
+ def _init_weights(self, module):
732
+ """
733
+ Following original retnet, weights are already initialized in their own
734
+ ways within their own init.
735
+ """
736
+ pass
737
+ # below is copied from LlamaPretrainedModel
738
+ # std = self.config.initializer_range
739
+ # if isinstance(module, nn.Linear):
740
+ # module.weight.data.normal_(mean=0.0, std=std)
741
+ # if module.bias is not None:
742
+ # module.bias.data.zero_()
743
+ # elif isinstance(module, nn.Embedding):
744
+ # module.weight.data.normal_(mean=0.0, std=std)
745
+ # if module.padding_idx is not None:
746
+ # module.weight.data[module.padding_idx].zero_()
747
+
748
+
749
+ @dataclass
750
+ class RetNetOutputWithPast(ModelOutput):
751
+ """
752
+ Class for RetNet model outputs that may also contain past key/values (to speed up sequential decoding).
753
+
754
+ config:
755
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, decoder_embed_dim)`):
756
+ Sequence of hidden-states at the output of the last layer of the model.
757
+
758
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
759
+ decoder_embed_dim)` is output.
760
+ past_key_values (`List(Dict(str, torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
761
+ - "prev_key_value": shape=(bsz * num_head * v_dim * qk_dim)
762
+ - "scale": shape=((1 or bsz) * num_head * 1 * 1)
763
+
764
+ Contains pre-computed hidden-states (key and values in the multi-scale retention blocks)
765
+ that can be used (see `past_key_values` input) to speed up sequential decoding.
766
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
767
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
768
+ one for the output of each layer) of shape `(batch_size, sequence_length, decoder_embed_dim)`.
769
+
770
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
771
+ retentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_retentions=True` is passed or when `config.output_retentions=True`):
772
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
773
+ sequence_length)`.
774
+
775
+ Retentions weights, used for visualization.
776
+
777
+ attentions (`tuple(torch.FloatTensor)`, *optional*): kept for backward compatibility; same as `retentions`.
778
+ """
779
+
780
+ last_hidden_state: torch.FloatTensor = None
781
+ past_key_values: Optional[List[Dict[str, torch.FloatTensor]]] = None
782
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
783
+ retentions: Optional[Tuple[torch.FloatTensor]] = None
784
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
785
+
786
+
787
+ class RetNetModel(RetNetPreTrainedModel):
788
+
789
+ def __init__(
790
+ self,
791
+ config: RetNetConfig,
792
+ embed_tokens: nn.Embedding = None,
793
+ tensor_parallel: bool = False,
794
+ ):
795
+ super().__init__(config)
796
+ self.config = config
797
+
798
+ norm_cls = RMSNorm if config.use_rms_norm else LayerNorm
799
+
800
+ self.dropout_module = torch.nn.Dropout(config.dropout)
801
+
802
+ self.embed_dim = config.decoder_embed_dim
803
+ self.embed_scale = (1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim))
804
+
805
+ if embed_tokens is None:
806
+ embed_tokens = nn.Embedding(config.vocab_size, config.decoder_embed_dim,
807
+ config.pad_token_id)
808
+ self.embed_tokens = embed_tokens
809
+
810
+ if config.layernorm_embedding:
811
+ self.layernorm_embedding = norm_cls(self.embed_dim, eps=config.layernorm_eps)
812
+ else:
813
+ self.layernorm_embedding = None
814
+
815
+ self.layers = nn.ModuleList([])
816
+
817
+ for i in range(config.decoder_layers):
818
+ self.layers.append(RetNetDecoderLayer(config, depth=i, tensor_parallel=tensor_parallel))
819
+
820
+ self.decoder_layers = len(self.layers)
821
+
822
+ if config.decoder_normalize_before:
823
+ self.layer_norm = norm_cls(self.embed_dim, eps=config.layernorm_eps)
824
+ else:
825
+ self.layer_norm = None
826
+
827
+ self.retnet_rel_pos = RetNetRelPos(config)
828
+ self.recurrent_chunk_size = config.recurrent_chunk_size
829
+
830
+ if config.deepnorm:
831
+ init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
832
+ for name, p in self.named_parameters():
833
+ if ("fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name):
834
+ p.data.div_(init_scale)
835
+
836
+ if config.subln and not config.use_glu:
837
+ init_scale = math.sqrt(math.log(config.decoder_layers * 2))
838
+ for name, p in self.named_parameters():
839
+ if ("fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name):
840
+ p.data.mul_(init_scale)
841
+
842
+ self.gradient_checkpointing = False
843
+ self.post_init()
844
+
845
+ def get_input_embeddings(self):
846
+ return self.embed_tokens
847
+
848
+ def set_input_embeddings(self, value):
849
+ self.embed_tokens = value
850
+
851
+ def forward_embedding(
852
+ self,
853
+ input_ids,
854
+ forward_impl,
855
+ inputs_embeds=None,
856
+ past_key_values=None,
857
+ ):
858
+ # if past_key_values is not None:
859
+ if forward_impl == "recurrent":
860
+ input_ids = input_ids[:, -1:]
861
+
862
+ if inputs_embeds is None:
863
+ inputs_embeds = self.embed_tokens(input_ids)
864
+
865
+ embed = self.embed_scale * inputs_embeds
866
+
867
+ if self.layernorm_embedding is not None:
868
+ embed = self.layernorm_embedding(embed)
869
+
870
+ embed = self.dropout_module(embed)
871
+
872
+ return embed
873
+
874
+ def forward(
875
+ self,
876
+ input_ids: torch.LongTensor = None,
877
+ retention_mask: Optional[torch.Tensor] = None,
878
+ attention_mask: Optional[torch.Tensor] = None,
879
+ past_key_values: Optional[List[Dict[str, torch.FloatTensor]]] = None,
880
+ inputs_embeds: Optional[torch.FloatTensor] = None,
881
+ output_retentions: Optional[bool] = None,
882
+ output_attentions: Optional[bool] = None,
883
+ output_hidden_states: Optional[bool] = None,
884
+ use_cache: Optional[bool] = None,
885
+ return_dict: Optional[bool] = None,
886
+ forward_impl: Optional[str] = "parallel",
887
+ recurrent_chunk_size: Optional[int] = None,
888
+ retention_rel_pos: Optional[Tuple[torch.Tensor]] = None,
889
+ ) -> Union[Tuple, RetNetOutputWithPast]:
890
+ if output_retentions is None and output_attentions is not None:
891
+ output_retentions = output_attentions
892
+ output_retentions = (output_retentions
893
+ if output_retentions is not None else self.config.output_retentions)
894
+ output_hidden_states = (output_hidden_states if output_hidden_states is not None else
895
+ self.config.output_hidden_states)
896
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
897
+
898
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
899
+
900
+ # retrieve input_ids and inputs_embeds
901
+ if input_ids is not None and inputs_embeds is not None:
902
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
903
+ elif input_ids is not None:
904
+ batch_size, seq_length = input_ids.shape
905
+ elif inputs_embeds is not None:
906
+ batch_size, seq_length, _ = inputs_embeds.shape
907
+ else:
908
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
909
+
910
+ # embed tokens
911
+ if inputs_embeds is None:
912
+ inputs_embeds = self.forward_embedding(input_ids, forward_impl, inputs_embeds,
913
+ past_key_values)
914
+
915
+ if retention_mask is None and attention_mask is not None:
916
+ retention_mask = attention_mask
917
+ if retention_mask is not None and forward_impl == "recurrent":
918
+ retention_mask = retention_mask[:, -1:]
919
+
920
+ hidden_states = inputs_embeds
921
+
922
+ # handling chunking here
923
+ if recurrent_chunk_size is None:
924
+ recurrent_chunk_size = self.recurrent_chunk_size
925
+ need_pad_for_chunkwise = (forward_impl == "chunkwise" and
926
+ seq_length % recurrent_chunk_size != 0)
927
+ if need_pad_for_chunkwise:
928
+ padding_len = recurrent_chunk_size - seq_length % recurrent_chunk_size
929
+ slen = seq_length + padding_len
930
+ hidden_states = F.pad(hidden_states, (0, 0, 0, padding_len))
931
+ else:
932
+ slen = seq_length
933
+ # relative position
934
+ if retention_rel_pos is None:
935
+ retention_rel_pos = self.retnet_rel_pos(
936
+ slen,
937
+ forward_impl=forward_impl,
938
+ recurrent_chunk_size=recurrent_chunk_size,
939
+ retention_mask=retention_mask,
940
+ get_decay_scale=not self.training,
941
+ )
942
+
943
+ # start running through the decoder layers
944
+ all_hidden_states = () if output_hidden_states else None
945
+ all_retentions = () if output_retentions else None
946
+ # layers * [bsz, num_head, qk_dim, decoder_embed_dim]
947
+ next_decoder_cache = () if use_cache else None
948
+
949
+ for idx, layer in enumerate(self.layers):
950
+ if output_hidden_states:
951
+ all_hidden_states += (hidden_states,)
952
+ past_key_value = (past_key_values[idx] if past_key_values is not None else None)
953
+
954
+ if self.gradient_checkpointing and self.training:
955
+
956
+ def create_custom_forward(module):
957
+
958
+ def custom_forward(*inputs):
959
+ return module(*inputs, output_retentions, use_cache)
960
+
961
+ return custom_forward
962
+
963
+ layer_outputs = torch.utils.checkpoint.checkpoint(
964
+ create_custom_forward(layer),
965
+ hidden_states,
966
+ retention_rel_pos,
967
+ retention_mask,
968
+ forward_impl,
969
+ past_key_value,
970
+ )
971
+ else:
972
+ layer_outputs = layer(
973
+ hidden_states,
974
+ retention_rel_pos,
975
+ retention_mask=retention_mask,
976
+ forward_impl=forward_impl,
977
+ past_key_value=past_key_value,
978
+ output_retentions=output_retentions,
979
+ use_cache=use_cache,
980
+ )
981
+
982
+ hidden_states = layer_outputs[0]
983
+
984
+ if use_cache:
985
+ next_decoder_cache += (layer_outputs[1],)
986
+
987
+ if output_retentions:
988
+ all_retentions += (layer_outputs[2],)
989
+
990
+ next_cache = next_decoder_cache if use_cache else None
991
+
992
+ if need_pad_for_chunkwise:
993
+ hidden_states = hidden_states[:, :seq_length, :]
994
+
995
+ if self.layer_norm is not None:
996
+ hidden_states = self.layer_norm(hidden_states)
997
+
998
+ # add hidden states from the last decoder layer
999
+ if output_hidden_states:
1000
+ all_hidden_states += (hidden_states,)
1001
+
1002
+ if not return_dict:
1003
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_retentions]
1004
+ if v is not None)
1005
+ return RetNetOutputWithPast(
1006
+ last_hidden_state=hidden_states,
1007
+ past_key_values=next_cache,
1008
+ hidden_states=all_hidden_states,
1009
+ retentions=all_retentions,
1010
+ attentions=all_retentions,
1011
+ )
1012
+
1013
+
1014
+ @dataclass
1015
+ class RetNetCausalLMOutputWithPast(ModelOutput):
1016
+ """
1017
+ Class for RetNet causal language model (autoregressive) outputs.
1018
+
1019
+ config:
1020
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
1021
+ Language modeling loss (for next-token prediction).
1022
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
1023
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
1024
+ past_key_values (`List(Dict(str, torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1025
+ - "prev_key_value": shape=(bsz * num_head * v_dim * qk_dim)
1026
+ - "scale": shape=((1 or bsz) * num_head * 1 * 1)
1027
+
1028
+ Contains pre-computed hidden-states (key and values in the multi-scale retention blocks)
1029
+ that can be used (see `past_key_values` input) to speed up sequential decoding.
1030
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
1031
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
1032
+ one for the output of each layer) of shape `(batch_size, sequence_length, decoder_embed_dim)`.
1033
+
1034
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
1035
+ retentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_retentions=True` is passed or when `config.output_retentions=True`):
1036
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
1037
+ sequence_length)`.
1038
+
1039
+ Retentions weights, used for visualization.
1040
+
1041
+ attentions (`tuple(torch.FloatTensor)`, *optional*): kept for backward compatibility; same as `retentions`.
1042
+ """
1043
+
1044
+ loss: Optional[torch.FloatTensor] = None
1045
+ logits: torch.FloatTensor = None
1046
+ past_key_values: Optional[List[Dict[str, torch.FloatTensor]]] = None
1047
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
1048
+ retentions: Optional[Tuple[torch.FloatTensor]] = None
1049
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
1050
+
1051
+
1052
+ class RetNetForCausalLM(RetNetPreTrainedModel):
1053
+
1054
+ def __init__(
1055
+ self,
1056
+ config: RetNetConfig,
1057
+ embed_tokens: nn.Embedding = None,
1058
+ tensor_parallel: bool = False,
1059
+ ) -> None:
1060
+ super().__init__(config)
1061
+ self.model = RetNetModel(config, embed_tokens=embed_tokens, tensor_parallel=tensor_parallel)
1062
+ self.lm_head = nn.Linear(config.decoder_embed_dim, config.vocab_size, bias=False)
1063
+ # init here
1064
+ torch.nn.init.normal_(self.lm_head.weight, mean=0, std=config.decoder_embed_dim**-0.5)
1065
+
1066
+ self.post_init()
1067
+
1068
+ def get_input_embeddings(self):
1069
+ return self.model.embed_tokens
1070
+
1071
+ def set_input_embeddings(self, value):
1072
+ self.model.embed_tokens = value
1073
+
1074
+ def get_output_embeddings(self):
1075
+ return self.lm_head
1076
+
1077
+ def set_output_embeddings(self, new_embeddings):
1078
+ self.lm_head = new_embeddings
1079
+
1080
+ def set_decoder(self, decoder):
1081
+ self.model = decoder
1082
+
1083
+ def get_decoder(self):
1084
+ return self.model
1085
+
1086
+ def forward(
1087
+ self,
1088
+ input_ids: torch.LongTensor = None,
1089
+ retention_mask: Optional[torch.Tensor] = None,
1090
+ attention_mask: Optional[torch.Tensor] = None,
1091
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1092
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1093
+ labels: Optional[torch.LongTensor] = None,
1094
+ use_cache: Optional[bool] = None,
1095
+ output_retentions: Optional[bool] = None,
1096
+ output_attentions: Optional[bool] = None,
1097
+ output_hidden_states: Optional[bool] = None,
1098
+ return_dict: Optional[bool] = None,
1099
+ forward_impl: Optional[str] = None,
1100
+ recurrent_chunk_size: Optional[int] = None,
1101
+ retention_rel_pos: Optional[Tuple[torch.Tensor]] = None,
1102
+ ) -> Union[Tuple, RetNetCausalLMOutputWithPast]:
1103
+ if output_retentions is None and output_attentions is not None:
1104
+ output_retentions = output_attentions
1105
+ output_retentions = (output_retentions
1106
+ if output_retentions is not None else self.config.output_retentions)
1107
+ output_hidden_states = (output_hidden_states if output_hidden_states is not None else
1108
+ self.config.output_hidden_states)
1109
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
1110
+ forward_impl = (forward_impl if forward_impl is not None else self.config.forward_impl)
1111
+ recurrent_chunk_size = (recurrent_chunk_size if recurrent_chunk_size is not None else
1112
+ self.config.recurrent_chunk_size)
1113
+
1114
+ if retention_mask is None and attention_mask is not None:
1115
+ retention_mask = attention_mask
1116
+
1117
+ outputs = self.model(
1118
+ input_ids,
1119
+ retention_mask=retention_mask,
1120
+ past_key_values=past_key_values,
1121
+ inputs_embeds=inputs_embeds,
1122
+ output_retentions=output_retentions,
1123
+ output_hidden_states=output_hidden_states,
1124
+ return_dict=return_dict,
1125
+ forward_impl=forward_impl,
1126
+ use_cache=use_cache,
1127
+ recurrent_chunk_size=recurrent_chunk_size,
1128
+ retention_rel_pos=retention_rel_pos,
1129
+ )
1130
+
1131
+ hidden_states = outputs[0]
1132
+ logits = self.lm_head(hidden_states)
1133
+
1134
+ loss = None
1135
+ if labels is not None:
1136
+ # Shift so that tokens < n predict n
1137
+ shift_logits = logits[..., :-1, :].contiguous()
1138
+ shift_labels = labels[..., 1:].contiguous()
1139
+ # Flatten the tokens
1140
+ loss_fct = nn.CrossEntropyLoss()
1141
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1142
+ shift_labels = shift_labels.view(-1)
1143
+ # Enable model parallelism
1144
+ shift_labels = shift_labels.to(shift_logits.device)
1145
+ loss = loss_fct(shift_logits, shift_labels)
1146
+
1147
+ if self.config.z_loss_coeff > 0:
1148
+ # z_loss from PaLM paper
1149
+ # z_loss = 1e-4 * log(log(z)), where z = sum(exp(logits))
1150
+ z_loss = torch.logsumexp(shift_logits, dim=-1).log().mean()
1151
+ loss += self.config.z_loss_coeff * z_loss
1152
+
1153
+ if not return_dict:
1154
+ output = (logits,) + outputs[1:]
1155
+ return (loss,) + output if loss is not None else output
1156
+
1157
+ return RetNetCausalLMOutputWithPast(
1158
+ loss=loss,
1159
+ logits=logits,
1160
+ past_key_values=outputs.past_key_values,
1161
+ hidden_states=outputs.hidden_states,
1162
+ retentions=outputs.retentions,
1163
+ attentions=outputs.retentions,
1164
+ )
1165
+
1166
+ def _crop_past_key_values(model, past_key_values, maximum_length):
1167
+ """Since retnet's kv do not have length, no need to crop. Just return"""
1168
+ return past_key_values
1169
+
1170
+ def prepare_inputs_for_generation(
1171
+ self,
1172
+ input_ids,
1173
+ past_key_values=None,
1174
+ attention_mask=None,
1175
+ inputs_embeds=None,
1176
+ **kwargs,
1177
+ ):
1178
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1179
+ if inputs_embeds is not None and past_key_values is None:
1180
+ model_inputs = {"inputs_embeds": inputs_embeds}
1181
+ else:
1182
+ model_inputs = {"input_ids": input_ids}
1183
+
1184
+ forward_impl = kwargs.get("forward_impl", "parallel")
1185
+ if past_key_values is not None:
1186
+ forward_impl = "recurrent"
1187
+
1188
+ model_inputs.update({
1189
+ "past_key_values": past_key_values,
1190
+ "use_cache": kwargs.get("use_cache"),
1191
+ "attention_mask": attention_mask,
1192
+ "forward_impl": forward_impl,
1193
+ })
1194
+ return model_inputs
1195
+
1196
+ @staticmethod
1197
+ def _reorder_cache(past_key_values, beam_idx):
1198
+ reordered_past = ()
1199
+ for layer_past in past_key_values: # dict
1200
+ layer_past_kv = layer_past["prev_key_value"] # [b, h, v_dim / h, qk_dim]
1201
+ layer_past_scale = layer_past["scale"] # [b, h, 1, 1]
1202
+ if layer_past_scale.size(0) > 1:
1203
+ # this means that retention_mask is not None, so the scale for
1204
+ # each batch is different. We need to select the correct scale then.
1205
+ # NOTE: during huggingface generate, it will generate attention_mask
1206
+ # if it is None, so this linke will always be true. Still, having
1207
+ # this line here for safety.
1208
+ layer_past_scale = layer_past_scale.index_select(0, beam_idx)
1209
+ reordered_past += ({
1210
+ "prev_key_value": layer_past_kv.index_select(0, beam_idx),
1211
+ "scale": layer_past_scale,
1212
+ },)
1213
+ return reordered_past
1214
+
1215
+ def sample_token(self, logit, do_sample=False, top_k=1, top_p=1.0, temperature=1.0):
1216
+ if not do_sample:
1217
+ return torch.argmax(logit, dim=-1, keepdim=True)
1218
+ filtered = top_k_top_p_filtering(logit / temperature, top_k=top_k, top_p=top_p)
1219
+ return torch.multinomial(torch.softmax(filtered, dim=-1), num_samples=1)
1220
+
1221
+ @torch.inference_mode()
1222
+ def custom_generate(
1223
+ self,
1224
+ input_ids: torch.LongTensor = None,
1225
+ retention_mask: Optional[torch.Tensor] = None,
1226
+ attention_mask: Optional[torch.Tensor] = None,
1227
+ parallel_compute_prompt=True,
1228
+ max_new_tokens=20,
1229
+ bos_token_id=0,
1230
+ eos_token_id=0,
1231
+ do_sample=False,
1232
+ top_k=0,
1233
+ top_p=1.0,
1234
+ temperature=1.0,
1235
+ early_stopping=True,
1236
+ ):
1237
+ if retention_mask is None and attention_mask is not None:
1238
+ retention_mask = attention_mask
1239
+
1240
+ if input_ids is not None:
1241
+ if input_ids.shape[1] == 1:
1242
+ past_key_values = None
1243
+ elif parallel_compute_prompt:
1244
+ ret_mask = (retention_mask[:, :-1] if retention_mask is not None else None)
1245
+ outputs = self(
1246
+ input_ids[:, :-1],
1247
+ retention_mask=ret_mask,
1248
+ forward_impl="parallel",
1249
+ return_dict=True,
1250
+ use_cache=True,
1251
+ )
1252
+ past_key_values = outputs.past_key_values
1253
+ else:
1254
+ past_key_values = None
1255
+ for p_i in range(input_ids.shape[1] - 1):
1256
+ ret_mask = (retention_mask[:, :p_i + 1] if retention_mask is not None else None)
1257
+ outputs = self(
1258
+ input_ids[:, :p_i + 1],
1259
+ retention_mask=ret_mask,
1260
+ forward_impl="recurrent",
1261
+ past_key_values=past_key_values,
1262
+ return_dict=True,
1263
+ use_cache=True,
1264
+ )
1265
+ past_key_values = outputs.past_key_values
1266
+
1267
+ generated = input_ids
1268
+ else:
1269
+ generated = torch.tensor([[bos_token_id]]).to(self.lm_head.weight.device)
1270
+ past_key_values = None
1271
+
1272
+ for i in range(max_new_tokens):
1273
+ outputs = self(
1274
+ generated,
1275
+ retention_mask=retention_mask,
1276
+ forward_impl="recurrent",
1277
+ past_key_values=past_key_values,
1278
+ use_cache=True,
1279
+ return_dict=True,
1280
+ )
1281
+ logit = outputs.logits[:, -1, :] # [batch_size, vocab_size]
1282
+ past_key_values = outputs.past_key_values
1283
+ token = self.sample_token(
1284
+ logit,
1285
+ do_sample=do_sample,
1286
+ top_k=top_k,
1287
+ top_p=top_p,
1288
+ temperature=temperature,
1289
+ )
1290
+ generated = torch.cat([generated, token], dim=-1)
1291
+ if retention_mask is not None:
1292
+ retention_mask = torch.cat([retention_mask, torch.ones_like(token)], dim=-1)
1293
+ if early_stopping and (token == eos_token_id).all():
1294
+ break
1295
+ return generated
1296
+
1297
+
1298
+ class RetNetForSequenceClassification(RetNetPreTrainedModel):
1299
+
1300
+ def __init__(self, config, tensor_parallel=False):
1301
+ super().__init__(config)
1302
+ self.num_labels = config.num_labels
1303
+ self.model = RetNetModel(config, tensor_parallel=tensor_parallel)
1304
+ self.score = nn.Linear(config.decoder_embed_dim, self.num_labels, bias=False)
1305
+
1306
+ # Initialize weights and apply final processing
1307
+ self.post_init()
1308
+
1309
+ def get_input_embeddings(self):
1310
+ return self.model.embed_tokens
1311
+
1312
+ def set_input_embeddings(self, value):
1313
+ self.model.embed_tokens = value
1314
+
1315
+ def forward(
1316
+ self,
1317
+ input_ids: torch.LongTensor = None,
1318
+ retention_mask: Optional[torch.Tensor] = None,
1319
+ attention_mask: Optional[torch.Tensor] = None,
1320
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1321
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1322
+ labels: Optional[torch.LongTensor] = None,
1323
+ use_cache: Optional[bool] = None,
1324
+ output_retentions: Optional[bool] = None,
1325
+ output_attentions: Optional[bool] = None,
1326
+ output_hidden_states: Optional[bool] = None,
1327
+ return_dict: Optional[bool] = None,
1328
+ forward_impl: Optional[str] = None,
1329
+ recurrent_chunk_size: Optional[int] = None,
1330
+ retention_rel_pos: Optional[Tuple[torch.Tensor]] = None,
1331
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1332
+ if output_retentions is None and output_attentions is not None:
1333
+ output_retentions = output_attentions
1334
+ output_retentions = (output_retentions
1335
+ if output_retentions is not None else self.config.output_retentions)
1336
+ output_hidden_states = (output_hidden_states if output_hidden_states is not None else
1337
+ self.config.output_hidden_states)
1338
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
1339
+ forward_impl = (forward_impl if forward_impl is not None else self.config.forward_impl)
1340
+ recurrent_chunk_size = (recurrent_chunk_size if recurrent_chunk_size is not None else
1341
+ self.config.recurrent_chunk_size)
1342
+
1343
+ if retention_mask is None and attention_mask is not None:
1344
+ retention_mask = attention_mask
1345
+
1346
+ outputs = self.model(
1347
+ input_ids,
1348
+ retention_mask=retention_mask,
1349
+ past_key_values=past_key_values,
1350
+ inputs_embeds=inputs_embeds,
1351
+ output_retentions=output_retentions,
1352
+ output_hidden_states=output_hidden_states,
1353
+ return_dict=return_dict,
1354
+ forward_impl=forward_impl,
1355
+ use_cache=use_cache,
1356
+ recurrent_chunk_size=recurrent_chunk_size,
1357
+ retention_rel_pos=retention_rel_pos,
1358
+ )
1359
+
1360
+ hidden_states = outputs[0]
1361
+ logits = self.score(hidden_states)
1362
+
1363
+ if input_ids is not None:
1364
+ batch_size = input_ids.shape[0]
1365
+ else:
1366
+ batch_size = inputs_embeds.shape[0]
1367
+
1368
+ if self.config.pad_token_id is None and batch_size != 1:
1369
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1370
+ if self.config.pad_token_id is None:
1371
+ sequence_lengths = -1
1372
+ else:
1373
+ if input_ids is not None:
1374
+ sequence_lengths = (
1375
+ torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
1376
+ logits.device)
1377
+ else:
1378
+ sequence_lengths = -1
1379
+
1380
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1381
+
1382
+ loss = None
1383
+ if labels is not None:
1384
+ labels = labels.to(logits.device)
1385
+ if self.config.problem_type is None:
1386
+ if self.num_labels == 1:
1387
+ self.config.problem_type = "regression"
1388
+ elif self.num_labels > 1 and (labels.dtype == torch.long or
1389
+ labels.dtype == torch.int):
1390
+ self.config.problem_type = "single_label_classification"
1391
+ else:
1392
+ self.config.problem_type = "multi_label_classification"
1393
+
1394
+ if self.config.problem_type == "regression":
1395
+ loss_fct = MSELoss()
1396
+ if self.num_labels == 1:
1397
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1398
+ else:
1399
+ loss = loss_fct(pooled_logits, labels)
1400
+ elif self.config.problem_type == "single_label_classification":
1401
+ loss_fct = CrossEntropyLoss()
1402
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1403
+ elif self.config.problem_type == "multi_label_classification":
1404
+ loss_fct = BCEWithLogitsLoss()
1405
+ loss = loss_fct(pooled_logits, labels)
1406
+ if not return_dict:
1407
+ output = (pooled_logits,) + outputs[1:]
1408
+ return ((loss,) + output) if loss is not None else output
1409
+
1410
+ return SequenceClassifierOutputWithPast(
1411
+ loss=loss,
1412
+ logits=pooled_logits,
1413
+ past_key_values=outputs.past_key_values,
1414
+ hidden_states=outputs.hidden_states,
1415
+ attentions=outputs.attentions,
1416
+ )
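
The three `forward_impl` paths above are intended to be interchangeable up to numerical error: "parallel" computes all positions at once, "recurrent" carries a fixed-size per-head kv state, and "chunkwise" mixes the two (and is marked as not working properly in the code). A small sanity sketch on a randomly initialized toy config, assuming the two files from this commit are importable; the sizes are made up purely to keep it fast and do not match the uploaded checkpoint:

import torch
from configuration_retnet import RetNetConfig
from modeling_retnet import RetNetForCausalLM

# toy config: values chosen only for speed, not those of the uploaded model
config = RetNetConfig(vocab_size=128, decoder_embed_dim=64, decoder_value_embed_dim=64,
                      decoder_ffn_embed_dim=128, decoder_layers=2, decoder_retention_heads=2,
                      use_glu=False, subln=False, use_rms_norm=False, use_bias=True,
                      parallel_residual=True, rotary_percentage=0.25)
model = RetNetForCausalLM(config).eval()

ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    # one parallel pass over the whole sequence
    par_logits = model(ids, forward_impl="parallel").logits[:, -1]
    # recurrent decoding: the model internally keeps only the last token and updates the kv state
    past = None
    for t in range(ids.shape[1]):
        out = model(ids[:, :t + 1], forward_impl="recurrent",
                    past_key_values=past, use_cache=True)
        past = out.past_key_values
    rec_logits = out.logits[:, -1]

print((par_logits - rec_logits).abs().max())  # expected to be small if the two paths agree
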
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ad607419888f703b2130ed20e6abf3153467b0879b40beeb13cc10edf9c24bb
3
+ size 1721828130
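
For reference, the LFS pointer above records a 1,721,828,130-byte float32 checkpoint; at 4 bytes per parameter that is roughly 1.72e9 / 4 ≈ 430M parameters, broadly consistent with the 410m label in the checkpoint path named in config.json.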