nguyenvulebinh committed
Commit 1e275bf · 1 Parent(s): b839dd6

add config for training multi epochs

Files changed (2):
  1. callbacks.py +12 -0
  2. main.py +29 -22
callbacks.py ADDED
@@ -0,0 +1,12 @@
+ from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl, logging
+
+
+ class BreakEachEpoch(TrainerCallback):
+     """
+     A :class:`~transformers.TrainerCallback` that stops training at the end of every epoch, so the outer
+     training script can reload the next dataset shard and resume from the last checkpoint.
+     """
+     def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+         control.should_training_stop = True
+         logging.get_logger().info("Breaking at end of epoch to reload the next dataset shard")
+         return control
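For context, a minimal sketch of how a callback like this drives shard-by-shard training (`load_shard`, `model`, `collator`, and `NUM_SHARDS` are hypothetical stand-ins, not part of this commit): the Trainer is configured for many epochs, the callback stops the loop after each one, and the outer script rebuilds the Trainer on the next shard and resumes from the last checkpoint.

    from transformers import Trainer, TrainingArguments
    from transformers.trainer_utils import get_last_checkpoint
    from callbacks import BreakEachEpoch

    args = TrainingArguments(output_dir='./ckpt', num_train_epochs=100, save_strategy='epoch')
    last_ckpt = None
    for epoch_idx in range(int(args.num_train_epochs)):
        shard = load_shard(epoch_idx % NUM_SHARDS)       # hypothetical shard loader
        trainer = Trainer(model=model, args=args,        # model / collator defined elsewhere (see main.py)
                          data_collator=collator,
                          train_dataset=shard,
                          callbacks=[BreakEachEpoch()])  # stops the loop after one epoch
        trainer.train(resume_from_checkpoint=last_ckpt)  # None on the first pass
        last_ckpt = get_last_checkpoint(args.output_dir)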
main.py CHANGED
@@ -1,12 +1,14 @@
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, \
+     TrainerCallback
  from datasets import load_from_disk
  from data_handler import DataCollatorCTCWithPadding
  from transformers import TrainingArguments
  from transformers import Trainer, logging
  from metric_utils import compute_metrics_fn
  from transformers.trainer_utils import get_last_checkpoint
- import json, random
+ import json
  import os, glob
+ from callbacks import BreakEachEpoch

  logging.set_verbosity_info()

@@ -68,8 +70,8 @@ def load_prepared_dataset(path, processor, cache_file_name):
      dataset = load_from_disk(path)
      processed_dataset = dataset.map(prepare_dataset,
                                      remove_columns=dataset.column_names,
-                                     batch_size=8,
-                                     num_proc=8,
+                                     batch_size=32,
+                                     num_proc=4,
                                      batched=True,
                                      fn_kwargs={"processor": processor},
                                      cache_file_name=cache_file_name)

@@ -90,8 +92,9 @@ if __name__ == "__main__":
      test_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/test_dataset'

      cache_processing_dataset_folder = './data-bin/cache/'
-     if not os.path.exists(cache_processing_dataset_folder):
-         os.makedirs(cache_processing_dataset_folder)
+     if not os.path.exists(os.path.join(cache_processing_dataset_folder, 'train')):
+         os.makedirs(os.path.join(cache_processing_dataset_folder, 'train'))
+         os.makedirs(os.path.join(cache_processing_dataset_folder, 'test'))
      num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
      num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
      num_epochs = 5000
@@ -100,20 +103,21 @@ if __name__ == "__main__":
          output_dir=checkpoint_path,
          # fp16=True,
          group_by_length=True,
-         per_device_train_batch_size=16,
-         per_device_eval_batch_size=16,
+         per_device_train_batch_size=4,
+         per_device_eval_batch_size=4,
          gradient_accumulation_steps=8,
-         num_train_epochs=1,  # each epoch per shard data
+         num_train_epochs=num_epochs,  # each "epoch" is one pass over a single shard
          logging_steps=1,
          learning_rate=1e-4,
          weight_decay=0.005,
-         warmup_steps=5000,
+         warmup_steps=1000,
          save_total_limit=2,
          ignore_data_skip=True,
          logging_dir=os.path.join(checkpoint_path, 'log'),
          metric_for_best_model='wer',
          save_strategy="epoch",
          evaluation_strategy="epoch",
+         greater_is_better=False,  # WER is a lower-is-better metric
          # save_steps=5,
          # eval_steps=5,
      )
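For scale, the batch-size change in the hunk above works out as follows (single-device arithmetic; the multi-GPU factor is omitted):

    # Effective samples per optimizer step = per_device_train_batch_size * gradient_accumulation_steps
    old_effective = 16 * 8  # = 128 before this commit
    new_effective = 4 * 8   # = 32 after it; the smaller per-device batch presumably fits a smaller GPU
    print(old_effective, new_effective)  # 128 32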
@@ -143,19 +147,19 @@ if __name__ == "__main__":
      train_dataset = load_prepared_dataset(os.path.join(train_dataset_root_folder,
                                                         'shard_{}'.format(train_dataset_shard_idx)),
                                            w2v_ctc_processor,
-                                           cache_file_name=os.path.join(cache_processing_dataset_folder,
+                                           cache_file_name=os.path.join(cache_processing_dataset_folder, 'train',
                                                                         'cache-train-shard-{}.arrow'.format(
                                                                             train_dataset_shard_idx))
-                                           )  # .shard(1000, 0)  # Remove shard split when train
+                                           ).shard(1000, 0)  # debug subset; remove this shard split for a real training run
      # load test shard subset
      test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
                                                        'shard_{}'.format(test_dataset_shard_idx)),
                                           w2v_ctc_processor,
-                                          cache_file_name=os.path.join(cache_processing_dataset_folder,
+                                          cache_file_name=os.path.join(cache_processing_dataset_folder, 'test',
                                                                        'cache-test-shard-{}.arrow'.format(
                                                                            test_dataset_shard_idx))
-                                          ).shard(num_test_sub_shard, idx_sub_shard)
-
+                                          )
+     test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
      # Init trainer
      trainer = Trainer(
          model=w2v_ctc_model,
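Both calls above rely on `datasets.Dataset.shard(num_shards, index)`, which keeps roughly 1/num_shards of the rows. A quick illustrative check with toy data (which rows land in a given shard depends on the library's `contiguous` setting and version):

    from datasets import Dataset

    ds = Dataset.from_dict({"x": list(range(10))})
    sub = ds.shard(num_shards=5, index=0)  # one of five disjoint slices of the rows
    print(len(sub))  # 2 -- so .shard(1000, 0) above keeps ~0.1% of a shard, a smoke-test subset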
@@ -164,13 +168,16 @@ if __name__ == "__main__":
          compute_metrics=compute_metrics_fn(w2v_ctc_processor),
          train_dataset=train_dataset,
          eval_dataset=test_dataset,
-         tokenizer=w2v_ctc_processor.feature_extractor
+         tokenizer=w2v_ctc_processor.feature_extractor,
+         callbacks=[BreakEachEpoch()]  # break manually at the end of each epoch, since an epoch covers only one shard
      )

-     # Manual add num_train_epochs because each epoch loop over a shard
-     training_args.num_train_epochs = epoch_idx + 1
+     # training_args.num_train_epochs = epoch_idx + 1

-     logging.get_logger().info('Train shard idx: {}'.format(train_dataset_shard_idx))
-     logging.get_logger().info('Valid shard idx: {} sub_shard: {}'.format(test_dataset_shard_idx, idx_sub_shard))
+     logging.get_logger().info('Train epoch {}'.format(training_args.num_train_epochs))
+     logging.get_logger().info('Train shard idx: {} / {}'.format(train_dataset_shard_idx + 1, num_train_shards))
+     logging.get_logger().info(
+         'Valid shard idx: {} / {} sub_shard: {}'.format(test_dataset_shard_idx + 1, num_test_shards, idx_sub_shard))

      if last_checkpoint_path is not None:
          # resume training from a checkpoint if one exists

@@ -181,5 +188,5 @@ if __name__ == "__main__":
          last_checkpoint_path = get_last_checkpoint(checkpoint_path)

      # Clear cache file to free disk
-     # test_dataset.cleanup_cache_files()
-     # train_dataset.cleanup_cache_files()
+     test_dataset.cleanup_cache_files()
+     train_dataset.cleanup_cache_files()
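A note on the now-enabled cleanup calls: `Dataset.cleanup_cache_files()` removes the cache files in the dataset's cache directory apart from the one currently in use, which is plausibly why this commit also splits the caches into `train/` and `test/` subfolders: cleaning up one split can no longer delete the other split's live cache. A minimal sketch under that assumption (toy data and an identity `map`; the paths mirror those above):

    import os
    from datasets import Dataset

    os.makedirs('./data-bin/cache/train', exist_ok=True)
    shard0 = Dataset.from_dict({"speech": [[0.0] * 160]}).map(
        lambda batch: batch, batched=True,
        cache_file_name='./data-bin/cache/train/cache-train-shard-0.arrow')
    shard1 = Dataset.from_dict({"speech": [[1.0] * 160]}).map(
        lambda batch: batch, batched=True,
        cache_file_name='./data-bin/cache/train/cache-train-shard-1.arrow')
    # Expected: 1 -- shard 0's stale cache file is deleted, shard 1's live cache is kept.
    print(shard1.cleanup_cache_files())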