add prefetch
main.py CHANGED
@@ -10,6 +10,7 @@ import json
 import os, glob
 from callbacks import BreakEachEpoch
 import subprocess
+from multiprocessing import Process
 
 logging.set_verbosity_info()
 
@@ -70,8 +71,34 @@ def prepare_dataset(batch, processor):
     return batch
 
 
-def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=
+def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=6):
     dataset = load_from_disk(path)
+    list_cache_prefetch_files = glob.glob(
+        cache_file_map_name.replace(cache_processing_dataset_folder, cache_processing_dataset_folder_prefetch).replace(
+            '.arrow', '*'))
+
+    # Do not re-compute what already in cache folder
+    if cache_file_map_name.startswith(cache_processing_dataset_folder_prefetch):
+        if len(glob.glob(cache_file_map_name.replace(cache_processing_dataset_folder_prefetch,
+                                                     cache_processing_dataset_folder).replace('.arrow', '*'))) > 0:
+            return
+        if len(list_cache_prefetch_files) > 0:
+            return
+
+    # check cache file
+    if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) == 0 and len(list_cache_prefetch_files) > 0:
+        for item_file in list_cache_prefetch_files:
+            os.rename(item_file, item_file.replace(cache_processing_dataset_folder_prefetch,
+                                                   cache_processing_dataset_folder))
+    if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) > 0:
+        return dataset.map(prepare_dataset,
+                           remove_columns=dataset.column_names,
+                           batch_size=32,
+                           num_proc=num_proc,
+                           batched=True,
+                           fn_kwargs={"processor": processor},
+                           cache_file_name=cache_file_map_name)
+
     dataset = dataset.filter(lambda example: len(example['speech']) < 160000,
                              batch_size=32,
                              num_proc=num_proc,
@@ -83,6 +110,7 @@ def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_ma
                                        batched=True,
                                        fn_kwargs={"processor": processor},
                                        cache_file_name=cache_file_map_name)
+    processed_dataset.cleanup_cache_files()
     return processed_dataset
 
 
@@ -95,6 +123,44 @@ def commit_checkpoint():
     for command in submit_commands:
         print(subprocess.run(command.split(), stdout=subprocess.PIPE).stdout.decode('utf-8'))
 
+
+def get_train_test_shard_id(epoch_count):
+    # loop over training shards
+    _train_dataset_shard_idx = epoch_count % num_train_shards
+    # Get test shard depend on train shard id
+    _test_dataset_shard_idx = round(_train_dataset_shard_idx / (num_train_shards / num_test_shards))
+    _num_test_sub_shard = 8  # Split test shard into subset. Default is 8
+    _idx_sub_shard = _train_dataset_shard_idx % _num_test_sub_shard  # loop over test shard subset
+    return _train_dataset_shard_idx, _test_dataset_shard_idx, _num_test_sub_shard, _idx_sub_shard
+
+
+def process_prefetch_epoch(epoch_count):
+    train_shard_idx, test_shard_idx, _, _ = get_train_test_shard_id(epoch_count)
+    load_prepared_dataset(os.path.join(train_dataset_root_folder,
+                                       'shard_{}'.format(train_shard_idx)),
+                          w2v_ctc_processor,
+                          cache_file_filter_name=os.path.join(cache_processing_dataset_folder_prefetch,
+                                                              'train',
+                                                              'cache-train-filter-shard-{}.arrow'.format(
+                                                                  train_shard_idx)),
+                          cache_file_map_name=os.path.join(cache_processing_dataset_folder_prefetch,
+                                                           'train',
+                                                           'cache-train-map-shard-{}.arrow'.format(
+                                                               train_shard_idx)),
+                          )
+    load_prepared_dataset(os.path.join(test_dataset_root_folder,
+                                       'shard_{}'.format(test_shard_idx)),
+                          w2v_ctc_processor,
+                          cache_file_filter_name=os.path.join(cache_processing_dataset_folder_prefetch,
+                                                              'test',
+                                                              'cache-test-filter-shard-{}.arrow'.format(
+                                                                  test_shard_idx)),
+                          cache_file_map_name=os.path.join(cache_processing_dataset_folder_prefetch, 'test',
+                                                           'cache-test-map-shard-{}.arrow'.format(
+                                                               test_shard_idx))
+                          )
+
+
 if __name__ == "__main__":
 
     checkpoint_path = "./model-bin/finetune/base/"
