Update JAX-to-PyTorch converter script
Browse files- jax2torch.py +6 -2
- run_pretraining.sh +1 -0
jax2torch.py
CHANGED
@@ -1,8 +1,12 @@
|
|
1 |
-
from transformers import
|
2 |
|
3 |
'''
|
4 |
-
This is a script to convert the JAX model to a PyTorch model
|
5 |
'''
|
6 |
|
7 |
model = GPT2LMHeadModel.from_pretrained(".", from_flax=True)
|
8 |
model.save_pretrained(".")
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, GPT2LMHeadModel
|
2 |
|
3 |
'''
|
4 |
+
This is a script to convert the Jax model and the tokenizer to Pytorch model
|
5 |
'''
|
6 |
|
7 |
model = GPT2LMHeadModel.from_pretrained(".", from_flax=True)
|
8 |
model.save_pretrained(".")
|
9 |
+
|
10 |
+
tokenizer = AutoTokenizer.from_pretrained(".")
|
11 |
+
tokenizer.save_pretrained(".")
|
12 |
+
|
run_pretraining.sh
CHANGED
@@ -4,6 +4,7 @@ export WANDB_PROJECT="hf-flax-gpt2-indonesian"
|
|
4 |
export WANDB_LOG_MODEL="true"
|
5 |
|
6 |
./run_clm_flax.py \
|
|
|
7 |
--output_dir="${MODEL_DIR}" \
|
8 |
--model_type="gpt2" \
|
9 |
--config_name="${MODEL_DIR}" \
|
|
|
4 |
export WANDB_LOG_MODEL="true"
|
5 |
|
6 |
./run_clm_flax.py \
|
7 |
+
--model_name_or_path="flax_model.msgpack" \
|
8 |
--output_dir="${MODEL_DIR}" \
|
9 |
--model_type="gpt2" \
|
10 |
--config_name="${MODEL_DIR}" \
|