Commit 1e275bf
1 Parent(s): b839dd6

add config for training multi epochs

Changed files:
- callbacks.py +12 -0
- main.py +29 -22
callbacks.py ADDED
@@ -0,0 +1,12 @@
+from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl, logging
+
+
+class BreakEachEpoch(TrainerCallback):
+    """
+    A :class:`~transformers.TrainerCallback` that stops the training loop at the end of every epoch so the
+    calling script can reload the next dataset shard before resuming training.
+    """
+    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        control.should_training_stop = True
+        logging.get_logger().info("Break each epoch to reload a new dataset shard")
+        return control
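The callback works by flipping the should_training_stop flag on the TrainerControl object, which makes Trainer.train() return once the current epoch finishes. A minimal, self-contained sketch of that mechanism (no actual training run; the ./tmp output directory is a placeholder):

# Hypothetical sketch: exercise BreakEachEpoch outside a real training run.
from transformers import TrainerControl, TrainerState, TrainingArguments
from callbacks import BreakEachEpoch

args = TrainingArguments(output_dir="./tmp")  # placeholder output dir
state = TrainerState()
control = TrainerControl()

control = BreakEachEpoch().on_epoch_end(args, state, control)
print(control.should_training_stop)  # True -> Trainer.train() would stop after this epoch,
                                     # letting an outer loop load the next dataset shard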
main.py CHANGED
@@ -1,12 +1,14 @@
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, \
+    TrainerCallback
 from datasets import load_from_disk
 from data_handler import DataCollatorCTCWithPadding
 from transformers import TrainingArguments
 from transformers import Trainer, logging
 from metric_utils import compute_metrics_fn
 from transformers.trainer_utils import get_last_checkpoint
-import json
+import json
 import os, glob
+from callbacks import BreakEachEpoch
 
 logging.set_verbosity_info()
 
@@ -68,8 +70,8 @@ def load_prepared_dataset(path, processor, cache_file_name):
     dataset = load_from_disk(path)
     processed_dataset = dataset.map(prepare_dataset,
                                     remove_columns=dataset.column_names,
-                                    batch_size=
-                                    num_proc=
+                                    batch_size=32,
+                                    num_proc=4,
                                     batched=True,
                                     fn_kwargs={"processor": processor},
                                     cache_file_name=cache_file_name)
@@ -90,8 +92,9 @@ if __name__ == "__main__":
     test_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/test_dataset'
 
     cache_processing_dataset_folder = './data-bin/cache/'
-    if not os.path.exists(cache_processing_dataset_folder):
-        os.makedirs(cache_processing_dataset_folder)
+    if not os.path.exists(os.path.join(cache_processing_dataset_folder, 'train')):
+        os.makedirs(os.path.join(cache_processing_dataset_folder, 'train'))
+        os.makedirs(os.path.join(cache_processing_dataset_folder, 'test'))
     num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
     num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
     num_epochs = 5000
@@ -100,20 +103,21 @@ if __name__ == "__main__":
         output_dir=checkpoint_path,
         # fp16=True,
         group_by_length=True,
-        per_device_train_batch_size=
-        per_device_eval_batch_size=
+        per_device_train_batch_size=4,
+        per_device_eval_batch_size=4,
         gradient_accumulation_steps=8,
-        num_train_epochs=
+        num_train_epochs=num_epochs,  # each epoch per shard data
         logging_steps=1,
         learning_rate=1e-4,
         weight_decay=0.005,
-        warmup_steps=
+        warmup_steps=1000,
         save_total_limit=2,
         ignore_data_skip=True,
         logging_dir=os.path.join(checkpoint_path, 'log'),
         metric_for_best_model='wer',
         save_strategy="epoch",
         evaluation_strategy="epoch",
+        greater_is_better=False,
         # save_steps=5,
         # eval_steps=5,
     )
@@ -143,19 +147,19 @@ if __name__ == "__main__":
         train_dataset = load_prepared_dataset(os.path.join(train_dataset_root_folder,
                                                            'shard_{}'.format(train_dataset_shard_idx)),
                                               w2v_ctc_processor,
-                                              cache_file_name=os.path.join(cache_processing_dataset_folder,
+                                              cache_file_name=os.path.join(cache_processing_dataset_folder, 'train',
                                                                            'cache-train-shard-{}.arrow'.format(
                                                                                train_dataset_shard_idx))
-                                              )
+                                              ).shard(1000, 0)  # Remove shard split when train
         # load test shard subset
         test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
                                                           'shard_{}'.format(test_dataset_shard_idx)),
                                              w2v_ctc_processor,
-                                             cache_file_name=os.path.join(cache_processing_dataset_folder,
+                                             cache_file_name=os.path.join(cache_processing_dataset_folder, 'test',
                                                                           'cache-test-shard-{}.arrow'.format(
                                                                               test_dataset_shard_idx))
-                                             )
-
+                                             )
+        test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
         # Init trainer
         trainer = Trainer(
             model=w2v_ctc_model,
@@ -164,13 +168,16 @@ if __name__ == "__main__":
             compute_metrics=compute_metrics_fn(w2v_ctc_processor),
             train_dataset=train_dataset,
             eval_dataset=test_dataset,
-            tokenizer=w2v_ctc_processor.feature_extractor
+            tokenizer=w2v_ctc_processor.feature_extractor,
+            callbacks=[BreakEachEpoch()]  # Manual break end of epoch because each epoch loop over a shard
         )
-        # Manual add num_train_epochs because each epoch loop over a shard
-        training_args.num_train_epochs = epoch_idx + 1
 
-
-
+        # training_args.num_train_epochs = epoch_idx + 1
+
+        logging.get_logger().info('Train epoch {}'.format(training_args.num_train_epochs))
+        logging.get_logger().info('Train shard idx: {} / {}'.format(train_dataset_shard_idx + 1, num_train_shards))
+        logging.get_logger().info(
+            'Valid shard idx: {} / {} sub_shard: {}'.format(test_dataset_shard_idx + 1, num_test_shards, idx_sub_shard))
 
         if last_checkpoint_path is not None:
             # start train from a checkpoint if exist
@@ -181,5 +188,5 @@ if __name__ == "__main__":
         last_checkpoint_path = get_last_checkpoint(checkpoint_path)
 
         # Clear cache file to free disk
-
-
+        test_dataset.cleanup_cache_files()
+        train_dataset.cleanup_cache_files()