Commit d995c83 · 1 Parent(s): c79f680

add prefetch

Files changed (1): main.py (+97 -11)
main.py CHANGED
@@ -10,6 +10,7 @@ import json
 import os, glob
 from callbacks import BreakEachEpoch
 import subprocess
+from multiprocessing import Process
 
 logging.set_verbosity_info()
 
@@ -70,8 +71,34 @@ def prepare_dataset(batch, processor):
     return batch
 
 
-def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=8):
+def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=6):
     dataset = load_from_disk(path)
+    list_cache_prefetch_files = glob.glob(
+        cache_file_map_name.replace(cache_processing_dataset_folder, cache_processing_dataset_folder_prefetch).replace(
+            '.arrow', '*'))
+
+    # Do not re-compute what already in cache folder
+    if cache_file_map_name.startswith(cache_processing_dataset_folder_prefetch):
+        if len(glob.glob(cache_file_map_name.replace(cache_processing_dataset_folder_prefetch,
+                                                     cache_processing_dataset_folder).replace('.arrow', '*'))) > 0:
+            return
+        if len(list_cache_prefetch_files) > 0:
+            return
+
+    # check cache file
+    if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) == 0 and len(list_cache_prefetch_files) > 0:
+        for item_file in list_cache_prefetch_files:
+            os.rename(item_file, item_file.replace(cache_processing_dataset_folder_prefetch,
+                                                   cache_processing_dataset_folder))
+    if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) > 0:
+        return dataset.map(prepare_dataset,
+                           remove_columns=dataset.column_names,
+                           batch_size=32,
+                           num_proc=num_proc,
+                           batched=True,
+                           fn_kwargs={"processor": processor},
+                           cache_file_name=cache_file_map_name)
+
     dataset = dataset.filter(lambda example: len(example['speech']) < 160000,
                              batch_size=32,
                              num_proc=num_proc,
@@ -83,6 +110,7 @@ def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_ma
                                      batched=True,
                                      fn_kwargs={"processor": processor},
                                      cache_file_name=cache_file_map_name)
+    processed_dataset.cleanup_cache_files()
     return processed_dataset
 
 
@@ -95,6 +123,44 @@ def commit_checkpoint():
     for command in submit_commands:
         print(subprocess.run(command.split(), stdout=subprocess.PIPE).stdout.decode('utf-8'))
 
+
+def get_train_test_shard_id(epoch_count):
+    # loop over training shards
+    _train_dataset_shard_idx = epoch_count % num_train_shards
+    # Get test shard depend on train shard id
+    _test_dataset_shard_idx = round(_train_dataset_shard_idx / (num_train_shards / num_test_shards))
+    _num_test_sub_shard = 8  # Split test shard into subset. Default is 8
+    _idx_sub_shard = _train_dataset_shard_idx % _num_test_sub_shard  # loop over test shard subset
+    return _train_dataset_shard_idx, _test_dataset_shard_idx, _num_test_sub_shard, _idx_sub_shard
+
+
+def process_prefetch_epoch(epoch_count):
+    train_shard_idx, test_shard_idx, _, _ = get_train_test_shard_id(epoch_count)
+    load_prepared_dataset(os.path.join(train_dataset_root_folder,
+                                       'shard_{}'.format(train_shard_idx)),
+                          w2v_ctc_processor,
+                          cache_file_filter_name=os.path.join(cache_processing_dataset_folder_prefetch,
+                                                              'train',
+                                                              'cache-train-filter-shard-{}.arrow'.format(
+                                                                  train_shard_idx)),
+                          cache_file_map_name=os.path.join(cache_processing_dataset_folder_prefetch,
+                                                           'train',
+                                                           'cache-train-map-shard-{}.arrow'.format(
+                                                               train_shard_idx)),
+                          )
+    load_prepared_dataset(os.path.join(test_dataset_root_folder,
+                                       'shard_{}'.format(test_shard_idx)),
+                          w2v_ctc_processor,
+                          cache_file_filter_name=os.path.join(cache_processing_dataset_folder_prefetch,
+                                                              'test',
+                                                              'cache-test-filter-shard-{}.arrow'.format(
+                                                                  test_shard_idx)),
+                          cache_file_map_name=os.path.join(cache_processing_dataset_folder_prefetch, 'test',
+                                                           'cache-test-map-shard-{}.arrow'.format(
+                                                               test_shard_idx))
+                          )
+
+
 if __name__ == "__main__":
 
     checkpoint_path = "./model-bin/finetune/base/"
@@ -106,9 +172,13 @@ if __name__ == "__main__":
     test_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/test_dataset'
 
     cache_processing_dataset_folder = './data-bin/cache/'
+    cache_processing_dataset_folder_prefetch = './data-bin/cache_prefetch/'
     if not os.path.exists(os.path.join(cache_processing_dataset_folder, 'train')):
         os.makedirs(os.path.join(cache_processing_dataset_folder, 'train'))
         os.makedirs(os.path.join(cache_processing_dataset_folder, 'test'))
+    if not os.path.exists(os.path.join(cache_processing_dataset_folder_prefetch, 'train')):
+        os.makedirs(os.path.join(cache_processing_dataset_folder_prefetch, 'train'))
+        os.makedirs(os.path.join(cache_processing_dataset_folder_prefetch, 'test'))
     num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
     num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
     num_epochs = 5000
@@ -121,7 +191,7 @@ if __name__ == "__main__":
         per_device_eval_batch_size=32,
         gradient_accumulation_steps=2,
         num_train_epochs=num_epochs,  # each epoch per shard data
-        logging_steps=1,
+        logging_steps=5,
         learning_rate=1e-5,
         weight_decay=0.005,
         warmup_steps=1000,
@@ -150,13 +220,23 @@ if __name__ == "__main__":
     w2v_ctc_model, w2v_ctc_processor = load_pretrained_model()
     data_collator = DataCollatorCTCWithPadding(processor=w2v_ctc_processor, padding=True)
 
+    prefetch_process = []
+
     for epoch_idx in range(last_epoch_idx, num_epochs):
-        # loop over training shards
-        train_dataset_shard_idx = epoch_idx % num_train_shards
-        # Get test shard depend on train shard id
-        test_dataset_shard_idx = round(train_dataset_shard_idx / (num_train_shards / num_test_shards))
-        num_test_sub_shard = 8  # Split test shard into subset. Default is 8
-        idx_sub_shard = train_dataset_shard_idx % num_test_sub_shard  # loop over test shard subset
+        # # loop over training shards
+        # train_dataset_shard_idx = epoch_idx % num_train_shards
+        # # Get test shard depend on train shard id
+        # test_dataset_shard_idx = round(train_dataset_shard_idx / (num_train_shards / num_test_shards))
+        # num_test_sub_shard = 8  # Split test shard into subset. Default is 8
+        # idx_sub_shard = train_dataset_shard_idx % num_test_sub_shard  # loop over test shard subset
+
+        train_dataset_shard_idx, test_dataset_shard_idx, num_test_sub_shard, idx_sub_shard = get_train_test_shard_id(
+            epoch_idx)
+
+        # waiting for all prefetch process done
+        for process_instance in prefetch_process:
+            process_instance.join()
+        prefetch_process.clear()
 
         # load train shard
         train_dataset = load_prepared_dataset(os.path.join(train_dataset_root_folder,
@@ -170,7 +250,7 @@ if __name__ == "__main__":
                                                                                'train',
                                                                                'cache-train-map-shard-{}.arrow'.format(
                                                                                    train_dataset_shard_idx)),
-                                              )  #.shard(1000, 0) # Remove shard split when train
+                                              )  # .shard(1000, 0) # Remove shard split when train
         # load test shard subset
         test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
                                              'shard_{}'.format(test_dataset_shard_idx)),
@@ -184,6 +264,12 @@ if __name__ == "__main__":
                                                                                test_dataset_shard_idx))
                                              )
        test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
+
+        # Prefetch_dataset
+        prefetch_process.append(Process(target=process_prefetch_epoch, args=(epoch_idx + 1,)))
+        for process_instance in prefetch_process:
+            process_instance.start()
+
         # Init trainer
         if trainer is None:
             trainer = Trainer(
@@ -216,5 +302,5 @@ if __name__ == "__main__":
         test_dataset.cleanup_cache_files()
         train_dataset.cleanup_cache_files()
 
-        if epoch_idx % 10 == 0:
-            commit_checkpoint()
+        if epoch_idx % 5 == 0:
+            commit_checkpoint()
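
For reference, the gist of the prefetch pattern this commit introduces, as a minimal runnable sketch: while the Trainer works on the current shard, a multiprocessing.Process preprocesses the next shard in the background, and the main loop joins that worker before loading the shard it prepared. The helpers prepare_shard and train_on_shard below are hypothetical stand-ins for load_prepared_dataset and trainer.train(), and NUM_SHARDS / NUM_EPOCHS are illustrative placeholders, not values from the script.

    # Minimal sketch of the prefetch loop (hypothetical stand-ins, not main.py itself).
    from multiprocessing import Process

    NUM_SHARDS = 8   # illustrative
    NUM_EPOCHS = 32  # illustrative


    def prepare_shard(shard_idx):
        # Stand-in for load_prepared_dataset(): filter/map the shard and write its
        # Arrow cache files so a later epoch can reuse them.
        print(f"preprocessing shard {shard_idx} in the background")


    def train_on_shard(shard_idx):
        # Stand-in for trainer.train() on the already-prepared shard.
        print(f"training on shard {shard_idx}")


    if __name__ == "__main__":
        prefetch_worker = None
        for epoch_idx in range(NUM_EPOCHS):
            # Wait for the previous epoch's prefetch worker so its cache is complete.
            if prefetch_worker is not None:
                prefetch_worker.join()

            current_shard = epoch_idx % NUM_SHARDS
            prepare_shard(current_shard)  # cheap if the prefetch worker already cached it

            # Start preparing the next shard while this one trains.
            prefetch_worker = Process(target=prepare_shard, args=((epoch_idx + 1) % NUM_SHARDS,))
            prefetch_worker.start()

            train_on_shard(current_shard)

In the actual commit the background worker writes its cache files under ./data-bin/cache_prefetch/, and load_prepared_dataset moves them into ./data-bin/cache/ with os.rename so the filter/map calls reuse the prefetched results instead of recomputing them.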