Blackroot committed
Commit 6aced58 · verified · 1 Parent(s): f2c8e64

Upload 18 files
epoch3.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:849af991539dcc0d5b278fb81e96e2e72d005b5d671b024e18573c30ea51f676
size 507810914
inference.py ADDED
@@ -0,0 +1,54 @@
import torch
from transformers import AutoTokenizer
from llama_modeling.front_end import LlamaForCausalLM
from llama_modeling.config import LlamaConfig
import json
import sys
from utils.trainutils import load_checkpoint

def generate_text(model, tokenizer, prompt, max_new_tokens=30):
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to("cuda")

    with torch.inference_mode():
        outputs = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.7
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def main():
    if len(sys.argv) != 2:
        print("Usage: python inference.py <path_to_model>")
        sys.exit(1)

    model_path = sys.argv[1]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    with open("config.json") as f:
        config_dict = json.load(f)
    config = LlamaConfig(**{k: v for k, v in config_dict.items() if k in LlamaConfig.__dataclass_fields__})

    model = LlamaForCausalLM(config).to(device)

    load_checkpoint(model, model_path)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("./SmolLM2-135M-Instruct")

    prompts = [
        "Once upon a time,",
        "The best way to learn programming is",
        "Here's a recipe for chocolate cake:"
    ]

    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=None):
        for prompt in prompts:
            print(f"\nPrompt: {prompt}")
            output = generate_text(model, tokenizer, prompt)
            print(f"Generated: {output}")
            print("-" * 50)

if __name__ == "__main__":
    main()
llama_modeling/__init__.py ADDED
File without changes
llama_modeling/attention.py ADDED
@@ -0,0 +1,106 @@
from flash_attn import flash_attn_func
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from .liger_rope import LigerRopeFunction
from .config import LlamaConfig

class LlamaAttention(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_attention_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.register_buffer(
            "cos_cached",
            self._compute_rope_embeddings(
                self.max_position_embeddings,
                self.head_dim,
                self.rope_theta,
                dtype=torch.float32,
                device=self.q_proj.weight.device,
            )[0],
            persistent=False,
        )
        self.register_buffer(
            "sin_cached",
            self._compute_rope_embeddings(
                self.max_position_embeddings,
                self.head_dim,
                self.rope_theta,
                dtype=torch.float32,
                device=self.q_proj.weight.device,
            )[1],
            persistent=False,
        )

    def _compute_rope_embeddings(self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None):
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
        t = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype)
        sin = emb.sin().to(dtype)
        return cos.unsqueeze(0), sin.unsqueeze(0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # In B S (H D)
        bsz, seq_len, _ = hidden_states.size()

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)
            position_ids = repeat(position_ids, 'l -> b l', b=bsz)

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = rearrange(query_states, "b s (h d) -> b s h d", h=self.num_heads, d=self.head_dim)
        key_states = rearrange(key_states, "b s (h d) -> b s h d", h=self.num_key_value_heads, d=self.head_dim)
        value_states = rearrange(value_states, "b s (h d) -> b s h d", h=self.num_key_value_heads, d=self.head_dim)

        # Slice off position specific rope freqs from the cached freqs
        cos = self.cos_cached[:, position_ids]  # [1, bsz, seq_len, dim]
        sin = self.sin_cached[:, position_ids]  # [1, bsz, seq_len, dim]

        query_states, key_states = LigerRopeFunction.apply(
            query_states,
            key_states,
            cos.squeeze(0),
            sin.squeeze(0),
            position_ids
        )

        attn_output = flash_attn_func(
            query_states,
            key_states,
            value_states,
            dropout_p=0.0,
            causal=attention_mask is None
        )

        attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
        return self.o_proj(attn_output)
llama_modeling/config.py ADDED
@@ -0,0 +1,15 @@
from dataclasses import dataclass

@dataclass
class LlamaConfig:
    hidden_size: int = 576
    num_attention_heads: int = 16
    num_key_value_heads: int = 4
    num_hidden_layers: int = 30
    intermediate_size: int = 1536
    hidden_act: str = "silu"
    rms_norm_eps: float = 1e-5
    vocab_size: int = 49152
    max_position_embeddings: int = 8192
    rope_theta: int = 100000
    tie_word_embeddings: bool = False
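Note: both test-train.py and inference.py build this dataclass from a HuggingFace-style config.json by filtering the JSON down to the declared fields, so unknown keys are dropped and missing keys fall back to the defaults above. A minimal sketch of that pattern (the example dict is illustrative, not the repo's actual config.json):

# Sketch only: how the scripts in this commit map a config.json dict onto LlamaConfig.
from llama_modeling.config import LlamaConfig

config_dict = {"hidden_size": 576, "num_attention_heads": 16, "architectures": ["LlamaForCausalLM"]}
known = {k: v for k, v in config_dict.items() if k in LlamaConfig.__dataclass_fields__}
config = LlamaConfig(**known)       # "architectures" is ignored, it is not a dataclass field
print(config.num_key_value_heads)   # 4, filled from the defaults above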
llama_modeling/decoder.py ADDED
@@ -0,0 +1,42 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from .mlp import LlamaMLP
from .config import LlamaConfig
from .rms_norm import LlamaRMSNorm
from .attention import LlamaAttention
from .diff_attn import DifferentialAttention
from .tensor_prod_attn import CausalTensorProductSelfAttn

class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig, layer_num):
        super().__init__()
        self.self_attn = CausalTensorProductSelfAttn(config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:

        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
llama_modeling/diff_attn.py ADDED
@@ -0,0 +1,150 @@
from flash_attn import flash_attn_func
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from .extact import xATGLU
from .liger_rope import LigerRopeFunction
from .config import LlamaConfig

# The four-flash attn strategy comes from here:
# https://github.com/microsoft/unilm/blob/master/Diff-Transformer/multihead_flashdiff_2.py

class DifferentialAttention(nn.Module):
    def __init__(self, config: LlamaConfig, layer_num):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.n_rep = self.num_heads // self.num_kv_heads
        self.head_dim = self.hidden_size // (2 * self.num_heads)
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(self.hidden_size, 2 * self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, 2 * self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, 2 * self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(2 * self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * layer_num)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))

        self.subln = nn.LayerNorm(2 * self.head_dim, elementwise_affine=False)

        self.register_buffer(
            "cos_cached",
            self._compute_rope_embeddings(
                self.max_position_embeddings,
                self.head_dim,
                self.rope_theta,
                dtype=torch.float32,
                device=self.q_proj.weight.device,
            )[0],
            persistent=False,
        )
        self.register_buffer(
            "sin_cached",
            self._compute_rope_embeddings(
                self.max_position_embeddings,
                self.head_dim,
                self.rope_theta,
                dtype=torch.float32,
                device=self.q_proj.weight.device,
            )[1],
            persistent=False,
        )

    def _compute_rope_embeddings(self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None):
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
        t = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype)
        sin = emb.sin().to(dtype)
        return cos.unsqueeze(0), sin.unsqueeze(0)

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
    ) -> torch.Tensor:
        bsz, seq_len, embed_dim = hidden_states.size()

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)
            position_ids = repeat(position_ids, 'l -> b l', b=bsz)

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        q = rearrange(q, 'b s (h d) -> b s h d', h=2*self.num_heads, d=self.head_dim)
        k = rearrange(k, 'b s (h d) -> b s h d', h=2*self.num_kv_heads, d=self.head_dim)

        # Reshaped for GQA
        v = rearrange(v, 'b s (h g d) -> b s h g d', h=self.num_kv_heads, g=2, d=self.head_dim)

        # Apply rotary embeddings using LigerRopeFunction
        cos = self.cos_cached[:, position_ids]  # [1, bsz, seq_len, dim]
        sin = self.sin_cached[:, position_ids]  # [1, bsz, seq_len, dim]
        q, k = LigerRopeFunction.apply(q, k, cos, sin, position_ids)

        # Rearrange into GQA style
        q = rearrange(q, 'b s (h g) d -> b s h g d', h=self.num_heads, g=2)
        k = rearrange(k, 'b s (h g) d -> b s h g d', h=self.num_kv_heads, g=2)

        q1, q2 = q[:, :, :, 0], q[:, :, :, 1]
        k1, k2 = k[:, :, :, 0], k[:, :, :, 1]
        v1, v2 = v[:, :, :, 0], v[:, :, :, 1]

        # First attention group on q1/k1 and the v's
        attn11 = flash_attn_func(
            q1,
            k1,
            v1,
            dropout_p=0.0,  # @Z TODO::
            causal=attention_mask is None
        )
        attn12 = flash_attn_func(
            q1,
            k1,
            v2,
            dropout_p=0.0,
            causal=attention_mask is None
        )
        attn1 = torch.cat([attn11, attn12], dim=-1)

        # Second attention group on q2/k2 and the v's
        attn21 = flash_attn_func(
            q2,
            k2,
            v1,
            dropout_p=0.0,
            causal=attention_mask is None
        )
        attn22 = flash_attn_func(
            q2,
            k2,
            v2,
            dropout_p=0.0,
            causal=attention_mask is None
        )
        attn2 = torch.cat([attn21, attn22], dim=-1)

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
        attn = attn1 - lambda_full * attn2

        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)

        attn_output = rearrange(attn, "b s h d -> b s (h d)")
        return self.o_proj(attn_output)
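Note: the mixing scalar follows the Differential Transformer reparameterization, lambda = exp(lambda_q1·lambda_k1) − exp(lambda_q2·lambda_k2) + lambda_init, with lambda_init depending on the layer index. A small self-contained sketch of that computation (the head_dim value here is illustrative):

# Sketch only: the lambda used above to combine the two flash-attention groups.
import math
import torch

layer_num = 0
head_dim = 32                                          # illustrative
lambda_init = 0.8 - 0.6 * math.exp(-0.3 * layer_num)   # 0.2 at the first layer
lambda_q1 = torch.zeros(head_dim).normal_(0, 0.1)
lambda_k1 = torch.zeros(head_dim).normal_(0, 0.1)
lambda_q2 = torch.zeros(head_dim).normal_(0, 0.1)
lambda_k2 = torch.zeros(head_dim).normal_(0, 0.1)

lambda_full = (torch.exp((lambda_q1 * lambda_k1).sum())
               - torch.exp((lambda_q2 * lambda_k2).sum())
               + lambda_init)
# forward() then takes attn1 - lambda_full * attn2, LayerNorms it, and rescales by (1 - lambda_init).
print(float(lambda_full))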
llama_modeling/extact.py ADDED
@@ -0,0 +1,25 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# Very similar to GeGLU or SwiGLU, there's a learned gate FN, uses arctan as the activation fn.
class xATGLU(nn.Module):
    def __init__(self, input_dim, output_dim, bias=True):
        super().__init__()
        # GATE path | VALUE path
        self.proj = nn.Linear(input_dim, output_dim * 2, bias=bias)
        nn.init.kaiming_normal_(self.proj.weight, nonlinearity='linear')

        self.alpha = nn.Parameter(torch.zeros(1))
        self.half_pi = torch.pi / 2
        self.inv_pi = 1 / torch.pi

    def forward(self, x):
        projected = self.proj(x)
        gate_path, value_path = projected.chunk(2, dim=-1)

        # Apply arctan gating with expanded range via learned alpha -- https://arxiv.org/pdf/2405.20768
        gate = (torch.arctan(gate_path) + self.half_pi) * self.inv_pi
        expanded_gate = gate * (1 + 2 * self.alpha) - self.alpha

        return expanded_gate * value_path  # g(x) × y
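Note: since arctan maps to (−π/2, π/2), the gate above lands in (0, 1), and the learned alpha widens that to (−α, 1 + α); alpha starts at zero, so at initialization it behaves like a plain arctan gate. A quick shape and range check (dimensions are illustrative):

# Sketch only: exercising xATGLU and the gate range it produces.
import torch
from llama_modeling.extact import xATGLU

glu = xATGLU(input_dim=576, output_dim=1536)
x = torch.randn(2, 8, 576)          # (batch, seq, features)
print(glu(x).shape)                 # torch.Size([2, 8, 1536])

gate = (torch.arctan(torch.tensor(3.0)) + torch.pi / 2) / torch.pi
print(float(gate))                  # ~0.90, always inside (0, 1) before the alpha expansion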
llama_modeling/front_end.py ADDED
@@ -0,0 +1,85 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from .config import LlamaConfig
from .model import LlamaModel

class LlamaForCausalLM(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying uses the head weights as the classifier for the token embeddings for both in and out.
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self._init_weights()

    def _init_weights(self):
        """Initialize weights for all layers."""
        # Initialize embeddings
        if hasattr(self.model, 'embed_tokens'):
            nn.init.normal_(self.model.embed_tokens.weight, mean=0.0, std=0.041666666666666664)

        # Initialize linear layers
        for module in self.modules():
            if isinstance(module, nn.Linear):
                # Xavier/Glorot initialization for weights
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    # Zero initialization for biases
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        return hidden_states, self.lm_head.weight

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 30,
        temperature: float = 0.0,
    ) -> torch.LongTensor:
        self.eval()
        bsz, seq_len = input_ids.shape

        position_ids = repeat(
            torch.arange(seq_len, device=input_ids.device),
            'l -> b l',
            b=bsz
        )

        for _ in range(max_new_tokens):
            hidden_states, classifier_weights = self.forward(input_ids, position_ids=position_ids)

            # Get logits by computing hidden_states @ classifier_weights.T
            next_token_logits = hidden_states[:, -1] @ classifier_weights.T

            if temperature == 0:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            else:
                scaled_logits = next_token_logits / temperature
                probs = torch.softmax(scaled_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_token], dim=1)
            new_position_ids = position_ids[:, -1:] + 1
            position_ids = torch.cat([position_ids, new_position_ids], dim=1)

        return input_ids
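Note: forward deliberately returns (hidden_states, lm_head.weight) instead of logits. generate() materializes logits only for the last position with a matmul, while test-train.py hands both tensors to cut_cross_entropy's linear_cross_entropy so the full (batch, seq, vocab) logit tensor is never allocated. A minimal sketch of the logits step with stand-in tensors (sizes follow the LlamaConfig defaults):

# Sketch only: turning (hidden_states, classifier_weights) into next-token logits.
import torch

hidden_size, vocab_size = 576, 49152
hidden_states = torch.randn(1, 16, hidden_size)             # what LlamaModel produces
classifier_weights = torch.randn(vocab_size, hidden_size)   # lm_head.weight

next_token_logits = hidden_states[:, -1] @ classifier_weights.T
print(next_token_logits.shape)                              # torch.Size([1, 49152])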
llama_modeling/liger_rope.py ADDED
@@ -0,0 +1,258 @@
import torch
import triton
import triton.language as tl

# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/rope.py
# BSD 2-CLAUSE LICENSE
# Copyright 2024 LinkedIn Corporation
# All Rights Reserved.
# Redistribution and use in source and binary forms, with or
# without modification, are permitted provided that the following
# conditions are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@triton.jit
def _triton_rope(
    q_ptr,
    q_row_stride,
    k_ptr,
    k_row_stride,
    cos,
    cos_row_stride,
    sin,
    sin_row_stride,
    sl,
    bs: tl.constexpr,
    cos_bs: tl.constexpr,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BACKWARD_PASS: tl.constexpr = False,
):
    # q size: (bsz, seq_len, num_q_heads, head_dim)
    # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1)
    # k size: (bsz, seq_len, num_kv_heads, head_dim)
    # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)

    # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
    # stride: (seq_len * head_dim, head_dim, 1)
    pid = tl.program_id(0)

    # locate start address
    q_ptr = q_ptr + pid * q_row_stride
    k_ptr = k_ptr + pid * k_row_stride

    # ####################################################################
    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
    # m of this program instance
    # ####################################################################

    # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
    # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
    # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
    # and pid % sl to get the sequence index.
    # 2. We only need the left half of cos and sin matrix because the right half is just
    # a clone of the left half.
    batch_idx = pid // sl
    cos_row_idx = pid % sl
    cos = cos + tl.where(
        cos_bs == 1,
        cos_row_idx * cos_row_stride,
        batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
    )
    sin = sin + tl.where(
        cos_bs == 1,
        cos_row_idx * sin_row_stride,
        batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
    )

    cos_offsets = tl.arange(0, pad_hd // 2)
    cos_mask = cos_offsets < hd // 2
    cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
    sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)

    # ####################################################################
    # Load the left and right half of q and k for the current
    # program instance (i.e. for the current token) separately
    # ####################################################################
    # left half of the head
    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

    # right half of the head
    second_half_q_offsets = first_half_q_offsets + (hd // 2)
    second_half_k_offsets = first_half_k_offsets + (hd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask
    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

    if not BACKWARD_PASS:
        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
    else:
        # with some math, we can get:
        # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
        new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)


def rope_forward(q, k, cos, sin):
    # transpose it back to the physical shape because Triton looks at the physical storage
    # note: q and k are incontiguous before the transformation and will become contiguous after transpose
    batch_size, seq_len, n_q_head, head_dim = q.shape
    n_kv_head = k.shape[2]

    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)

    n_row = batch_size * seq_len

    # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
    q = q.contiguous()
    k = k.contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()
    cos_batch_size = cos.shape[0]

    _triton_rope[(n_row,)](
        q,
        q.stride(1),
        k,
        k.stride(1),
        cos,
        cos.stride(-2),
        sin,
        sin.stride(-2),
        seq_len,
        batch_size,
        cos_batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        BLOCK_SIZE=BLOCK_SIZE,
        BACKWARD_PASS=False,
    )
    return q, k, cos, sin


def rope_backward(dq, dk, cos, sin):
    batch_size, seq_len, n_q_head, head_dim = dq.shape
    cos_batch_size = cos.shape[0]
    n_kv_head = dk.shape[2]
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)

    n_row = batch_size * seq_len

    # ensure dq and dk are contiguous
    dq = dq.contiguous()
    dk = dk.contiguous()

    # backward is similar to forward except swapping few ops
    _triton_rope[(n_row,)](
        dq,
        dq.stride(1),
        dk,
        dk.stride(1),
        cos,
        cos.stride(-2),
        sin,
        sin.stride(-2),
        seq_len,
        batch_size,
        cos_batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        BLOCK_SIZE=BLOCK_SIZE,
        BACKWARD_PASS=True,
    )
    return dq, dk


class LigerRopeFunction(torch.autograd.Function):
    """
    Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that
    this implements the HuggingFace Llama & Mistral version, whose rotation matrix is slightly different
    than the original RoPE paper.

    Please find the corresponding HuggingFace implementation here:
    https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184

    For more details about the rotation matrix used here, please refer to:
    https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2
    """

    @staticmethod
    def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        """
        q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        """
        q, k, cos, sin = rope_forward(q, k, cos, sin)
        ctx.save_for_backward(cos, sin)
        return q, k

    # Note: declared as a staticmethod to match forward() and the upstream Liger kernel.
    @staticmethod
    def backward(ctx, dq, dk):
        """
        dq size: (bsz, n_q_head, seq_len, head_dim)
        dk size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        """

        cos, sin = ctx.saved_tensors
        dq, dk = rope_backward(dq, dk, cos, sin)
        return dq, dk, None, None, None, None
llama_modeling/mlp.py ADDED
@@ -0,0 +1,18 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import LlamaConfig

class LlamaMLP(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
llama_modeling/model.py ADDED
@@ -0,0 +1,35 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from .mlp import LlamaMLP
from .config import LlamaConfig
from .rms_norm import LlamaRMSNorm
from .decoder import LlamaDecoderLayer

class LlamaModel(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=None)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config, i) for i in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states
llama_modeling/rms_norm.py ADDED
@@ -0,0 +1,16 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
llama_modeling/rope.py ADDED
@@ -0,0 +1,34 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

def rotate_half(x):
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=8192, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_position_embeddings = max_position_embeddings
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, position_ids: torch.LongTensor):
        # position_ids: [batch_size, seq_len]
        inv_freq = self.inv_freq.to(device=position_ids.device)
        inv_freq_expanded = inv_freq[None, None, :]  # [1, 1, dim//2]
        position_ids_expanded = position_ids[:, :, None].float()  # [batch_size, seq_len, 1]
        freqs = torch.matmul(position_ids_expanded, inv_freq_expanded)  # [batch_size, seq_len, dim//2]
        freqs = torch.cat([freqs, freqs], dim=-1)  # [batch_size, seq_len, dim]
        cos = torch.cos(freqs)
        sin = torch.sin(freqs)
        cos = cos.unsqueeze(1)  # [batch_size, 1, seq_len, dim]
        sin = sin.unsqueeze(1)  # [batch_size, 1, seq_len, dim]
        return cos, sin
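Note: this file is a plain eager RoPE helper; the attention modules in this commit call the Triton LigerRopeFunction instead, so apply_rotary_pos_emb appears to be kept as a reference path. A usage sketch with illustrative shapes, showing how the cos/sin broadcast over the head axis:

# Sketch only: eager RoPE from this file (shapes are illustrative).
import torch
from llama_modeling.rope import LlamaRotaryEmbedding, apply_rotary_pos_emb

bsz, heads, seq, dim = 2, 4, 16, 64
rope = LlamaRotaryEmbedding(dim)
position_ids = torch.arange(seq).unsqueeze(0).expand(bsz, -1)   # [bsz, seq]
cos, sin = rope(position_ids)                                   # [bsz, 1, seq, dim]

q = torch.randn(bsz, heads, seq, dim)
k = torch.randn(bsz, heads, seq, dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)                                 # both [bsz, heads, seq, dim]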
llama_modeling/tensor_prod_attn.py ADDED
@@ -0,0 +1,147 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from dataclasses import dataclass
from einops import rearrange, repeat

from flash_attn import flash_attn_func
from .liger_rope import LigerRopeFunction
from .rms_norm import LlamaRMSNorm
from .config import LlamaConfig

class CPLinear(nn.Module):
    def __init__(self, in_features, n_head, head_dim, kv_rank=2, q_rank=6):
        super().__init__()
        self.W_A_q = nn.Linear(in_features, n_head * q_rank, bias=False)
        self.W_B_q = nn.Linear(in_features, q_rank * head_dim, bias=False)
        self.W_A_k = nn.Linear(in_features, n_head * kv_rank, bias=False)
        self.W_B_k = nn.Linear(in_features, kv_rank * head_dim, bias=False)
        self.W_A_v = nn.Linear(in_features, n_head * kv_rank, bias=False)
        self.W_B_v = nn.Linear(in_features, kv_rank * head_dim, bias=False)

        nn.init.xavier_uniform_(self.W_A_q.weight)
        nn.init.xavier_uniform_(self.W_B_q.weight)
        nn.init.xavier_uniform_(self.W_A_k.weight)
        nn.init.xavier_uniform_(self.W_B_k.weight)
        nn.init.xavier_uniform_(self.W_A_v.weight)
        nn.init.xavier_uniform_(self.W_B_v.weight)

        self.n_head = n_head
        self.q_rank = q_rank
        self.head_dim = head_dim
        self.kv_rank = kv_rank

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        A_q = self.W_A_q(x).view(batch_size, seq_len, self.n_head, self.q_rank)
        A_k = self.W_A_k(x).view(batch_size, seq_len, self.n_head, self.kv_rank)
        A_v = self.W_A_v(x).view(batch_size, seq_len, self.n_head, self.kv_rank)

        B_q = self.W_B_q(x).view(batch_size, seq_len, self.q_rank, self.head_dim)
        B_k = self.W_B_k(x).view(batch_size, seq_len, self.kv_rank, self.head_dim)
        B_v = self.W_B_v(x).view(batch_size, seq_len, self.kv_rank, self.head_dim)

        A_q = A_q.view(batch_size * seq_len, self.n_head, self.q_rank)
        A_k = A_k.view(batch_size * seq_len, self.n_head, self.kv_rank)
        A_v = A_v.view(batch_size * seq_len, self.n_head, self.kv_rank)

        B_q = B_q.view(batch_size * seq_len, self.q_rank, self.head_dim)
        B_k = B_k.view(batch_size * seq_len, self.kv_rank, self.head_dim)
        B_v = B_v.view(batch_size * seq_len, self.kv_rank, self.head_dim)

        q = torch.bmm(A_q, B_q).div_(self.q_rank).view(batch_size, seq_len, self.n_head, self.head_dim)
        k = torch.bmm(A_k, B_k).div_(self.kv_rank).view(batch_size, seq_len, self.n_head, self.head_dim)
        v = torch.bmm(A_v, B_v).div_(self.kv_rank).view(batch_size, seq_len, self.n_head, self.head_dim)

        return q, k, v

class CausalTensorProductSelfAttn(nn.Module):
    def __init__(self, config, kv_rank=2, q_rank=6):
        super().__init__()
        self.n_head = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.n_embd = config.hidden_size
        self.rank = kv_rank
        self.q_rank = q_rank
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        self.c_qkv = CPLinear(self.n_embd, self.n_head, self.head_dim, self.rank, self.q_rank)
        self.o_proj = nn.Linear(self.n_head * self.head_dim, self.n_embd, bias=False)

        self.register_buffer(
            "cos_cached",
            self._compute_rope_embeddings(
                self.max_position_embeddings,
                self.head_dim,
                self.rope_theta,
                dtype=torch.float32,
                device=self.o_proj.weight.device,
            )[0],
            persistent=False,
        )
        self.register_buffer(
            "sin_cached",
            self._compute_rope_embeddings(
                self.max_position_embeddings,
                self.head_dim,
                self.rope_theta,
                dtype=torch.float32,
                device=self.o_proj.weight.device,
            )[1],
            persistent=False,
        )

        self.using_groupnorm = getattr(config, 'using_groupnorm', False)
        self.subln = LlamaRMSNorm(self.head_dim, eps=1e-5)

    def _compute_rope_embeddings(self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None):
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
        t = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype)
        sin = emb.sin().to(dtype)
        return cos.unsqueeze(0), sin.unsqueeze(0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # In B S (H D)
        bsz, seq_len, _ = hidden_states.size()

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)
            position_ids = repeat(position_ids, 'l -> b l', b=bsz)

        q, k, v = self.c_qkv(hidden_states)  # B S (HD) -> B S H D

        cos = self.cos_cached[:, position_ids]  # [1, bsz, seq_len, dim]
        sin = self.sin_cached[:, position_ids]  # [1, bsz, seq_len, dim]

        q, k = LigerRopeFunction.apply(
            q,
            k,
            cos.squeeze(0),
            sin.squeeze(0),
            position_ids
        )

        attn_out = flash_attn_func(
            q,
            k,
            v,
            dropout_p=0.0,
            causal=attention_mask is None
        )

        attn_out = self.subln(attn_out)

        attn_out = rearrange(attn_out, "b s h d -> b s (h d)")
        attn_out = self.o_proj(attn_out)
        return attn_out
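Note: CPLinear builds each projection as a per-token product of a small head-mixing factor A and a feature-basis factor B (rank 6 for Q, rank 2 for K/V by default), which is considerably smaller than a dense fused QKV projection at this hidden size. A quick check of the shapes and parameter counts (batch and sequence sizes are illustrative; hidden/head sizes follow the LlamaConfig defaults above):

# Sketch only: shape and parameter-count check for the factored QKV projection.
import torch
from llama_modeling.tensor_prod_attn import CPLinear

hidden, n_head, head_dim = 576, 16, 36
cp = CPLinear(hidden, n_head, head_dim, kv_rank=2, q_rank=6)

x = torch.randn(2, 8, hidden)
q, k, v = cp(x)
print(q.shape, k.shape, v.shape)                 # each torch.Size([2, 8, 16, 36])

cp_params = sum(p.numel() for p in cp.parameters())
dense_params = 3 * hidden * (n_head * head_dim)  # a standard dense QKV projection
print(cp_params, dense_params)                   # 299520 vs 995328 weights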
test-train.py ADDED
@@ -0,0 +1,238 @@
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import math, os, sys, json, glob, time, random
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoTokenizer
from distributed_shampoo import AdamGraftingConfig, DistributedShampoo
from cut_cross_entropy import linear_cross_entropy
from torch.nn.utils import clip_grad_norm_
from utils.trainutils import count_parameters_layerwise, save_checkpoint, TBLogger

from llama_modeling.front_end import LlamaForCausalLM
from llama_modeling.config import LlamaConfig

class JSONLDataset(Dataset):
    def __init__(self, directory_path, tokenizer, seq_length=1024,
                 text_key="text", max_files=None, batch_size=1000,
                 pad_token_id=0):
        self.seq_length = seq_length
        self.tokenizer = tokenizer
        self.pad_token_id = pad_token_id
        self.sequences = []

        files = glob.glob(os.path.join(directory_path, "*.jsonl"))
        if max_files is not None:
            files = files[:max_files]

        text_batch = []
        for file_idx, file_path in enumerate(files):
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        data = json.loads(line)
                        text = data.get(text_key, "")
                        if len(text) >= 100:
                            text_batch.append(text)

                            if len(text_batch) >= batch_size:
                                self._process_batch(text_batch)
                                text_batch = []
                    except:
                        continue

        if text_batch:
            self._process_batch(text_batch)

        if self.sequences:
            self.sequences = torch.tensor(self.sequences, dtype=torch.long)
        else:
            self.sequences = torch.empty((0, seq_length), dtype=torch.long)

    def _process_batch(self, texts):
        encoded = self.tokenizer(
            texts,
            add_special_tokens=False,
            truncation=True,
            padding=False,
            return_attention_mask=False,
            return_tensors=None
        )['input_ids']

        mlen = 0
        for token_ids in encoded:
            for i in range(0, len(token_ids), self.seq_length):
                chunk = token_ids[i:i+self.seq_length]

                # Pad
                if len(chunk) < self.seq_length:
                    chunk += [self.pad_token_id] * (self.seq_length - len(chunk))

                self.sequences.append(chunk)
                mlen = max(mlen, len(chunk))

        print("MAX: ", mlen)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx]

def train_model(model, train_loader, optimizer, device, epochs=5, forward_dtype=torch.float32):
    model.train()
    criterion = nn.CrossEntropyLoss()
    scaler = torch.amp.GradScaler("cuda")

    logger = TBLogger(log_dir=f'logs/run-{time.time()}')

    total_steps = len(train_loader) * epochs
    scheduler = CosineAnnealingLR(
        optimizer,
        T_max=total_steps,
        eta_min=5e-6
    )

    model = torch.compile(
        model,
    )

    global_step = 0
    for epoch in range(epochs):
        running_loss = 0.0
        total_batches = 0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')

        for batch_idx, data in enumerate(progress_bar):
            data = data.to(device)
            optimizer.zero_grad(set_to_none=True)

            with torch.autocast(device_type='cuda', dtype=forward_dtype):
                hidden_states, classifier_weights = model(data)

                loss = linear_cross_entropy(
                    hidden_states,
                    classifier_weights,
                    data,
                    shift=True,
                    reduction="mean"
                )

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            # Update metrics - just add the loss itself
            running_loss += loss.item()
            total_batches += 1
            global_step += 1
            avg_loss = running_loss / total_batches
            perplexity = math.exp(min(avg_loss, 100))

            progress_bar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'ppl': f'{perplexity:.2f}'
            })

            metrics = {
                'loss': loss.item(),
                'perplexity': perplexity,
                'learning_rate': optimizer.param_groups[0]['lr'],
                'batch_size': data.size(0)
            }

            logger.log(metrics, step=global_step, model=model, grad_checking=True)

            if batch_idx % 100 == 0:
                print(f'\nBatch {batch_idx}/{len(train_loader)}: '
                      f'Loss: {avg_loss:.4f}, '
                      f'Perplexity: {perplexity:.2f}, '
                      f'Batches Processed: {total_batches}')

        epoch_loss = running_loss / total_batches
        epoch_ppl = math.exp(min(epoch_loss, 100))
        print(f'\nEpoch {epoch+1} Summary:')
        print(f'Average Loss: {epoch_loss:.4f}')
        print(f'Perplexity: {epoch_ppl:.2f}')
        print(f'Total Batches Processed: {total_batches}\n')

        save_checkpoint(model, f'epoch_{epoch+1}.safetensors')

def sample_examples(dataset, tokenizer, num_samples=5):
    if len(dataset) == 0:
        print("The dataset is empty.")
        return

    num_samples = min(num_samples, len(dataset))

    sampled_indices = random.sample(range(len(dataset)), num_samples)

    for i, idx in enumerate(sampled_indices):
        sequence = dataset[idx]
        print(f"Sample {i + 1} (Index {idx}):")
        print(sequence)
        decoded_text = tokenizer.decode(sequence, skip_special_tokens=False, decode_special_tokens=False)
        print(decoded_text)
        print("-" * 40)

def main():
    BATCH_SIZE = 36
    SEQ_LENGTH = 512
    EPOCHS = 3
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("./SmolLM2-135M-Instruct")

    config_path = "config.json"
    with open(config_path) as f:
        config_dict = json.load(f)
    config = LlamaConfig(**{k: v for k, v in config_dict.items() if k in LlamaConfig.__dataclass_fields__})

    model = LlamaForCausalLM(config).to("cuda")

    dataset = JSONLDataset(
        directory_path="./Data_big",
        tokenizer=tokenizer,
        seq_length=SEQ_LENGTH,
        text_key="text",
        max_files=None,
    )

    train_loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True
    )

    optimizer = DistributedShampoo(
        model.parameters(),
        lr=0.0001,
        betas=(0.9, 0.999),
        epsilon=1e-12,
        weight_decay=1e-05,
        max_preconditioner_dim=2048,
        precondition_frequency=100,
        start_preconditioning_step=250,
        use_decoupled_weight_decay=False,
        grafting_config=AdamGraftingConfig(
            beta2=0.999,
            epsilon=1e-12,
        ),
    )

    print("*"*100)
    torch.set_float32_matmul_precision('high')

    count_parameters_layerwise(model)

    train_model(model, train_loader, optimizer, DEVICE, EPOCHS, forward_dtype=torch.bfloat16)

if __name__ == "__main__":
    main()
utils/__init__.py ADDED
File without changes
utils/trainutils.py ADDED
@@ -0,0 +1,92 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from safetensors.torch import save_file, load_file
from pathlib import Path
import time

def count_parameters_layerwise(model):
    # Layerwise params, turn this into a util function.
    total_params = 0
    layer_params = {}

    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue

        param_count = parameter.numel()
        layer_params[name] = param_count
        total_params += param_count

    print(f"\nModel Parameter Summary:")
    print("-" * 60)
    for name, count in layer_params.items():
        print(f"{name}: {count:,} parameters")
    print("-" * 60)
    print(f"Total Trainable Parameters: {total_params:,}\n")

    return total_params

def save_checkpoint(model, filename="checkpoint.safetensors"):
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod

    torch.save(model.state_dict(), filename.replace('.safetensors', '.pt'))

def load_checkpoint(model, filename="checkpoint.safetensors"):
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod

    try:
        model_state = load_file(filename)
        model.load_state_dict(model_state)
    except Exception as e:
        model_state = torch.load(filename.replace('.safetensors', '.pt'), weights_only=True)
        model.load_state_dict(model_state)

class TBLogger:
    def __init__(self, log_dir='logs/current_run', flush_secs=10, enable_grad_logging=True):
        Path(log_dir).mkdir(parents=True, exist_ok=True)
        self.writer = SummaryWriter(log_dir, flush_secs=flush_secs)
        self.enable_grad_logging = enable_grad_logging
        self.start_time = time.time()

    def log(self, metrics, step=None, model=None, prefix='', grad_checking=False):
        for name, value in metrics.items():
            full_name = f"{prefix}{name}" if prefix else name

            if isinstance(value, (int, float)):
                self.writer.add_scalar(full_name, value, step)
            elif isinstance(value, torch.Tensor):
                self.writer.add_scalar(full_name, value.item(), step)
            elif isinstance(value, (list, tuple)) and len(value) > 0:
                if all(isinstance(x, (int, float)) for x in value):
                    self.writer.add_histogram(full_name, torch.tensor(value), step)

        if self.enable_grad_logging and model is not None:
            self._log_gradients(model, step, grad_checking)

    def _log_gradients(self, model, step, grad_checking):
        total_norm = 0.0
        for name, param in model.named_parameters():
            if grad_checking and param.grad is not None:
                # Check for inf/nan in gradients
                if torch.isnan(param.grad).any():
                    print(f"Warning: Found nan in gradients for layer: {name}")
                    continue
                if torch.isinf(param.grad).any():
                    print(f"Warning: Found inf in gradients for layer: {name}")
                    continue

                param_norm = param.grad.detach().data.norm(2)
                self.writer.add_scalar(f"gradients/{name}_norm", param_norm, step)
                total_norm += param_norm.item() ** 2

        # Only compute total norm if we haven't encountered inf/nan
        if total_norm > 0:  # This means we had valid gradients
            total_norm = total_norm ** 0.5
            self.writer.add_scalar("gradients/total_norm", total_norm, step)

    def close(self):
        self.writer.close()