Upload FlashSTU
- config.json +3 -3
- config.py +5 -0
- model.py +7 -6
- model.safetensors +3 -0
config.json
CHANGED
@@ -11,9 +11,9 @@
   "dropout": 0.0,
   "mlp_scale": 12,
   "model_type": "FlashSTU",
-  "n_embd":
-  "n_heads":
-  "n_layers":
+  "n_embd": 256,
+  "n_heads": 2,
+  "n_layers": 2,
   "num_eigh": 24,
   "seq_len": 8192,
   "softcap": 50.0,
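The updated hyperparameters describe a small model: 2 layers, 2 attention heads, and 256-dimensional embeddings, with the 8192 sequence length and 24 eigenvalues unchanged. A minimal sketch of reading this file back through the accompanying FlashSTUConfig class, assuming config.py sits in the same directory as config.json:

# Sketch: load the uploaded config.json via the custom config class from this repo.
# Assumes the repository files are in the current working directory.
from config import FlashSTUConfig

config = FlashSTUConfig.from_pretrained(".")   # reads ./config.json
print(config.n_embd, config.n_heads, config.n_layers)    # 256 2 2
print(config.seq_len, config.num_eigh, config.softcap)   # 8192 24 50.0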
config.py
CHANGED
@@ -1,5 +1,8 @@
+import torch
+
 from transformers import PretrainedConfig

+
 class FlashSTUConfig(PretrainedConfig):
     model_type = "FlashSTU"

@@ -20,6 +23,7 @@ class FlashSTUConfig(PretrainedConfig):
         use_flash_fft: bool = True,
         use_approx: bool = True,
         softcap: float = 50.0,
+        torch_dtype: torch.dtype = torch.bfloat16,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -38,3 +42,4 @@ class FlashSTUConfig(PretrainedConfig):
         self.use_flash_fft = use_flash_fft
         self.use_approx = use_approx
         self.softcap = softcap
+        self.torch_dtype = torch_dtype
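The config now takes a torch_dtype argument (defaulting to torch.bfloat16) and stores it on the instance, so downstream modules can be built directly in that precision. A hedged usage sketch; the n_embd/n_heads/n_layers keywords mirror the keys in config.json and are assumed to be accepted either as explicit parameters or through **kwargs:

import torch
from config import FlashSTUConfig

# Sketch only: override the new bfloat16 default with float16.
# n_embd / n_heads / n_layers are assumed keyword arguments (they match config.json).
config = FlashSTUConfig(
    n_embd=256,
    n_heads=2,
    n_layers=2,
    torch_dtype=torch.float16,
)
print(config.torch_dtype)  # torch.float16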
model.py
CHANGED
@@ -33,19 +33,20 @@ class Block(nn.Module):
     def __init__(self, config, phi, n) -> None:
         super(Block, self).__init__()
         # For more complex %-split arrangements, see https://arxiv.org/pdf/2406.07887
-        self.rn_1 = RMSNorm(config.n_embd)
+        self.rn_1 = RMSNorm(config.n_embd, dtype=config.torch_dtype)
         self.stu = STU(config, phi, n)
-        self.rn_2 = RMSNorm(config.n_embd)
+        self.rn_2 = RMSNorm(config.n_embd, dtype=config.torch_dtype)
         self.attn = Attention(config)
-        self.rn_3 = RMSNorm(config.n_embd)
+        self.rn_3 = RMSNorm(config.n_embd, dtype=config.torch_dtype)
         self.mlp = MLP(
             config.n_embd,
             config.n_embd * config.mlp_scale,
             activation=F.silu,  # Use SwiGLU
             bias1=config.bias,
             bias2=config.bias,
-        ) if triton_mlp else MLP(config)
-        self.rn_4 = RMSNorm(config.n_embd)
+            dtype=config.torch_dtype,
+        ) if triton_mlp else MLP(config, dtype=config.torch_dtype)
+        self.rn_4 = RMSNorm(config.n_embd, dtype=config.torch_dtype)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x + self.stu(self.rn_1(x))
@@ -84,7 +85,7 @@ class FlashSTU(PreTrainedModel):
                         for _ in range(self.n_layers)
                     ]
                 ),
-                rn_f=RMSNorm(config.n_embd)
+                rn_f=RMSNorm(config.n_embd, dtype=config.torch_dtype)
             )
         )
         self.lm_head = nn.Linear(self.n_embd, self.vocab_size, bias=self.bias)
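The Block changes thread config.torch_dtype into every RMSNorm and into the MLP, so parameters are allocated in the configured precision (bfloat16 by default) rather than float32. This relies on RMSNorm and MLP accepting a dtype keyword; a rough sketch of an RMSNorm with that interface (hypothetical, since the actual RMSNorm/MLP implementations are not part of this commit):

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Hypothetical RMSNorm whose weight is created in a chosen dtype,
    matching the RMSNorm(config.n_embd, dtype=config.torch_dtype) calls above."""

    def __init__(self, dim: int, eps: float = 1e-5, dtype: torch.dtype = torch.float32) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability, then cast back to the input dtype.
        x_fp32 = x.float()
        normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed.type_as(x) * self.weight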
model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:95020e80f8dc983d77cbd36edbf46090d71628617c31e9fce4aedf6fbc79e74e
+size 420811608
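model.safetensors is tracked with Git LFS, so the commit records only the pointer file: the LFS spec version, the SHA-256 of the blob, and its size (420,811,608 bytes, about 420 MB); the weights themselves live in LFS storage. A sketch of downloading and opening them, assuming the repository id is known (the id below is a placeholder, not taken from this commit):

# Sketch: fetch the LFS-backed weights and inspect them.
# "your-username/FlashSTU" is a placeholder repo id.
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

path = hf_hub_download(repo_id="your-username/FlashSTU", filename="model.safetensors")
state_dict = load_file(path)  # maps parameter names to torch.Tensors
total = sum(t.numel() for t in state_dict.values())
print(f"{len(state_dict)} tensors, {total / 1e6:.1f}M parameters")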