Commit: add wandb integration
Files changed:
- README.md       +5 -0
- run_clm_flax.py +0 -3
NOTE(review): the commit title says "add wandb integration", but the run_clm_flax.py diff below only *removes* the wandb.log calls — the title appears to describe the overall branch/PR, not this individual commit; verify against the commit history.
README.md CHANGED
@@ -1 +1,6 @@
+---
+widget:
+- text: "Sewindu sudah kita tak berjumpa, rinduku padamu sudah tak terkira."
+---
+
 # GPT2-medium-indonesian
run_clm_flax.py CHANGED
(indentation of context lines is approximate — the extraction destroyed it)
@@ -606,8 +606,6 @@ def main():
                 epochs.write(
                     f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                 )
-                if jax.process_index() == 0:
-                    wandb.log({'my_metric': train_metrics})

                 train_metrics = []

@@ -640,7 +638,6 @@ def main():
         if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(train_dataset) // train_batch_size)
             write_eval_metric(summary_writer, eval_metrics, cur_step)
-            wandb.log({'my_metric': eval_metrics})

         if cur_step % training_args.save_steps == 0 and cur_step > 0:
             # save checkpoint after each epoch and push checkpoint to the hub
|
|
run_clm_flax.py — resulting state after the change (right-hand column of the diff view; indentation approximate):

606                 epochs.write(
607                     f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
608                 )
609
610                 train_metrics = []
611

638         if has_tensorboard and jax.process_index() == 0:
639             cur_step = epoch * (len(train_dataset) // train_batch_size)
640             write_eval_metric(summary_writer, eval_metrics, cur_step)
641
642         if cur_step % training_args.save_steps == 0 and cur_step > 0:
643             # save checkpoint after each epoch and push checkpoint to the hub