Commit 2808233 · 1 Parent(s): 6417fe2

fix error when read shard

Files changed (1): main.py (+46 -40)
main.py CHANGED
@@ -73,46 +73,49 @@ def prepare_dataset(batch, processor):
 
 
 def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=5):
-    dataset = load_from_disk(path)
-    list_cache_prefetch_files = glob.glob(
-        cache_file_map_name.replace(cache_processing_dataset_folder, cache_processing_dataset_folder_prefetch).replace(
-            '.arrow', '*'))
-
-    # Do not re-compute what already in cache folder
-    if cache_file_map_name.startswith(cache_processing_dataset_folder_prefetch):
-        if len(glob.glob(cache_file_map_name.replace(cache_processing_dataset_folder_prefetch,
-                                                     cache_processing_dataset_folder).replace('.arrow', '*'))) > 0:
-            return
-        if len(list_cache_prefetch_files) > 0:
-            return
-
-    # check cache file
-    if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) == 0 and len(list_cache_prefetch_files) > 0:
-        for item_file in list_cache_prefetch_files:
-            shutil.move(item_file, item_file.replace(cache_processing_dataset_folder_prefetch,
-                                                     cache_processing_dataset_folder))
-    if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) > 0:
-        return dataset.map(prepare_dataset,
-                           remove_columns=dataset.column_names,
-                           batch_size=32,
-                           num_proc=num_proc,
-                           batched=True,
-                           fn_kwargs={"processor": processor},
-                           cache_file_name=cache_file_map_name)
-
-    dataset = dataset.filter(lambda example: len(example['speech']) < 160000,
-                             batch_size=32,
-                             num_proc=num_proc,
-                             cache_file_name=cache_file_filter_name)
-    processed_dataset = dataset.map(prepare_dataset,
-                                    remove_columns=dataset.column_names,
-                                    batch_size=32,
-                                    num_proc=num_proc,
-                                    batched=True,
-                                    fn_kwargs={"processor": processor},
-                                    cache_file_name=cache_file_map_name)
-    processed_dataset.cleanup_cache_files()
-    return processed_dataset
+    try:
+        dataset = load_from_disk(path)
+        list_cache_prefetch_files = glob.glob(
+            cache_file_map_name.replace(cache_processing_dataset_folder, cache_processing_dataset_folder_prefetch).replace(
+                '.arrow', '*'))
+
+        # Do not re-compute what already in cache folder
+        if cache_file_map_name.startswith(cache_processing_dataset_folder_prefetch):
+            if len(glob.glob(cache_file_map_name.replace(cache_processing_dataset_folder_prefetch,
+                                                         cache_processing_dataset_folder).replace('.arrow', '*'))) > 0:
+                return
+            if len(list_cache_prefetch_files) > 0:
+                return
+
+        # check cache file
+        if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) == 0 and len(list_cache_prefetch_files) > 0:
+            for item_file in list_cache_prefetch_files:
+                shutil.move(item_file, item_file.replace(cache_processing_dataset_folder_prefetch,
+                                                         cache_processing_dataset_folder))
+        if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) > 0:
+            return dataset.map(prepare_dataset,
+                               remove_columns=dataset.column_names,
+                               batch_size=32,
+                               num_proc=num_proc,
+                               batched=True,
+                               fn_kwargs={"processor": processor},
+                               cache_file_name=cache_file_map_name)
+
+        dataset = dataset.filter(lambda example: len(example['speech']) < 160000,
+                                 batch_size=32,
+                                 num_proc=num_proc,
+                                 cache_file_name=cache_file_filter_name)
+        processed_dataset = dataset.map(prepare_dataset,
+                                        remove_columns=dataset.column_names,
+                                        batch_size=32,
+                                        num_proc=num_proc,
+                                        batched=True,
+                                        fn_kwargs={"processor": processor},
+                                        cache_file_name=cache_file_map_name)
+        processed_dataset.cleanup_cache_files()
+        return processed_dataset
+    except:
+        return None
 
 
 def commit_checkpoint():
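The first hunk wraps the whole body of `load_prepared_dataset` in a try/except, so a shard whose cached Arrow files are missing or corrupt now makes the function return `None` instead of crashing the whole preprocessing run. A minimal sketch of the same pattern, assuming only that `datasets.load_from_disk` can raise on a damaged shard; the `load_shard` helper is illustrative, not part of main.py:

```python
from datasets import load_from_disk


def load_shard(path):
    """Return the dataset stored at `path`, or None if the shard is unreadable."""
    try:
        return load_from_disk(path)
    except Exception as err:
        # The commit uses a bare `except:`, which also swallows KeyboardInterrupt;
        # catching Exception, as here, is the narrower choice.
        print("Failed to read shard {}: {}".format(path, err))
        return None
```

Returning `None` rather than re-raising only works if every caller checks for it, which is exactly what the second hunk adds.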
@@ -264,6 +267,9 @@ if __name__ == "__main__":
                                                'cache-test-map-shard-{}.arrow'.format(
                                                    test_dataset_shard_idx))
                                )
+        if train_dataset is None or test_dataset is None:
+            print("Ignore Shard {}".format(train_dataset_shard_idx))
+            continue
         test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
 
         # Prefetch_dataset
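At the call site, the guard turns an unreadable shard into a skipped iteration: if either split came back as `None`, the shard index is logged and the loop moves on. A self-contained sketch of that control flow, with a fake loader standing in for the real `load_prepared_dataset` and made-up shard names:

```python
def load_shard(path):
    # Stand-in for load_prepared_dataset: pretend shards 2 and 5 are corrupt.
    return None if path.endswith(("-2", "-5")) else {"path": path}


for shard_idx in range(8):
    train_dataset = load_shard("train-shard-{}".format(shard_idx))
    test_dataset = load_shard("test-shard-{}".format(shard_idx))

    # Skip the whole shard if either split failed to load.
    if train_dataset is None or test_dataset is None:
        print("Ignore Shard {}".format(shard_idx))
        continue

    # ... per-shard sub-sharding, preprocessing, and training continue here ...
```

Note that the committed code prints `train_dataset_shard_idx` even when it is the test split that failed to load; the sketch above uses a single index for both splits.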