import torch
import torch.nn as nn
from transformers import PreTrainedModel
from stu import STU
from modules_stu import Attention
from utils import nearest_power_of_two
from flash_stu.config import FlashSTUConfig
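
# Prefer the fused Triton kernels from Liger-Kernel when they are installed;
# otherwise fall back to the vanilla SwiGLU MLP and PyTorch RMSNorm implementations.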
try:
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP
    triton_mlp = True
except ImportError as e:
    print(f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead.")
    from modules import MLP
    triton_mlp = False

try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
    triton_norm = True
except ImportError as e:
    print(f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation.")
    from torch.nn import RMSNorm
    triton_norm = False
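

# A pre-norm residual block: RMSNorm -> STU sequence mixer -> residual add,
# followed by RMSNorm -> SwiGLU MLP -> residual add.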
class STULayer(nn.Module):
    def __init__(self, config, phi, n):
        super().__init__()
        self.stu_norm = TritonNorm(config.n_embd) if triton_norm else RMSNorm(config.n_embd, dtype=config.torch_dtype)
        self.stu = STU(config, phi, n)
        self.mlp_norm = TritonNorm(config.n_embd) if triton_norm else RMSNorm(config.n_embd, dtype=config.torch_dtype)
        self.mlp = TritonMLP(config) if triton_mlp else MLP(config, dtype=config.torch_dtype)

        # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP
        self.stu_norm = self.stu_norm.to(dtype=config.torch_dtype)
        self.mlp = self.mlp.to(dtype=config.torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=config.torch_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.stu(self.stu_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x
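

# Same residual structure as STULayer, but with self-attention as the
# sequence mixer instead of the STU operator.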
class AttentionLayer(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.attn_norm = TritonNorm(config.n_embd) if triton_norm else RMSNorm(config.n_embd, dtype=config.torch_dtype)
        self.attn = Attention(config)
        self.mlp_norm = TritonNorm(config.n_embd) if triton_norm else RMSNorm(config.n_embd, dtype=config.torch_dtype)
        self.mlp = TritonMLP(config) if triton_mlp else MLP(config, dtype=config.torch_dtype)

        # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP
        self.attn_norm = self.attn_norm.to(dtype=config.torch_dtype)
        self.mlp = self.mlp.to(dtype=config.torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=config.torch_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x
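

# Top-level language model: token embedding + dropout, a stack of STU layers
# optionally interleaved with attention layers, a final RMSNorm, and a
# weight-tied LM head.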
class FlashSTU(PreTrainedModel):
    config_class = FlashSTUConfig

    def __init__(self, config, phi) -> None:
        super().__init__(config)
        self.n_layers = config.n_layers
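        # Convolution length passed to the STU layers: the next power of two
        # >= 2 * seq_len - 1, large enough for a full linear convolution.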
        self.n = nearest_power_of_two(config.seq_len * 2 - 1, round_up=True)
        self.phi = phi
        self.use_approx = config.use_approx

        # TODO: Add support for Liger-Kernel Embedding once no longer experimental
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd, dtype=config.torch_dtype)
        self.dropout = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList()
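        # Even-indexed layers are always STU; odd-indexed layers use attention
        # when config.use_attn is set, and STU otherwise.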
        for layer_idx in range(self.n_layers):
            # For more complex %-split arrangements, see https://arxiv.org/pdf/2406.07887
            if layer_idx % 2 == 0:
                self.layers.append(STULayer(config, self.phi, self.n))
            else:
                self.layers.append(AttentionLayer(config) if config.use_attn else STULayer(config, self.phi, self.n))

        self.norm = TritonNorm(config.n_embd) if triton_norm else RMSNorm(config.n_embd, dtype=config.torch_dtype)
        # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for RMS Norm
        self.norm = self.norm.to(dtype=config.torch_dtype)

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=config.bias, dtype=config.torch_dtype)
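        # Tie the token embedding and output projection weights.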
        self.tok_emb.weight = self.lm_head.weight

        self.std = config.n_embd ** -0.5
        self.apply(self._init_weights)
        print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tok_emb = self.tok_emb(x)
        x = self.dropout(tok_emb)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        y_hat = self.lm_head(x)
        return y_hat
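
    # Parameter count, excluding positional embeddings (if any) and, when the
    # embedding is not tied to the output head, the token embedding.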
    def _get_num_params(self):
        n_params = sum(p.numel() for p in self.parameters())
        if hasattr(self, "pos_emb") and self.pos_emb is not None:
            n_params -= self.pos_emb.weight.numel()
        if self.tok_emb.weight is not self.lm_head.weight:
            n_params -= self.tok_emb.weight.numel()
        return n_params
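
    # Initialization: normal(0, n_embd ** -0.5) for linear and embedding weights
    # (Linear modules flagged with SCALE_INIT get an extra 1/sqrt(2 * n_layers)
    # factor), and Xavier-normal for the STU and attention projection matrices.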
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = self.std
            if hasattr(module, "SCALE_INIT"):
                # Use a local copy so the 1/sqrt(2 * n_layers) scaling is not
                # compounded in self.std across modules.
                std *= (2 * self.n_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
        elif isinstance(module, STU):
            if self.use_approx:
                torch.nn.init.xavier_normal_(module.M_inputs)
                torch.nn.init.xavier_normal_(module.M_filters)
            else:
                torch.nn.init.xavier_normal_(module.M_phi_plus)
                torch.nn.init.xavier_normal_(module.M_phi_minus)
        elif isinstance(module, Attention):
            torch.nn.init.xavier_normal_(module.c_attn.weight)
            torch.nn.init.xavier_normal_(module.c_proj.weight)
            if module.c_attn.bias is not None:
                torch.nn.init.zeros_(module.c_attn.bias)
            if module.c_proj.bias is not None:
                torch.nn.init.zeros_(module.c_proj.bias)