windsornguyen committed
Commit 0f28fb5 · verified · 1 parent: b0544cb

Upload FlashSTU

Files changed (4):
  1. config.json +3 -3
  2. config.py +5 -0
  3. model.py +7 -6
  4. model.safetensors +3 -0
config.json CHANGED
@@ -11,9 +11,9 @@
   "dropout": 0.0,
   "mlp_scale": 12,
   "model_type": "FlashSTU",
-  "n_embd": 2304,
-  "n_heads": 9,
-  "n_layers": 7,
+  "n_embd": 256,
+  "n_heads": 2,
+  "n_layers": 2,
   "num_eigh": 24,
   "seq_len": 8192,
   "softcap": 50.0,
config.py CHANGED
@@ -1,5 +1,8 @@
+import torch
+
 from transformers import PretrainedConfig
 
+
 class FlashSTUConfig(PretrainedConfig):
     model_type = "FlashSTU"
 
@@ -20,6 +23,7 @@ class FlashSTUConfig(PretrainedConfig):
         use_flash_fft: bool = True,
         use_approx: bool = True,
         softcap: float = 50.0,
+        torch_dtype: torch.dtype = torch.bfloat16,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -38,3 +42,4 @@ class FlashSTUConfig(PretrainedConfig):
         self.use_flash_fft = use_flash_fft
         self.use_approx = use_approx
         self.softcap = softcap
+        self.torch_dtype = torch_dtype
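
With the new argument, a default torch dtype can be pinned at config time and read back from the config object (a minimal sketch; it assumes the remaining constructor arguments keep their defaults):

import torch
from config import FlashSTUConfig

config = FlashSTUConfig(torch_dtype=torch.bfloat16)
print(config.torch_dtype)  # torch.bfloat16
print(config.softcap)      # 50.0 (unchanged default)
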
model.py CHANGED
@@ -33,19 +33,20 @@ class Block(nn.Module):
     def __init__(self, config, phi, n) -> None:
         super(Block, self).__init__()
         # For more complex %-split arrangements, see https://arxiv.org/pdf/2406.07887
-        self.rn_1 = RMSNorm(config.n_embd)
+        self.rn_1 = RMSNorm(config.n_embd, dtype=config.torch_dtype)
         self.stu = STU(config, phi, n)
-        self.rn_2 = RMSNorm(config.n_embd)
+        self.rn_2 = RMSNorm(config.n_embd, dtype=config.torch_dtype)
         self.attn = Attention(config)
-        self.rn_3 = RMSNorm(config.n_embd)
+        self.rn_3 = RMSNorm(config.n_embd, dtype=config.torch_dtype)
         self.mlp = MLP(
             config.n_embd,
             config.n_embd * config.mlp_scale,
             activation=F.silu,  # Use SwiGLU
             bias1=config.bias,
             bias2=config.bias,
-        ) if triton_mlp else MLP(config)
-        self.rn_4 = RMSNorm(config.n_embd)
+            dtype=config.torch_dtype,
+        ) if triton_mlp else MLP(config, dtype=config.torch_dtype)
+        self.rn_4 = RMSNorm(config.n_embd, dtype=config.torch_dtype)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x + self.stu(self.rn_1(x))
@@ -84,7 +85,7 @@ class FlashSTU(PreTrainedModel):
                         for _ in range(self.n_layers)
                     ]
                 ),
-                rn_f=RMSNorm(config.n_embd)
+                rn_f=RMSNorm(config.n_embd, dtype=config.torch_dtype)
             )
         )
         self.lm_head = nn.Linear(self.n_embd, self.vocab_size, bias=self.bias)
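
Threading dtype= through the norms and the MLP means their parameters are created directly in config.torch_dtype instead of the float32 default. The RMSNorm below is a hypothetical stand-in written only to illustrate the pattern (the repository's actual RMSNorm/MLP imports are outside this hunk); the only requirement the diff places on the real class is that it accept a dtype keyword:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    # Hypothetical stand-in for the RMSNorm used in model.py.
    def __init__(self, dim: int, eps: float = 1e-6, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Standard RMS normalization followed by the learned scale.
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * self.weight

rn = RMSNorm(256, dtype=torch.bfloat16)
print(rn.weight.dtype)  # torch.bfloat16
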
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:95020e80f8dc983d77cbd36edbf46090d71628617c31e9fce4aedf6fbc79e74e
+size 420811608
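
The committed file is a Git LFS pointer, so the ~420 MB of weights live in LFS storage rather than in the git history. Once the object has been fetched (e.g. with git lfs pull), the tensors can be inspected with the safetensors library (sketch; assumes the safetensors package is installed):

from safetensors.torch import load_file

state_dict = load_file("model.safetensors")  # needs the real LFS object, not just the pointer
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape), tensor.dtype)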