Commit · fbde94d
Parent(s): 7907b9c

add model code
Files changed:
- README.md +72 -0
- config.json +39 -0
- configuration_muddpythia.py +81 -0
- generation_demo.py +50 -0
- modeling_muddpythia.py +454 -0
- requirements.txt +2 -0
- tokenizer.json +0 -0
- tokenizer_config.json +9 -0
README.md CHANGED
---
language:
- en
tags:
- pytorch
- causal-lm
- muddformer
license: mit
---

MUDDPythia-2.8B is a language model pretrained on the Pile for 300B tokens, a direct counterpart to Pythia-2.8B. It uses a simple yet effective method to address the limitations of residual connections and enhance cross-layer information flow in Transformers. Please see downstream evaluations and more details in the paper [MUDDFormer: Breaking Residual Bottlenecks in Transformers via Multiway Dynamic Dense Connections](https://arxiv.org). In addition, we open-source the JAX training code on [GitHub](https://github.com/Caiyun-AI/MUDDFormer/).

We recommend the <strong>compiled version</strong> of MUDDPythia with *torch.compile* for inference acceleration. Please refer to the Generation section for the compile implementation.
# Usage

## Env

```
pip install transformers==4.35.0 torch==2.5.1
```

## Generation

```
import os
import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

device = torch.device('cuda:0')
dtype = torch.bfloat16
MAX_BATCH_SIZE = 1
MAX_SEQ_LENGTH = 2048
NUM_TOKENS_TO_GENERATE = 10
COMPILE = True
OPTIMIZED_COMPILE = False

if OPTIMIZED_COMPILE:
    import torch._dynamo.config
    import torch._inductor.config
    torch._dynamo.config.cache_size_limit = 64
    torch._inductor.config.coordinate_descent_tuning = True
    torch._inductor.config.triton.unique_kernel_names = True
    torch._inductor.config.fx_graph_cache = True

tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/MUDDPythia-2.8B")
model = AutoModelForCausalLM.from_pretrained("Caiyun-AI/MUDDPythia-2.8B", trust_remote_code=True)

_ = model.to(device=device, dtype=dtype)
with torch.device(device):
    # Pre-allocate the KV caches and the cross-layer cache used at decode time.
    model.setup_caches(max_batch_size=MAX_BATCH_SIZE, max_seq_length=MAX_SEQ_LENGTH, dtype=dtype)

def decode_one_token(model, cur_token, input_pos):
    # Greedy decoding: pick the argmax token from the last position's logits.
    logits = model(cur_token, input_pos=input_pos, return_tensor=True)
    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token

prompt = "Beijing is the capital of China. London is the capital of"
input_ids = tokenizer.encode(prompt, return_tensors='pt')

compiled_decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) if COMPILE else None

print('Start generating tokens; the first iteration takes a few minutes to compile.')
for i in range(10):
    t0 = time.time()
    with torch.no_grad():
        generated_ids = model.generate(input_ids.to(device), num_tokens_to_generate=NUM_TOKENS_TO_GENERATE, compiled_decode_one_token=compiled_decode_one_token)
        text = tokenizer.decode(generated_ids[0])
    if i == 0:
        print(f'Generated text: {text}')
    t1 = time.time()
    print(f'Time consumed at iteration {i}: {t1-t0}s')
```
config.json ADDED
```
{
  "architectures": [
    "MUDDPythia"
  ],
  "auto_map": {
    "AutoConfig": "configuration_muddpythia.MUDDPythiaConfig",
    "AutoModelForCausalLM": "modeling_muddpythia.MUDDPythia"
  },
  "block_size": 2048,
  "bos_token_id": 0,
  "dense": true,
  "dense_type": "qkvr",
  "dim": 2560,
  "dynamic_dense": true,
  "eos_token_id": 0,
  "expand_last": true,
  "head_dim": 80,
  "intermediate_size": 10240,
  "is_training": false,
  "model_type": "muddpythia",
  "n_head": 32,
  "n_layer": 32,
  "n_local_heads": 32,
  "norm_eps": 1e-05,
  "rope_base": 10000,
  "rotary_pct": 0.25,
  "round64": true,
  "sepln": true,
  "stack_hidden": false,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.35.0",
  "use_gradient_checkpointing": false,
  "use_layer_cache": true,
  "use_linear_bias": true,
  "use_parallel_residual": true,
  "use_qk_norm": true,
  "vocab_size": 50432
}
```
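As a quick sanity check of how this config is consumed, here is a hedged sketch that loads it through `AutoConfig` (`trust_remote_code=True` is required because `MUDDPythiaConfig` is resolved via the `auto_map` entry above); the asserted values come straight from the JSON.

```
from transformers import AutoConfig

# trust_remote_code is needed because the config class is defined in
# configuration_muddpythia.py inside this repo, referenced by "auto_map".
config = AutoConfig.from_pretrained("Caiyun-AI/MUDDPythia-2.8B", trust_remote_code=True)

assert config.model_type == "muddpythia"
assert config.dim == 2560 and config.n_head == 32
assert config.head_dim == config.dim // config.n_head  # 80, recomputed in __init__
```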
configuration_muddpythia.py ADDED
```
from transformers.configuration_utils import PretrainedConfig
from typing import Optional


class MUDDPythiaConfig(PretrainedConfig):
    '''
    MUDDPythiaConfig is a config class for MUDDPythia, adapted from
    https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L21
    '''
    model_type = "muddpythia"

    def __init__(
        self,
        block_size: int = 2048,
        vocab_size: int = 32000,
        n_layer: int = 32,
        n_head: int = 32,
        dim: int = 2560,
        intermediate_size: int = None,
        n_local_heads: int = -1,
        head_dim: int = 64,
        rope_base: float = 10000,
        norm_eps: float = 1e-5,
        use_gradient_checkpointing: bool = False,
        is_training: bool = False,
        use_qk_norm: bool = False,
        pad_token_id: Optional[int] = None,
        use_parallel_residual: bool = True,
        use_linear_bias: bool = True,
        rotary_pct: float = 0.25,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        use_layer_cache: bool = True,
        stack_hidden: bool = False,
        dense: bool = True,
        dynamic_dense: bool = True,
        sepln: bool = True,
        dense_type: str = 'qkvr',
        expand_last: bool = False,
        round64: bool = False,
        **kwargs,
    ):
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.dim = dim
        self.intermediate_size = intermediate_size
        self.n_local_heads = n_local_heads
        self.head_dim = head_dim
        self.rope_base = rope_base
        self.norm_eps = norm_eps
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.is_training = is_training
        self.use_qk_norm = use_qk_norm
        self.use_parallel_residual = use_parallel_residual
        self.use_linear_bias = use_linear_bias
        self.rotary_pct = rotary_pct

        self.use_layer_cache = use_layer_cache
        self.stack_hidden = stack_hidden
        self.dense = dense
        self.dynamic_dense = dynamic_dense
        self.sepln = sepln
        self.dense_type = dense_type
        self.expand_last = expand_last
        self.round64 = round64
        # post init: fill in derived defaults
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            self.intermediate_size = 4 * self.dim
        self.head_dim = self.dim // self.n_head

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
```
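A minimal usage sketch of the post-init logic above, assuming the file is importable from the working directory: `n_local_heads` falls back to `n_head`, `intermediate_size` to `4 * dim`, and `head_dim` is recomputed as `dim // n_head`.

```
from configuration_muddpythia import MUDDPythiaConfig

# Derived defaults are filled in by the "post init" block above.
cfg = MUDDPythiaConfig(dim=2560, n_head=32)
print(cfg.n_local_heads)      # 32 (falls back to n_head)
print(cfg.intermediate_size)  # 10240 (4 * dim)
print(cfg.head_dim)           # 80 (dim // n_head)
```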
generation_demo.py ADDED
```
import os
import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

device = torch.device('cuda:0')
dtype = torch.bfloat16
MAX_BATCH_SIZE = 1
MAX_SEQ_LENGTH = 2048
NUM_TOKENS_TO_GENERATE = 10
COMPILE = True
OPTIMIZED_COMPILE = False

if OPTIMIZED_COMPILE:
    import torch._dynamo.config
    import torch._inductor.config
    torch._dynamo.config.cache_size_limit = 64
    torch._inductor.config.coordinate_descent_tuning = True
    torch._inductor.config.triton.unique_kernel_names = True
    torch._inductor.config.fx_graph_cache = True

tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/MUDDPythia-2.8B")
model = AutoModelForCausalLM.from_pretrained("Caiyun-AI/MUDDPythia-2.8B", trust_remote_code=True)

_ = model.to(device=device, dtype=dtype)
with torch.device(device):
    # Pre-allocate the KV caches and the cross-layer cache used at decode time.
    model.setup_caches(max_batch_size=MAX_BATCH_SIZE, max_seq_length=MAX_SEQ_LENGTH, dtype=dtype)

def decode_one_token(model, cur_token, input_pos):
    # Greedy decoding: pick the argmax token from the last position's logits.
    logits = model(cur_token, input_pos=input_pos, return_tensor=True)
    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token

prompt = "Beijing is the capital of China. London is the capital of"
input_ids = tokenizer.encode(prompt, return_tensors='pt')

compiled_decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) if COMPILE else None

print('Start generating tokens; the first iteration takes a few minutes to compile.')
for i in range(10):
    t0 = time.time()
    with torch.no_grad():
        generated_ids = model.generate(input_ids.to(device), num_tokens_to_generate=NUM_TOKENS_TO_GENERATE, compiled_decode_one_token=compiled_decode_one_token)
        text = tokenizer.decode(generated_ids[0])
    if i == 0:
        print(f'Generated text: {text}')
    t1 = time.time()
    print(f'Time consumed at iteration {i}: {t1-t0}s')
```
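To try the demo end to end (it hardcodes `cuda:0`, so a CUDA GPU is assumed), something like the following should work with the pinned dependencies from requirements.txt below:

```
pip install -r requirements.txt
python generation_demo.py
```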
modeling_muddpythia.py ADDED
```
from typing import Optional, Tuple, Union
from collections import namedtuple
from einops import rearrange

import math
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint

try:
    from .configuration_muddpythia import MUDDPythiaConfig
except ImportError:
    from configuration_muddpythia import MUDDPythiaConfig

from transformers.modeling_utils import PreTrainedModel


class KVCache(nn.Module):
    """Static key/value cache, written in place at the given positions."""

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.float16):
        super().__init__()
        self.seq_length = max_seq_length
        cache_shape = (max_batch_size, n_heads, self.seq_length, head_dim)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]
        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val
        return k_out, v_out


class LayerCache(nn.Module):
    """Caches the last-token hidden state of every layer for the dense connections."""

    def __init__(self, max_batch_size, num_layers, model_dim, dtype=torch.float16):
        super().__init__()
        cache_shape = (num_layers + 1, max_batch_size, 1, model_dim)  # LBTD
        self.register_buffer('layer_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, x, lidx):
        self.layer_cache[lidx] = x
        return self.layer_cache[:lidx + 1]


class MultiwayDynamicDenseBlock(nn.Module):
    """Produces per-token mixing weights over all previous layers' hidden states,
    one weight set per connection type (q, k, v, r)."""

    def __init__(self, config: MUDDPythiaConfig, lidx: int, last_layer=False) -> None:
        super().__init__()
        self.norm = RMSnormNoscale(epsilon=config.norm_eps)
        self.C = len(config.dense_type) if not last_layer else 1
        self.lidx = lidx
        l = lidx + 2
        hid_dim, out_dim = l * self.C, l * self.C
        if last_layer and config.expand_last:
            hid_dim *= 4
        if config.round64:
            hid_dim = (hid_dim // 64 + 1) * 64
        self.w1 = nn.Linear(config.dim, hid_dim, bias=False)
        self.act = nn.GELU()
        self.w2 = nn.Linear(hid_dim, out_dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        x = self.norm(x)
        dw = self.w2(self.act(self.w1(x)))  # BTD -> BT(CL)
        dw = rearrange(dw, 'B T (C L) -> C B T L', C=self.C)
        return dw

    def layer_mix(self, hids, dw) -> Tensor:
        # dw: CBTL, hids: list of L BTD tensors -> tuple of C mixed BTD tensors
        x = tuple([sum(dw[cidx, :, :, j, None] * hids[j] for j in range(self.lidx + 2)) for cidx in range(self.C)])
        return x


class MUDDPythia(PreTrainedModel):
    config_class = MUDDPythiaConfig

    def __init__(self, config: MUDDPythiaConfig) -> None:
        super().__init__(config)
        self.config = config
        self.use_gradient_checkpointing = config.use_gradient_checkpointing
        self.is_training = config.is_training
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList(TransformerBlock(config, lidx) for lidx in range(config.n_layer))
        self.norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
        C = len(self.config.dense_type)
        # Static (input-independent) dense connection weights, one set per layer.
        self.dense_bs = nn.ParameterList([nn.Parameter(data=torch.randn(C if lidx != config.n_layer - 1 else 1, lidx + 2)) for lidx in range(config.n_layer)])

        self.layer_cache = None
        self.use_layer_cache = False if self.is_training else self.config.use_layer_cache
        self.stack_hidden = self.config.stack_hidden
        self.dynamic = self.config.dynamic_dense
        self.dense = self.config.dense
        if self.dynamic:
            self.dynamic_dense = nn.ModuleList([MultiwayDynamicDenseBlock(config, lidx, last_layer=lidx == config.n_layer - 1) for lidx in range(config.n_layer)])

        self.rotary_ndims = int(config.head_dim * config.rotary_pct)
        self.freqs_cis: Optional[Tensor] = None
        self.mask_cache: Optional[Tensor] = None
        self.max_batch_size = -1
        self.max_seq_length = -1

    def tie_weights(self):  # placeholder
        return

    def setup_caches(self, max_batch_size, max_seq_length, dtype=torch.float16):
        if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        if not self.config.is_training:
            if self.use_layer_cache:
                self.layer_cache = LayerCache(max_batch_size, self.config.n_layer, self.config.dim, dtype=dtype)
            for b in self.layers:
                b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype=dtype)

        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.rotary_ndims, self.config.rope_base).to(self.tok_embeddings.weight.device)
        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool, device=self.tok_embeddings.weight.device))

    def generate(self, input_ids, num_tokens_to_generate=10, compiled_decode_one_token=None):
        batch_size, seq_length = input_ids.shape
        input_pos = torch.arange(seq_length, device=self.device)
        generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate, dtype=torch.int, device=self.device)
        generated_ids[:, :seq_length] = input_ids.to(self.device).to(torch.int)
        # Prefill the caches with the prompt, then decode one token at a time.
        logits = self.forward(input_ids, input_pos=input_pos, return_tensor=True)
        _next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
        next_token = torch.zeros(self.max_batch_size, 1, device=self.device, dtype=torch.int)
        next_token[:batch_size] = _next_token
        generated_ids[:, seq_length] = next_token[:batch_size, 0]
        input_pos = torch.tensor([seq_length], device=self.device)
        for _ in range(1, num_tokens_to_generate):
            if compiled_decode_one_token is not None:
                next_token = compiled_decode_one_token(self, next_token.clone(), input_pos)
            else:
                next_token = self.decode_one_token(next_token.clone(), input_pos)
            generated_ids[:, input_pos + 1] = next_token.int()[:batch_size]
            input_pos += 1
        return generated_ids

    def decode_one_token(self, cur_token, input_pos):
        logits = self.forward(
            cur_token,
            input_pos=input_pos,
            return_tensor=True
        )
        new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
        return new_token

    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None, return_tensor=False) -> Tensor:
        assert self.freqs_cis is not None, "Caches must be initialized first"
        if input_pos is None:
            input_pos = torch.arange(idx.shape[-1], device=idx.device, dtype=torch.int)
        mask = self.causal_mask[None, None, input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        x = self.tok_embeddings(idx)
        _, seqlen, _ = x.shape
        use_layer_cache = self.use_layer_cache and seqlen == 1
        if use_layer_cache:
            self.layer_cache.update(x, 0)
        else:
            hiddens = [x]
        for i, layer in enumerate(self.layers):
            if self.use_gradient_checkpointing:
                x = checkpoint(layer, x, input_pos, freqs_cis, mask)
            else:
                x = layer(x, input_pos, freqs_cis, mask)
            if use_layer_cache:
                _hidden = self.layer_cache.update(x, i + 1)  # LBTD
            else:
                hiddens.append(x)
                # Stack into a single LBTD tensor for the einsum path below;
                # otherwise keep the list of per-layer hidden states.
                _hidden = torch.stack(hiddens) if self.stack_hidden else hiddens
            if self.dynamic and self.dense:
                dw = self.dynamic_dense[i](x)  # BTD -> CBTL
                dw = dw + self.dense_bs[i][:, None, None, :]  # CBTL
                if self.stack_hidden:
                    x = torch.einsum('LBTD, CBTL -> CBTD', _hidden, dw)
                else:
                    x = self.dynamic_dense[i].layer_mix(_hidden, dw)
        if self.config.dense_type == 'qkvr' and self.config.dense and self.config.dynamic_dense:
            x = x[0]  # the last layer produces a single mixed stream
        x = self.norm(x)
        logits = self.output(x)
        if return_tensor:
            return logits
        else:
            CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
            return CausalLMOutput(logits=logits)


class TransformerBlock(nn.Module):
    def __init__(self, config: MUDDPythiaConfig, lidx) -> None:
        super().__init__()
        self.lidx = lidx
        self.config = config
        self.attention = Attention(config, lidx)

        self.feed_forward = FeedForward(config, lidx)
        self.ffn_norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
        self.use_parallel_residual = config.use_parallel_residual

        # With separate layernorms (sepln), the q/k/v streams each get their own norm.
        if self.config.sepln and self.lidx > 0:
            self.attention_norms = torch.nn.ModuleList([nn.LayerNorm(config.dim, eps=config.norm_eps) for _ in range(3)])
        else:
            self.attention_norm = nn.LayerNorm(config.dim, eps=config.norm_eps)

    def forward(self, x: Union[Tuple[Tensor], Tensor], input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
        if self.config.dense_type == 'l' or self.lidx == 0 or not self.config.dense:
            res = x
            normed_x = self.attention_norm(x)
        elif self.config.dense_type == 'qkvr':
            res = x[-1]  # the r (residual) stream feeds the MLP
            if self.config.stack_hidden or not self.config.sepln:
                normed_x = self.attention_norm(x[:3])
            else:
                normed_x = tuple([norm_fn(_x) for norm_fn, _x in zip(self.attention_norms, x[:3])])
        h = res + self.attention(normed_x, freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(res if self.use_parallel_residual else h))
        return out


class Attention(nn.Module):
    def __init__(self, config: MUDDPythiaConfig, lidx):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.config = config
        if self.config.dense_type == 'l' or not self.config.dense:
            self.wqkv = nn.Linear(config.dim, total_head_dim, bias=True)
        elif self.config.dense_type == 'qkvr':
            # Separate projections so q, k and v can consume different dense mixes.
            self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=True)
            self.wk = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=True)
            self.wv = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=True)

        self.wo = nn.Linear(config.dim, config.dim, bias=True)
        self.lidx = lidx
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.scale_factor = 1 / math.sqrt(self.head_dim)
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_qk_norm = config.use_qk_norm
        if self.use_qk_norm:
            self.q_norm = RMSNorm(self.head_dim, config.norm_eps)
            self.k_norm = RMSNorm(self.head_dim, config.norm_eps)

        self.rotary_ndims = int(self.head_dim * config.rotary_pct)

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        # Fuse separate q/k/v weights into one wqkv when the fused layout is used.
        if prefix + "wq.weight" in state_dict and (self.config.dense_type == 'l' or not self.config.dense):
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(self, x: Union[Tuple[Tensor], Tensor], freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        if self.lidx == 0 or self.config.dense_type == 'l' or not self.config.dense:
            bsz, seqlen, _ = x.shape
        else:
            if self.config.stack_hidden:
                C, bsz, seqlen, _ = x.shape
            else:
                C, (bsz, seqlen, _) = len(x), x[0].shape
        kv_size = self.n_local_heads * self.head_dim

        if self.config.dense_type == 'l' or not self.config.dense:
            q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

            q = q.view(bsz, seqlen, self.n_head, self.head_dim)
            k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        elif self.config.dense_type == 'qkvr':
            if self.lidx == 0:
                xq, xk, xv = x, x, x
            else:
                xq, xk, xv = x[0], x[1], x[2]
            q = self.wq(xq).view(bsz, seqlen, self.n_head, self.head_dim)
            k = self.wk(xk).view(bsz, seqlen, self.n_local_heads, self.head_dim)
            v = self.wv(xv).view(bsz, seqlen, self.n_local_heads, self.head_dim)

        if self.use_qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        if self.rotary_ndims == self.head_dim:
            q = apply_rotary_emb(q, freqs_cis)  # BTND
            k = apply_rotary_emb(k, freqs_cis)
        else:
            # Partial rotary (rotary_pct < 1): rotate only the first rotary_ndims channels.
            q_rot = q[..., :self.rotary_ndims]
            q_pass = q[..., self.rotary_ndims:]
            k_rot = k[..., :self.rotary_ndims]
            k_pass = k[..., self.rotary_ndims:]
            q_rot = apply_rotary_emb(q_rot, freqs_cis, mode='half')  # BTND
            k_rot = apply_rotary_emb(k_rot, freqs_cis, mode='half')
            q = torch.cat((q_rot, q_pass), dim=-1)
            k = torch.cat((k_rot, k_pass), dim=-1)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            if seqlen == 1:
                k, v = self.kv_cache.update(input_pos, k, v)
            else:
                _, _ = self.kv_cache.update(input_pos, k, v)

        if seqlen == 1:  # one-token generation
            k_mask = mask[:, :, :, :self.kv_cache.seq_length]
        else:  # prefill
            k_mask = mask[:, :, :, :k.shape[-2]]

        logits = q @ k.transpose(-2, -1) * self.scale_factor
        min_value = torch.finfo(torch.float16).min
        logits = torch.where(k_mask, logits, min_value)
        probs = logits.softmax(-1)
        y = probs @ v
        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
        y = self.wo(y)
        return y


class RMSnormNoscale(nn.Module):

    def __init__(self, epsilon=1e-6, dim=-1):
        super().__init__()
        self.dim = dim
        self.epsilon = epsilon

    def forward(self, inputs):
        var = inputs.pow(2).mean(dim=self.dim, keepdim=True)
        normed_inputs = inputs * torch.rsqrt(var + self.epsilon)
        return normed_inputs


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class FeedForward(nn.Module):
    def __init__(self, config: MUDDPythiaConfig, lidx, round128=True, scale_with_layer=True) -> None:
        super().__init__()
        hid_dim = config.intermediate_size
        if scale_with_layer:
            # Scale the hidden width linearly with depth, from 0.5x to 1.5x.
            hid_dim = hid_dim * (lidx / (config.n_layer - 1) + 0.5)
        if round128:
            hid_dim = round(hid_dim / 128) * 128
        self.w1 = nn.Linear(config.dim, hid_dim, bias=config.use_linear_bias)
        self.w2 = nn.Linear(hid_dim, config.dim, bias=config.use_linear_bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.gelu(self.w1(x)))


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)


def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000
) -> Tensor:
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.float16)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor, mode='half') -> Tensor:
    if mode == 'half':
        xshaped = x.float().reshape(*x.shape[:-1], 2, -1).transpose(-1, -2)
    elif mode == 'alternative':
        xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)


def match_weight_muddpythia(model, w, strict=False, pythia=True):
    # Map checkpoint parameter names from the JAX training format to this module's names.
    if pythia:
        map_dict = {'wq': 'query', 'wk': 'key', 'wv': 'value', 'wo': 'post', 'w1': 'ffn_layer1', 'w2': 'ffn_layer2',
                    'weight': 'w', 'bias': 'b'}
    else:
        map_dict = {'wq': 'query', 'wk': 'key', 'wv': 'value', 'wo': 'post', 'w1': 'ffn_layer1_gate', 'w3': 'ffn_layer1', 'w2': 'ffn_layer2',
                    'weight': 'w', 'bias': 'b'}
    ln_dict = {'weight': 'scale', 'bias': 'bias'}
    E, H, D = model.config.dim, model.config.n_head, model.config.head_dim
    N = model.config.vocab_size
    state_dict = {}
    for k, v in model.named_parameters():
        if k == 'tok_embeddings.weight':
            v = w['state.mdl_vars.params.lm.embedding_lookup.emb_var'][:N, :]
        elif k == 'norm.weight':
            v = w['state.mdl_vars.params.lm.final_ln.scale']
        elif k == 'norm.bias':
            v = w['state.mdl_vars.params.lm.final_ln.bias']
        elif k == 'output.weight':
            v = w['state.mdl_vars.params.lm.softmax.logits_ffn.linear.w'].T[:N, :]  # E,N -> N,E
        elif 'dense_bs' in k:  # static dense w
            lidx = int(k.split('.')[-1])
            v = w[f'state.mdl_vars.params.lm.transformer.dense_conn_{lidx}']
        elif 'dynamic_dense' in k:
            lidx = int(k.split('.')[1])
            widx = int(k.split('.')[2][-1])  # 1 or 2 in w1, w2
            v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.dynamic_dense_conn{widx}_{lidx}'].T
        else:
            assert 'layers' in k
            lidx = int(k.split('.')[1])
            if '.attention.' in k:
                _, _, _, ptype, wtype = k.split('.')
                if ptype in ['wq', 'wk', 'wv', 'wo']:
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.self_attention.{map_dict.get(ptype, ptype)}.{map_dict.get(wtype, wtype)}']
                    if wtype == 'weight':
                        v = v.reshape(E, E)
                    elif wtype == 'bias':
                        v = v.reshape(E)
                    if ptype != 'wo' and wtype == 'weight':
                        v = v.T
                elif ptype in ['q_norm', 'k_norm']:
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.self_attention.{map_dict.get(ptype, ptype)}.scale']
            elif 'feed_forward' in k:
                ptype = k.split('.')[3]  # w1, w3, w2
                wtype = k.split('.')[4]  # weight or bias
                if wtype == 'weight':
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.ff_layer.{map_dict[ptype]}.linear.{map_dict[wtype]}']
                    v = v.T
                elif wtype == 'bias':
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.ff_layer.{map_dict[ptype]}.bias.{map_dict[wtype]}']
            elif 'ffn_norm' in k:  # mlp layernorm
                wtype = k.split('.')[-1]  # weight or bias
                v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.ff_layer.layer_norm.{ln_dict[wtype]}']
            elif 'attention_norm' in k:  # attention layernorm
                wtype = k.split('.')[-1]  # weight or bias
                if 'attention_norms' in k:
                    ln_idx = int(k.split('.')[3])
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.layer_norms_{ln_idx}.{ln_dict[wtype]}']
                else:
                    v = w[f'state.mdl_vars.params.lm.transformer.x_layers_{lidx}.layer_norm.{ln_dict[wtype]}']
        if pythia and 'weight' in k and 'norm' in k and 'q_norm' not in k and 'k_norm' not in k:
            v = v + 1
        state_dict[k] = torch.tensor(v)
    model.load_state_dict(state_dict, strict=strict)
    return model
```
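As a hedged, self-contained shape check of the rotary helpers above (assuming only this file and PyTorch are available), the sketch below mirrors the partial-rotary split in `Attention.forward`, using the repo's `head_dim=80` and `rotary_pct=0.25`:

```
import torch
from modeling_muddpythia import precompute_freqs_cis, apply_rotary_emb

B, T, N, D = 1, 8, 32, 80     # batch, time, heads, head_dim
rotary_ndims = int(D * 0.25)  # 20, matching rotary_pct=0.25

freqs_cis = precompute_freqs_cis(2048, rotary_ndims)[:T]
q = torch.randn(B, T, N, D)

# Only the first rotary_ndims channels are rotated; the rest pass through.
q_rot = apply_rotary_emb(q[..., :rotary_ndims], freqs_cis, mode='half')
q = torch.cat((q_rot, q[..., rotary_ndims:]), dim=-1)
print(q.shape)  # torch.Size([1, 8, 32, 80])
```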
requirements.txt ADDED
```
torch==2.5.1
transformers==4.35.0
```
tokenizer.json ADDED
(The diff for this file is too large to render.)
tokenizer_config.json ADDED
```
{
  "add_prefix_space": false,
  "bos_token": "<|endoftext|>",
  "clean_up_tokenization_spaces": true,
  "eos_token": "<|endoftext|>",
  "model_max_length": 1000000000000000019884624838656,
  "tokenizer_class": "GPTNeoXTokenizer",
  "unk_token": "<|endoftext|>"
}
```
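Finally, a small hedged check that the tokenizer settings above behave as expected (a GPT-NeoX-style BPE tokenizer with `<|endoftext|>` serving as bos/eos/unk):

```
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Caiyun-AI/MUDDPythia-2.8B")
print(tok.bos_token, tok.eos_token)  # <|endoftext|> <|endoftext|>
ids = tok.encode("Beijing is the capital of China.")
print(tok.decode(ids))               # round-trips the original text
```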