import torch
import torch.nn as nn

from transformers import PreTrainedModel

from stu import STU
from modules_stu import Attention
from utils import nearest_power_of_two
from flash_stu.config import FlashSTUConfig

try:
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP
    triton_mlp = True
except ImportError as e:
    print(f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead.")
    from modules import MLP
    triton_mlp = False

try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
    triton_norm = True
except ImportError as e:
    print(f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation.")
    from torch.nn import RMSNorm
    triton_norm = False

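# Pre-norm residual block: RMSNorm -> STU (spectral transform unit) -> residual,
# followed by RMSNorm -> SwiGLU MLP -> residual.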
class STULayer(nn.Module):
    def __init__(self, config, phi, n):
        super(STULayer, self).__init__()
        self.stu_norm = TritonNorm(config.n_embd) if triton_norm else RMSNorm(config.n_embd, dtype=config.torch_dtype)
        self.stu = STU(config, phi, n)
        self.mlp_norm = TritonNorm(config.n_embd) if triton_norm else RMSNorm(config.n_embd, dtype=config.torch_dtype)
        self.mlp = TritonMLP(config) if triton_mlp else MLP(config, dtype=config.torch_dtype)
        
        # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP
        self.stu_norm = self.stu_norm.to(dtype=config.torch_dtype)
        self.mlp = self.mlp.to(dtype=config.torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=config.torch_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.stu(self.stu_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x

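# Pre-norm residual block: RMSNorm -> self-attention -> residual,
# followed by RMSNorm -> SwiGLU MLP -> residual.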
class AttentionLayer(nn.Module):
    def __init__(self, config) -> None:
        super(AttentionLayer, self).__init__()
        self.attn_norm = TritonNorm(config.n_embd) if triton_norm else RMSNorm(config.n_embd, dtype=config.torch_dtype)
        self.attn = Attention(config)
        self.mlp_norm = TritonNorm(config.n_embd) if triton_norm else RMSNorm(config.n_embd, dtype=config.torch_dtype)
        self.mlp = TritonMLP(config) if triton_mlp else MLP(config, dtype=config.torch_dtype)

        # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP
        self.attn_norm = self.attn_norm.to(dtype=config.torch_dtype)
        self.mlp = self.mlp.to(dtype=config.torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=config.torch_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x

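# Decoder-style language model that interleaves STU layers with (optional)
# attention layers and ties the token embedding to the LM head.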
class FlashSTU(PreTrainedModel):
    config_class = FlashSTUConfig

    def __init__(self, config, phi) -> None:
        super(FlashSTU, self).__init__(config)
        self.n_layers = config.n_layers
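        # FFT length used by the STU layers: the next power of two large enough
        # to hold a full linear convolution over the sequence (2 * seq_len - 1).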
        self.n = nearest_power_of_two(config.seq_len * 2 - 1, round_up=True)
        self.phi = phi
        self.use_approx = config.use_approx

        # TODO: Add support for Liger-Kernel Embedding once no longer experimental
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd, dtype=config.torch_dtype)
        self.dropout = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList()
        for layer_idx in range(self.n_layers):
            # For more complex %-split arrangements, see https://arxiv.org/pdf/2406.07887
            if layer_idx % 2 == 0:
                self.layers.append(STULayer(config, self.phi, self.n))
            else:
                self.layers.append(AttentionLayer(config) if config.use_attn else STULayer(config, self.phi, self.n))

        self.norm = TritonNorm(config.n_embd) if triton_norm else RMSNorm(config.n_embd, dtype=config.torch_dtype)
        # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for RMS Norm
        self.norm = self.norm.to(dtype=config.torch_dtype)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=config.bias, dtype=config.torch_dtype)
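        # Tie the token embedding and output projection weights.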
        self.tok_emb.weight = self.lm_head.weight

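        # Base init scale of 1 / sqrt(n_embd); modules flagged with SCALE_INIT get an
        # additional (2 * n_layers) ** -0.5 factor in _init_weights.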
        self.std = (config.n_embd) ** -0.5
        self.apply(self._init_weights)
        print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tok_emb = self.tok_emb(x)
        x = self.dropout(tok_emb)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        y_hat = self.lm_head(x)

        return y_hat

    def _get_num_params(self):
        n_params = sum(p.numel() for p in self.parameters())
        if hasattr(self, "pos_emb") and self.pos_emb is not None:
            n_params -= self.pos_emb.weight.numel()
        if self.tok_emb.weight is not self.lm_head.weight:
            n_params -= self.tok_emb.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            # Use a local std so the depth-dependent scaling for SCALE_INIT modules
            # does not permanently shrink self.std for subsequently visited modules.
            std = self.std
            if hasattr(module, "SCALE_INIT"):
                std *= (2 * self.n_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
        elif isinstance(module, STU):
            if self.use_approx:
                torch.nn.init.xavier_normal_(module.M_inputs)
                torch.nn.init.xavier_normal_(module.M_filters)
            else:
                torch.nn.init.xavier_normal_(module.M_phi_plus)
                torch.nn.init.xavier_normal_(module.M_phi_minus)
        elif isinstance(module, Attention):
            torch.nn.init.xavier_normal_(module.c_attn.weight)
            torch.nn.init.xavier_normal_(module.c_proj.weight)
            if module.c_attn.bias is not None:
                torch.nn.init.zeros_(module.c_attn.bias)
            if module.c_proj.bias is not None:
                torch.nn.init.zeros_(module.c_proj.bias)
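

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of the original module).
    # It assumes FlashSTUConfig accepts the fields referenced in this file (n_embd,
    # n_layers, seq_len, vocab_size, dropout, bias, use_attn, use_approx, torch_dtype);
    # the STU, Attention, and MLP modules may require additional fields (e.g. a head
    # count). The spectral filters `phi` below are random placeholders; in practice
    # they would come from the repo's spectral-filter computation.
    config = FlashSTUConfig(
        n_embd=256,
        n_layers=4,
        seq_len=512,
        vocab_size=32000,
        dropout=0.0,
        bias=False,
        use_attn=True,
        use_approx=True,
        torch_dtype=torch.float32,
    )
    num_filters = 24  # placeholder filter count; match the STU implementation in practice
    phi = torch.randn(config.seq_len, num_filters, dtype=config.torch_dtype)

    model = FlashSTU(config, phi)
    tokens = torch.randint(0, config.vocab_size, (2, config.seq_len))
    logits = model(tokens)
    print(logits.shape)  # expected: (2, seq_len, vocab_size)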