Update scripts to work around collator ValueError. Update weights
- config.json +2 -0
- flax_model.msgpack +1 -1
- opt_state.msgpack +3 -0
- pytorch_model.bin +1 -1
- run_t5.sh +37 -22
- run_t5_mlm_flax_custom_dataset.py +5 -0
- runs/{Jul11_17-06-36_t1v-n-0e7426e8-w-0/events.out.tfevents.1626023202.t1v-n-0e7426e8-w-0.178001.3.v2 → Jul12_06-43-08_t1v-n-0e7426e8-w-0/events.out.tfevents.1626072193.t1v-n-0e7426e8-w-0.238699.3.v2} +2 -2
- training_state.json +1 -0
config.json
CHANGED
@@ -1,4 +1,5 @@
 {
+  "_name_or_path": ".",
   "architectures": [
     "T5ForConditionalGeneration"
   ],
@@ -50,6 +51,7 @@
       "prefix": "translate English to Romanian: "
     }
   },
+  "torch_dtype": "float32",
   "transformers_version": "4.9.0.dev0",
   "use_cache": true,
   "vocab_size": 32103
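The new "torch_dtype" entry records the precision the PyTorch weights were exported in. A minimal sketch of how a loader can honour it, assuming a transformers version where from_pretrained accepts torch_dtype="auto" (not part of this commit):

from transformers import T5ForConditionalGeneration

# Let transformers pick the dtype recorded in config.json ("float32" here);
# torch_dtype="auto" is assumed to be available in the installed version.
model = T5ForConditionalGeneration.from_pretrained(".", torch_dtype="auto")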
flax_model.msgpack
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:8c8d5a4eb1275b4c679b148f38edb974772997a3925809f39095204009f83502
 size 891548548
opt_state.msgpack
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97c0ff372805930fa4d7e81ae09094b7daf3cc2c1ba06224fc522a8e672af91a
+size 1985609
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:782edc5c7aa8aa66320a3417abff572760287ee6a7759f1867486d2217563650
 size 891650495
run_t5.sh
CHANGED
@@ -7,28 +7,42 @@ mkdir -p "${MODEL_DIR}/runs"
 # T5 paper lr 0.01 with batch size 128
 # We have a batch size of 8 devices * 32 = 256, so lr = 0.01/2
 
-[22 removed lines; their content is not shown in the source]
+while true; do
+
+  # Set the seed to random before each run, so data shuffling per epoch is different each run.
+  # This kills reproducibility, but is required as long as during training ValueError can be raised.
+  SEED=$RANDOM
+
+  ./run_t5_mlm_flax_custom_dataset.py \
+    --output_dir="${MODEL_DIR}" \
+    --model_type="t5" \
+    --config_name="flax-community/${MODEL}" \
+    --tokenizer_name="${MODEL_DIR}" \
+    --seed="${SEED}" \
+    --preprocessing_num_workers="96" \
+    --do_train --do_eval \
+    --adafactor \
+    --max_seq_length="512" \
+    --per_device_train_batch_size="32" \
+    --per_device_eval_batch_size="32" \
+    --learning_rate="5e-3" \
+    --dtype="bfloat16" \
+    --overwrite_output_dir \
+    --num_train_epochs="3" \
+    --logging_steps="50" \
+    --save_steps="501" \
+    --eval_steps="10000000" \
+    --resume_from_checkpoint="${MODEL_DIR}" \
+    --warmup_steps="3413"
+
+  # \
+  # --push_to_hub
+
+  echo "RESTARTING"
+  sleep 20
+done
+#
+# \
 
 
 #git add pytorch_model.bin
@@ -37,3 +51,4 @@ mkdir -p "${MODEL_DIR}/runs"
 
 # --gradient_accumulation_steps="2" \
 
+# --resume_from_checkpoint="${MODEL_DIR}/ckpt-18000" \
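The loop above restarts the whole training process whenever the collator's ValueError kills a run, drawing a fresh $RANDOM seed each time so the data is shuffled differently on retry. A hypothetical in-process equivalent, for illustration only (train_once is an assumed entry point, not part of the actual script):

import random
import time

while True:
    seed = random.randrange(2**31)  # fresh shuffle order per attempt, like SEED=$RANDOM
    try:
        train_once(seed=seed)  # assumed wrapper around one full training run
        break  # all epochs finished without the collator raising
    except ValueError:
        print("RESTARTING")
        time.sleep(20)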
run_t5_mlm_flax_custom_dataset.py
CHANGED
@@ -432,6 +432,11 @@ def save_checkpoint(model, save_dir, state, with_opt: bool = True):
         push_to_hub=training_args.push_to_hub,
         commit_message=f"Saving weights and logs of step {cur_step}",
     )
+    if with_opt:
+        with open(os.path.join(training_args.output_dir, "opt_state.msgpack"), "wb") as f:
+            f.write(to_bytes(state.opt_state))
+        with open(os.path.join(training_args.output_dir, "training_state.json"), "w") as f:
+            json.dump({"step": state.step.item()}, f)
     logger.info("checkpoint saved")
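These two files imply a matching load path for --resume_from_checkpoint. A minimal sketch of what that restore side could look like, assuming a flax TrainState whose freshly initialised opt_state supplies the pytree structure for from_bytes (restore_checkpoint is illustrative, not necessarily the code in the script):

import json
import os

from flax.serialization import from_bytes

def restore_checkpoint(state, ckpt_dir):
    # Rebuild the optimizer state saved by save_checkpoint above.
    with open(os.path.join(ckpt_dir, "opt_state.msgpack"), "rb") as f:
        opt_state = from_bytes(state.opt_state, f.read())
    # Recover the global step so the schedule continues where it left off.
    with open(os.path.join(ckpt_dir, "training_state.json")) as f:
        step = json.load(f)["step"]
    return state.replace(opt_state=opt_state, step=step)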
runs/{Jul11_17-06-36_t1v-n-0e7426e8-w-0/events.out.tfevents.1626023202.t1v-n-0e7426e8-w-0.178001.3.v2 → Jul12_06-43-08_t1v-n-0e7426e8-w-0/events.out.tfevents.1626072193.t1v-n-0e7426e8-w-0.238699.3.v2}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:9f5f6fcc83f8cf7fac87cc276fa00a02c9ce4e252c6bb69a3988452bed73f67e
+size 200238
training_state.json
ADDED
@@ -0,0 +1 @@
+{"step": 15004}