fixed the save_steps, make test
Browse files- run_clm_flax.py +2 -2
- run_pretraining.sh +6 -3
run_clm_flax.py
CHANGED
@@ -413,7 +413,8 @@ def main():
|
|
413 |
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
414 |
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
415 |
# customize this part to your needs.
|
416 |
-
|
|
|
417 |
# Split by chunks of max_len.
|
418 |
result = {
|
419 |
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
@@ -636,7 +637,6 @@ def main():
|
|
636 |
|
637 |
# Save metrics
|
638 |
if has_tensorboard and jax.process_index() == 0:
|
639 |
-
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
640 |
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
641 |
|
642 |
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
|
|
413 |
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
414 |
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
415 |
# customize this part to your needs.
|
416 |
+
if total_length >= block_size:
|
417 |
+
total_length = (total_length // block_size) * block_size
|
418 |
# Split by chunks of max_len.
|
419 |
result = {
|
420 |
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
|
|
637 |
|
638 |
# Save metrics
|
639 |
if has_tensorboard and jax.process_index() == 0:
|
|
|
640 |
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
641 |
|
642 |
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
run_pretraining.sh
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
export WANDB_ENTITY="wandb"
|
2 |
export WANDB_PROJECT="hf-flax-gpt2-indonesian"
|
3 |
export WANDB_LOG_MODEL="true"
|
@@ -13,12 +14,14 @@ export WANDB_LOG_MODEL="true"
|
|
13 |
--block_size="512" \
|
14 |
--per_device_train_batch_size="24" \
|
15 |
--per_device_eval_batch_size="24" \
|
16 |
-
--learning_rate="
|
17 |
--adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
|
18 |
--overwrite_output_dir \
|
19 |
--num_train_epochs="20" \
|
20 |
--dataloader_num_workers="64" \
|
21 |
--preprocessing_num_workers="64" \
|
22 |
-
--save_steps="
|
23 |
-
--eval_steps="
|
|
|
|
|
24 |
--push_to_hub
|
|
|
1 |
+
export MODEL_DIR=`pwd`
|
2 |
export WANDB_ENTITY="wandb"
|
3 |
export WANDB_PROJECT="hf-flax-gpt2-indonesian"
|
4 |
export WANDB_LOG_MODEL="true"
|
|
|
14 |
--block_size="512" \
|
15 |
--per_device_train_batch_size="24" \
|
16 |
--per_device_eval_batch_size="24" \
|
17 |
+
--learning_rate="0.0024" --warmup_steps="1000" \
|
18 |
--adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
|
19 |
--overwrite_output_dir \
|
20 |
--num_train_epochs="20" \
|
21 |
--dataloader_num_workers="64" \
|
22 |
--preprocessing_num_workers="64" \
|
23 |
+
--save_steps="10" \
|
24 |
+
--eval_steps="10" \
|
25 |
+
--max_train_samples="10000" \
|
26 |
+
--max_eval_samples="1000" \
|
27 |
--push_to_hub
|