SmerkyG commited on
Commit
bde54d0
·
verified ·
1 Parent(s): 4651be4

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Rwkv7ForCausalLM"
4
+ ],
5
+ "attention_hidden_size": 2048,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_rwkv7.Rwkv7Config",
8
+ "AutoModelForCausalLM": "modeling_rwkv7.Rwkv7ForCausalLM"
9
+ },
10
+ "bos_token_id": 0,
11
+ "eos_token_id": 0,
12
+ "head_size": 64,
13
+ "hidden_size": 2048,
14
+ "intermediate_size": null,
15
+ "layer_norm_epsilon": 1e-05,
16
+ "lora_rank_decay": null,
17
+ "lora_rank_gate": null,
18
+ "lora_rank_iclr": null,
19
+ "lora_rank_value_residual_mix": null,
20
+ "model_type": "rwkv7",
21
+ "num_hidden_layers": 24,
22
+ "tie_word_embeddings": false,
23
+ "transformers_version": "4.46.2",
24
+ "use_cache": true,
25
+ "vocab_size": 50304
26
+ }
configuration_rwkv7.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ RWKV configuration"""
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ RWKV7_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
25
+
26
+
27
+ class Rwkv7Config(PretrainedConfig):
28
+ """
29
+ This is the configuration class to store the configuration of a [`Rwkv7Model`]. It is used to instantiate a RWKV7
30
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
31
+ defaults will yield a similar configuration to that of the RWVK-7
32
+ [RWKV/v7-Goose-1.6B-Pile-HF](https://huggingface.co/RWKV/v7-Goose-1.6B-Pile-HF) architecture.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 65536):
40
+ Vocabulary size of the RWKV7 model. Defines the number of different tokens that can be represented by the
41
+ `inputs_ids` passed when calling [`Rwkv7Model`].
42
+ hidden_size (`int`, *optional*, defaults to 768):
43
+ Dimensionality of the embeddings and hidden states.
44
+ num_hidden_layers (`int`, *optional*, defaults to 24):
45
+ Number of hidden layers in the model.
46
+ attention_hidden_size (`int`, *optional*):
47
+ Dimensionality of the attention hidden states. Will default to `hidden_size` if unset.
48
+ num_attention_heads (`int`, *optional*, defaults to 64):
49
+ The attention heads to use in rwkv7 self_attention module.
50
+ head_size (`int`, *optional*, defaults to 64): head_size of rwkv7 self_attention module.
51
+ intermediate_size (`int`, *optional*):
52
+ Dimensionality of the inner feed-forward layers. Will default to 4 times `hidden_size` if unset.
53
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
54
+ The epsilon to use in the layer normalization layers.
55
+ bos_token_id (`int`, *optional*, defaults to 0):
56
+ The id of the beginning of sentence token in the vocabulary. Defaults to 0.
57
+ eos_token_id (`int`, *optional*, defaults to 0):
58
+ The id of the end of sentence token in the vocabulary. Defaults to 0.
59
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
60
+ Whether or not to tie the word embeddings with the input token embeddings.
61
+ use_cache (`bool`, *optional*, defaults to `True`):
62
+ Whether or not the model should return the last state.
63
+
64
+
65
+ Example:
66
+
67
+ ```python
68
+ >>> from transformers import Rwkv7Config, Rwkv7Model
69
+
70
+ >>> # Initializing a Rwkv7 configuration
71
+ >>> configuration = Rwkv7Config()
72
+
73
+ >>> # Initializing a model (with random weights) from the configuration
74
+ >>> model = Rwkv7Model(configuration)
75
+
76
+ >>> # Accessing the model configuration
77
+ >>> configuration = model.config
78
+ ```"""
79
+
80
+ model_type = "rwkv7"
81
+
82
+ def __init__(
83
+ self,
84
+ vocab_size=65536,
85
+ hidden_size=768,
86
+ num_hidden_layers=24,
87
+ attention_hidden_size=None,
88
+ head_size=64,
89
+ intermediate_size=None,
90
+ lora_rank_decay=None,
91
+ lora_rank_iclr=None,
92
+ lora_rank_value_residual_mix=None,
93
+ lora_rank_gate=None,
94
+ layer_norm_epsilon=1e-5,
95
+ bos_token_id=0,
96
+ eos_token_id=0,
97
+ tie_word_embeddings=False,
98
+ use_cache=True,
99
+ **kwargs,
100
+ ):
101
+ self.vocab_size = vocab_size
102
+ self.hidden_size = hidden_size
103
+ self.num_hidden_layers = num_hidden_layers
104
+ self.attention_hidden_size = attention_hidden_size if attention_hidden_size is not None else hidden_size
105
+ self.head_size = head_size
106
+ self.intermediate_size = intermediate_size
107
+ self.lora_rank_decay = lora_rank_decay
108
+ self.lora_rank_iclr = lora_rank_iclr
109
+ self.lora_rank_value_residual_mix = lora_rank_value_residual_mix
110
+ self.lora_rank_gate = lora_rank_gate
111
+ self.layer_norm_epsilon = layer_norm_epsilon
112
+ self.use_cache = use_cache
113
+
114
+ super().__init__(
115
+ tie_word_embeddings=tie_word_embeddings, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs
116
+ )
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:01d1b9c811799e33f691a4628d5662c02c309dc3f718e052ccbdfb5ef9c8ea3c
3
+ size 2930113328
modeling_rwkv7.py ADDED
@@ -0,0 +1,874 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The RWKV team and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch RWKV7 World model."""
16
+
17
+ from dataclasses import dataclass
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ from pathlib import Path
21
+
22
+ import math
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import CrossEntropyLoss
28
+
29
+ from transformers.modeling_utils import PreTrainedModel, GenerationMixin, _init_weights
30
+ from transformers.utils import (
31
+ ModelOutput,
32
+ add_code_sample_docstrings,
33
+ add_start_docstrings,
34
+ add_start_docstrings_to_model_forward,
35
+ is_ninja_available,
36
+ is_torch_cuda_available,
37
+ logging,
38
+ )
39
+
40
+ from .configuration_rwkv7 import Rwkv7Config
41
+
42
+ # MIT License
43
+
44
+ # Copyright (c) 2024 Songlin Yang
45
+
46
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
47
+ # of this software and associated documentation files (the "Software"), to deal
48
+ # in the Software without restriction, including without limitation the rights
49
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
50
+ # copies of the Software, and to permit persons to whom the Software is
51
+ # furnished to do so, subject to the following conditions:
52
+
53
+ # The above copyright notice and this permission notice shall be included in all
54
+ # copies or substantial portions of the Software.
55
+
56
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
57
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
58
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
59
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
60
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
61
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
62
+ # SOFTWARE.
63
+
64
+ # Copyright (c) 2024, Johan Sokrates Wind
65
+
66
+ import torch as th
67
+ import triton
68
+ import triton.language as tl
69
+
70
+ @triton.jit
71
+ def IND4(a,b,c,d,nb,nc,nd):
72
+ return ((a*nb+b)*nc+c)*nd+d
73
+ @triton.jit
74
+ def IND5(a,b,c,d,e,nb,nc,nd,ne):
75
+ return (((a*nb+b)*nc+c)*nd+d)*ne+e
76
+
77
+ @triton.jit
78
+ def _prod(a,b): return a*b
79
+
80
+ # inv(I-A) where A is a strictly lower triangular nxn matrix
81
+ @triton.jit
82
+ def tri_minv(A, n:tl.constexpr, prec:tl.constexpr):
83
+ i = tl.arange(0,n)
84
+ prod = (i[None,:]==i[:,None]).to(tl.float32)
85
+ for j in range(n-1):
86
+ prod += tl_dot(prec, prod, (A*((i[None,:]==j)*(i[:,None]>i[None,:]))).trans())
87
+ return prod.trans()
88
+
89
+ @triton.jit
90
+ def fw_attn_triton(w_,q_,k_,v_,a_,b_, s0_,y_,s_,sT_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr):
91
+ bi = tl.program_id(1)
92
+ hi = tl.program_id(0)
93
+
94
+ i = tl.arange(0,C)[None,:]
95
+ state = tl.load(s0_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32)
96
+ for t0 in range(T//dT):
97
+ t = t0*dT+tl.arange(0,dT)[:,None]
98
+ sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
99
+ sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
100
+ sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
101
+ sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
102
+ sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
103
+ sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
104
+
105
+ w = (-sw.exp()).exp()
106
+ fw = tl.reduce(w, 0, _prod, keep_dims=True)
107
+ incl_pref = tl.cumprod(w,axis=0)
108
+ non_incl_pref = incl_pref / w
109
+ inv_incl_pref = 1 / incl_pref
110
+
111
+ wq = sq * incl_pref
112
+ wa = sa * non_incl_pref
113
+ kwi = sk * inv_incl_pref
114
+ bwi = sb * inv_incl_pref
115
+
116
+ mask1 = (t > t.trans())
117
+ ab = tl_dot(prec, wa, bwi.trans()) * mask1
118
+ ak = tl_dot(prec, wa, kwi.trans()) * mask1
119
+
120
+ ab_inv = tri_minv(ab, dT, prec)
121
+
122
+ ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans())
123
+ u = tl_dot(prec, ab_inv, ab_u)
124
+ mask2 = (t >= t.trans())
125
+ qk = tl_dot(prec, wq, kwi.trans()) * mask2
126
+ qb = tl_dot(prec, wq, bwi.trans()) * mask2
127
+ yy = tl_dot(prec, qk, sv) + tl_dot(prec, qb, u) + tl_dot(prec, wq, state.trans())
128
+ tl.store(y_+IND4(bi,t,hi,i, T,H,C), yy.to(tl.bfloat16))
129
+
130
+ tl.store(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C), state.to(tl.float32))
131
+ state = state * fw + tl_dot(prec, sv.trans(), kwi*fw) + tl_dot(prec, u.trans(), bwi*fw)
132
+ tl.store(sT_+IND4(bi,hi,i.trans(),i, H,C,C), state.to(tl.bfloat16))
133
+
134
+ @triton.jit
135
+ def bw_attn_triton(w_,q_,k_,v_,a_,b_, dy_,s_,dsT_, dw_,dq_,dk_,dv_,da_,db_,ds0_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr):
136
+ bi = tl.program_id(1)
137
+ hi = tl.program_id(0)
138
+
139
+ i = tl.arange(0,C)[None,:]
140
+ dstate = tl.load(dsT_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32)
141
+
142
+ for t0 in range(T//dT-1,-1,-1):
143
+ t = t0*dT+tl.arange(0,dT)[:,None]
144
+
145
+ state = tl.load(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C)).to(tl.float32)
146
+
147
+ sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
148
+ sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
149
+ sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
150
+ sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
151
+ sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
152
+ sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
153
+ sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
154
+
155
+ dw_fac = -sw.exp()
156
+ w = dw_fac.exp()
157
+ fw = tl.reduce(w, 0, _prod, keep_dims=True)
158
+ incl_pref = tl.cumprod(w,axis=0)
159
+ non_incl_pref = incl_pref / w
160
+ inv_incl_pref = 1 / incl_pref
161
+
162
+ wq = sq * incl_pref
163
+ wa = sa * non_incl_pref
164
+ kwi = sk * inv_incl_pref
165
+ bwi = sb * inv_incl_pref
166
+
167
+ mask1 = (t > t.trans())
168
+ ab = tl_dot(prec, wa, bwi.trans()) * mask1
169
+ ak = tl_dot(prec, wa, kwi.trans()) * mask1
170
+
171
+ ab_inv = tri_minv(ab, dT, prec)
172
+
173
+ ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans())
174
+ u = tl_dot(prec, ab_inv, ab_u)
175
+ mask2 = (t >= t.trans())
176
+ qk = tl_dot(prec, wq, kwi.trans()) * mask2
177
+ qb = tl_dot(prec, wq, bwi.trans()) * mask2
178
+
179
+ du = tl_dot(prec, qb.trans(), sdy) + tl_dot(prec, bwi*fw, dstate.trans())
180
+ dab_u = tl_dot(prec, ab_inv.trans(), du)
181
+
182
+ dv = tl_dot(prec, qk.trans(), sdy) + tl_dot(prec, kwi*fw, dstate.trans()) + tl_dot(prec, ak.trans(), dab_u)
183
+ tl.store(dv_+IND4(bi,t,hi,i, T,H,C), dv.to(tl.bfloat16))
184
+
185
+ dab = tl_dot(prec, tl_dot(prec, ab_inv.trans(), du), u.trans()) * mask1
186
+ dak = tl_dot(prec, dab_u, sv.trans()) * mask1
187
+ dab_u_state = tl_dot(prec, dab_u, state)
188
+ da = non_incl_pref * (tl_dot(prec, dab, bwi) + tl_dot(prec, dak, kwi) + dab_u_state)
189
+ tl.store(da_+IND4(bi,t,hi,i, T,H,C), da.to(tl.bfloat16))
190
+
191
+ dqb = tl_dot(prec, sdy, u.trans()) * mask2
192
+ dqk = tl_dot(prec, sdy, sv.trans()) * mask2
193
+ dy_state = tl_dot(prec, sdy, state)
194
+ dq = incl_pref * (tl_dot(prec, dqb, bwi) + tl_dot(prec, dqk, kwi) + dy_state)
195
+ tl.store(dq_+IND4(bi,t,hi,i, T,H,C), dq.to(tl.bfloat16))
196
+
197
+ fw_u_dstate = fw * tl_dot(prec, u, dstate)
198
+ db = inv_incl_pref * (tl_dot(prec, dab.trans(), wa) + tl_dot(prec, dqb.trans(), wq) + fw_u_dstate)
199
+ tl.store(db_+IND4(bi,t,hi,i, T,H,C), db.to(tl.bfloat16))
200
+
201
+ fw_v_dstate = fw * tl_dot(prec, sv, dstate)
202
+ dk = inv_incl_pref * (tl_dot(prec, dak.trans(), wa) + tl_dot(prec, dqk.trans(), wq) + fw_v_dstate)
203
+ tl.store(dk_+IND4(bi,t,hi,i, T,H,C), dk.to(tl.bfloat16))
204
+
205
+ dw0 = fw * tl.sum(state*dstate, axis=0,keep_dims=True)
206
+ for k in range(t0*dT,t0*dT+dT):
207
+ lmask = (t<k).trans()
208
+ A = (tl_dot(prec, dab*lmask, bwi) + tl_dot(prec, dak*lmask, kwi)) * wa * (t>k)
209
+ A += (tl_dot(prec, dqb*lmask, bwi) + tl_dot(prec, dqk*lmask, kwi)) * wq * (t>=k)
210
+ A += (fw_v_dstate*kwi + fw_u_dstate*bwi) * (t<k)
211
+ A += dab_u_state*wa * (t>k) + dy_state*wq * (t>=k)
212
+ dw = tl.sum(A, axis=0,keep_dims=True) + dw0
213
+
214
+ wk = tl.load(w_+IND4(bi,k,hi,i, T,H,C)).to(tl.float32)
215
+ dw *= -wk.exp()
216
+ tl.store(dw_+IND4(bi,k,hi,i, T,H,C), dw.to(tl.bfloat16))
217
+
218
+ dstate = dstate * fw + tl_dot(prec, sdy.trans(), wq) + tl_dot(prec, dab_u.trans(), wa)
219
+ tl.store(ds0_+IND4(bi,hi,i.trans(),i, H,C,C), dstate.to(tl.bfloat16))
220
+
221
+
222
+ class TritonRWKV7(th.autograd.Function):
223
+ @staticmethod
224
+ def forward(ctx, w,q,k,v,z,b,s0, dot_prec):
225
+ K = 16
226
+ B,T,H,C = w.shape
227
+ s0 = th.zeros(B,H,C,C, dtype=w.dtype,device=w.device) if s0 is None else s0
228
+ y = th.empty_like(v)
229
+ sT = th.empty_like(s0)
230
+ s = th.zeros(B,H,T//K,C,C, dtype=th.float32,device=w.device)
231
+ fw_attn_triton[(H,B)](w,q,k,v,z,b, s0,y,s,sT, B,T,H,C,K, dot_prec)
232
+ ctx.dot_prec = dot_prec
233
+ ctx.save_for_backward(w,q,k,v,z,b,s)
234
+ return y, sT
235
+ @staticmethod
236
+ def backward(ctx, dy, dsT):
237
+ K = 16
238
+ w,q,k,v,z,b,s = ctx.saved_tensors
239
+ B,T,H,C = w.shape
240
+ dw,dq,dk,dv,dz,db,ds0 = [th.empty_like(x) for x in [w,q,k,v,z,b,dsT]]
241
+ bw_attn_triton[(H,B)](w,q,k,v,z,b, dy,s,dsT, dw,dq,dk,dv,dz,db,ds0, B,T,H,C,K, ctx.dot_prec)
242
+ return dw,dq,dk,dv,dz,db,ds0,None
243
+
244
+ @triton.jit
245
+ def tl_dot(prec:tl.constexpr, a, b) -> torch.Tensor:
246
+ if prec == 'fp32':
247
+ return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=False)
248
+ elif prec == 'tf32':
249
+ return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=True)
250
+ elif prec == 'bf16':
251
+ return tl.dot(a.to(tl.bfloat16),b.trans().to(tl.bfloat16).trans(), allow_tf32=True)
252
+ else:
253
+ tl.static_assert(False)
254
+
255
+ def rwkv7_attn_triton(r,w,k,v,a,b, HEAD_SIZE, dot_prec = 'fp32'):
256
+ B,T,HC = w.shape
257
+ C = HEAD_SIZE
258
+ H = HC//C
259
+ r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,a,b]]
260
+ s0 = th.zeros(B,H,C,C, dtype=th.bfloat16,device=w.device)
261
+ return TritonRWKV7.apply(w,r,k,v,a,b,s0,dot_prec)[0].view(B,T,HC)
262
+
263
+ logger = logging.get_logger(__name__)
264
+
265
+ _CHECKPOINT_FOR_DOC = "RWKV/v7-Goose-1.6B-Pile-HF"
266
+ _CONFIG_FOR_DOC = "Rwkv7Config"
267
+
268
+ class Rwkv7SelfAttention(nn.Module):
269
+ def __init__(self, config, layer_id=0):
270
+ super().__init__()
271
+ self.config = config
272
+ self.layer_id = layer_id
273
+ C = hidden_size = config.hidden_size
274
+ attention_hidden_size = config.attention_hidden_size
275
+ self.attention_hidden_size = attention_hidden_size
276
+ H = self.num_heads = attention_hidden_size // config.head_size
277
+ N = self.head_size = config.head_size
278
+
279
+ calc_lora_rank = lambda exponent, multiplier: max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
280
+ lora_rank_decay = config.lora_rank_decay or calc_lora_rank(0.5, 1.8)
281
+ lora_rank_iclr = config.lora_rank_iclr or calc_lora_rank(0.5, 1.8)
282
+ lora_rank_value_residual_mix = config.lora_rank_value_residual_mix or calc_lora_rank(0.5, 1.3)
283
+ lora_rank_gate = config.lora_rank_gate or calc_lora_rank(0.8, 0.6)
284
+
285
+ self.x_r = nn.Parameter(torch.empty(1,1,C))
286
+ self.x_w = nn.Parameter(torch.empty(1,1,C))
287
+ self.x_k = nn.Parameter(torch.empty(1,1,C))
288
+ self.x_v = nn.Parameter(torch.empty(1,1,C))
289
+ self.x_a = nn.Parameter(torch.empty(1,1,C))
290
+ self.x_g = nn.Parameter(torch.empty(1,1,C))
291
+
292
+ self.w0 = nn.Parameter(torch.empty(1,1,C))
293
+ self.w1 = nn.Parameter(torch.empty(C, lora_rank_decay))
294
+ self.w2 = nn.Parameter(torch.empty(lora_rank_decay, C))
295
+
296
+ self.a0 = nn.Parameter(torch.empty(1,1,C))
297
+ self.a1 = nn.Parameter(torch.empty(C, lora_rank_iclr))
298
+ self.a2 = nn.Parameter(torch.empty(lora_rank_iclr, C))
299
+
300
+ if layer_id > 0:
301
+ self.v0 = nn.Parameter(torch.empty(1,1,C))
302
+ self.v1 = nn.Parameter(torch.empty(C, lora_rank_value_residual_mix))
303
+ self.v2 = nn.Parameter(torch.empty(lora_rank_value_residual_mix, C))
304
+
305
+ self.g1 = nn.Parameter(torch.empty(C, lora_rank_gate))
306
+ self.g2 = nn.Parameter(torch.empty(lora_rank_gate, C))
307
+
308
+ self.k_k = nn.Parameter(torch.empty(1,1,C))
309
+ self.k_a = nn.Parameter(torch.empty(1,1,C))
310
+ self.r_k = nn.Parameter(torch.empty(H,N))
311
+
312
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
313
+ self.receptance = nn.Linear(C, C, bias=False)
314
+ self.key = nn.Linear(C, C, bias=False)
315
+ self.value = nn.Linear(C, C, bias=False)
316
+ self.output = nn.Linear(C, C, bias=False)
317
+ self.ln_x = nn.GroupNorm(H, C, eps=self.head_size * 1e-5)
318
+
319
+
320
+ def forward(self, hidden, state=None, v_first=None, use_cache=False, seq_mode=True):
321
+ # Mix hidden with the previous timestep to produce key, value, receptance
322
+ if hidden.size(1) == 1 and state is not None:
323
+ shifted = state[0][self.layer_id]
324
+ else:
325
+ shifted = self.time_shift(hidden)
326
+ if state is not None:
327
+ shifted[:, 0] = state[0][self.layer_id]
328
+ if len(shifted.size()) == 2:
329
+ shifted = shifted.unsqueeze(1)
330
+
331
+ x = hidden
332
+
333
+ B, T, C = hidden.shape
334
+ H = self.num_heads
335
+ N = self.head_size
336
+
337
+ xx = shifted - x
338
+
339
+ xr = x+xx*self.x_r
340
+ xw = x+xx*self.x_w
341
+ xk = x+xx*self.x_k
342
+ xv = x+xx*self.x_v
343
+ xa = x+xx*self.x_a
344
+ xg = x+xx*self.x_g
345
+
346
+ r = self.receptance(xr)
347
+ w = torch.tanh(xw @ self.w1) @ self.w2
348
+ k = self.key(xk)
349
+ v = self.value(xv)
350
+ a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2)
351
+ g = torch.sigmoid(xg @ self.g1) @ self.g2
352
+
353
+ kk = torch.nn.functional.normalize((k * self.k_k).view(B,T,H,-1), dim=-1, p=2.0).view(B,T,-1)
354
+ k = k * (1 + (a-1) * self.k_a)
355
+ if self.layer_id == 0: v_first = v
356
+ else: v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2)
357
+
358
+ if T == 1 or not self.training:
359
+ w = torch.exp(-0.606531 * torch.sigmoid((self.w0 + w).float())) # 0.606531 = exp(-0.5)
360
+ vk_state = state[1][self.layer_id]
361
+ for t in range(T):
362
+ r_, w_, k_, v_, kk_, a_ = r[:,t], w[:,t], k[:,t], v[:,t], kk[:,t], a[:,t]
363
+ vk = v_.view(B,H,N,1) @ k_.view(B,H,1,N)
364
+ ab = (-kk_).view(B,H,N,1) @ (kk_*a_).view(B,H,1,N)
365
+ vk_state = vk_state * w_.view(B,H,1,N) + vk_state @ ab.float() + vk.float()
366
+ xx[:,t] = (vk_state.to(dtype=x.dtype) @ r_.view(B,H,N,1)).view(B,H*N)
367
+ state[1][self.layer_id] = vk_state
368
+ # FIXME - support fast triton kernel for non-training pre-fill with state in and out
369
+ else:
370
+ w = -torch.nn.functional.softplus(-(self.w0 + w)) - 0.5
371
+ rwkv7_attn_triton(r, w, k, v, -kk, kk*a, self.head_size)
372
+
373
+ xx = torch.nn.functional.group_norm(xx.view(B*T,H*N), num_groups=H, weight=self.ln_x.weight, bias=self.ln_x.bias, eps = self.ln_x.eps).view(B,T,H*N)
374
+ xx = xx + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,C)
375
+ xx = self.output(xx * g)
376
+
377
+ if state is not None:
378
+ state[0][self.layer_id] = hidden[:, -1]
379
+
380
+ return xx, state, v_first
381
+
382
+
383
+ class Rwkv7FeedForward(nn.Module):
384
+ def __init__(self, config, layer_id=0):
385
+ super().__init__()
386
+ self.config = config
387
+ self.layer_id = layer_id
388
+ hidden_size = config.hidden_size
389
+ intermediate_size = (
390
+ config.intermediate_size
391
+ if config.intermediate_size is not None
392
+ else int(config.hidden_size * 4)
393
+ )
394
+
395
+
396
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
397
+
398
+ self.x_k = nn.Parameter(torch.empty(1, 1, hidden_size))
399
+
400
+ self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
401
+ self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
402
+
403
+ def forward(self, hidden, state=None):
404
+ if hidden.size(1) == 1 and state is not None:
405
+ shifted = state[2][self.layer_id]
406
+ else:
407
+ shifted = self.time_shift(hidden)
408
+ if state is not None:
409
+ shifted[:, 0] = state[2][self.layer_id]
410
+ if len(shifted.size()) == 2:
411
+ shifted = shifted.unsqueeze(1)
412
+
413
+ delta_hidden_to_shifted = shifted - hidden
414
+ key = hidden + delta_hidden_to_shifted * self.x_k
415
+
416
+ key = torch.square(torch.relu(self.key(key)))
417
+ value = self.value(key)
418
+
419
+ if state is not None:
420
+ state[2][self.layer_id] = hidden[:, -1]
421
+
422
+ return value, state
423
+
424
+
425
+ class Rwkv7Block(nn.Module):
426
+ def __init__(self, config, layer_id):
427
+ super().__init__()
428
+ self.config = config
429
+ self.layer_id = layer_id
430
+
431
+ self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
432
+ self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
433
+
434
+ self.attention = Rwkv7SelfAttention(config, layer_id)
435
+ self.feed_forward = Rwkv7FeedForward(config, layer_id)
436
+
437
+ def forward(self, hidden, state=None, v_first=None, use_cache=False, output_attentions=False, seq_mode=True):
438
+ attention, state, v_first = self.attention(self.ln1(hidden), state=state, v_first=v_first, use_cache=use_cache, seq_mode=seq_mode)
439
+ hidden = hidden + attention
440
+
441
+ feed_forward, state = self.feed_forward(self.ln2(hidden), state=state)
442
+ hidden = hidden + feed_forward
443
+
444
+ outputs = (hidden, state, v_first)
445
+ if output_attentions:
446
+ outputs += (attention,)
447
+ else:
448
+ outputs += (None,)
449
+
450
+ return outputs
451
+
452
+
453
+ class Rwkv7PreTrainedModel(PreTrainedModel):
454
+ """
455
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
456
+ models.
457
+ """
458
+
459
+ config_class = Rwkv7Config
460
+ base_model_prefix = "rwkv7"
461
+ _no_split_modules = ["Rwkv7Block"]
462
+ _keep_in_fp32_modules = []
463
+ supports_gradient_checkpointing = True
464
+
465
+ def _init_weights(self, module):
466
+ return
467
+
468
+ """Initialize the weights."""
469
+ if isinstance(module, Rwkv7SelfAttention):
470
+ layer_id = module.layer_id
471
+ num_hidden_layers = module.config.num_hidden_layers
472
+ hidden_size = module.config.hidden_size
473
+ attention_hidden_size = module.attention_hidden_size
474
+ head_size = module.config.head_size
475
+ num_heads = attention_hidden_size // head_size
476
+
477
+ ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1
478
+ ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
479
+
480
+ time_weight = torch.tensor(
481
+ [i / hidden_size for i in range(hidden_size)],
482
+ dtype=module.x_k.dtype,
483
+ device=module.x_k.device,
484
+ )
485
+ time_weight = time_weight[None, None, :]
486
+
487
+ decay_speed = [
488
+ -7.0 + 5.0 * (n / (attention_hidden_size - 1)) ** (0.85 + 1.0 * ratio_0_to_1 ** 0.5)
489
+ for n in range(attention_hidden_size)
490
+ ]
491
+ decay_speed = torch.tensor(decay_speed, dtype=module.w0.dtype, device=module.w0.device)
492
+
493
+ with torch.no_grad():
494
+ module.x_r.copy_( 1.0 - torch.pow(time_weight, 0.2 * ratio_1_to_almost0) )
495
+ module.x_w.copy_( 1.0 - torch.pow(time_weight, 0.9 * ratio_1_to_almost0) )
496
+ module.x_k.copy_( 1.0 - (torch.pow(time_weight, 0.9 * ratio_1_to_almost0) + 0.4 * ratio_0_to_1) )
497
+ module.x_v.copy_( 1.0 - (torch.pow(time_weight, 0.4 * ratio_1_to_almost0) + 0.6 * ratio_0_to_1) )
498
+ module.x_a.copy_( 1.0 - torch.pow(time_weight, 0.9 * ratio_1_to_almost0) )
499
+ module.x_g.copy_( 1.0 - torch.pow(time_weight, 0.2 * ratio_1_to_almost0) )
500
+
501
+ def ortho_init(x, scale):
502
+ with torch.no_grad():
503
+ shape = x.shape
504
+ if len(shape) == 2:
505
+ gain = math.sqrt(shape[0] / shape[1]) if shape[0] > shape[1] else 1
506
+ nn.init.orthogonal_(x, gain=gain * scale)
507
+ elif len(shape) == 3:
508
+ gain = math.sqrt(shape[1] / shape[2]) if shape[1] > shape[2] else 1
509
+ for i in range(shape[0]):
510
+ nn.init.orthogonal_(x[i], gain=gain * scale)
511
+ else:
512
+ assert False
513
+ return x
514
+
515
+ module.w0.copy_(decay_speed.reshape(1,1,attention_hidden_size) + 0.5) # !!! 0.5 comes from F.softplus !!!
516
+ module.w1.zero_()
517
+ ortho_init(module.w2, 0.1)
518
+
519
+ module.a0.zero_()
520
+ module.a1.zero_()
521
+ ortho_init(module.a2, 0.1)
522
+
523
+ module.v0.copy_(1.0)
524
+ module.v1.zero_()
525
+ ortho_init(module.v2, 0.1)
526
+
527
+ module.g1.zero_()
528
+ ortho_init(module.g2, 0.1)
529
+
530
+ self.k_k.copy_(0.85)
531
+ self.k_a.copy_(1.0)
532
+ self.r_k.zero_()
533
+
534
+ module.receptance.weight.data.uniform_(-0.5/(hidden_size**0.5), 0.5/(attention_hidden_size**0.5))
535
+ module.key.weight.data.uniform_(-0.05/(hidden_size**0.5), 0.05/(attention_hidden_size**0.5))
536
+ module.value.weight.data.uniform_(-0.5/(hidden_size**0.5), 0.5/(attention_hidden_size**0.5))
537
+ module.output.weight.data.zero_()
538
+
539
+ elif isinstance(module, Rwkv7FeedForward):
540
+ layer_id = module.layer_id
541
+ num_hidden_layers = module.config.num_hidden_layers
542
+ hidden_size = module.config.hidden_size
543
+
544
+ ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
545
+
546
+ time_weight = torch.tensor(
547
+ [i / hidden_size for i in range(hidden_size)],
548
+ dtype=module.x_k.dtype,
549
+ device=module.x_k.device,
550
+ )
551
+ time_weight = time_weight[None, None, :]
552
+
553
+ with torch.no_grad():
554
+ module.x_k.copy_( 1.0 - torch.pow(time_weight, ratio_1_to_almost0**4) )
555
+
556
+ self.key.weight.data.uniform_(-0.5/(hidden_size**0.5), 0.5/(hidden_size**0.5))
557
+ self.value.weight.data.zero_()
558
+
559
+ @dataclass
560
+ class Rwkv7Output(ModelOutput):
561
+ """
562
+ Class for the RWKV model outputs.
563
+ Args:
564
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
565
+ Sequence of hidden-states at the output of the last layer of the model.
566
+ state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
567
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
568
+ avoid providing the old `input_ids`.
569
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
570
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
571
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
572
+ the model at the output of each layer plus the optional initial embedding outputs.
573
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
574
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
575
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
576
+ the self-attention heads.
577
+ """
578
+
579
+ last_hidden_state: torch.FloatTensor = None
580
+ state: Optional[List[torch.FloatTensor]] = None
581
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
582
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
583
+
584
+
585
+ @dataclass
586
+ class Rwkv7CausalLMOutput(ModelOutput):
587
+ """
588
+ Base class for causal language model (or autoregressive) outputs.
589
+ Args:
590
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
591
+ Language modeling loss (for next-token prediction).
592
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
593
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
594
+ state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
595
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
596
+ avoid providing the old `input_ids`.
597
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
598
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
599
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
600
+ the model at the output of each layer plus the optional initial embedding outputs.
601
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
602
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
603
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
604
+ the self-attention heads.
605
+ """
606
+
607
+ loss: Optional[torch.FloatTensor] = None
608
+ logits: torch.FloatTensor = None
609
+ state: Optional[List[torch.FloatTensor]] = None
610
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
611
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
612
+
613
+
614
+ RWKV7_START_DOCSTRING = r"""
615
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
616
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
617
+ etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)
618
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
619
+ general usage and behavior.
620
+ Parameters:
621
+ config ([`Rwkv7Config`]): Model configuration class with all the parameters of the model.
622
+ Initializing with a config file does not load the weights associated with the model, only the
623
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
624
+ """
625
+
626
+ RWKV7_INPUTS_DOCSTRING = r"""
627
+ Args:
628
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
629
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
630
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
631
+ sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their
632
+ past calculated should be passed as `input_ids`. Indices can be obtained using [`AutoTokenizer`]. See
633
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
634
+ IDs?](../glossary#input-ids)
635
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
636
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
637
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
638
+ model's internal embedding lookup matrix.
639
+ state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
640
+ If passed along, the model uses the previous state in all the blocks (which will give the output for the
641
+ `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
642
+ use_cache (`bool`, *optional*):
643
+ If set to `True`, the last state is returned and can be used to quickly generate the next logits.
644
+ output_attentions (`bool`, *optional*):
645
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
646
+ tensors for more detail.
647
+ output_hidden_states (`bool`, *optional*):
648
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
649
+ more detail.
650
+ return_dict (`bool`, *optional*):
651
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
652
+ """
653
+
654
+
655
+ @add_start_docstrings(
656
+ "The bare RWKV7 Model transformer outputting raw hidden-states without any specific head on top.",
657
+ RWKV7_START_DOCSTRING,
658
+ )
659
+ class Rwkv7Model(Rwkv7PreTrainedModel):
660
+ def __init__(self, config):
661
+ super().__init__(config)
662
+
663
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
664
+ self.pre_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
665
+ self.blocks = nn.ModuleList([Rwkv7Block(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
666
+ self.ln_out = nn.LayerNorm(config.hidden_size)
667
+
668
+ self.gradient_checkpointing = False
669
+
670
+ # Initialize weights and apply final processing
671
+ self.post_init()
672
+
673
+ def get_input_embeddings(self):
674
+ return self.embeddings
675
+
676
+ def set_input_embeddings(self, new_embeddings):
677
+ self.embeddings = new_embeddings
678
+
679
+ @add_start_docstrings_to_model_forward(RWKV7_INPUTS_DOCSTRING)
680
+ @add_code_sample_docstrings(
681
+ checkpoint=_CHECKPOINT_FOR_DOC,
682
+ output_type=Rwkv7Output,
683
+ config_class=_CONFIG_FOR_DOC,
684
+ )
685
+ def forward(
686
+ self,
687
+ input_ids: Optional[torch.LongTensor] = None,
688
+ attention_mask: Optional[torch.LongTensor] = None, # noqa
689
+ inputs_embeds: Optional[torch.FloatTensor] = None,
690
+ state: Optional[List[torch.FloatTensor]] = None,
691
+ use_cache: Optional[bool] = None,
692
+ output_attentions: Optional[bool] = None,
693
+ output_hidden_states: Optional[bool] = None,
694
+ return_dict: Optional[bool] = None,
695
+ ) -> Union[Tuple, Rwkv7Output]:
696
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
697
+ output_hidden_states = (
698
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
699
+ )
700
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
701
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
702
+
703
+ if input_ids is not None and inputs_embeds is not None:
704
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
705
+ elif input_ids is None and inputs_embeds is None:
706
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
707
+
708
+ if inputs_embeds is None:
709
+ inputs_embeds = self.embeddings(input_ids)
710
+
711
+ if state is None:
712
+ state = []
713
+ head_size = self.config.head_size
714
+ num_heads = self.config.attention_hidden_size // head_size
715
+ state_attn_x = torch.zeros(
716
+ (self.config.num_hidden_layers, inputs_embeds.size(0), self.config.hidden_size),
717
+ dtype=inputs_embeds.dtype,
718
+ requires_grad=False,
719
+ device=inputs_embeds.device,
720
+ ).contiguous()
721
+ state_attn_vk = torch.zeros(
722
+ (
723
+ self.config.num_hidden_layers,
724
+ inputs_embeds.size(0),
725
+ num_heads,
726
+ head_size,
727
+ head_size,
728
+ ),
729
+ dtype=torch.float32,
730
+ requires_grad=False,
731
+ device=inputs_embeds.device,
732
+ ).contiguous()
733
+ state_ffn_x = torch.zeros(
734
+ (self.config.num_hidden_layers, inputs_embeds.size(0), self.config.hidden_size),
735
+ dtype=inputs_embeds.dtype,
736
+ requires_grad=False,
737
+ device=inputs_embeds.device,
738
+ ).contiguous()
739
+ state.append(state_attn_x)
740
+ state.append(state_attn_vk)
741
+ state.append(state_ffn_x)
742
+
743
+ seq_mode = inputs_embeds.shape[1] > 1
744
+ hidden_states = self.pre_ln(inputs_embeds)
745
+ v_first = None
746
+
747
+ all_self_attentions = () if output_attentions else None
748
+ all_hidden_states = () if output_hidden_states else None
749
+ for idx, block in enumerate(self.blocks):
750
+ hidden_states, state, v_first, attentions = block(
751
+ hidden_states, state=state, v_first=v_first, use_cache=use_cache, output_attentions=output_attentions, seq_mode=seq_mode
752
+ )
753
+
754
+ if output_hidden_states:
755
+ all_hidden_states = all_hidden_states + (hidden_states,)
756
+
757
+ if output_attentions:
758
+ all_self_attentions = all_self_attentions + (attentions,)
759
+
760
+ hidden_states = self.ln_out(hidden_states)
761
+
762
+ if output_hidden_states:
763
+ all_hidden_states = all_hidden_states + (hidden_states,)
764
+
765
+ if not return_dict:
766
+ return (hidden_states, state, all_hidden_states, all_self_attentions)
767
+
768
+ return Rwkv7Output(
769
+ last_hidden_state=hidden_states,
770
+ state=state,
771
+ hidden_states=all_hidden_states, # None
772
+ attentions=all_self_attentions, # None
773
+ )
774
+
775
+ # copied from HuggingFace https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
776
+ @add_start_docstrings(
777
+ """
778
+ The RWKV7 Model transformer with a language modeling head on top (linear layer with weights tied to the input
779
+ embeddings).
780
+ """,
781
+ RWKV7_START_DOCSTRING,
782
+ )
783
+ class Rwkv7ForCausalLM(Rwkv7PreTrainedModel, GenerationMixin):
784
+ _tied_weights_keys = ["head.weight"]
785
+
786
+ def __init__(self, config):
787
+ super().__init__(config)
788
+ self.model = Rwkv7Model(config)
789
+ self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
790
+
791
+ # Initialize weights and apply final processing
792
+ self.post_init()
793
+
794
+ def get_output_embeddings(self):
795
+ return self.head
796
+
797
+ def set_output_embeddings(self, new_embeddings):
798
+ self.head = new_embeddings
799
+
800
+ def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs):
801
+ # only last token for inputs_ids if the state is passed along.
802
+ if state is not None:
803
+ input_ids = input_ids[:, -1].unsqueeze(-1)
804
+
805
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
806
+ if inputs_embeds is not None and state is None:
807
+ model_inputs = {"inputs_embeds": inputs_embeds}
808
+ else:
809
+ model_inputs = {"input_ids": input_ids}
810
+
811
+ model_inputs["state"] = state
812
+ return model_inputs
813
+
814
+ @add_start_docstrings_to_model_forward(RWKV7_INPUTS_DOCSTRING)
815
+ @add_code_sample_docstrings(
816
+ checkpoint=_CHECKPOINT_FOR_DOC,
817
+ output_type=Rwkv7CausalLMOutput,
818
+ config_class=_CONFIG_FOR_DOC,
819
+ )
820
+ def forward(
821
+ self,
822
+ input_ids: Optional[torch.LongTensor] = None,
823
+ attention_mask: Optional[torch.LongTensor] = None,
824
+ inputs_embeds: Optional[torch.FloatTensor] = None,
825
+ state: Optional[List[torch.FloatTensor]] = None,
826
+ labels: Optional[torch.LongTensor] = None,
827
+ use_cache: Optional[bool] = None,
828
+ output_attentions: Optional[bool] = None,
829
+ output_hidden_states: Optional[bool] = None,
830
+ return_dict: Optional[bool] = None,
831
+ ) -> Union[Tuple, Rwkv7CausalLMOutput]:
832
+ r"""
833
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
834
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
835
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
836
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
837
+ """
838
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
839
+
840
+ outputs = self.model(
841
+ input_ids,
842
+ inputs_embeds=inputs_embeds,
843
+ state=state,
844
+ use_cache=use_cache,
845
+ output_attentions=output_attentions,
846
+ output_hidden_states=output_hidden_states,
847
+ return_dict=return_dict,
848
+ )
849
+ hidden_states = outputs[0]
850
+
851
+ logits = self.head(hidden_states)
852
+
853
+ loss = None
854
+ if labels is not None:
855
+ # move labels to correct device to enable model parallelism
856
+ labels = labels.to(logits.device)
857
+ # Shift so that tokens < n predict n
858
+ shift_logits = logits[..., :-1, :].contiguous()
859
+ shift_labels = labels[..., 1:].contiguous()
860
+ # Flatten the tokens
861
+ loss_fct = CrossEntropyLoss()
862
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
863
+
864
+ if not return_dict:
865
+ output = (logits,) + outputs[1:]
866
+ return ((loss,) + output) if loss is not None else output
867
+
868
+ return Rwkv7CausalLMOutput(
869
+ loss=loss,
870
+ logits=logits,
871
+ state=outputs.state,
872
+ hidden_states=outputs.hidden_states,
873
+ attentions=outputs.attentions,
874
+ )
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ }
206
+ },
207
+ "bos_token": "<|endoftext|>",
208
+ "clean_up_tokenization_spaces": false,
209
+ "eos_token": "<|endoftext|>",
210
+ "model_max_length": 1000000000000000019884624838656,
211
+ "pad_token": null,
212
+ "tokenizer_class": "GPTNeoXTokenizer",
213
+ "unk_token": "<|endoftext|>"
214
+ }