Add Weights & Biases (wandb) integration
Files changed: run_clm_flax.py (+18 -0)
run_clm_flax.py
CHANGED
@@ -53,6 +53,7 @@ from transformers import (
|
|
53 |
is_tensorboard_available,
|
54 |
)
|
55 |
from transformers.testing_utils import CaptureLogger
|
|
|
56 |
|
57 |
|
58 |
logger = logging.getLogger(__name__)
|
@@ -232,6 +233,13 @@ def main():
|
|
232 |
# or by passing the --help flag to this script.
|
233 |
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
234 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
236 |
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
237 |
# If we pass only one argument to the script and it's the path to a json file,
|
@@ -250,6 +258,13 @@ def main():
|
|
250 |
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
251 |
"Use --overwrite_output_dir to overcome."
|
252 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
|
254 |
# Make one log on every process with the configuration for debugging.
|
255 |
logging.basicConfig(
|
@@ -591,6 +606,8 @@ def main():
|
|
591 |
epochs.write(
|
592 |
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
|
593 |
)
|
|
|
|
|
594 |
|
595 |
train_metrics = []
|
596 |
|
@@ -623,6 +640,7 @@ def main():
|
|
623 |
if has_tensorboard and jax.process_index() == 0:
|
624 |
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
625 |
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
|
|
626 |
|
627 |
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
628 |
# save checkpoint after each epoch and push checkpoint to the hub
|
|
|
53 |
is_tensorboard_available,
|
54 |
)
|
55 |
from transformers.testing_utils import CaptureLogger
|
56 |
+
import wandb
|
57 |
|
58 |
|
59 |
logger = logging.getLogger(__name__)
|
|
|
233 |
# or by passing the --help flag to this script.
|
234 |
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
235 |
|
236 |
+
if jax.process_index() == 0:
|
237 |
+
wandb.init(
|
238 |
+
entity = os.getenv("WANDB_ENTITY", "indonesian-nlp"),
|
239 |
+
project = os.getenv("WANDB_PROJECT", "huggingface"),
|
240 |
+
sync_tensorboard =True
|
241 |
+
)
|
242 |
+
|
243 |
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
244 |
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
245 |
# If we pass only one argument to the script and it's the path to a json file,
|
|
|
258 |
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
259 |
"Use --overwrite_output_dir to overcome."
|
260 |
)
|
261 |
+
# log your configs with wandb.config, accepts a dict
|
262 |
+
if jax.process_index() == 0:
|
263 |
+
wandb.config.update(training_args) # optional, log your configs
|
264 |
+
wandb.config.update(model_args) # optional, log your configs
|
265 |
+
wandb.config.update(data_args) # optional, log your configs
|
266 |
+
|
267 |
+
wandb.config['test_log'] = 12345 # log additional things
|
268 |
|
269 |
# Make one log on every process with the configuration for debugging.
|
270 |
logging.basicConfig(
|
|
|
606 |
epochs.write(
|
607 |
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
|
608 |
)
|
609 |
+
if jax.process_index() == 0:
|
610 |
+
wandb.log({'my_metric': train_metrics})
|
611 |
|
612 |
train_metrics = []
|
613 |
|
|
|
640 |
if has_tensorboard and jax.process_index() == 0:
|
641 |
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
642 |
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
643 |
+
wandb.log({'my_metric': eval_metrics})
|
644 |
|
645 |
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
646 |
# save checkpoint after each epoch and push checkpoint to the hub
|