KoichiYasuoka commited on
Commit
353db10
1 Parent(s): 7a3c511
Files changed (2) hide show
  1. config.json +1 -1
  2. maker.sh +2 -2
config.json CHANGED
@@ -153,7 +153,7 @@
153
  },
154
  "max_position_embeddings": 32768,
155
  "mlp_bias": false,
156
- "model_type": "llama",
157
  "num_attention_heads": 32,
158
  "num_hidden_layers": 32,
159
  "num_key_value_heads": 8,
 
153
  },
154
  "max_position_embeddings": 32768,
155
  "mlp_bias": false,
156
+ "model_type": "mistral",
157
  "num_attention_heads": 32,
158
  "num_hidden_layers": 32,
159
  "num_key_value_heads": 8,
maker.sh CHANGED
@@ -8,11 +8,11 @@ then TMPA=./maker$$a.py
8
  src="Rakuten/RakutenAI-7B"
9
  tgt="exRakutenAI-7B"
10
  import json,torch,unicodedata
11
- from transformers import LlamaTokenizerFast,LlamaForCausalLM
12
  tkz=LlamaTokenizerFast.from_pretrained(src,cls_token="<s>",sep_token="<s>",mask_token="<unk>",pad_token="</s>")
13
  d=json.loads(tkz.backend_tokenizer.to_str())
14
  tkz.backend_tokenizer.from_str(json.dumps(d)).save("tokenizer.json")
15
- mdl=LlamaForCausalLM.from_pretrained(src)
16
  tkz=LlamaTokenizerFast(tokenizer_file="tokenizer.json",model_max_length=mdl.config.max_position_embeddings,cls_token="<s>",sep_token="<s>",mask_token="<unk>",pad_token="</s>")
17
  e=mdl.resize_token_embeddings(len(tkz))
18
  f=mdl.get_output_embeddings()
 
8
  src="Rakuten/RakutenAI-7B"
9
  tgt="exRakutenAI-7B"
10
  import json,torch,unicodedata
11
+ from transformers import LlamaTokenizerFast,MistralForCausalLM
12
  tkz=LlamaTokenizerFast.from_pretrained(src,cls_token="<s>",sep_token="<s>",mask_token="<unk>",pad_token="</s>")
13
  d=json.loads(tkz.backend_tokenizer.to_str())
14
  tkz.backend_tokenizer.from_str(json.dumps(d)).save("tokenizer.json")
15
+ mdl=MistralForCausalLM.from_pretrained(src)
16
  tkz=LlamaTokenizerFast(tokenizer_file="tokenizer.json",model_max_length=mdl.config.max_position_embeddings,cls_token="<s>",sep_token="<s>",mask_token="<unk>",pad_token="</s>")
17
  e=mdl.resize_token_embeddings(len(tkz))
18
  f=mdl.get_output_embeddings()