KoichiYasuoka
commited on
Commit
•
353db10
1
Parent(s):
7a3c511
bug fix
Browse files- config.json +1 -1
- maker.sh +2 -2
config.json
CHANGED
@@ -153,7 +153,7 @@
|
|
153 |
},
|
154 |
"max_position_embeddings": 32768,
|
155 |
"mlp_bias": false,
|
156 |
-
"model_type": "
|
157 |
"num_attention_heads": 32,
|
158 |
"num_hidden_layers": 32,
|
159 |
"num_key_value_heads": 8,
|
|
|
153 |
},
|
154 |
"max_position_embeddings": 32768,
|
155 |
"mlp_bias": false,
|
156 |
+
"model_type": "mistral",
|
157 |
"num_attention_heads": 32,
|
158 |
"num_hidden_layers": 32,
|
159 |
"num_key_value_heads": 8,
|
maker.sh
CHANGED
@@ -8,11 +8,11 @@ then TMPA=./maker$$a.py
|
|
8 |
src="Rakuten/RakutenAI-7B"
|
9 |
tgt="exRakutenAI-7B"
|
10 |
import json,torch,unicodedata
|
11 |
-
from transformers import LlamaTokenizerFast,
|
12 |
tkz=LlamaTokenizerFast.from_pretrained(src,cls_token="<s>",sep_token="<s>",mask_token="<unk>",pad_token="</s>")
|
13 |
d=json.loads(tkz.backend_tokenizer.to_str())
|
14 |
tkz.backend_tokenizer.from_str(json.dumps(d)).save("tokenizer.json")
|
15 |
-
mdl=
|
16 |
tkz=LlamaTokenizerFast(tokenizer_file="tokenizer.json",model_max_length=mdl.config.max_position_embeddings,cls_token="<s>",sep_token="<s>",mask_token="<unk>",pad_token="</s>")
|
17 |
e=mdl.resize_token_embeddings(len(tkz))
|
18 |
f=mdl.get_output_embeddings()
|
|
|
8 |
src="Rakuten/RakutenAI-7B"
|
9 |
tgt="exRakutenAI-7B"
|
10 |
import json,torch,unicodedata
|
11 |
+
from transformers import LlamaTokenizerFast,MistralForCausalLM
|
12 |
tkz=LlamaTokenizerFast.from_pretrained(src,cls_token="<s>",sep_token="<s>",mask_token="<unk>",pad_token="</s>")
|
13 |
d=json.loads(tkz.backend_tokenizer.to_str())
|
14 |
tkz.backend_tokenizer.from_str(json.dumps(d)).save("tokenizer.json")
|
15 |
+
mdl=MistralForCausalLM.from_pretrained(src)
|
16 |
tkz=LlamaTokenizerFast(tokenizer_file="tokenizer.json",model_max_length=mdl.config.max_position_embeddings,cls_token="<s>",sep_token="<s>",mask_token="<unk>",pad_token="</s>")
|
17 |
e=mdl.resize_token_embeddings(len(tkz))
|
18 |
f=mdl.get_output_embeddings()
|