@@ -106,9 +172,13 @@ if __name__ == "__main__":
     test_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/test_dataset'
 
     cache_processing_dataset_folder = './data-bin/cache/'
+    cache_processing_dataset_folder_prefetch = './data-bin/cache_prefetch/'
    if not os.path.exists(os.path.join(cache_processing_dataset_folder, 'train')):
         os.makedirs(os.path.join(cache_processing_dataset_folder, 'train'))
         os.makedirs(os.path.join(cache_processing_dataset_folder, 'test'))
+    if not os.path.exists(os.path.join(cache_processing_dataset_folder_prefetch, 'train')):
+        os.makedirs(os.path.join(cache_processing_dataset_folder_prefetch, 'train'))
+        os.makedirs(os.path.join(cache_processing_dataset_folder_prefetch, 'test'))
     num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
     num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
     num_epochs = 5000
@@ -121,7 +191,7 @@ if __name__ == "__main__":
         per_device_eval_batch_size=32,
         gradient_accumulation_steps=2,
         num_train_epochs=num_epochs,  # each epoch per shard data
-        logging_steps=
+        logging_steps=5,
         learning_rate=1e-5,
         weight_decay=0.005,
         warmup_steps=1000,
@@ -150,13 +220,23 @@ if __name__ == "__main__":
     w2v_ctc_model, w2v_ctc_processor = load_pretrained_model()
     data_collator = DataCollatorCTCWithPadding(processor=w2v_ctc_processor, padding=True)
 
+    prefetch_process = []
+
     for epoch_idx in range(last_epoch_idx, num_epochs):
-        # loop over training shards
-        train_dataset_shard_idx = epoch_idx % num_train_shards
-        # Get test shard depend on train shard id
-        test_dataset_shard_idx = round(train_dataset_shard_idx / (num_train_shards / num_test_shards))
-        num_test_sub_shard = 8  # Split test shard into subset. Default is 8
-        idx_sub_shard = train_dataset_shard_idx % num_test_sub_shard  # loop over test shard subset
+        # # loop over training shards
+        # train_dataset_shard_idx = epoch_idx % num_train_shards
+        # # Get test shard depend on train shard id
+        # test_dataset_shard_idx = round(train_dataset_shard_idx / (num_train_shards / num_test_shards))
+        # num_test_sub_shard = 8  # Split test shard into subset. Default is 8
+        # idx_sub_shard = train_dataset_shard_idx % num_test_sub_shard  # loop over test shard subset
+
+        train_dataset_shard_idx, test_dataset_shard_idx, num_test_sub_shard, idx_sub_shard = get_train_test_shard_id(
+            epoch_idx)
+
+        # waiting for all prefetch process done
+        for process_instance in prefetch_process:
+            process_instance.join()
+        prefetch_process.clear()
 
         # load train shard
         train_dataset = load_prepared_dataset(os.path.join(train_dataset_root_folder,
@@ -170,7 +250,7 @@ if __name__ == "__main__":
                                                                    'train',
                                                                    'cache-train-map-shard-{}.arrow'.format(
                                                                        train_dataset_shard_idx)),
-                                              )
+                                              )  # .shard(1000, 0) # Remove shard split when train
         # load test shard subset
         test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
                                              'shard_{}'.format(test_dataset_shard_idx)),
@@ -184,6 +264,12 @@ if __name__ == "__main__":
                                                                   test_dataset_shard_idx))
                                              )
         test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
+
+        # Prefetch_dataset
+        prefetch_process.append(Process(target=process_prefetch_epoch, args=(epoch_idx + 1,)))
+        for process_instance in prefetch_process:
+            process_instance.start()
+
         # Init trainer
         if trainer is None:
             trainer = Trainer(
@@ -216,5 +302,5 @@ if __name__ == "__main__":
         test_dataset.cleanup_cache_files()
         train_dataset.cleanup_cache_files()
 
-        if epoch_idx %
-
+        if epoch_idx % 5 == 0:
+            commit_checkpoint()