cahya commited on
Commit
60f90b3
1 Parent(s): db85c97

add wandb integration

Browse files
Files changed (2) hide show
  1. README.md +5 -0
  2. run_clm_flax.py +0 -3
README.md CHANGED
@@ -1 +1,6 @@
 
 
 
 
 
1
  # GPT2-medium-indonesian
 
1
+ ---
2
+ widget:
3
+ - text: "Sewindu sudah kita tak berjumpa, rinduku padamu sudah tak terkira."
4
+ ---
5
+
6
  # GPT2-medium-indonesian
run_clm_flax.py CHANGED
@@ -606,8 +606,6 @@ def main():
606
  epochs.write(
607
  f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
608
  )
609
- if jax.process_index() == 0:
610
- wandb.log({'my_metric': train_metrics})
611
 
612
  train_metrics = []
613
 
@@ -640,7 +638,6 @@ def main():
640
  if has_tensorboard and jax.process_index() == 0:
641
  cur_step = epoch * (len(train_dataset) // train_batch_size)
642
  write_eval_metric(summary_writer, eval_metrics, cur_step)
643
- wandb.log({'my_metric': eval_metrics})
644
 
645
  if cur_step % training_args.save_steps == 0 and cur_step > 0:
646
  # save checkpoint after each epoch and push checkpoint to the hub
 
606
  epochs.write(
607
  f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
608
  )
 
 
609
 
610
  train_metrics = []
611
 
 
638
  if has_tensorboard and jax.process_index() == 0:
639
  cur_step = epoch * (len(train_dataset) // train_batch_size)
640
  write_eval_metric(summary_writer, eval_metrics, cur_step)
 
641
 
642
  if cur_step % training_args.save_steps == 0 and cur_step > 0:
643
  # save checkpoint after each epoch and push checkpoint to the hub