Add dataset processing cache in shared memory (move prefetched cache files into /dev/shm/cache/)
Browse files
main.py
CHANGED
@@ -11,6 +11,7 @@ import os, glob
|
|
11 |
from callbacks import BreakEachEpoch
|
12 |
import subprocess
|
13 |
from multiprocessing import Process
|
|
|
14 |
|
15 |
logging.set_verbosity_info()
|
16 |
|
@@ -71,7 +72,7 @@ def prepare_dataset(batch, processor):
|
|
71 |
return batch
|
72 |
|
73 |
|
74 |
-
def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=
|
75 |
dataset = load_from_disk(path)
|
76 |
list_cache_prefetch_files = glob.glob(
|
77 |
cache_file_map_name.replace(cache_processing_dataset_folder, cache_processing_dataset_folder_prefetch).replace(
|
@@ -88,7 +89,7 @@ def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_ma
|
|
88 |
# check cache file
|
89 |
if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) == 0 and len(list_cache_prefetch_files) > 0:
|
90 |
for item_file in list_cache_prefetch_files:
|
91 |
-
|
92 |
cache_processing_dataset_folder))
|
93 |
if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) > 0:
|
94 |
return dataset.map(prepare_dataset,
|
@@ -128,7 +129,7 @@ def get_train_test_shard_id(epoch_count):
|
|
128 |
# loop over training shards
|
129 |
_train_dataset_shard_idx = epoch_count % num_train_shards
|
130 |
# Get test shard depend on train shard id
|
131 |
-
_test_dataset_shard_idx = round(_train_dataset_shard_idx / (num_train_shards / num_test_shards))
|
132 |
_num_test_sub_shard = 8 # Split test shard into subset. Default is 8
|
133 |
_idx_sub_shard = _train_dataset_shard_idx % _num_test_sub_shard # loop over test shard subset
|
134 |
return _train_dataset_shard_idx, _test_dataset_shard_idx, _num_test_sub_shard, _idx_sub_shard
|
@@ -171,7 +172,7 @@ if __name__ == "__main__":
|
|
171 |
train_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/train_dataset'
|
172 |
test_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/test_dataset'
|
173 |
|
174 |
-
cache_processing_dataset_folder = '
|
175 |
cache_processing_dataset_folder_prefetch = './data-bin/cache_prefetch/'
|
176 |
if not os.path.exists(os.path.join(cache_processing_dataset_folder, 'train')):
|
177 |
os.makedirs(os.path.join(cache_processing_dataset_folder, 'train'))
|
|
|
11 |
from callbacks import BreakEachEpoch
|
12 |
import subprocess
|
13 |
from multiprocessing import Process
|
14 |
+
import shutil
|
15 |
|
16 |
logging.set_verbosity_info()
|
17 |
|
|
|
72 |
return batch
|
73 |
|
74 |
|
75 |
+
def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=5):
|
76 |
dataset = load_from_disk(path)
|
77 |
list_cache_prefetch_files = glob.glob(
|
78 |
cache_file_map_name.replace(cache_processing_dataset_folder, cache_processing_dataset_folder_prefetch).replace(
|
|
|
89 |
# check cache file
|
90 |
if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) == 0 and len(list_cache_prefetch_files) > 0:
|
91 |
for item_file in list_cache_prefetch_files:
|
92 |
+
shutil.move(item_file, item_file.replace(cache_processing_dataset_folder_prefetch,
|
93 |
cache_processing_dataset_folder))
|
94 |
if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) > 0:
|
95 |
return dataset.map(prepare_dataset,
|
|
|
129 |
# loop over training shards
|
130 |
_train_dataset_shard_idx = epoch_count % num_train_shards
|
131 |
# Get test shard depend on train shard id
|
132 |
+
_test_dataset_shard_idx = min(round(_train_dataset_shard_idx / (num_train_shards / num_test_shards)), num_test_shards - 1)
|
133 |
_num_test_sub_shard = 8 # Split test shard into subset. Default is 8
|
134 |
_idx_sub_shard = _train_dataset_shard_idx % _num_test_sub_shard # loop over test shard subset
|
135 |
return _train_dataset_shard_idx, _test_dataset_shard_idx, _num_test_sub_shard, _idx_sub_shard
|
|
|
172 |
train_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/train_dataset'
|
173 |
test_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/test_dataset'
|
174 |
|
175 |
+
cache_processing_dataset_folder = '/dev/shm/cache/'
|
176 |
cache_processing_dataset_folder_prefetch = './data-bin/cache_prefetch/'
|
177 |
if not os.path.exists(os.path.join(cache_processing_dataset_folder, 'train')):
|
178 |
os.makedirs(os.path.join(cache_processing_dataset_folder, 'train'))
|