nguyenvulebinh committed on
Commit
b839dd6
·
1 Parent(s): cae4858

add config for training

Browse files
Files changed (1) hide show
  1. main.py +5 -5
main.py CHANGED
@@ -94,15 +94,15 @@ if __name__ == "__main__":
94
  os.makedirs(cache_processing_dataset_folder)
95
  num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
96
  num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
97
- num_epochs = 20
98
 
99
  training_args = TrainingArguments(
100
  output_dir=checkpoint_path,
101
  # fp16=True,
102
  group_by_length=True,
103
- per_device_train_batch_size=2,
104
- per_device_eval_batch_size=2,
105
- gradient_accumulation_steps=1,
106
  num_train_epochs=1, # each epoch per shard data
107
  logging_steps=1,
108
  learning_rate=1e-4,
@@ -146,7 +146,7 @@ if __name__ == "__main__":
146
  cache_file_name=os.path.join(cache_processing_dataset_folder,
147
  'cache-train-shard-{}.arrow'.format(
148
  train_dataset_shard_idx))
149
- ).shard(1000, 0) # Remove shard split when train
150
  # load test shard subset
151
  test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
152
  'shard_{}'.format(test_dataset_shard_idx)),
 
94
  os.makedirs(cache_processing_dataset_folder)
95
  num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
96
  num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
97
+ num_epochs = 5000
98
 
99
  training_args = TrainingArguments(
100
  output_dir=checkpoint_path,
101
  # fp16=True,
102
  group_by_length=True,
103
+ per_device_train_batch_size=16,
104
+ per_device_eval_batch_size=16,
105
+ gradient_accumulation_steps=8,
106
  num_train_epochs=1, # each epoch per shard data
107
  logging_steps=1,
108
  learning_rate=1e-4,
 
146
  cache_file_name=os.path.join(cache_processing_dataset_folder,
147
  'cache-train-shard-{}.arrow'.format(
148
  train_dataset_shard_idx))
149
+ ) # .shard(1000, 0) # Remove shard split when train
150
  # load test shard subset
151
  test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
152
  'shard_{}'.format(test_dataset_shard_idx)),