Check commited on
Commit
7bf0ac3
·
1 Parent(s): d0051dc

add cache to mem

Browse files
Files changed (1) hide show
  1. main.py +5 -4
main.py CHANGED
@@ -11,6 +11,7 @@ import os, glob
11
  from callbacks import BreakEachEpoch
12
  import subprocess
13
  from multiprocessing import Process
 
14
 
15
  logging.set_verbosity_info()
16
 
@@ -71,7 +72,7 @@ def prepare_dataset(batch, processor):
71
  return batch
72
 
73
 
74
- def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=6):
75
  dataset = load_from_disk(path)
76
  list_cache_prefetch_files = glob.glob(
77
  cache_file_map_name.replace(cache_processing_dataset_folder, cache_processing_dataset_folder_prefetch).replace(
@@ -88,7 +89,7 @@ def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_ma
88
  # check cache file
89
  if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) == 0 and len(list_cache_prefetch_files) > 0:
90
  for item_file in list_cache_prefetch_files:
91
- os.rename(item_file, item_file.replace(cache_processing_dataset_folder_prefetch,
92
  cache_processing_dataset_folder))
93
  if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) > 0:
94
  return dataset.map(prepare_dataset,
@@ -128,7 +129,7 @@ def get_train_test_shard_id(epoch_count):
128
  # loop over training shards
129
  _train_dataset_shard_idx = epoch_count % num_train_shards
130
  # Get test shard depend on train shard id
131
- _test_dataset_shard_idx = round(_train_dataset_shard_idx / (num_train_shards / num_test_shards))
132
  _num_test_sub_shard = 8 # Split test shard into subset. Default is 8
133
  _idx_sub_shard = _train_dataset_shard_idx % _num_test_sub_shard # loop over test shard subset
134
  return _train_dataset_shard_idx, _test_dataset_shard_idx, _num_test_sub_shard, _idx_sub_shard
@@ -171,7 +172,7 @@ if __name__ == "__main__":
171
  train_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/train_dataset'
172
  test_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/test_dataset'
173
 
174
- cache_processing_dataset_folder = './data-bin/cache/'
175
  cache_processing_dataset_folder_prefetch = './data-bin/cache_prefetch/'
176
  if not os.path.exists(os.path.join(cache_processing_dataset_folder, 'train')):
177
  os.makedirs(os.path.join(cache_processing_dataset_folder, 'train'))
 
11
  from callbacks import BreakEachEpoch
12
  import subprocess
13
  from multiprocessing import Process
14
+ import shutil
15
 
16
  logging.set_verbosity_info()
17
 
 
72
  return batch
73
 
74
 
75
+ def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=5):
76
  dataset = load_from_disk(path)
77
  list_cache_prefetch_files = glob.glob(
78
  cache_file_map_name.replace(cache_processing_dataset_folder, cache_processing_dataset_folder_prefetch).replace(
 
89
  # check cache file
90
  if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) == 0 and len(list_cache_prefetch_files) > 0:
91
  for item_file in list_cache_prefetch_files:
92
+ shutil.move(item_file, item_file.replace(cache_processing_dataset_folder_prefetch,
93
  cache_processing_dataset_folder))
94
  if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) > 0:
95
  return dataset.map(prepare_dataset,
 
129
  # loop over training shards
130
  _train_dataset_shard_idx = epoch_count % num_train_shards
131
  # Get test shard depend on train shard id
132
+ _test_dataset_shard_idx = min(round(_train_dataset_shard_idx / (num_train_shards / num_test_shards)), num_test_shards - 1)
133
  _num_test_sub_shard = 8 # Split test shard into subset. Default is 8
134
  _idx_sub_shard = _train_dataset_shard_idx % _num_test_sub_shard # loop over test shard subset
135
  return _train_dataset_shard_idx, _test_dataset_shard_idx, _num_test_sub_shard, _idx_sub_shard
 
172
  train_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/train_dataset'
173
  test_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/test_dataset'
174
 
175
+ cache_processing_dataset_folder = '/dev/shm/cache/'
176
  cache_processing_dataset_folder_prefetch = './data-bin/cache_prefetch/'
177
  if not os.path.exists(os.path.join(cache_processing_dataset_folder, 'train')):
178
  os.makedirs(os.path.join(cache_processing_dataset_folder, 'train'))