yagizdevre commited on
Commit
5a1cdf2
·
1 Parent(s): cd36468

initial commit

Browse files
__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .configuration_ministu import MiniSTUConfig
2
+ from .modeling_ministu import MiniSTU
__pycache__/__init__.cpython-312.pyc ADDED
Binary file (261 Bytes). View file
 
__pycache__/attn.cpython-312.pyc ADDED
Binary file (7.99 kB). View file
 
__pycache__/configuration_ministu.cpython-312.pyc ADDED
Binary file (2.22 kB). View file
 
__pycache__/convolve.cpython-312.pyc ADDED
Binary file (5.44 kB). View file
 
__pycache__/filters.cpython-312.pyc ADDED
Binary file (4.71 kB). View file
 
__pycache__/layers.cpython-312.pyc ADDED
Binary file (5.53 kB). View file
 
__pycache__/mlp.cpython-312.pyc ADDED
Binary file (2.67 kB). View file
 
__pycache__/modeling_ministu.cpython-312.pyc ADDED
Binary file (12.2 kB). View file
 
__pycache__/modules.cpython-312.pyc ADDED
Binary file (291 Bytes). View file
 
__pycache__/stu.cpython-312.pyc ADDED
Binary file (4.49 kB). View file
 
__pycache__/utils.cpython-312.pyc ADDED
Binary file (5.89 kB). View file
 
