Commit 5a1cdf2
1 Parent(s): cd36468
initial commit

Files changed:
- __init__.py +2 -0
- __pycache__/__init__.cpython-312.pyc +0 -0
- __pycache__/attn.cpython-312.pyc +0 -0
- __pycache__/configuration_ministu.cpython-312.pyc +0 -0
- __pycache__/convolve.cpython-312.pyc +0 -0
- __pycache__/filters.cpython-312.pyc +0 -0
- __pycache__/layers.cpython-312.pyc +0 -0
- __pycache__/mlp.cpython-312.pyc +0 -0
- __pycache__/modeling_ministu.cpython-312.pyc +0 -0
- __pycache__/modules.cpython-312.pyc +0 -0
- __pycache__/stu.cpython-312.pyc +0 -0
- __pycache__/utils.cpython-312.pyc +0 -0
- added_tokens.json +3 -0
- attn.py +125 -0
- config.json +55 -0
- configuration_ministu.py +50 -0
- convolve.py +84 -0
- filters.py +106 -0
- layers.py +102 -0
- merges.txt +0 -0
- mlp.py +27 -0
- model.safetensors +3 -0
- modeling_ministu.py +248 -0
- modules.py +6 -0
- results.txt +0 -0
- special_tokens_map.json +23 -0
- stu.py +87 -0
- tokenizer.json +0 -0
- tokenizer_config.json +27 -0
- utils.py +105 -0
- vocab.json +0 -0
__init__.py
ADDED
@@ -0,0 +1,2 @@
from .configuration_ministu import MiniSTUConfig
from .modeling_ministu import MiniSTU

__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (261 Bytes)

__pycache__/attn.cpython-312.pyc
ADDED
Binary file (7.99 kB)

__pycache__/configuration_ministu.cpython-312.pyc
ADDED
Binary file (2.22 kB)

__pycache__/convolve.cpython-312.pyc
ADDED
Binary file (5.44 kB)

__pycache__/filters.cpython-312.pyc
ADDED
Binary file (4.71 kB)

__pycache__/layers.cpython-312.pyc
ADDED
Binary file (5.53 kB)

__pycache__/mlp.cpython-312.pyc
ADDED
Binary file (2.67 kB)

__pycache__/modeling_ministu.cpython-312.pyc
ADDED
Binary file (12.2 kB)

__pycache__/modules.cpython-312.pyc
ADDED
Binary file (291 Bytes)

__pycache__/stu.cpython-312.pyc
ADDED
Binary file (4.49 kB)

__pycache__/utils.cpython-312.pyc
ADDED
Binary file (5.89 kB)

added_tokens.json
ADDED
@@ -0,0 +1,3 @@
{
  "<|endofprompt|>": 200018
}

attn.py
ADDED
@@ -0,0 +1,125 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import nearest_power_of_two

try:
    from flash_attn import flash_attn_func as fa2
except ImportError as e:
    print(
        f"Unable to import Triton-based flash attention: {e}. No alternative currently available."
    )
# TODO: Add FlexAttention + local attention mask when it's in stable release


class Attention(nn.Module):
    def __init__(self, config):
        super(Attention, self).__init__()
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        assert torch.cuda.is_available(), "CUDA is required."
        assert config.n_embd % config.n_heads == 0
        self.n_heads = config.n_heads

        self.device = torch.device("cuda")
        self.bsz = config.bsz
        self.c_attn = nn.Linear(
            config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype
        )
        self.c_proj = nn.Linear(
            config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype
        )
        self.c_proj.SCALE_INIT = 1
        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)
        self.alibi_slopes = self._get_alibi_slopes(self.n_heads)
        self.window_size = config.window_size
        self.softcap = config.softcap

    def _generate_slopes(self, n: int):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start**i) for i in range(n)]

    def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25):
        # If n_heads is a power of 2, generate slopes directly
        if math.log2(n_heads).is_integer():
            slopes = self._generate_slopes(n_heads)
        else:
            # Get slopes for the nearest power of two
            n = nearest_power_of_two(n_heads, round_up=False)
            slopes_power_of_two = self._generate_slopes(n)

            # Generate extra slopes
            extra_slopes = self._generate_slopes(2 * n)
            extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
            slopes = slopes_power_of_two + extra_slopes_trunc
        slopes = torch.tensor(slopes, device=self.device)
        slopes = slopes * interpolation_factor  # https://arxiv.org/pdf/2310.13017
        return slopes.to(torch.float32)  # Ensure slopes are in float32

    def forward(self, x):
        bsz, seq_len, d_in = x.size()

        qkv = self.c_attn(x)
        q, k, v = torch.chunk(qkv, 3, dim=2)

        q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
        k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
        v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
        y = fa2(  # https://arxiv.org/pdf/2307.08691
            q,
            k,
            v,
            dropout_p=self.dropout if self.training else 0.0,
            causal=True,
            window_size=(self.window_size, 0),
            alibi_slopes=self.alibi_slopes,  # https://arxiv.org/pdf/2108.12409
            softcap=self.softcap,  # https://arxiv.org/pdf/2408.00118
        )
        y = y.contiguous().view(bsz, seq_len, d_in)
        y = self.resid_dropout(self.c_proj(y))
        return y


class AttentionSDPA(nn.Module):
    def __init__(self, config):
        super(AttentionSDPA, self).__init__()
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        assert torch.cuda.is_available(), "CUDA is required."
        assert config.n_embd % config.n_heads == 0
        self.n_heads = config.n_heads

        self.device = torch.device("cuda")  # Technically don't need CUDA for SDPA
        self.bsz = config.bsz
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype)
        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)

    def forward(self, x):
        bsz, seq_len, d_in = x.size()

        qkv = self.c_attn(x)
        q, k, v = torch.chunk(qkv, 3, dim=2)

        q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
        k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
        v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)

        y = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=True,
            dropout_p=self.dropout if self.training else 0.0
        )

        y = y.transpose(1, 2).contiguous().view(bsz, seq_len, d_in)

        y = self.resid_dropout(self.c_proj(y))
        return y

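For reference, a minimal standalone sketch (not part of the diff) of the slope construction used by _get_alibi_slopes above, restricted to the power-of-two head count case; the function name below is illustrative only:

import math

def alibi_slopes(n_heads: int, interpolation_factor: float = 0.25):
    # Same geometric construction as Attention._generate_slopes, power-of-two case only.
    start = 2 ** (-(2 ** -(math.log2(n_heads) - 3)))
    return [start * (start ** i) * interpolation_factor for i in range(n_heads)]

# For the 8 heads used in config.json, the raw slopes are 1/2, 1/4, ..., 1/256,
# each then scaled by the 0.25 interpolation factor.
print(alibi_slopes(8))
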
config.json
ADDED
@@ -0,0 +1,55 @@
{
  "model_type": "ministu",
  "_name_or_path": "STU-500M",
  "architectures": ["MiniSTU"],
  "n_embd": 768,
  "n_heads": 8,
  "n_layers": 12,
  "seq_len": 8192,
  "window_size": 1024,
  "vocab_size": 200064,
  "mlp_scale": 12,
  "bias": false,
  "dropout": 0.0,
  "num_eigh": 24,
  "use_hankel_L": false,
  "num_epochs": 1,
  "global_bsz": 524288,
  "bsz": 1,
  "warmup_steps": 1907,
  "eval_period": 25,
  "save_period": 4500,
  "max_lr": 3.0e-3,
  "min_lr": 3.0e-5,
  "max_norm": 1.0,
  "dilation": 1,
  "fsdp": true,
  "ddp": false,
  "mixed_precision": true,
  "use_cpu_offload": false,
  "sharding_strategy": "full_shard",
  "state_dict_type": "full",
  "auto_wrap_policy": "partial",
  "backward_prefetch": "backward_pre",
  "forward_prefetch": false,
  "sync_module_states": true,
  "use_orig_params": true,
  "device_id": null,
  "precision": {
    "param": "bfloat16",
    "reduce": "bfloat16",
    "buffer": "bfloat16"
  },
  "fsdp_modules": [
    "STU",
    "Attention",
    "MLP"
  ],
  "use_activation_checkpointing": true,
  "use_flash_fft": false,
  "use_approx": true,
  "use_attn": true,
  "softcap": 50.0,
  "torch_compile": false
}

configuration_ministu.py
ADDED
@@ -0,0 +1,50 @@
import torch
from transformers import PretrainedConfig, AutoConfig

class MiniSTUConfig(PretrainedConfig):
    model_type = "ministu"

    def __init__(
        self,
        bsz: int = 1,
        n_embd: int = 768,
        n_heads: int = 8,
        n_layers: int = 12,
        seq_len: int = 8192,
        window_size: int = 1024,
        vocab_size: int = 200064,
        mlp_scale: int = 12,
        bias: bool = False,
        dropout: float = 0.0,
        num_eigh: int = 24,
        use_hankel_L: bool = False,
        use_flash_fft: bool = True,
        use_approx: bool = True,
        use_attn: bool = True,
        softcap: float = 50.0,
        torch_dtype=torch.bfloat16,
        device: str = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.bsz = bsz
        self.n_embd = n_embd
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.seq_len = seq_len
        self.window_size = window_size
        self.vocab_size = vocab_size
        self.hidden_size = n_embd
        self.intermediate_size = n_embd * mlp_scale
        self.mlp_scale = mlp_scale
        self.hidden_act = "swish"
        self.bias = bias
        self.dropout = dropout
        self.num_eigh = num_eigh
        self.use_hankel_L = use_hankel_L
        self.use_flash_fft = use_flash_fft
        self.use_approx = use_approx
        self.use_attn = use_attn
        self.softcap = softcap
        self.torch_dtype = torch_dtype  # May be a torch.dtype or its string name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")  # Stored as a string

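For orientation, a minimal usage sketch (not part of the diff) tying this config class to the model class in modeling_ministu.py. It assumes the repo is importable as a package named ministu, that a CUDA device is present (Attention asserts one), and that the optional flash-attn/flashfftconv/liger dependencies resolve; the sizes are deliberately small and hypothetical, the real checkpoint values live in config.json:

import torch
from ministu import MiniSTUConfig, MiniSTU  # assumed package name for this repo

config = MiniSTUConfig(n_embd=256, n_heads=4, n_layers=2, seq_len=512, torch_dtype=torch.bfloat16)
model = MiniSTU(config)

input_ids = torch.randint(0, config.vocab_size, (1, 32), device=config.device)
out = model(input_ids)
print(out.logits.shape)  # (1, 32, vocab_size)
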
convolve.py
ADDED
@@ -0,0 +1,84 @@
import torch
import torch.nn.functional as F

from .utils import nearest_power_of_two
from flashfftconv import FlashFFTConv


def convolve(u: torch.Tensor, v: torch.Tensor, n: int, use_approx: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
    bsz, seq_len, d_in = u.shape

    sgn = torch.full((1, seq_len, 1), 1, device=u.device, dtype=torch.float32)
    sgn[:, 1::2] *= -1

    # Cast u and v to float32 for FFT, remembering the original dtype
    dtype = u.dtype
    u = u.to(torch.float32)
    v = v.to(torch.float32)

    if use_approx:
        _, d_out = v.shape
        v = v.view(1, -1, d_out, 1)
    else:
        _, K = v.shape
        sgn = sgn.unsqueeze(-1)
        v = v.view(1, -1, K, 1, 1)
        u = u.view(bsz, -1, 1, d_in).expand(bsz, -1, K, d_in)

    v = torch.fft.rfft(v, n=n, dim=1)
    U = torch.stack([u, u * sgn], dim=-1)
    U = torch.fft.rfft(U, n=n, dim=1)
    U_conv = torch.fft.irfft(v * U, n=n, dim=1)[:, :seq_len]
    U_plus, U_minus = torch.unbind(U_conv, dim=-1)
    U_minus = U_minus * sgn

    # Convert back to the original dtype
    U_plus = U_plus.to(dtype)
    U_minus = U_minus.to(dtype)

    return U_plus, U_minus


def flash_convolve(
    u: torch.Tensor, v: torch.Tensor, flash_fft: FlashFFTConv, use_approx: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    dtype = u.dtype  # Store the original dtype
    u = u.to(torch.float32)
    v = v.to(torch.float32)

    bsz, seq_len, d_in = u.shape
    _, K = v.shape

    padded_len = nearest_power_of_two(seq_len, round_up=True)
    pad_len = padded_len - seq_len

    sgn = torch.full((1, 1, padded_len), 1, device=u.device, dtype=torch.float32)
    sgn[:, :, 1::2] = -1

    if use_approx:
        u_padded = F.pad(u.transpose(1, 2), (0, pad_len)).contiguous()
        v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).contiguous()
        u_conv = torch.stack([u_padded, u_padded * sgn], dim=0).reshape(2 * bsz, d_in, padded_len)
    else:
        u_k_padded = F.pad(u.transpose(1, 2), (0, pad_len)).repeat_interleave(K, dim=1).contiguous()
        v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).repeat(d_in, 1).contiguous()
        u_conv = torch.stack([u_k_padded, u_k_padded * sgn], dim=0).reshape(2 * bsz, K * d_in, padded_len)

    U_conv = flash_fft(u_conv, v_padded)

    # Trim the output back to the original sequence length
    U_conv = U_conv[..., :seq_len]

    u_plus, u_minus = torch.chunk(U_conv, 2, dim=0)

    if use_approx:
        u_minus = u_minus * sgn[:, :, :seq_len]
        U_plus, U_minus = u_plus.transpose(1, 2), u_minus.transpose(1, 2)
    else:
        sgn = sgn[:, :, :seq_len].unsqueeze(-1).transpose(1, 2)
        U_plus = u_plus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous()
        U_minus = u_minus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous() * sgn

    # Convert back to original dtype
    U_plus = U_plus.to(dtype)
    U_minus = U_minus.to(dtype)

    return U_plus, U_minus

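A small self-contained sketch (not part of the diff) of the non-approximate convolve path: each of the K filters is FFT-convolved with the input and with its sign-alternated copy. The import path is an assumption, and flashfftconv must be installed since this module imports it at the top:

import torch
from ministu.convolve import convolve  # assumed import path for this repo

bsz, seq_len, d_in, K = 2, 16, 8, 4
u = torch.randn(bsz, seq_len, d_in)   # inputs
v = torch.randn(seq_len, K)           # K spectral filters, one per column
n = 32                                # FFT length: 2 * seq_len - 1 rounded up to a power of two

U_plus, U_minus = convolve(u, v, n, use_approx=False)
print(U_plus.shape, U_minus.shape)    # both (bsz, seq_len, K, d_in)
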
filters.py
ADDED
@@ -0,0 +1,106 @@
import math

import numpy as np
import torch

from .utils import logger
from .utils import get_hankel

def get_spectral_filters(
    seq_len: int,
    K: int,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    # Generate the Hankel matrix using PyTorch
    Z = get_hankel(seq_len, use_hankel_L, device=device, dtype=dtype)

    # Cast Z to torch.float32 for the eigenvalue decomposition
    Z_float32 = Z.to(torch.float32)

    # Perform eigen decomposition using torch.float32
    sigma, phi = torch.linalg.eigh(Z_float32)

    # Cast the results back to the original dtype (torch.bfloat16)
    sigma = sigma.to(dtype=dtype)
    phi = phi.to(dtype=dtype)

    # Select the top K eigenvalues and eigenvectors
    sigma_k, phi_k = sigma[-K:], phi[:, -K:]

    # Compute the spectral filters
    phi_k = phi_k * sigma_k ** 0.25

    # Ensure the filters are in the correct dtype and device
    filters = phi_k.to(device=device, dtype=dtype)

    return filters


def compute_dimensions(n: int) -> tuple[int, int, int]:
    if n <= 2:
        raise ValueError("n must be greater than 2")

    T_prime = (math.ceil(math.sqrt(n - 2)))**2 + 2
    sqrt_T_prime = math.ceil(math.sqrt(T_prime - 2))
    k_max = sqrt_T_prime
    return T_prime, sqrt_T_prime, k_max

def get_tensorized_spectral_filters_explicit(n: int, k: int, device: torch.device) -> torch.Tensor:
    T_prime, sqrt_T_prime, k_max = compute_dimensions(n)
    k = min(k, k_max)

    Z = get_hankel(sqrt_T_prime).to(device)
    sigma, phi = torch.linalg.eigh(Z)
    sigma_k = sigma[-k:]
    phi_k = phi[:, -k:]

    result = torch.zeros(sqrt_T_prime * sqrt_T_prime, device=device)

    for i in range(k):
        for j in range(k):
            phi_i = phi_k[:, i] * (sigma_k[i] ** 0.25)
            phi_j = phi_k[:, j] * (sigma_k[j] ** 0.25)
            kron = torch.kron(phi_i, phi_j)
            result += kron

    return result


def get_tensorized_spectral_filters(
    n: int = 8192,
    k: int = 24,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """
    Compute tensorized spectral filters for given sequence length and filter count.

    Args:
        n: Sequence length
        k: Number of filters
        use_hankel_L: Hankel_main ⊗ Hankel_L? Default is Hankel_main ⊗ Hankel_main.
        device: Computation device
        dtype: Computation dtype
    """
    assert torch.cuda.is_available(), "CUDA is required."

    T_prime, sqrt_T_prime, k_max = compute_dimensions(n)
    k = min(k, k_max)

    Z = get_hankel(sqrt_T_prime)
    sigma, phi = torch.linalg.eigh(Z)
    phi_i = phi[:, -k:] * sigma[-k:] ** 0.25

    if use_hankel_L:  # TODO: We may want to use Hankel_L above too if use_hankel_L is true, make another variable for this (mix != use_hankel_L)
        logger.info("Mixing Hankel_L with Hankel_main to generate tensorized filters.")
        Z_L = get_hankel(sqrt_T_prime, True)
        sigma_L, phi_L = torch.linalg.eigh(Z_L)
        phi_j = phi_L[:, -k:] * sigma_L[-k:] ** 0.25
    else:
        phi_j = phi_i

    filters = torch.kron(phi_i, phi_j)
    return filters.to(device=device, dtype=dtype)

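A short sketch (not part of the diff) of what get_spectral_filters returns: the top-K eigenvectors of the Hankel matrix built in utils.get_hankel, scaled by the fourth root of their eigenvalues. The import path is an assumption, and colorama plus python-dotenv must be installed because this module pulls in utils:

import torch
from ministu.filters import get_spectral_filters  # assumed import path

phi = get_spectral_filters(
    seq_len=256,
    K=24,
    use_hankel_L=False,
    device=torch.device("cpu"),
    dtype=torch.float32,
)
print(phi.shape)  # (256, 24): one length-256 filter per column
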
layers.py
ADDED
@@ -0,0 +1,102 @@
import torch
import torch.nn as nn

from .modules import STU
from .modules import MLP
from .modules import Attention
try:
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP
    triton_mlp = True
except ImportError as e:
    print(
        f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead."
    )
    triton_mlp = False

try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
    triton_norm = True
except ImportError as e:
    print(
        f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
    )
    from torch.nn import RMSNorm
    triton_norm = False


class STULayer(nn.Module):
    def __init__(self, config, phi, n):
        super(STULayer, self).__init__()
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        self.stu_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.stu = STU(config, phi, n)
        self.stu = self.stu.to(dtype=torch_dtype)
        self.mlp_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.mlp = (
            TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)
        )

        # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP
        self.stu_norm = self.stu_norm.to(dtype=torch_dtype)
        self.mlp = self.mlp.to(dtype=torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize and apply STU, matching the dtype of the STU parameters
        stu_dtype = self.stu.M_inputs.dtype if self.stu.use_approx else self.stu.M_phi_plus.dtype
        x_normed = self.stu_norm(x).to(dtype=stu_dtype)
        x_stu = self.stu(x_normed).to(dtype=x.dtype)  # Ensure output matches `x`'s dtype
        x = x + x_stu

        # Normalize and apply MLP
        x_normed_mlp = self.mlp_norm(x).to(dtype=self.mlp.gate_proj.weight.dtype)  # Match dtype for MLP
        x_mlp = self.mlp(x_normed_mlp).to(dtype=x.dtype)  # Ensure output matches `x`'s dtype
        x = x + x_mlp

        return x


class AttentionLayer(nn.Module):
    def __init__(self, config) -> None:
        super(AttentionLayer, self).__init__()
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        self.attn_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.attn = Attention(config)
        self.attn = self.attn.to(dtype=torch_dtype)
        self.mlp_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.mlp = (
            TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)
        )
        self.mlp = self.mlp.to(dtype=torch_dtype)

        # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP
        self.attn_norm = self.attn_norm.to(dtype=torch_dtype)
        self.mlp = self.mlp.to(dtype=torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x

merges.txt
ADDED
The diff for this file is too large to render.
mlp.py
ADDED
@@ -0,0 +1,27 @@
import torch
import torch.nn as nn
from torch.nn import functional as F

class MLP(nn.Module):
    def __init__(self, config, dtype=None):
        # https://arxiv.org/pdf/2002.05202
        super().__init__()
        # Use the config dtype, accepting either a string name or a torch.dtype
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        dtype = dtype if dtype is not None else torch_dtype
        self.hidden_size = config.n_embd
        self.intermediate_size = config.n_embd * config.mlp_scale
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias, dtype=dtype)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias, dtype=dtype)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias, dtype=dtype)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        dtype = self.gate_proj.weight.dtype  # Match the dtype of the projection layers
        x = x.to(dtype=dtype)  # Convert input to the same dtype
        gate = self.gate_proj(x)
        gate = F.gelu(gate, approximate="tanh").to(dtype=dtype)
        up = self.up_proj(x).to(dtype=dtype)
        fuse = gate * up
        outputs = self.down_proj(fuse).to(dtype=dtype)
        outputs = self.dropout(outputs)
        return outputs

model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:101763e8119c492ada9816aa38006cbf6ba8bbc0530224510d62b2c7e20a8bfd
size 1140654808

modeling_ministu.py
ADDED
@@ -0,0 +1,248 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from .modules import STU, Attention, MLP
from .utils import nearest_power_of_two
from .layers import STULayer, AttentionLayer
from .configuration_ministu import MiniSTUConfig
from .filters import get_spectral_filters

try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
    triton_norm = True
except ImportError as e:
    print(
        f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
    )
    from torch.nn import RMSNorm
    triton_norm = False
# Load the tokenizer
#from transformers import AutoModelForCausalLM, AutoTokenizer
#model_name = "Hazan-Lab/STU-426M"
#tokenizer = AutoTokenizer.from_pretrained(
#    model_name,
#    trust_remote_code=True
#)

class MiniSTU(PreTrainedModel):
    config_class = MiniSTUConfig

    def __init__(self, config) -> None:
        super(MiniSTU, self).__init__(config)
        self.n_layers = config.n_layers
        self.n = nearest_power_of_two(config.seq_len * 2 - 1, round_up=True)

        if isinstance(config.torch_dtype, torch.dtype):
            torch_dtype = config.torch_dtype
        else:
            torch_dtype = getattr(torch, config.torch_dtype)

        device = torch.device(config.device)

        self.phi = get_spectral_filters(
            config.seq_len,
            config.num_eigh,
            config.use_hankel_L,
            device=device,
            dtype=torch_dtype,
        )

        self.use_approx = config.use_approx
        self.use_hankel_L = config.use_hankel_L

        self.tok_emb = nn.Embedding(
            config.vocab_size, config.n_embd, dtype=torch_dtype, device=device
        )
        self.dropout = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList()
        for layer_idx in range(self.n_layers):
            if layer_idx % 2 == 0:
                self.layers.append(STULayer(config, self.phi, self.n))
            else:
                self.layers.append(
                    AttentionLayer(config)
                    if config.use_attn
                    else STULayer(config, self.phi, self.n)
                )

        self.norm = TritonNorm(config.n_embd) if triton_norm else RMSNorm(config.n_embd)

        self.lm_head = nn.Linear(
            config.n_embd, config.vocab_size, bias=config.bias, dtype=torch_dtype, device=device
        )
        self.tok_emb.weight = self.lm_head.weight

        self.std = (config.n_embd) ** -0.5
        self.apply(self._init_weights)
        print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor = None,
        **kwargs
    ) -> CausalLMOutput:
        # Compute embeddings
        tok_emb = self.tok_emb(input_ids)
        x = self.dropout(tok_emb)

        # Pass through layers
        for layer in self.layers:
            x = layer(x)

        # Normalize and project to vocabulary
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift so that tokens predict the next token
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    def _get_num_params(self):
        n_params = sum(p.numel() for p in self.parameters())
        if hasattr(self, "pos_emb") and self.pos_emb is not None:
            n_params -= self.pos_emb.weight.numel()
        if self.tok_emb.weight is not self.lm_head.weight:
            n_params -= self.tok_emb.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            if hasattr(module, "SCALE_INIT"):
                self.std *= (2 * self.n_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
        elif isinstance(module, STU):
            if self.use_approx:
                torch.nn.init.xavier_normal_(module.M_inputs)
                torch.nn.init.xavier_normal_(module.M_filters)
            else:
                torch.nn.init.xavier_normal_(module.M_phi_plus)
                if not self.use_hankel_L:
                    torch.nn.init.xavier_normal_(module.M_phi_minus)
        elif isinstance(module, Attention):
            torch.nn.init.xavier_normal_(module.c_attn.weight)
            torch.nn.init.xavier_normal_(module.c_proj.weight)
            if module.c_attn.bias is not None:
                torch.nn.init.zeros_(module.c_attn.bias)
            if module.c_proj.bias is not None:
                torch.nn.init.zeros_(module.c_proj.bias)

    @staticmethod
    def top_k_top_p_filtering(
        logits: torch.Tensor,
        top_k: int = 50,
        top_p: float = 0.95,
        filter_value: float = float("-inf"),
    ):
        """
        Filters a distribution of logits using top-k and/or nucleus (top-p) filtering.
        """
        # top_k
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))
            # Remove all logits that are not in the top k
            indices_to_remove = logits < torch.topk(logits, top_k, dim=-1).values[:, -1, None]
            logits[indices_to_remove] = filter_value

        # top_p (nucleus)
        if 0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = False

            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = filter_value

        return logits

    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 50,
        temperature: float = 0.5,
        top_k: int = 50,
        top_p: float = 0.95,
        eos_token_id: int = None,
        pad_token_id: int = 0,
        **kwargs
    ):
        """
        Naive token-by-token generation loop that uses top-k/top-p filtering and optional temperature.

        Args:
            input_ids (torch.LongTensor): shape (batch_size, sequence_length).
            max_new_tokens (int): max number of tokens to generate (beyond input_ids length).
            temperature (float): sampling temperature (>=0).
            top_k (int): Top-K sampling cutoff.
            top_p (float): Nucleus sampling cutoff.
            eos_token_id (int): If set, stop generation when this token is produced.
            pad_token_id (int): If set, can be used to pad sequences. (Not fully used here.)
            kwargs: Unused arguments (like num_beams) for compatibility.

        Returns:
            torch.LongTensor: shape (batch_size, sequence_length + generated_tokens).
        """
        device = input_ids.device
        #print("1=====================")
        #print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
        #print("1=====================")

        # We'll accumulate new tokens into generated_ids
        generated_ids = input_ids.clone()

        for _ in range(max_new_tokens):
            # Forward pass to get logits for the last token
            outputs = self.forward(generated_ids)
            logits = outputs.logits[:, -1, :]  # shape: (batch_size, vocab_size)

            # Scale logits by temperature
            if temperature != 1.0:
                logits = logits / temperature

            # Filter logits using top-k and/or top-p
            logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

            # Convert to probabilities
            probabilities = F.softmax(logits, dim=-1)

            # Sample from the distribution
            next_token = torch.multinomial(probabilities, num_samples=1)  # (batch_size, 1)

            # Append next token
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # If eos_token_id is set and any sample produced it, we optionally could break early
            if eos_token_id is not None:
                # Check if all sequences in the batch ended
                # or if you want to do a more fine-grained approach
                if (next_token == eos_token_id).all():
                    break
        #print("2=====================")
        #print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
        #print("2=====================")
        return generated_ids

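A rough end-to-end sketch (not part of the diff) of loading this checkpoint and sampling with the custom generate loop above. It assumes a local checkout of this repo, a CUDA device, and that the optional flash-attn/flashfftconv/liger dependencies resolve; the package name and the path are placeholders:

import torch
from transformers import AutoTokenizer
from ministu import MiniSTU  # direct import of the class added in this commit

repo = "path/to/this/repo"  # local checkout containing config.json and model.safetensors
tokenizer = AutoTokenizer.from_pretrained(repo)
model = MiniSTU.from_pretrained(repo).to("cuda")

input_ids = tokenizer("The spectral transform unit", return_tensors="pt").input_ids.to("cuda")
with torch.no_grad():
    output_ids = model.generate(input_ids, max_new_tokens=32, temperature=0.7, top_k=50, top_p=0.95)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
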
modules.py
ADDED
@@ -0,0 +1,6 @@
from .attn import Attention
from .attn import AttentionSDPA
from .mlp import MLP
from .stu import STU

results.txt
ADDED
File without changes
special_tokens_map.json
ADDED
@@ -0,0 +1,23 @@
{
  "bos_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}

stu.py
ADDED
@@ -0,0 +1,87 @@
import torch
import torch.nn as nn

from .convolve import convolve, flash_convolve

try:
    from flashfftconv import FlashFFTConv

    flash_fft_available = True
except ImportError as e:
    print(
        f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation."
    )
    flash_fft_available = False


class STU(nn.Module):
    def __init__(self, config, phi, n) -> None:
        super(STU, self).__init__()
        self.config = config
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        self.phi = phi.to(device=config.device, dtype=torch_dtype)
        self.n = n
        self.K = config.num_eigh
        self.d_in = config.n_embd
        self.d_out = config.n_embd
        self.use_hankel_L = config.use_hankel_L
        self.use_approx = config.use_approx
        self.flash_fft = (
            FlashFFTConv(self.n, dtype=torch.bfloat16)
            if config.use_flash_fft and flash_fft_available
            else None
        )
        if self.use_approx:
            self.M_inputs = nn.Parameter(
                torch.empty(self.d_in, self.d_out, dtype=torch_dtype)
            )
            self.M_filters = nn.Parameter(
                torch.empty(self.K, self.d_in, dtype=torch_dtype)
            )
        else:
            self.M_phi_plus = nn.Parameter(
                torch.empty(self.K, self.d_in, self.d_out, dtype=torch_dtype)
            )
            if not self.use_hankel_L:
                self.M_phi_minus = nn.Parameter(
                    torch.empty(self.K, self.d_in, self.d_out, dtype=torch_dtype)
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # M_inputs only exists on the approximate path; pick the matching parameter dtype
        dtype = self.M_inputs.dtype if self.use_approx else self.M_phi_plus.dtype
        x = x.to(dtype=dtype)
        if self.use_approx:
            # Contract inputs and filters over the K and d_in dimensions, then convolve
            x_proj = x @ self.M_inputs
            phi_proj = self.phi @ self.M_filters
            x_proj = x_proj.to(dtype=dtype)
            phi_proj = phi_proj.to(dtype=dtype)
            if self.flash_fft:
                spectral_plus, spectral_minus = flash_convolve(
                    x_proj, phi_proj, self.flash_fft, self.use_approx
                )
            else:
                spectral_plus, spectral_minus = convolve(
                    x_proj, phi_proj, self.n, self.use_approx
                )
        else:
            # Convolve inputs and filters
            if self.flash_fft:
                U_plus, U_minus = flash_convolve(
                    x, self.phi, self.flash_fft, self.use_approx
                )
            else:
                U_plus, U_minus = convolve(x, self.phi, self.n, self.use_approx)
            # Then, contract over the K and d_in dimensions
            spectral_plus = torch.tensordot(
                U_plus, self.M_phi_plus, dims=([2, 3], [0, 1])
            )
            if not self.use_hankel_L:
                spectral_minus = torch.tensordot(
                    U_minus, self.M_phi_minus, dims=([2, 3], [0, 1])
                )

        return spectral_plus if self.use_hankel_L else spectral_plus + spectral_minus

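For orientation, a shape walk-through (not part of the diff) of STU.forward on the approximate path, using the checkpoint's sizes as an illustration:

# Shape walk-through of STU.forward with use_approx=True (sizes are illustrative):
#   x:         (bsz, seq_len, n_embd)          e.g. (1, 512, 768)
#   M_inputs:  (n_embd, n_embd)                so x_proj   = x @ M_inputs    -> (1, 512, 768)
#   phi:       (seq_len, K)                    e.g. (512, 24)
#   M_filters: (K, n_embd)                     so phi_proj = phi @ M_filters -> (512, 768)
#   convolve(x_proj, phi_proj, n, use_approx=True)
#              -> spectral_plus, spectral_minus, both (1, 512, 768)
#   output = spectral_plus + spectral_minus (or spectral_plus alone if use_hankel_L)
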
tokenizer.json
ADDED
The diff for this file is too large to render.
tokenizer_config.json
ADDED
@@ -0,0 +1,27 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "199999": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "200018": {
      "content": "<|endofprompt|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<|endoftext|>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "<|endoftext|>",
  "model_max_length": 128000,
  "tokenizer_class": "GPT2Tokenizer",
  "unk_token": "<|endoftext|>"
}

utils.py
ADDED
@@ -0,0 +1,105 @@
import math
import numpy as np
import torch

import logging
import os
import sys
from colorama import Fore, Style, init
from dotenv import load_dotenv

load_dotenv()
init(autoreset=True)

def nearest_power_of_two(x: int, round_up: bool = False) -> int:
    return (
        1 << math.floor(math.log2(x)) if not round_up else 1 << math.ceil(math.log2(x))
    )

def get_hankel(seq_len: int, use_hankel_L: bool = False, device: torch.device = None, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    entries = torch.arange(1, seq_len + 1, dtype=dtype, device=device)
    i_plus_j = entries[:, None] + entries[None, :]

    if use_hankel_L:
        sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0
        denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0)
        Z = sgn * (8.0 / denom)
    elif not use_hankel_L:
        Z = 2.0 / (i_plus_j**3 - i_plus_j)
    else:
        raise ValueError("use_hankel_L must be a boolean")

    return Z


class ColorFormatter(logging.Formatter):
    """
    A custom log formatter that applies color based on the log level using the Colorama library.

    Attributes:
        LOG_COLORS (dict): A dictionary mapping log levels to their corresponding color codes.
    """

    # Colors for each log level
    LOG_COLORS = {
        logging.DEBUG: Fore.LIGHTMAGENTA_EX + Style.BRIGHT,
        logging.INFO: Fore.CYAN,
        logging.WARNING: Fore.YELLOW + Style.BRIGHT,
        logging.ERROR: Fore.RED + Style.BRIGHT,
        logging.CRITICAL: Fore.RED + Style.BRIGHT + Style.NORMAL,
    }

    # Colors for other parts of the log message
    TIME_COLOR = Fore.GREEN
    FILE_COLOR = Fore.BLUE
    LEVEL_COLOR = Style.BRIGHT

    def __init__(self, fmt=None):
        super().__init__(fmt or "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", "%Y-%m-%d %H:%M:%S")

    def format(self, record):
        """
        Formats a log record with the appropriate color based on the log level.

        Args:
            record (logging.LogRecord): The log record to format.

        Returns:
            str: The formatted log message with colors applied.
        """
        # Apply color based on the log level
        level_color = self.LOG_COLORS.get(record.levelno, Fore.WHITE)
        time_str = f"{self.TIME_COLOR}{self.formatTime(record)}{Style.RESET_ALL}"
        levelname_str = f"{level_color}{record.levelname}{Style.RESET_ALL}"
        file_info_str = f"{self.FILE_COLOR}{record.filename}:{record.lineno}{Style.RESET_ALL}"

        # Format the log message with color
        log_msg = f"{time_str} - {levelname_str} - {file_info_str} - {record.msg}"
        return log_msg

def setup_logger():
    """
    Sets up a logger with a custom color formatter that logs to standard output (stdout).

    The logger is configured with the ColorFormatter to format log messages with color based on the log level.
    The log level is set to INFO by default, but this can be changed to show more or less detailed messages.

    Returns:
        logging.Logger: A logger instance that logs formatted messages to stdout.
    """
    handler = logging.StreamHandler(sys.stdout)

    # Set custom formatter
    formatter = ColorFormatter()
    handler.setFormatter(formatter)
    logger = logging.getLogger(__name__)

    # Set to DEBUG to capture all logging levels
    DEBUG = os.environ.get("DEBUG", "False").lower() in ("true", "1", "t")
    logger.setLevel(logging.DEBUG) if DEBUG else logger.setLevel(logging.INFO)
    logger.addHandler(handler)
    logger.propagate = False  # Prevents multiple logging if re-initialized

    return logger

logger = setup_logger()  # Initialize once to prevent multiple loggers

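A quick sanity-check sketch (not part of the diff) for the two numeric helpers above; the import path is an assumption:

import torch
from ministu.utils import nearest_power_of_two, get_hankel  # assumed import path

print(nearest_power_of_two(8191, round_up=True))   # 8192
print(nearest_power_of_two(12, round_up=False))    # 8

Z = get_hankel(4)  # 4x4 Hankel matrix with entries 2 / ((i+j)^3 - (i+j)), 1-indexed
print(Z[0, 0])     # 2 / (2^3 - 2) = 1/3
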
vocab.json
ADDED
The diff for this file is too large to render.