Commit
·
b839dd6
1
Parent(s):
cae4858
add config for training
Browse files
main.py
CHANGED
@@ -94,15 +94,15 @@ if __name__ == "__main__":
|
|
94 |
os.makedirs(cache_processing_dataset_folder)
|
95 |
num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
|
96 |
num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
|
97 |
-
num_epochs =
|
98 |
|
99 |
training_args = TrainingArguments(
|
100 |
output_dir=checkpoint_path,
|
101 |
# fp16=True,
|
102 |
group_by_length=True,
|
103 |
-
per_device_train_batch_size=
|
104 |
-
per_device_eval_batch_size=
|
105 |
-
gradient_accumulation_steps=
|
106 |
num_train_epochs=1, # each epoch per shard data
|
107 |
logging_steps=1,
|
108 |
learning_rate=1e-4,
|
@@ -146,7 +146,7 @@ if __name__ == "__main__":
|
|
146 |
cache_file_name=os.path.join(cache_processing_dataset_folder,
|
147 |
'cache-train-shard-{}.arrow'.format(
|
148 |
train_dataset_shard_idx))
|
149 |
-
).shard(1000, 0) # Remove shard split when train
|
150 |
# load test shard subset
|
151 |
test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
|
152 |
'shard_{}'.format(test_dataset_shard_idx)),
|
|
|
94 |
os.makedirs(cache_processing_dataset_folder)
|
95 |
num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
|
96 |
num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
|
97 |
+
num_epochs = 5000
|
98 |
|
99 |
training_args = TrainingArguments(
|
100 |
output_dir=checkpoint_path,
|
101 |
# fp16=True,
|
102 |
group_by_length=True,
|
103 |
+
per_device_train_batch_size=16,
|
104 |
+
per_device_eval_batch_size=16,
|
105 |
+
gradient_accumulation_steps=8,
|
106 |
num_train_epochs=1, # each epoch per shard data
|
107 |
logging_steps=1,
|
108 |
learning_rate=1e-4,
|
|
|
146 |
cache_file_name=os.path.join(cache_processing_dataset_folder,
|
147 |
'cache-train-shard-{}.arrow'.format(
|
148 |
train_dataset_shard_idx))
|
149 |
+
) # .shard(1000, 0) # Remove shard split when train
|
150 |
# load test shard subset
|
151 |
test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
|
152 |
'shard_{}'.format(test_dataset_shard_idx)),
|