added_tokens.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "<|endofprompt|>": 200018
3
+ }
attn.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from .utils import nearest_power_of_two
8
+
9
+ try:
10
+ from flash_attn import flash_attn_func as fa2
11
+ except ImportError as e:
12
+ print(
13
+ f"Unable to import Triton-based flash attention: {e}. No alternative currently available."
14
+ )
15
+ # TODO: Add FlexAttention + local attention mask when it's in stable release
16
+
17
+ class Attention(nn.Module):
18
+ def __init__(self, config):
19
+ super(Attention, self).__init__()
20
+ if isinstance(config.torch_dtype, str):
21
+ torch_dtype = getattr(torch, config.torch_dtype)
22
+ else:
23
+ torch_dtype = config.torch_dtype
24
+ assert torch.cuda.is_available(), "CUDA is required."
25
+ assert config.n_embd % config.n_heads == 0
26
+ self.n_heads = config.n_heads
27
+
28
+ self.device = torch.device("cuda")
29
+ self.bsz = config.bsz
30
+ self.c_attn = nn.Linear(
31
+ config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype
32
+ )
33
+ self.c_proj = nn.Linear(
34
+ config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype
35
+ )
36
+ self.c_proj.SCALE_INIT = 1
37
+ self.dropout = config.dropout
38
+ self.resid_dropout = nn.Dropout(self.dropout)
39
+ self.alibi_slopes = self._get_alibi_slopes(self.n_heads)
40
+ self.window_size = config.window_size
41
+ self.softcap = config.softcap
42
+
43
+ def _generate_slopes(self, n: int):
44
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
45
+ return [start * (start**i) for i in range(n)]
46
+
47
+ def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25):
48
+ # If n_heads is a power of 2, generate slopes directly
49
+ if math.log2(n_heads).is_integer():
50
+ slopes = self._generate_slopes(n_heads)
51
+ else:
52
+ # Get slopes for the nearest power of two
53
+ n = nearest_power_of_two(n_heads, round_up=False)
54
+ slopes_power_of_two = self._generate_slopes(n)
55
+
56
+ # Generate extra slopes
57
+ extra_slopes = self._generate_slopes(2 * n)
58
+ extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
59
+ slopes = slopes_power_of_two + extra_slopes_trunc
60
+ slopes = torch.tensor(slopes, device=self.device)
61
+ slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017
62
+ return slopes.to(torch.float32) # Ensure slopes are in float32
63
+
64
+
65
+ def forward(self, x):
66
+ bsz, seq_len, d_in = x.size()
67
+
68
+ qkv = self.c_attn(x)
69
+ q, k, v = torch.chunk(qkv, 3, dim=2)
70
+
71
+ q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
72
+ k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
73
+ v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
74
+ y = fa2( # https://arxiv.org/pdf/2307.08691
75
+ q,
76
+ k,
77
+ v,
78
+ dropout_p=self.dropout if self.training else 0.0,
79
+ causal=True,
80
+ window_size=(self.window_size, 0),
81
+ alibi_slopes=self.alibi_slopes, # https://arxiv.org/pdf/2108.12409
82
+ softcap=self.softcap, # https://arxiv.org/pdf/2408.00118
83
+ )
84
+ y = y.contiguous().view(bsz, seq_len, d_in)
85
+ y = self.resid_dropout(self.c_proj(y))
86
+ return y
87
+
88
+ class AttentionSDPA(nn.Module):
89
+ def __init__(self, config):
90
+ super(Attention, self).__init__()
91
+ if isinstance(config.torch_dtype, str):
92
+ torch_dtype = getattr(torch, config.torch_dtype)
93
+ else:
94
+ torch_dtype = config.torch_dtype
95
+ assert torch.cuda.is_available(), "CUDA is required."
96
+ assert config.n_embd % config.n_heads == 0
97
+ self.n_heads = config.n_heads
98
+
99
+ self.device = torch.device("cuda") # Technically don't need CUDA for SDPA
100
+ self.bsz = config.bsz
101
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype)
102
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype)
103
+ self.dropout = config.dropout
104
+ self.resid_dropout = nn.Dropout(self.dropout)
105
+
106
+ def forward(self, x):
107
+ bsz, seq_len, d_in = x.size()
108
+
109
+ qkv = self.c_attn(x)
110
+ q, k, v = torch.chunk(qkv, 3, dim=2)
111
+
112
+ q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
113
+ k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
114
+ v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
115
+
116
+ y = F.scaled_dot_product_attention(
117
+ q, k, v,
118
+ is_causal=True,
119
+ dropout_p=self.dropout if self.training else 0.0
120
+ )
121
+
122
+ y = y.transpose(1, 2).contiguous().view(bsz, seq_len, d_in)
123
+
124
+ y = self.resid_dropout(self.c_proj(y))
125
+ return y
config.json ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "ministu",
3
+ "_name_or_path": "STU-500M",
4
+ "architectures": ["MiniSTU"],
5
+ "n_embd": 768,
6
+ "n_heads": 8,
7
+ "n_layers": 12,
8
+ "seq_len": 8192,
9
+ "window_size": 1024,
10
+ "vocab_size": 200064,
11
+ "mlp_scale": 12,
12
+ "bias": false,
13
+ "dropout": 0.0,
14
+ "num_eigh": 24,
15
+ "use_hankel_L": false,
16
+ "num_epochs": 1,
17
+ "global_bsz": 524288,
18
+ "bsz": 1,
19
+ "warmup_steps": 1907,
20
+ "eval_period": 25,
21
+ "save_period": 4500,
22
+ "max_lr": 3.0e-3,
23
+ "min_lr": 3.0e-5,
24
+ "max_norm": 1.0,
25
+ "dilation": 1,
26
+ "fsdp": true,
27
+ "ddp": false,
28
+ "mixed_precision": true,
29
+ "use_cpu_offload": false,
30
+ "sharding_strategy": "full_shard",
31
+ "state_dict_type": "full",
32
+ "auto_wrap_policy": "partial",
33
+ "backward_prefetch": "backward_pre",
34
+ "forward_prefetch": false,
35
+ "sync_module_states": true,
36
+ "use_orig_params": true,
37
+ "device_id": null,
38
+ "precision": {
39
+ "param": "bfloat16",
40
+ "reduce": "bfloat16",
41
+ "buffer": "bfloat16"
42
+ },
43
+ "fsdp_modules": [
44
+ "STU",
45
+ "Attention",
46
+ "MLP"
47
+ ],
48
+ "use_activation_checkpointing": true,
49
+ "use_flash_fft": false,
50
+ "use_approx": true,
51
+ "use_attn": true,
52
+ "softcap": 50.0,
53
+ "torch_compile": false
54
+ }
55
+
configuration_ministu.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import PretrainedConfig, AutoConfig
3
+
4
+ class MiniSTUConfig(PretrainedConfig):
5
+ model_type = "ministu"
6
+
7
+ def __init__(
8
+ self,
9
+ bsz: int = 1,
10
+ n_embd: int = 768,
11
+ n_heads: int = 8,
12
+ n_layers: int = 12,
13
+ seq_len: int = 8192,
14
+ window_size: int = 1024,
15
+ vocab_size: int = 200064,
16
+ mlp_scale: int = 12,
17
+ bias: bool = False,
18
+ dropout: float = 0.0,
19
+ num_eigh: int = 24,
20
+ use_hankel_L: bool = False,
21
+ use_flash_fft: bool = True,
22
+ use_approx: bool = True,
23
+ use_attn: bool = True,
24
+ softcap: float = 50.0,
25
+ torch_dtype = torch.bfloat16,
26
+ device: str = None,
27
+ **kwargs,
28
+ ):
29
+ super().__init__(**kwargs)
30
+ self.bsz = bsz
31
+ self.n_embd = n_embd
32
+ self.n_heads = n_heads
33
+ self.n_layers = n_layers
34
+ self.seq_len = seq_len
35
+ self.window_size = window_size
36
+ self.vocab_size = vocab_size
37
+ self.hidden_size = n_embd
38
+ self.intermediate_size = n_embd * mlp_scale
39
+ self.mlp_scale = mlp_scale
40
+ self.hidden_act = "swish"
41
+ self.bias = bias
42
+ self.dropout = dropout
43
+ self.num_eigh = num_eigh
44
+ self.use_hankel_L = use_hankel_L
45
+ self.use_flash_fft = use_flash_fft
46
+ self.use_approx = use_approx
47
+ self.use_attn = use_attn
48
+ self.softcap = softcap
49
+ self.torch_dtype = torch_dtype # Store as string
50
+ self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu') # Store as string
convolve.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from .utils import nearest_power_of_two
5
+ from flashfftconv import FlashFFTConv
6
+
7
+
8
+ def convolve(u: torch.Tensor, v: torch.Tensor, n: int, use_approx: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
9
+ bsz, seq_len, d_in = u.shape
10
+
11
+ sgn = torch.full((1, seq_len, 1), 1, device=u.device, dtype=torch.float32)
12
+ sgn[:, 1::2] *= -1
13
+
14
+ # Cast u and v to float32 for FFT
15
+ u = u.to(torch.float32)
16
+ v = v.to(torch.float32)
17
+
18
+ if use_approx:
19
+ _, d_out = v.shape
20
+ v = v.view(1, -1, d_out, 1)
21
+ else:
22
+ _, K = v.shape
23
+ sgn = sgn.unsqueeze(-1)
24
+ v = v.view(1, -1, K, 1, 1)
25
+ u = u.view(bsz, -1, 1, d_in).expand(bsz, -1, K, d_in)
26
+
27
+ v = torch.fft.rfft(v, n=n, dim=1)
28
+ U = torch.stack([u, u * sgn], dim=-1)
29
+ U = torch.fft.rfft(U, n=n, dim=1)
30
+ U_conv = torch.fft.irfft(v * U, n=n, dim=1)[:, :seq_len]
31
+ U_plus, U_minus = torch.unbind(U_conv, dim=-1)
32
+ U_minus = U_minus * sgn
33
+
34
+ # Convert back to original dtype
35
+ U_plus = U_plus.to(u.dtype)
36
+ U_minus = U_minus.to(u.dtype)
37
+
38
+ return U_plus, U_minus
39
+
40
+ def flash_convolve(
41
+ u: torch.Tensor, v: torch.Tensor, flash_fft: FlashFFTConv, use_approx: bool = True,
42
+ ) -> tuple[torch.Tensor, torch.Tensor]:
43
+ dtype = u.dtype # Store the original dtype
44
+ u = u.to(torch.float32)
45
+ v = v.to(torch.float32)
46
+
47
+ bsz, seq_len, d_in = u.shape
48
+ _, K = v.shape
49
+
50
+ padded_len = nearest_power_of_two(seq_len, round_up=True)
51
+ pad_len = padded_len - seq_len
52
+
53
+ sgn = torch.full((1, 1, padded_len), 1, device=u.device, dtype=torch.float32)
54
+ sgn[:, :, 1::2] = -1
55
+
56
+ if use_approx:
57
+ u_padded = F.pad(u.transpose(1, 2), (0, pad_len)).contiguous()
58
+ v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).contiguous()
59
+ u_conv = torch.stack([u_padded, u_padded * sgn], dim=0).reshape(2 * bsz, d_in, padded_len)
60
+ else:
61
+ u_k_padded = F.pad(u.transpose(1, 2), (0, pad_len)).repeat_interleave(K, dim=1).contiguous()
62
+ v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).repeat(d_in, 1).contiguous()
63
+ u_conv = torch.stack([u_k_padded, u_k_padded * sgn], dim=0).reshape(2 * bsz, K * d_in, padded_len)
64
+
65
+ U_conv = flash_fft(u_conv, v_padded)
66
+
67
+ # Trim the output back to the original sequence length
68
+ U_conv = U_conv[..., :seq_len]
69
+
70
+ u_plus, u_minus = torch.chunk(U_conv, 2, dim=0)
71
+
72
+ if use_approx:
73
+ u_minus = u_minus * sgn[:, :, :seq_len]
74
+ U_plus, U_minus = u_plus.transpose(1, 2), u_minus.transpose(1, 2)
75
+ else:
76
+ sgn = sgn[:, :, :seq_len].unsqueeze(-1).transpose(1, 2)
77
+ U_plus = u_plus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous()
78
+ U_minus = u_minus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous() * sgn
79
+
80
+ # Convert back to original dtype
81
+ U_plus = U_plus.to(dtype)
82
+ U_minus = U_minus.to(dtype)
83
+
84
+ return U_plus, U_minus
filters.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from .utils import logger
7
+ from .utils import get_hankel
8
+
9
+ def get_spectral_filters(
10
+ seq_len: int,
11
+ K: int,
12
+ use_hankel_L: bool = False,
13
+ device: torch.device = None,
14
+ dtype: torch.dtype = torch.bfloat16,
15
+ ) -> torch.Tensor:
16
+ # Generate the Hankel matrix using PyTorch
17
+ Z = get_hankel(seq_len, use_hankel_L, device=device, dtype=dtype)
18
+
19
+ # Cast Z to torch.float32 for the eigenvalue decomposition
20
+ Z_float32 = Z.to(torch.float32)
21
+
22
+ # Perform eigen decomposition using torch.float32
23
+ sigma, phi = torch.linalg.eigh(Z_float32)
24
+
25
+ # Cast the results back to the original dtype (torch.bfloat16)
26
+ sigma = sigma.to(dtype=dtype)
27
+ phi = phi.to(dtype=dtype)
28
+
29
+ # Select the top K eigenvalues and eigenvectors
30
+ sigma_k, phi_k = sigma[-K:], phi[:, -K:]
31
+
32
+ # Compute the spectral filters
33
+ phi_k = phi_k * sigma_k ** 0.25
34
+
35
+ # Ensure the filters are in the correct dtype and device
36
+ filters = phi_k.to(device=device, dtype=dtype)
37
+
38
+ return filters
39
+
40
+
41
+ def compute_dimensions(n: int) -> tuple[int, int, int]:
42
+ if n <= 2:
43
+ raise ValueError("n must be greater than 2")
44
+
45
+ T_prime = (math.ceil(math.sqrt(n - 2)))**2 + 2
46
+ sqrt_T_prime = math.ceil(math.sqrt(T_prime - 2))
47
+ k_max = sqrt_T_prime
48
+ return T_prime, sqrt_T_prime, k_max
49
+
50
+ def get_tensorized_spectral_filters_explicit(n: int, k: int, device: torch.device) -> torch.Tensor:
51
+ T_prime, sqrt_T_prime, k_max = compute_dimensions(n)
52
+ k = min(k, k_max)
53
+
54
+ Z = get_hankel(sqrt_T_prime).to(device)
55
+ sigma, phi = torch.linalg.eigh(Z)
56
+ sigma_k = sigma[-k:]
57
+ phi_k = phi[:, -k:]
58
+
59
+ result = torch.zeros(sqrt_T_prime * sqrt_T_prime, device=device)
60
+
61
+ for i in range(k):
62
+ for j in range(k):
63
+ phi_i = phi_k[:, i] * (sigma_k[i] ** 0.25)
64
+ phi_j = phi_k[:, j] * (sigma_k[j] ** 0.25)
65
+ kron = torch.kron(phi_i, phi_j)
66
+ result += kron
67
+
68
+ return result
69
+
70
+
71
+ def get_tensorized_spectral_filters(
72
+ n: int = 8192,
73
+ k: int = 24,
74
+ use_hankel_L: bool = False,
75
+ device: torch.device = None,
76
+ dtype: torch.dtype = torch.bfloat16,
77
+ ) -> torch.Tensor:
78
+ """
79
+ Compute tensorized spectral filters for given sequence length and filter count.
80
+
81
+ Args:
82
+ n: Sequence length
83
+ k: Number of filters
84
+ use_hankel_L: Hankel_main ⊗ Hankel_L? Default is Hankel_main ⊗ Hankel_main.
85
+ device: Computation device
86
+ dtype: Computation dtype
87
+ """
88
+ assert torch.cuda.is_available(), "CUDA is required."
89
+
90
+ T_prime, sqrt_T_prime, k_max = compute_dimensions(n)
91
+ k = min(k, k_max)
92
+
93
+ Z = get_hankel(sqrt_T_prime)
94
+ sigma, phi = torch.linalg.eigh(Z)
95
+ phi_i = phi[:, -k:] * sigma[-k:] ** 0.25
96
+
97
+ if use_hankel_L: # TODO: We may want to use Hankel_L above too if use_hankel_L is true, make another variable for this (mix != use_hankel_L)
98
+ logger.info("Mixing Hankel_L with Hankel_main to generate tensorized filters.")
99
+ Z_L = get_hankel(sqrt_T_prime, True)
100
+ sigma_L, phi_L = torch.linalg.eigh(Z_L)
101
+ phi_j = phi_L[:, -k:] * sigma_L[-k:] ** 0.25
102
+ else:
103
+ phi_j = phi_i
104
+
105
+ filters = torch.kron(phi_i, phi_j)
106
+ return filters.to(device=device, dtype=dtype)
layers.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .modules import STU
5
+ from .modules import MLP
6
+ from .modules import Attention
7
+ try:
8
+ from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP
9
+ triton_mlp = True
10
+ except ImportError as e:
11
+ print(
12
+ f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead."
13
+ )
14
+ triton_mlp = False
15
+
16
+ try:
17
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
18
+ triton_norm = True
19
+ except ImportError as e:
20
+ print(
21
+ f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
22
+ )
23
+ from torch.nn import RMSNorm
24
+ triton_norm = False
25
+
26
+
27
+ class STULayer(nn.Module):
28
+ def __init__(self, config, phi, n):
29
+ super(STULayer, self).__init__()
30
+ if isinstance(config.torch_dtype, str):
31
+ torch_dtype = getattr(torch, config.torch_dtype)
32
+ else:
33
+ torch_dtype = config.torch_dtype
34
+ self.stu_norm = (
35
+ TritonNorm(config.n_embd)
36
+ if triton_norm
37
+ else RMSNorm(config.n_embd, dtype=torch_dtype)
38
+ )
39
+ self.stu = STU(config, phi, n)
40
+ self.stu = self.stu.to(dtype=torch_dtype)
41
+ self.mlp_norm = (
42
+ TritonNorm(config.n_embd)
43
+ if triton_norm
44
+ else RMSNorm(config.n_embd, dtype=torch_dtype)
45
+ )
46
+ self.mlp = (
47
+ TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)
48
+ )
49
+
50
+ # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP
51
+ self.stu_norm = self.stu_norm.to(dtype=torch_dtype)
52
+ self.mlp = self.mlp.to(dtype=torch_dtype)
53
+ self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)
54
+
55
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
56
+ # Debug dtype
57
+
58
+ # Normalize and apply STU
59
+ x_normed = self.stu_norm(x).to(dtype=self.stu.M_inputs.dtype) # Match dtype for STU
60
+ x_stu = self.stu(x_normed).to(dtype=x.dtype) # Ensure output matches `x`'s dtype
61
+ x = x + x_stu
62
+
63
+ # Normalize and apply MLP
64
+ x_normed_mlp = self.mlp_norm(x).to(dtype=self.mlp.gate_proj.weight.dtype) # Match dtype for MLP
65
+ x_mlp = self.mlp(x_normed_mlp).to(dtype=x.dtype) # Ensure output matches `x`'s dtype
66
+ x = x + x_mlp
67
+
68
+ return x
69
+
70
+ class AttentionLayer(nn.Module):
71
+ def __init__(self, config) -> None:
72
+ super(AttentionLayer, self).__init__()
73
+ if isinstance(config.torch_dtype, str):
74
+ torch_dtype = getattr(torch, config.torch_dtype)
75
+ else:
76
+ torch_dtype = config.torch_dtype
77
+ self.attn_norm = (
78
+ TritonNorm(config.n_embd)
79
+ if triton_norm
80
+ else RMSNorm(config.n_embd, dtype=torch_dtype)
81
+ )
82
+ self.attn = Attention(config)
83
+ self.attn = self.attn.to(dtype=torch_dtype)
84
+ self.mlp_norm = (
85
+ TritonNorm(config.n_embd)
86
+ if triton_norm
87
+ else RMSNorm(config.n_embd, dtype=torch_dtype)
88
+ )
89
+ self.mlp = (
90
+ TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)
91
+ )
92
+ self.mlp = self.mlp.to(dtype=torch_dtype)
93
+
94
+ # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP
95
+ self.attn_norm = self.attn_norm.to(dtype=torch_dtype)
96
+ self.mlp = self.mlp.to(dtype=torch_dtype)
97
+ self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)
98
+
99
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
100
+ x = x + self.attn(self.attn_norm(x))
101
+ x = x + self.mlp(self.mlp_norm(x))
102
+ return x
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
mlp.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from torch.nn import functional as F
3
+
4
+ class MLP(nn.Module):
5
+ def __init__(self, config, dtype=None):
6
+ # https://arxiv.org/pdf/2002.05202
7
+ super().__init__()
8
+ torch_dtype = getattr(torch, config.torch_dtype, torch.float32) # Use config dtype
9
+ dtype = dtype if dtype is not None else torch_dtype
10
+ self.hidden_size = config.n_embd
11
+ self.intermediate_size = config.n_embd * config.mlp_scale
12
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias, dtype=torch.bfloat16)
13
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias, dtype=torch.bfloat16)
14
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias, dtype=torch.bfloat16)
15
+ self.dropout = nn.Dropout(config.dropout)
16
+
17
+ def forward(self, x):
18
+ dtype = self.gate_proj.weight.dtype # Match the dtype of projection layers
19
+ x = x.to(dtype=dtype) # Convert input to the same dtype
20
+ x = x.to(self.gate_proj.weight.dtype)
21
+ gate = self.gate_proj(x)
22
+ gate = F.gelu(gate, approximate="tanh").to(dtype=dtype)
23
+ up = self.up_proj(x).to(dtype=dtype)
24
+ fuse = gate * up
25
+ outputs = self.down_proj(fuse).to(dtype=dtype)
26
+ outputs = self.dropout(outputs)
27
+ return outputs
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:101763e8119c492ada9816aa38006cbf6ba8bbc0530224510d62b2c7e20a8bfd
3
+ size 1140654808
modeling_ministu.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers import PreTrainedModel
5
+ from transformers.modeling_outputs import CausalLMOutput
6
+ from .modules import STU, Attention, MLP
7
+ from .utils import nearest_power_of_two
8
+ from .layers import STULayer, AttentionLayer
9
+ from .configuration_ministu import MiniSTUConfig
10
+ from .filters import get_spectral_filters
11
+
12
+ try:
13
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
14
+ triton_norm = True
15
+ except ImportError as e:
16
+ print(
17
+ f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
18
+ )
19
+ from torch.nn import RMSNorm
20
+ triton_norm = False
21
+ # Load the tokenizer
22
+ #from transformers import AutoModelForCausalLM, AutoTokenizer
23
+ #model_name = "Hazan-Lab/STU-426M"
24
+ #tokenizer = AutoTokenizer.from_pretrained(
25
+ # model_name,
26
+ # trust_remote_code=True
27
+ #)
28
+
29
+ class MiniSTU(PreTrainedModel):
30
+ config_class = MiniSTUConfig
31
+
32
+ def __init__(self, config) -> None:
33
+ super(MiniSTU, self).__init__(config)
34
+ self.n_layers = config.n_layers
35
+ self.n = nearest_power_of_two(config.seq_len * 2 - 1, round_up=True)
36
+
37
+ if isinstance(config.torch_dtype, torch.dtype):
38
+ torch_dtype = config.torch_dtype
39
+ else:
40
+ torch_dtype = getattr(torch, config.torch_dtype)
41
+
42
+ device = torch.device(config.device)
43
+
44
+ self.phi = get_spectral_filters(
45
+ config.seq_len,
46
+ config.num_eigh,
47
+ config.use_hankel_L,
48
+ device=device,
49
+ dtype=torch_dtype,
50
+ )
51
+
52
+ self.use_approx = config.use_approx
53
+ self.use_hankel_L = config.use_hankel_L
54
+
55
+ self.tok_emb = nn.Embedding(
56
+ config.vocab_size, config.n_embd, dtype=torch_dtype, device=device
57
+ )
58
+ self.dropout = nn.Dropout(config.dropout)
59
+
60
+ self.layers = nn.ModuleList()
61
+ for layer_idx in range(self.n_layers):
62
+ if layer_idx % 2 == 0:
63
+ self.layers.append(STULayer(config, self.phi, self.n))
64
+ else:
65
+ self.layers.append(
66
+ AttentionLayer(config)
67
+ if config.use_attn
68
+ else STULayer(config, self.phi, self.n)
69
+ )
70
+
71
+ self.norm = TritonNorm(config.n_embd) if triton_norm else RMSNorm(config.n_embd)
72
+
73
+ self.lm_head = nn.Linear(
74
+ config.n_embd, config.vocab_size, bias=config.bias, dtype=torch_dtype, device=device
75
+ )
76
+ self.tok_emb.weight = self.lm_head.weight
77
+
78
+ self.std = (config.n_embd) ** -0.5
79
+ self.apply(self._init_weights)
80
+ print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))
81
+
82
+ def forward(
83
+ self,
84
+ input_ids: torch.Tensor,
85
+ labels: torch.Tensor = None,
86
+ **kwargs
87
+ ) -> CausalLMOutput:
88
+ # Compute embeddings
89
+ tok_emb = self.tok_emb(input_ids)
90
+ x = self.dropout(tok_emb)
91
+
92
+ # Pass through layers
93
+ for layer in self.layers:
94
+ x = layer(x)
95
+
96
+ # Normalize and project to vocabulary
97
+ x = self.norm(x)
98
+ logits = self.lm_head(x)
99
+
100
+ loss = None
101
+ if labels is not None:
102
+ # Shift so that tokens predict the next token
103
+ shift_logits = logits[..., :-1, :].contiguous()
104
+ shift_labels = labels[..., 1:].contiguous()
105
+ loss_fct = nn.CrossEntropyLoss()
106
+ loss = loss_fct(
107
+ shift_logits.view(-1, shift_logits.size(-1)),
108
+ shift_labels.view(-1)
109
+ )
110
+
111
+ return CausalLMOutput(
112
+ loss=loss,
113
+ logits=logits,
114
+ )
115
+
116
+ def _get_num_params(self):
117
+ n_params = sum(p.numel() for p in self.parameters())
118
+ if hasattr(self, "pos_emb") and self.pos_emb is not None:
119
+ n_params -= self.pos_emb.weight.numel()
120
+ if self.tok_emb.weight is not self.lm_head.weight:
121
+ n_params -= self.tok_emb.weight.numel()
122
+ return n_params
123
+
124
+ def _init_weights(self, module):
125
+ if isinstance(module, nn.Linear):
126
+ if hasattr(module, "SCALE_INIT"):
127
+ self.std *= (2 * self.n_layers) ** -0.5
128
+ torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
129
+ if module.bias is not None:
130
+ torch.nn.init.zeros_(module.bias)
131
+ elif isinstance(module, nn.Embedding):
132
+ torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
133
+ elif isinstance(module, STU):
134
+ if self.use_approx:
135
+ torch.nn.init.xavier_normal_(module.M_inputs)
136
+ torch.nn.init.xavier_normal_(module.M_filters)
137
+ else:
138
+ torch.nn.init.xavier_normal_(module.M_phi_plus)
139
+ if not self.use_hankel_L:
140
+ torch.nn.init.xavier_normal_(module.M_phi_minus)
141
+ elif isinstance(module, Attention):
142
+ torch.nn.init.xavier_normal_(module.c_attn.weight)
143
+ torch.nn.init.xavier_normal_(module.c_proj.weight)
144
+ if module.c_attn.bias is not None:
145
+ torch.nn.init.zeros_(module.c_attn.bias)
146
+ if module.c_proj.bias is not None:
147
+ torch.nn.init.zeros_(module.c_proj.bias)
148
+ @staticmethod
149
+ def top_k_top_p_filtering(
150
+ logits: torch.Tensor,
151
+ top_k: int = 50,
152
+ top_p: float = 0.95,
153
+ filter_value: float = float("-inf"),
154
+ ):
155
+ """
156
+ Filters a distribution of logits using top-k and/or nucleus (top-p) filtering.
157
+ """
158
+ # top_k
159
+ if top_k > 0:
160
+ top_k = min(top_k, logits.size(-1))
161
+ # Remove all logits that are not in the top k
162
+ indices_to_remove = logits < torch.topk(logits, top_k, dim=-1).values[:, -1, None]
163
+ logits[indices_to_remove] = filter_value
164
+
165
+ # top_p (nucleus)
166
+ if 0 < top_p < 1.0:
167
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
168
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
169
+
170
+ # Remove tokens with cumulative probability above the threshold
171
+ sorted_indices_to_remove = cumulative_probs > top_p
172
+ # Shift the indices to the right to keep also the first token above the threshold
173
+ sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
174
+ sorted_indices_to_remove[:, 0] = False
175
+
176
+ indices_to_remove = sorted_indices_to_remove.scatter(
177
+ dim=1, index=sorted_indices, src=sorted_indices_to_remove
178
+ )
179
+ logits[indices_to_remove] = filter_value
180
+
181
+ return logits
182
+
183
+ def generate(
184
+ self,
185
+ input_ids: torch.LongTensor,
186
+ max_new_tokens: int = 50,
187
+ temperature: float = 0.5,
188
+ top_k: int = 50,
189
+ top_p: float = 0.95,
190
+ eos_token_id: int = None,
191
+ pad_token_id: int = 0,
192
+ **kwargs
193
+ ):
194
+ """
195
+ Naive token-by-token generation loop that uses top-k/top-p filtering and optional temperature.
196
+
197
+ Args:
198
+ input_ids (torch.LongTensor): shape (batch_size, sequence_length).
199
+ max_new_tokens (int): max number of tokens to generate (beyond input_ids length).
200
+ temperature (float): sampling temperature (>=0).
201
+ top_k (int): Top-K sampling cutoff.
202
+ top_p (float): Nucleus sampling cutoff.
203
+ eos_token_id (int): If set, stop generation when this token is produced.
204
+ pad_token_id (int): If set, can be used to pad sequences. (Not fully used here.)
205
+ kwargs: Unused arguments (like num_beams) for compatibility.
206
+
207
+ Returns:
208
+ torch.LongTensor: shape (batch_size, sequence_length + generated_tokens).
209
+ """
210
+ device = input_ids.device
211
+ #print("1=====================")
212
+ #print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
213
+ #print("1=====================")
214
+
215
+ # We'll accumulate new tokens into generated_ids
216
+ generated_ids = input_ids.clone()
217
+
218
+ for _ in range(max_new_tokens):
219
+ # Forward pass to get logits for the last token
220
+ outputs = self.forward(generated_ids)
221
+ logits = outputs.logits[:, -1, :] # shape: (batch_size, vocab_size)
222
+
223
+ # Scale logits by temperature
224
+ if temperature != 1.0:
225
+ logits = logits / temperature
226
+
227
+ # Filter logits using top-k and/or top-p
228
+ logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
229
+
230
+ # Convert to probabilities
231
+ probabilities = F.softmax(logits, dim=-1)
232
+
233
+ # Sample from the distribution
234
+ next_token = torch.multinomial(probabilities, num_samples=1) # (batch_size, 1)
235
+
236
+ # Append next token
237
+ generated_ids = torch.cat([generated_ids, next_token], dim=1)
238
+
239
+ # If eos_token_id is set and any sample produced it, we optionally could break early
240
+ if eos_token_id is not None:
241
+ # Check if all sequences in the batch ended
242
+ # or if you want to do a more fine-grained approach
243
+ if (next_token == eos_token_id).all():
244
+ break
245
+ #print("2=====================")
246
+ #print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
247
+ #print("2=====================")
248
+ return generated_ids
modules.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .attn import Attention
2
+ from .attn import AttentionSDPA
3
+ from .mlp import MLP
4
+ from .stu import STU
5
+
6
+
results.txt ADDED
File without changes
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
stu.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .convolve import convolve, flash_convolve
5
+
6
+ try:
7
+ from flashfftconv import FlashFFTConv
8
+
9
+ flash_fft_available = True
10
+ except ImportError as e:
11
+ print(
12
+ f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation."
13
+ )
14
+ flash_fft_available = False
15
+
16
+
17
+ class STU(nn.Module):
18
+ def __init__(self, config, phi, n) -> None:
19
+ super(STU, self).__init__()
20
+ self.config = config
21
+ if isinstance(config.torch_dtype, str):
22
+ torch_dtype = getattr(torch, config.torch_dtype)
23
+ else:
24
+ torch_dtype = config.torch_dtype
25
+ self.phi = phi.to(device=config.device, dtype=torch_dtype)
26
+ self.n = n
27
+ self.K = config.num_eigh
28
+ self.d_in = config.n_embd
29
+ self.d_out = config.n_embd
30
+ self.use_hankel_L = config.use_hankel_L
31
+ self.use_approx = config.use_approx
32
+ self.flash_fft = (
33
+ FlashFFTConv(self.n, dtype=torch.bfloat16)
34
+ if config.use_flash_fft and flash_fft_available
35
+ else None
36
+ )
37
+ if self.use_approx:
38
+ self.M_inputs = nn.Parameter(
39
+ torch.empty(self.d_in, self.d_out, dtype=torch_dtype)
40
+ )
41
+ self.M_filters = nn.Parameter(
42
+ torch.empty(self.K, self.d_in, dtype=torch_dtype)
43
+ )
44
+ else:
45
+ self.M_phi_plus = nn.Parameter(
46
+ torch.empty(self.K, self.d_in, self.d_out, dtype=torch_dtype)
47
+ )
48
+ if not self.use_hankel_L:
49
+ self.M_phi_minus = nn.Parameter(
50
+ torch.empty(self.K, self.d_in, self.d_out, dtype=torch_dtype)
51
+ )
52
+
53
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
54
+ dtype = self.M_inputs.dtype
55
+ x = x.to(dtype=dtype)
56
+ if self.use_approx:
57
+ # Contract inputs and filters over the K and d_in dimensions, then convolve
58
+ x_proj = x @ self.M_inputs
59
+ phi_proj = self.phi @ self.M_filters
60
+ x_proj = x_proj.to(dtype=dtype)
61
+ phi_proj = phi_proj.to(dtype=dtype)
62
+ if self.flash_fft:
63
+ spectral_plus, spectral_minus = flash_convolve(
64
+ x_proj, phi_proj, self.flash_fft, self.use_approx
65
+ )
66
+ else:
67
+ spectral_plus, spectral_minus = convolve(
68
+ x_proj, phi_proj, self.n, self.use_approx
69
+ )
70
+ else:
71
+ # Convolve inputs and filters,
72
+ if self.flash_fft:
73
+ U_plus, U_minus = flash_convolve(
74
+ x, self.phi, self.flash_fft, self.use_approx
75
+ )
76
+ else:
77
+ U_plus, U_minus = convolve(x, self.phi, self.n, self.use_approx)
78
+ # Then, contract over the K and d_in dimensions
79
+ spectral_plus = torch.tensordot(
80
+ U_plus, self.M_phi_plus, dims=([2, 3], [0, 1])
81
+ )
82
+ if not self.use_hankel_L:
83
+ spectral_minus = torch.tensordot(
84
+ U_minus, self.M_phi_minus, dims=([2, 3], [0, 1])
85
+ )
86
+
87
+ return spectral_plus if self.use_hankel_L else spectral_plus + spectral_minus
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "199999": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "200018": {
13
+ "content": "<|endofprompt|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ }
20
+ },
21
+ "bos_token": "<|endoftext|>",
22
+ "clean_up_tokenization_spaces": false,
23
+ "eos_token": "<|endoftext|>",
24
+ "model_max_length": 128000,
25
+ "tokenizer_class": "GPT2Tokenizer",
26
+ "unk_token": "<|endoftext|>"
27
+ }
utils.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+
5
+ import logging
6
+ import os
7
+ import sys
8
+ from colorama import Fore, Style, init
9
+ from dotenv import load_dotenv
10
+
11
+ load_dotenv()
12
+ init(autoreset=True)
13
+
14
+ def nearest_power_of_two(x: int, round_up: bool = False) -> int:
15
+ return (
16
+ 1 << math.floor(math.log2(x)) if not round_up else 1 << math.ceil(math.log2(x))
17
+ )
18
+
19
+ def get_hankel(seq_len: int, use_hankel_L: bool = False, device: torch.device = None, dtype: torch.dtype = torch.float32) -> torch.Tensor:
20
+ entries = torch.arange(1, seq_len + 1, dtype=dtype, device=device)
21
+ i_plus_j = entries[:, None] + entries[None, :]
22
+
23
+ if use_hankel_L:
24
+ sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0
25
+ denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0)
26
+ Z = sgn * (8.0 / denom)
27
+ elif not use_hankel_L:
28
+ Z = 2.0 / (i_plus_j**3 - i_plus_j)
29
+ else:
30
+ raise ValueError("use_hankel_L must be a boolean")
31
+
32
+ return Z
33
+
34
+
35
+ class ColorFormatter(logging.Formatter):
36
+ """
37
+ A custom log formatter that applies color based on the log level using the Colorama library.
38
+
39
+ Attributes:
40
+ LOG_COLORS (dict): A dictionary mapping log levels to their corresponding color codes.
41
+ """
42
+
43
+ # Colors for each log level
44
+ LOG_COLORS = {
45
+ logging.DEBUG: Fore.LIGHTMAGENTA_EX + Style.BRIGHT,
46
+ logging.INFO: Fore.CYAN,
47
+ logging.WARNING: Fore.YELLOW + Style.BRIGHT,
48
+ logging.ERROR: Fore.RED + Style.BRIGHT,
49
+ logging.CRITICAL: Fore.RED + Style.BRIGHT + Style.NORMAL,
50
+ }
51
+
52
+ # Colors for other parts of the log message
53
+ TIME_COLOR = Fore.GREEN
54
+ FILE_COLOR = Fore.BLUE
55
+ LEVEL_COLOR = Style.BRIGHT
56
+
57
+ def __init__(self, fmt=None):
58
+ super().__init__(fmt or "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", "%Y-%m-%d %H:%M:%S")
59
+
60
+ def format(self, record):
61
+ """
62
+ Formats a log record with the appropriate color based on the log level.
63
+
64
+ Args:
65
+ record (logging.LogRecord): The log record to format.
66
+
67
+ Returns:
68
+ str: The formatted log message with colors applied.
69
+ """
70
+ # Apply color based on the log level
71
+ level_color = self.LOG_COLORS.get(record.levelno, Fore.WHITE)
72
+ time_str = f"{self.TIME_COLOR}{self.formatTime(record)}{Style.RESET_ALL}"
73
+ levelname_str = f"{level_color}{record.levelname}{Style.RESET_ALL}"
74
+ file_info_str = f"{self.FILE_COLOR}{record.filename}:{record.lineno}{Style.RESET_ALL}"
75
+
76
+ # Format the log message with color
77
+ log_msg = f"{time_str} - {levelname_str} - {file_info_str} - {record.msg}"
78
+ return log_msg
79
+
80
+ def setup_logger():
81
+ """
82
+ Sets up a logger with a custom color formatter that logs to standard output (stdout).
83
+
84
+ The logger is configured with the ColorFormatter to format log messages with color based on the log level.
85
+ The log level is set to INFO by default, but this can be changed to show more or less detailed messages.
86
+
87
+ Returns:
88
+ logging.Logger: A logger instance that logs formatted messages to stdout.
89
+ """
90
+ handler = logging.StreamHandler(sys.stdout)
91
+
92
+ # Set custom formatter
93
+ formatter = ColorFormatter()
94
+ handler.setFormatter(formatter)
95
+ logger = logging.getLogger(__name__)
96
+
97
+ # Set to DEBUG to capture all logging levels
98
+ DEBUG = os.environ.get("DEBUG", "False").lower() in ("true", "1", "t")
99
+ logger.setLevel(logging.DEBUG) if DEBUG else logger.setLevel(logging.INFO)
100
+ logger.addHandler(handler)
101
+ logger.propagate = False # Prevents multiple logging if re-initialized
102
+
103
+ return logger
104
+
105
+ logger = setup_logger() # Initialize once to prevent multiple loggers
vocab.json ADDED
The diff for this file is too large to render. See raw diff