Upload 18 files
- epoch3.pt +3 -0
- inference.py +54 -0
- llama_modeling/__init__.py +0 -0
- llama_modeling/attention.py +106 -0
- llama_modeling/config.py +15 -0
- llama_modeling/decoder.py +42 -0
- llama_modeling/diff_attn.py +150 -0
- llama_modeling/extact.py +25 -0
- llama_modeling/front_end.py +85 -0
- llama_modeling/liger_rope.py +258 -0
- llama_modeling/mlp.py +18 -0
- llama_modeling/model.py +35 -0
- llama_modeling/rms_norm.py +16 -0
- llama_modeling/rope.py +34 -0
- llama_modeling/tensor_prod_attn.py +147 -0
- test-train.py +238 -0
- utils/__init__.py +0 -0
- utils/trainutils.py +92 -0
epoch3.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:849af991539dcc0d5b278fb81e96e2e72d005b5d671b024e18573c30ea51f676
size 507810914
inference.py
ADDED
@@ -0,0 +1,54 @@
import torch
from transformers import AutoTokenizer
from llama_modeling.front_end import LlamaForCausalLM
from llama_modeling.config import LlamaConfig
import json
import sys
from utils.trainutils import load_checkpoint

def generate_text(model, tokenizer, prompt, max_new_tokens=30):
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to("cuda")

    with torch.inference_mode():
        outputs = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.7
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def main():
    if len(sys.argv) != 2:
        print("Usage: python inference.py <path_to_model>")
        sys.exit(1)

    model_path = sys.argv[1]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    with open("config.json") as f:
        config_dict = json.load(f)
    config = LlamaConfig(**{k: v for k, v in config_dict.items() if k in LlamaConfig.__dataclass_fields__})

    model = LlamaForCausalLM(config).to(device)

    load_checkpoint(model, model_path)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("./SmolLM2-135M-Instruct")

    prompts = [
        "Once upon a time,",
        "The best way to learn programming is",
        "Here's a recipe for chocolate cake:"
    ]

    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=None):
        for prompt in prompts:
            print(f"\nPrompt: {prompt}")
            output = generate_text(model, tokenizer, prompt)
            print(f"Generated: {output}")
            print("-" * 50)

if __name__ == "__main__":
    main()
llama_modeling/__init__.py
ADDED
File without changes
llama_modeling/attention.py
ADDED
@@ -0,0 +1,106 @@
from flash_attn import flash_attn_func
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from .liger_rope import LigerRopeFunction
from .config import LlamaConfig

class LlamaAttention(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_attention_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.register_buffer(
            "cos_cached",
            self._compute_rope_embeddings(
                self.max_position_embeddings,
                self.head_dim,
                self.rope_theta,
                dtype=torch.float32,
                device=self.q_proj.weight.device,
            )[0],
            persistent=False,
        )
        self.register_buffer(
            "sin_cached",
            self._compute_rope_embeddings(
                self.max_position_embeddings,
                self.head_dim,
                self.rope_theta,
                dtype=torch.float32,
                device=self.q_proj.weight.device,
            )[1],
            persistent=False,
        )

    def _compute_rope_embeddings(self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None):
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
        t = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype)
        sin = emb.sin().to(dtype)
        return cos.unsqueeze(0), sin.unsqueeze(0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # In B S (H D)
        bsz, seq_len, _ = hidden_states.size()

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)
            position_ids = repeat(position_ids, 'l -> b l', b=bsz)

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = rearrange(query_states, "b s (h d) -> b s h d", h=self.num_heads, d=self.head_dim)
        key_states = rearrange(key_states, "b s (h d) -> b s h d", h=self.num_key_value_heads, d=self.head_dim)
        value_states = rearrange(value_states, "b s (h d) -> b s h d", h=self.num_key_value_heads, d=self.head_dim)

        # Slice off position specific rope freqs from the cached freqs
        cos = self.cos_cached[:, position_ids]  # [1, bsz, seq_len, dim]
        sin = self.sin_cached[:, position_ids]  # [1, bsz, seq_len, dim]

        query_states, key_states = LigerRopeFunction.apply(
            query_states,
            key_states,
            cos.squeeze(0),
            sin.squeeze(0),
            position_ids
        )

        attn_output = flash_attn_func(
            query_states,
            key_states,
            value_states,
            dropout_p=0.0,
            causal=attention_mask is None
        )

        attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
        return self.o_proj(attn_output)
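Note on the cached-RoPE slicing above: a toy, CPU-only shape check of just the indexing (the sizes below are made up for illustration; the Triton kernel itself is not involved):

import torch

# Build a tiny cos cache the same way _compute_rope_embeddings does.
max_pos, head_dim, bsz, seq_len = 32, 8, 2, 5
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(max_pos, dtype=torch.float32)
emb = torch.cat([torch.outer(t, inv_freq)] * 2, dim=-1)
cos_cached = emb.cos().unsqueeze(0)                # [1, max_pos, head_dim]

# Indexing dim 1 with a [bsz, seq_len] tensor yields [1, bsz, seq_len, head_dim];
# squeeze(0) then matches the (bsz, seq_len, head_dim) variant documented in liger_rope.py.
position_ids = torch.arange(seq_len).expand(bsz, seq_len)
cos = cos_cached[:, position_ids]
print(cos.shape, cos.squeeze(0).shape)             # torch.Size([1, 2, 5, 8]) torch.Size([2, 5, 8])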
llama_modeling/config.py
ADDED
@@ -0,0 +1,15 @@
from dataclasses import dataclass

@dataclass
class LlamaConfig:
    hidden_size: int = 576
    num_attention_heads: int = 16
    num_key_value_heads: int = 4
    num_hidden_layers: int = 30
    intermediate_size: int = 1536
    hidden_act: str = "silu"
    rms_norm_eps: float = 1e-5
    vocab_size: int = 49152
    max_position_embeddings: int = 8192
    rope_theta: int = 100000
    tie_word_embeddings: bool = False
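Note: both inference.py and test-train.py build this config by filtering a config.json dictionary against the dataclass fields, so extra keys are ignored (config.json itself is not part of this upload). A small sketch with a stand-in dictionary:

from llama_modeling.config import LlamaConfig

# Stand-in for json.load(open("config.json")); unknown keys are simply dropped by the filter.
config_dict = {"hidden_size": 576, "num_attention_heads": 16, "some_unused_key": 123}
config = LlamaConfig(**{k: v for k, v in config_dict.items()
                        if k in LlamaConfig.__dataclass_fields__})

# Head sizes derived from the defaults above:
print(config.hidden_size // config.num_attention_heads)        # 36, used by attention.py / tensor_prod_attn.py
print(config.hidden_size // (2 * config.num_attention_heads))  # 18, the split head size in diff_attn.py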
llama_modeling/decoder.py
ADDED
@@ -0,0 +1,42 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from .mlp import LlamaMLP
from .config import LlamaConfig
from .rms_norm import LlamaRMSNorm
from .attention import LlamaAttention
from .diff_attn import DifferentialAttention
from .tensor_prod_attn import CausalTensorProductSelfAttn

class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig, layer_num):
        super().__init__()
        self.self_attn = CausalTensorProductSelfAttn(config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:

        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
llama_modeling/diff_attn.py
ADDED
@@ -0,0 +1,150 @@
from flash_attn import flash_attn_func
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from .extact import xATGLU
from .liger_rope import LigerRopeFunction
from .config import LlamaConfig

# The four-flash attn strategy comes from here:
# https://github.com/microsoft/unilm/blob/master/Diff-Transformer/multihead_flashdiff_2.py

class DifferentialAttention(nn.Module):
    def __init__(self, config: LlamaConfig, layer_num):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.n_rep = self.num_heads // self.num_kv_heads
        self.head_dim = self.hidden_size // (2 * self.num_heads)
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(self.hidden_size, 2 * self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, 2 * self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, 2 * self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(2 * self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * layer_num)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))

        self.subln = nn.LayerNorm(2 * self.head_dim, elementwise_affine=False)

        self.register_buffer(
            "cos_cached",
            self._compute_rope_embeddings(
                self.max_position_embeddings,
                self.head_dim,
                self.rope_theta,
                dtype=torch.float32,
                device=self.q_proj.weight.device,
            )[0],
            persistent=False,
        )
        self.register_buffer(
            "sin_cached",
            self._compute_rope_embeddings(
                self.max_position_embeddings,
                self.head_dim,
                self.rope_theta,
                dtype=torch.float32,
                device=self.q_proj.weight.device,
            )[1],
            persistent=False,
        )

    def _compute_rope_embeddings(self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None):
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
        t = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype)
        sin = emb.sin().to(dtype)
        return cos.unsqueeze(0), sin.unsqueeze(0)

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
    ) -> torch.Tensor:
        bsz, seq_len, embed_dim = hidden_states.size()

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)
            position_ids = repeat(position_ids, 'l -> b l', b=bsz)

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        q = rearrange(q, 'b s (h d) -> b s h d', h=2*self.num_heads, d=self.head_dim)
        k = rearrange(k, 'b s (h d) -> b s h d', h=2*self.num_kv_heads, d=self.head_dim)

        # Reshaped for GQA
        v = rearrange(v, 'b s (h g d) -> b s h g d', h=self.num_kv_heads, g=2, d=self.head_dim)

        # Apply rotary embeddings using LigerRopeFunction
        cos = self.cos_cached[:, position_ids]  # [1, bsz, seq_len, dim]
        sin = self.sin_cached[:, position_ids]  # [1, bsz, seq_len, dim]
        q, k = LigerRopeFunction.apply(q, k, cos, sin, position_ids)

        # Rearrange into GQA style
        q = rearrange(q, 'b s (h g) d -> b s h g d', h=self.num_heads, g=2)
        k = rearrange(k, 'b s (h g) d -> b s h g d', h=self.num_kv_heads, g=2)

        q1, q2 = q[:, :, :, 0], q[:, :, :, 1]
        k1, k2 = k[:, :, :, 0], k[:, :, :, 1]
        v1, v2 = v[:, :, :, 0], v[:, :, :, 1]

        # First attention group on q1/k1 and the v's
        attn11 = flash_attn_func(
            q1,
            k1,
            v1,
            dropout_p=0.0,  # @Z TODO::
            causal=attention_mask is None
        )
        attn12 = flash_attn_func(
            q1,
            k1,
            v2,
            dropout_p=0.0,
            causal=attention_mask is None
        )
        attn1 = torch.cat([attn11, attn12], dim=-1)

        # Second attention group on q2/k2 and the v's
        attn21 = flash_attn_func(
            q2,
            k2,
            v1,
            dropout_p=0.0,
            causal=attention_mask is None
        )
        attn22 = flash_attn_func(
            q2,
            k2,
            v2,
            dropout_p=0.0,
            causal=attention_mask is None
        )
        attn2 = torch.cat([attn21, attn22], dim=-1)

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
        attn = attn1 - lambda_full * attn2

        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)

        attn_output = rearrange(attn, "b s h d -> b s (h d)")
        return self.o_proj(attn_output)
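Note on the mixing coefficient: the scalar that weights the second attention map is reparameterized through four small vectors plus a depth-dependent offset. A torch-only sketch of just that computation (random values, CPU):

import math
import torch

head_dim = 18                                          # 576 // (2 * 16) with the default config
layer_num = 0
lambda_init = 0.8 - 0.6 * math.exp(-0.3 * layer_num)   # 0.2 at the first layer, approaching 0.8 with depth

# Same init distribution as the nn.Parameters above.
lambda_q1, lambda_k1 = torch.randn(head_dim) * 0.1, torch.randn(head_dim) * 0.1
lambda_q2, lambda_k2 = torch.randn(head_dim) * 0.1, torch.randn(head_dim) * 0.1

lambda_full = (torch.exp((lambda_q1 * lambda_k1).sum())
               - torch.exp((lambda_q2 * lambda_k2).sum())
               + lambda_init)
print(lambda_full)   # the combination is attn1 - lambda_full * attn2, then subln and the (1 - lambda_init) scale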
llama_modeling/extact.py
ADDED
@@ -0,0 +1,25 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# Very similar to GeGLU or SwiGLU, there's a learned gate FN, uses arctan as the activation fn.
class xATGLU(nn.Module):
    def __init__(self, input_dim, output_dim, bias=True):
        super().__init__()
        # GATE path | VALUE path
        self.proj = nn.Linear(input_dim, output_dim * 2, bias=bias)
        nn.init.kaiming_normal_(self.proj.weight, nonlinearity='linear')

        self.alpha = nn.Parameter(torch.zeros(1))
        self.half_pi = torch.pi / 2
        self.inv_pi = 1 / torch.pi

    def forward(self, x):
        projected = self.proj(x)
        gate_path, value_path = projected.chunk(2, dim=-1)

        # Apply arctan gating with expanded range via learned alpha -- https://arxiv.org/pdf/2405.20768
        gate = (torch.arctan(gate_path) + self.half_pi) * self.inv_pi
        expanded_gate = gate * (1 + 2 * self.alpha) - self.alpha

        return expanded_gate * value_path  # g(x) × y
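Note: a quick usage sketch of the gate above (CPU, random inputs). At initialization alpha is zero, so the gate reduces to (arctan(x) + π/2) / π, which lies strictly in (0, 1); a non-zero learned alpha widens that range to (-alpha, 1 + alpha).

import torch
from llama_modeling.extact import xATGLU

glu = xATGLU(input_dim=576, output_dim=1536)
x = torch.randn(2, 16, 576)
print(glu(x).shape)   # torch.Size([2, 16, 1536])

# With alpha = 0 the gate stays inside (0, 1).
gate = (torch.arctan(torch.randn(1000)) + torch.pi / 2) / torch.pi
print(gate.min().item() > 0, gate.max().item() < 1)   # True True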
llama_modeling/front_end.py
ADDED
@@ -0,0 +1,85 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from .config import LlamaConfig
from .model import LlamaModel

class LlamaForCausalLM(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying uses the head weights as the classifier for the token embeddings for both in and out.
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self._init_weights()

    def _init_weights(self):
        """Initialize weights for all layers."""
        # Initialize embeddings
        if hasattr(self.model, 'embed_tokens'):
            nn.init.normal_(self.model.embed_tokens.weight, mean=0.0, std=0.041666666666666664)

        # Initialize linear layers
        for module in self.modules():
            if isinstance(module, nn.Linear):
                # Xavier/Glorot initialization for weights
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    # Zero initialization for biases
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        return hidden_states, self.lm_head.weight

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 30,
        temperature: float = 0.0,
    ) -> torch.LongTensor:
        self.eval()
        bsz, seq_len = input_ids.shape

        position_ids = repeat(
            torch.arange(seq_len, device=input_ids.device),
            'l -> b l',
            b=bsz
        )

        for _ in range(max_new_tokens):
            hidden_states, classifier_weights = self.forward(input_ids, position_ids=position_ids)

            # Get logits by computing hidden_states @ classifier_weights.T
            next_token_logits = hidden_states[:, -1] @ classifier_weights.T

            if temperature == 0:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            else:
                scaled_logits = next_token_logits / temperature
                probs = torch.softmax(scaled_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_token], dim=1)
            new_position_ids = position_ids[:, -1:] + 1
            position_ids = torch.cat([position_ids, new_position_ids], dim=1)

        return input_ids
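Note: forward deliberately returns the final hidden states together with the lm_head weight instead of materializing logits; test-train.py hands that pair to cut_cross_entropy's linear_cross_entropy, and generate() does the last-position matmul by hand. As an unfused reference of what that loss corresponds to (my own sketch, not part of the repo, and assuming shift=True means standard next-token shifting):

import torch
import torch.nn.functional as F

def reference_causal_lm_loss(hidden_states, classifier_weights, input_ids):
    logits = hidden_states @ classifier_weights.T   # [B, S, vocab] -- the tensor the fused kernel avoids building
    shift_logits = logits[:, :-1, :]                # predict token t+1 from position t
    shift_labels = input_ids[:, 1:]
    return F.cross_entropy(shift_logits.reshape(-1, shift_logits.size(-1)),
                           shift_labels.reshape(-1))

B, S, H, V = 2, 8, 576, 49152
print(reference_causal_lm_loss(torch.randn(B, S, H), torch.randn(V, H),
                               torch.randint(0, V, (B, S))))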
llama_modeling/liger_rope.py
ADDED
@@ -0,0 +1,258 @@
import torch
import triton
import triton.language as tl

# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/rope.py
# BSD 2-CLAUSE LICENSE
# Copyright 2024 LinkedIn Corporation
# All Rights Reserved.
# Redistribution and use in source and binary forms, with or
# without modification, are permitted provided that the following
# conditions are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@triton.jit
def _triton_rope(
    q_ptr,
    q_row_stride,
    k_ptr,
    k_row_stride,
    cos,
    cos_row_stride,
    sin,
    sin_row_stride,
    sl,
    bs: tl.constexpr,
    cos_bs: tl.constexpr,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BACKWARD_PASS: tl.constexpr = False,
):
    # q size: (bsz, seq_len, num_q_heads, head_dim)
    # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1)
    # k size: (bsz, seq_len, num_kv_heads, head_dim)
    # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)

    # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
    # stride: (seq_len * head_dim, head_dim, 1)
    pid = tl.program_id(0)

    # locate start address
    q_ptr = q_ptr + pid * q_row_stride
    k_ptr = k_ptr + pid * k_row_stride

    # ####################################################################
    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
    # m of this program instance
    # ####################################################################

    # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
    # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
    # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
    # and pid % sl to get the sequence index.
    # 2. We only need the left half of cos and sin matrix because the right half is just
    # a clone of the left half.
    batch_idx = pid // sl
    cos_row_idx = pid % sl
    cos = cos + tl.where(
        cos_bs == 1,
        cos_row_idx * cos_row_stride,
        batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
    )
    sin = sin + tl.where(
        cos_bs == 1,
        cos_row_idx * sin_row_stride,
        batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
    )

    cos_offsets = tl.arange(0, pad_hd // 2)
    cos_mask = cos_offsets < hd // 2
    cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
    sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)

    # ####################################################################
    # Load the left and right half of q and k for the current
    # program instance (i.e. for the current token) separately
    # ####################################################################
    # left half of the head
    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

    # right half of the head
    second_half_q_offsets = first_half_q_offsets + (hd // 2)
    second_half_k_offsets = first_half_k_offsets + (hd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask
    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

    if not BACKWARD_PASS:
        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
    else:
        # with some math, we can get:
        # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
        new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)


def rope_forward(q, k, cos, sin):
    # transpose it back to the physical shape because Triton looks at the physical storage
    # note: q and k are incontiguous before the transformation and will become contiguous after transpose
    batch_size, seq_len, n_q_head, head_dim = q.shape
    n_kv_head = k.shape[2]

    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)

    n_row = batch_size * seq_len

    # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
    q = q.contiguous()
    k = k.contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()
    cos_batch_size = cos.shape[0]

    _triton_rope[(n_row,)](
        q,
        q.stride(1),
        k,
        k.stride(1),
        cos,
        cos.stride(-2),
        sin,
        sin.stride(-2),
        seq_len,
        batch_size,
        cos_batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        BLOCK_SIZE=BLOCK_SIZE,
        BACKWARD_PASS=False,
    )
    return q, k, cos, sin


def rope_backward(dq, dk, cos, sin):
    batch_size, seq_len, n_q_head, head_dim = dq.shape
    cos_batch_size = cos.shape[0]
    n_kv_head = dk.shape[2]
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)

    n_row = batch_size * seq_len

    # ensure dq and dk are contiguous
    dq = dq.contiguous()
    dk = dk.contiguous()

    # backward is similar to forward except swapping few ops
    _triton_rope[(n_row,)](
        dq,
        dq.stride(1),
        dk,
        dk.stride(1),
        cos,
        cos.stride(-2),
        sin,
        sin.stride(-2),
        seq_len,
        batch_size,
        cos_batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        BLOCK_SIZE=BLOCK_SIZE,
        BACKWARD_PASS=True,
    )
    return dq, dk


class LigerRopeFunction(torch.autograd.Function):
    """
    Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that
    this implements the HuggingFace Llama & Mistral version, whose rotation matrix is slightly different
    than the original RoPE paper.

    Please find the corresponding HuggingFace implementation here:
    https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184

    For more details about the rotation matrix used here, please refer to:
    https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2
    """

    @staticmethod
    def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        """
        q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        """
        q, k, cos, sin = rope_forward(q, k, cos, sin)
        ctx.save_for_backward(cos, sin)
        return q, k

    def backward(ctx, dq, dk):
        """
        dq size: (bsz, n_q_head, seq_len, head_dim)
        dk size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        """

        cos, sin = ctx.saved_tensors
        dq, dk = rope_backward(dq, dk, cos, sin)
        return dq, dk, None, None, None, None
llama_modeling/mlp.py
ADDED
@@ -0,0 +1,18 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import LlamaConfig

class LlamaMLP(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
llama_modeling/model.py
ADDED
@@ -0,0 +1,35 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from .mlp import LlamaMLP
from .config import LlamaConfig
from .rms_norm import LlamaRMSNorm
from .decoder import LlamaDecoderLayer

class LlamaModel(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=None)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config, i) for i in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states
llama_modeling/rms_norm.py
ADDED
@@ -0,0 +1,16 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
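Note: a quick equivalence check for the normalization above, i.e. weight * x / sqrt(mean(x², dim=-1) + eps) computed in float32:

import torch
from llama_modeling.rms_norm import LlamaRMSNorm

norm = LlamaRMSNorm(hidden_size=576)
x = torch.randn(2, 16, 576)
expected = norm.weight * (x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5))
print(torch.allclose(norm(x), expected, atol=1e-6))   # True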
llama_modeling/rope.py
ADDED
@@ -0,0 +1,34 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

def rotate_half(x):
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=8192, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_position_embeddings = max_position_embeddings
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, position_ids: torch.LongTensor):
        # position_ids: [batch_size, seq_len]
        inv_freq = self.inv_freq.to(device=position_ids.device)
        inv_freq_expanded = inv_freq[None, None, :]  # [1, 1, dim//2]
        position_ids_expanded = position_ids[:, :, None].float()  # [batch_size, seq_len, 1]
        freqs = torch.matmul(position_ids_expanded, inv_freq_expanded)  # [batch_size, seq_len, dim//2]
        freqs = torch.cat([freqs, freqs], dim=-1)  # [batch_size, seq_len, dim]
        cos = torch.cos(freqs)
        sin = torch.sin(freqs)
        cos = cos.unsqueeze(1)  # [batch_size, 1, seq_len, dim]
        sin = sin.unsqueeze(1)  # [batch_size, 1, seq_len, dim]
        return cos, sin
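Note: these eager helpers don't appear to be imported by the attention modules in this upload (they use the Triton LigerRopeFunction instead), but they are convenient for sanity checks. A toy CPU sketch of the expected shapes, plus a check that the rotation preserves per-position norms:

import torch
from llama_modeling.rope import LlamaRotaryEmbedding, apply_rotary_pos_emb

B, H, S, D = 2, 16, 32, 36
rope = LlamaRotaryEmbedding(dim=D)
position_ids = torch.arange(S).expand(B, S)

cos, sin = rope(position_ids)            # each [B, 1, S, D], broadcasting over the head dimension
q, k = torch.randn(B, H, S, D), torch.randn(B, H, S, D)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape)                       # torch.Size([2, 16, 32, 36])

# Pairwise 2D rotations leave each position's vector norm unchanged (up to float error).
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4))   # True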
llama_modeling/tensor_prod_attn.py
ADDED
@@ -0,0 +1,147 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from dataclasses import dataclass
from einops import rearrange, repeat

from flash_attn import flash_attn_func
from .liger_rope import LigerRopeFunction
from .rms_norm import LlamaRMSNorm
from .config import LlamaConfig

class CPLinear(nn.Module):
    def __init__(self, in_features, n_head, head_dim, kv_rank=2, q_rank=6):
        super().__init__()
        self.W_A_q = nn.Linear(in_features, n_head * q_rank, bias=False)
        self.W_B_q = nn.Linear(in_features, q_rank * head_dim, bias=False)
        self.W_A_k = nn.Linear(in_features, n_head * kv_rank, bias=False)
        self.W_B_k = nn.Linear(in_features, kv_rank * head_dim, bias=False)
        self.W_A_v = nn.Linear(in_features, n_head * kv_rank, bias=False)
        self.W_B_v = nn.Linear(in_features, kv_rank * head_dim, bias=False)

        nn.init.xavier_uniform_(self.W_A_q.weight)
        nn.init.xavier_uniform_(self.W_B_q.weight)
        nn.init.xavier_uniform_(self.W_A_k.weight)
        nn.init.xavier_uniform_(self.W_B_k.weight)
        nn.init.xavier_uniform_(self.W_A_v.weight)
        nn.init.xavier_uniform_(self.W_B_v.weight)

        self.n_head = n_head
        self.q_rank = q_rank
        self.head_dim = head_dim
        self.kv_rank = kv_rank

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        A_q = self.W_A_q(x).view(batch_size, seq_len, self.n_head, self.q_rank)
        A_k = self.W_A_k(x).view(batch_size, seq_len, self.n_head, self.kv_rank)
        A_v = self.W_A_v(x).view(batch_size, seq_len, self.n_head, self.kv_rank)

        B_q = self.W_B_q(x).view(batch_size, seq_len, self.q_rank, self.head_dim)
        B_k = self.W_B_k(x).view(batch_size, seq_len, self.kv_rank, self.head_dim)
        B_v = self.W_B_v(x).view(batch_size, seq_len, self.kv_rank, self.head_dim)

        A_q = A_q.view(batch_size * seq_len, self.n_head, self.q_rank)
        A_k = A_k.view(batch_size * seq_len, self.n_head, self.kv_rank)
        A_v = A_v.view(batch_size * seq_len, self.n_head, self.kv_rank)

        B_q = B_q.view(batch_size * seq_len, self.q_rank, self.head_dim)
        B_k = B_k.view(batch_size * seq_len, self.kv_rank, self.head_dim)
        B_v = B_v.view(batch_size * seq_len, self.kv_rank, self.head_dim)

        q = torch.bmm(A_q, B_q).div_(self.q_rank).view(batch_size, seq_len, self.n_head, self.head_dim)
        k = torch.bmm(A_k, B_k).div_(self.kv_rank).view(batch_size, seq_len, self.n_head, self.head_dim)
        v = torch.bmm(A_v, B_v).div_(self.kv_rank).view(batch_size, seq_len, self.n_head, self.head_dim)

        return q, k, v

class CausalTensorProductSelfAttn(nn.Module):
    def __init__(self, config, kv_rank=2, q_rank=6):
        super().__init__()
        self.n_head = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.n_embd = config.hidden_size
        self.rank = kv_rank
        self.q_rank = q_rank
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        self.c_qkv = CPLinear(self.n_embd, self.n_head, self.head_dim, self.rank, self.q_rank)
        self.o_proj = nn.Linear(self.n_head * self.head_dim, self.n_embd, bias=False)

        self.register_buffer(
            "cos_cached",
            self._compute_rope_embeddings(
                self.max_position_embeddings,
                self.head_dim,
                self.rope_theta,
                dtype=torch.float32,
                device=self.o_proj.weight.device,
            )[0],
            persistent=False,
        )
        self.register_buffer(
            "sin_cached",
            self._compute_rope_embeddings(
                self.max_position_embeddings,
                self.head_dim,
                self.rope_theta,
                dtype=torch.float32,
                device=self.o_proj.weight.device,
            )[1],
            persistent=False,
        )

        self.using_groupnorm = getattr(config, 'using_groupnorm', False)
        self.subln = LlamaRMSNorm(self.head_dim, eps=1e-5)

    def _compute_rope_embeddings(self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None):
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
        t = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype)
        sin = emb.sin().to(dtype)
        return cos.unsqueeze(0), sin.unsqueeze(0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # In B S (H D)
        bsz, seq_len, _ = hidden_states.size()

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)
            position_ids = repeat(position_ids, 'l -> b l', b=bsz)

        q, k, v = self.c_qkv(hidden_states)  # B S (HD) -> B S H D

        cos = self.cos_cached[:, position_ids]  # [1, bsz, seq_len, dim]
        sin = self.sin_cached[:, position_ids]  # [1, bsz, seq_len, dim]

        q, k = LigerRopeFunction.apply(
            q,
            k,
            cos.squeeze(0),
            sin.squeeze(0),
            position_ids
        )

        attn_out = flash_attn_func(
            q,
            k,
            v,
            dropout_p=0.0,
            causal=attention_mask is None
        )

        attn_out = self.subln(attn_out)

        attn_out = rearrange(attn_out, "b s h d -> b s (h d)")
        attn_out = self.o_proj(attn_out)
        return attn_out
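Note: a torch-only toy of what CPLinear computes per token, to make the factorization concrete (sizes follow the defaults; the parameter arithmetic in the comments assumes the default config):

import torch

B, S, H, D, q_rank = 2, 16, 16, 36, 6
A_q = torch.randn(B * S, H, q_rank)   # per-token head factors, as produced by W_A_q
B_q = torch.randn(B * S, q_rank, D)   # per-token channel factors, as produced by W_B_q

# Each token's (H x D) query block is a rank-q_rank product, scaled by 1/q_rank.
q = torch.bmm(A_q, B_q).div_(q_rank).view(B, S, H, D)
print(q.shape)                        # torch.Size([2, 16, 16, 36])

# Weight count for the query path with hidden_size 576, 16 heads x 36 dims:
#   dense q_proj:             576 * 576            = 331,776
#   factorized W_A_q + W_B_q: 576 * (16*6 + 6*36)  = 179,712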
test-train.py
ADDED
@@ -0,0 +1,238 @@
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import math, os, sys, json, glob, time, random
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoTokenizer
from distributed_shampoo import AdamGraftingConfig, DistributedShampoo
from cut_cross_entropy import linear_cross_entropy
from torch.nn.utils import clip_grad_norm_
from utils.trainutils import count_parameters_layerwise, save_checkpoint, TBLogger

from llama_modeling.front_end import LlamaForCausalLM
from llama_modeling.config import LlamaConfig

class JSONLDataset(Dataset):
    def __init__(self, directory_path, tokenizer, seq_length=1024,
                 text_key="text", max_files=None, batch_size=1000,
                 pad_token_id=0):
        self.seq_length = seq_length
        self.tokenizer = tokenizer
        self.pad_token_id = pad_token_id
        self.sequences = []

        files = glob.glob(os.path.join(directory_path, "*.jsonl"))
        if max_files is not None:
            files = files[:max_files]

        text_batch = []
        for file_idx, file_path in enumerate(files):
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        data = json.loads(line)
                        text = data.get(text_key, "")
                        if len(text) >= 100:
                            text_batch.append(text)

                        if len(text_batch) >= batch_size:
                            self._process_batch(text_batch)
                            text_batch = []
                    except:
                        continue

        if text_batch:
            self._process_batch(text_batch)

        if self.sequences:
            self.sequences = torch.tensor(self.sequences, dtype=torch.long)
        else:
            self.sequences = torch.empty((0, seq_length), dtype=torch.long)

    def _process_batch(self, texts):
        encoded = self.tokenizer(
            texts,
            add_special_tokens=False,
            truncation=True,
            padding=False,
            return_attention_mask=False,
            return_tensors=None
        )['input_ids']

        mlen = 0
        for token_ids in encoded:
            for i in range(0, len(token_ids), self.seq_length):
                chunk = token_ids[i:i+self.seq_length]

                # Pad
                if len(chunk) < self.seq_length:
                    chunk += [self.pad_token_id] * (self.seq_length - len(chunk))

                self.sequences.append(chunk)
                mlen = max(mlen, len(chunk))

        print("MAX: ", mlen)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx]

def train_model(model, train_loader, optimizer, device, epochs=5, forward_dtype=torch.float32):
    model.train()
    criterion = nn.CrossEntropyLoss()
    scaler = torch.amp.GradScaler("cuda")

    logger = TBLogger(log_dir=f'logs/run-{time.time()}')

    total_steps = len(train_loader) * epochs
    scheduler = CosineAnnealingLR(
        optimizer,
        T_max=total_steps,
        eta_min=5e-6
    )

    model = torch.compile(
        model,
    )

    global_step = 0
    for epoch in range(epochs):
        running_loss = 0.0
        total_batches = 0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')

        for batch_idx, data in enumerate(progress_bar):
            data = data.to(device)
            optimizer.zero_grad(set_to_none=True)

            with torch.autocast(device_type='cuda', dtype=forward_dtype):
                hidden_states, classifier_weights = model(data)

                loss = linear_cross_entropy(
                    hidden_states,
                    classifier_weights,
                    data,
                    shift=True,
                    reduction="mean"
                )

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            # Update metrics - just add the loss itself
            running_loss += loss.item()
            total_batches += 1
            global_step += 1
            avg_loss = running_loss / total_batches
            perplexity = math.exp(min(avg_loss, 100))

            progress_bar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'ppl': f'{perplexity:.2f}'
            })

            metrics = {
                'loss': loss.item(),
                'perplexity': perplexity,
                'learning_rate': optimizer.param_groups[0]['lr'],
                'batch_size': data.size(0)
            }

            logger.log(metrics, step=global_step, model=model, grad_checking=True)

            if batch_idx % 100 == 0:
                print(f'\nBatch {batch_idx}/{len(train_loader)}: '
                      f'Loss: {avg_loss:.4f}, '
                      f'Perplexity: {perplexity:.2f}, '
                      f'Batches Processed: {total_batches}')

        epoch_loss = running_loss / total_batches
        epoch_ppl = math.exp(min(epoch_loss, 100))
        print(f'\nEpoch {epoch+1} Summary:')
        print(f'Average Loss: {epoch_loss:.4f}')
        print(f'Perplexity: {epoch_ppl:.2f}')
        print(f'Total Batches Processed: {total_batches}\n')

        save_checkpoint(model, f'epoch_{epoch+1}.safetensors')

def sample_examples(dataset, tokenizer, num_samples=5):
    if len(dataset) == 0:
        print("The dataset is empty.")
        return

    num_samples = min(num_samples, len(dataset))

    sampled_indices = random.sample(range(len(dataset)), num_samples)

    for i, idx in enumerate(sampled_indices):
        sequence = dataset[idx]
        print(f"Sample {i + 1} (Index {idx}):")
        print(sequence)
        decoded_text = tokenizer.decode(sequence, skip_special_tokens=False, decode_special_tokens=False)
        print(decoded_text)
        print("-" * 40)

def main():
    BATCH_SIZE = 36
    SEQ_LENGTH = 512
    EPOCHS = 3
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("./SmolLM2-135M-Instruct")

    config_path = "config.json"
    with open(config_path) as f:
        config_dict = json.load(f)
    config = LlamaConfig(**{k: v for k, v in config_dict.items() if k in LlamaConfig.__dataclass_fields__})

    model = LlamaForCausalLM(config).to("cuda")

    dataset = JSONLDataset(
        directory_path="./Data_big",
        tokenizer=tokenizer,
        seq_length=SEQ_LENGTH,
        text_key="text",
        max_files=None,
    )

    train_loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True
    )

    optimizer = DistributedShampoo(
        model.parameters(),
        lr=0.0001,
        betas=(0.9, 0.999),
        epsilon=1e-12,
        weight_decay=1e-05,
        max_preconditioner_dim=2048,
        precondition_frequency=100,
        start_preconditioning_step=250,
        use_decoupled_weight_decay=False,
        grafting_config=AdamGraftingConfig(
            beta2=0.999,
            epsilon=1e-12,
        ),
    )

    print("*"*100)
    torch.set_float32_matmul_precision('high')

    count_parameters_layerwise(model)

    train_model(model, train_loader, optimizer, DEVICE, EPOCHS, forward_dtype=torch.bfloat16)

if __name__ == "__main__":
    main()
utils/__init__.py
ADDED
File without changes
utils/trainutils.py
ADDED
@@ -0,0 +1,92 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from safetensors.torch import save_file, load_file
from pathlib import Path
import time

def count_parameters_layerwise(model):
    # Layerwise params, turn this into a util function.
    total_params = 0
    layer_params = {}

    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue

        param_count = parameter.numel()
        layer_params[name] = param_count
        total_params += param_count

    print(f"\nModel Parameter Summary:")
    print("-" * 60)
    for name, count in layer_params.items():
        print(f"{name}: {count:,} parameters")
    print("-" * 60)
    print(f"Total Trainable Parameters: {total_params:,}\n")

    return total_params

def save_checkpoint(model, filename="checkpoint.safetensors"):
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod

    torch.save(model.state_dict(), filename.replace('.safetensors', '.pt'))

def load_checkpoint(model, filename="checkpoint.safetensors"):
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod

    try:
        model_state = load_file(filename)
        model.load_state_dict(model_state)
    except Exception as e:
        model_state = torch.load(filename.replace('.safetensors', '.pt'), weights_only=True)
        model.load_state_dict(model_state)

class TBLogger:
    def __init__(self, log_dir='logs/current_run', flush_secs=10, enable_grad_logging=True):
        Path(log_dir).mkdir(parents=True, exist_ok=True)
        self.writer = SummaryWriter(log_dir, flush_secs=flush_secs)
        self.enable_grad_logging = enable_grad_logging
        self.start_time = time.time()

    def log(self, metrics, step=None, model=None, prefix='', grad_checking=False):
        for name, value in metrics.items():
            full_name = f"{prefix}{name}" if prefix else name

            if isinstance(value, (int, float)):
                self.writer.add_scalar(full_name, value, step)
            elif isinstance(value, torch.Tensor):
                self.writer.add_scalar(full_name, value.item(), step)
            elif isinstance(value, (list, tuple)) and len(value) > 0:
                if all(isinstance(x, (int, float)) for x in value):
                    self.writer.add_histogram(full_name, torch.tensor(value), step)

        if self.enable_grad_logging and model is not None:
            self._log_gradients(model, step, grad_checking)

    def _log_gradients(self, model, step, grad_checking):
        total_norm = 0.0
        for name, param in model.named_parameters():
            if grad_checking and param.grad is not None:
                # Check for inf/nan in gradients
                if torch.isnan(param.grad).any():
                    print(f"Warning: Found nan in gradients for layer: {name}")
                    continue
                if torch.isinf(param.grad).any():
                    print(f"Warning: Found inf in gradients for layer: {name}")
                    continue

                param_norm = param.grad.detach().data.norm(2)
                self.writer.add_scalar(f"gradients/{name}_norm", param_norm, step)
                total_norm += param_norm.item() ** 2

        # Only compute total norm if we haven't encountered inf/nan
        if total_norm > 0:  # This means we had valid gradients
            total_norm = total_norm ** 0.5
            self.writer.add_scalar("gradients/total_norm", total_norm, step)

    def close(self):
        self.writer.close()