mpt-7b-8k / data_prep_utils.py
irenedea's picture
LLM-foundry update March 26, 2024 23:50:31
cf2d90f verified
raw
history blame
3.18 kB
import json
import os
from glob import glob
from typing import List, Optional
def with_id(basename: str, shard_id: int) -> str:
    """Return ``basename`` with its shard-ID field swapped for ``shard_id``.

    The basename is split on ``.`` and the second field is overwritten with
    ``shard_id`` zero-padded to a width of five; every other field is kept.

    From https://github.com/mosaicml/streaming/blob/main/examples/multiprocess_dataset_conversion.ipynb.

    Args:
        basename (str): Old basename of file.
        shard_id (int): New shard ID.

    Returns:
        str: New basename of file.

    Raises:
        IndexError: If ``basename`` contains no ``.`` separator.
    """
    fields = basename.split('.')
    # Field 1 holds the shard ID; indexing (not slicing) so a basename
    # without that field fails loudly instead of being silently extended.
    fields[1] = format(shard_id, '05')
    return '.'.join(fields)
def merge_shard_groups(root: str) -> None:
    """Merge ephemeral sub-datasets created in parallel into one dataset.

    Every entry matched by ``<root>/*`` is visited in sorted order and is
    expected to hold an ``index.json`` with a ``shards`` list. Each shard file
    is renumbered with a running shard ID, moved into ``root``, and its index
    entry is updated to the new basename. The sub-dataset's ``index.json`` and
    directory are then removed, and a combined ``index.json`` is written to
    ``root``.

    From https://github.com/mosaicml/streaming/blob/main/examples/multiprocess_dataset_conversion.ipynb.

    Args:
        root (str): Root directory.
    """
    merged_shards = []
    next_id = 0

    for group_dir in sorted(glob(os.path.join(root, '*'))):
        group_index_path = os.path.join(group_dir, 'index.json')
        with open(group_index_path) as group_index_file:
            group_index = json.load(group_index_file)

        for shard in group_index['shards']:
            # Always renumber the raw entry first ...
            src_name = shard['raw_data']['basename']
            dst_name = with_id(src_name, next_id)
            shard['raw_data']['basename'] = dst_name

            # ... and when a zip entry exists, its basename supersedes the
            # raw one as the file that actually gets moved below.
            zip_entry = shard['zip_data']
            if zip_entry is not None:
                src_name = zip_entry['basename']
                dst_name = with_id(src_name, next_id)
                zip_entry['basename'] = dst_name

            os.rename(
                os.path.join(group_dir, src_name),
                os.path.join(root, dst_name),
            )

            next_id += 1
            merged_shards.append(shard)

        # The group directory must be empty apart from its index by now;
        # os.rmdir raises otherwise.
        os.remove(group_index_path)
        os.rmdir(group_dir)

    # Serialize before opening so a serialization error cannot leave a
    # truncated index behind.
    merged_text = json.dumps({'version': 2, 'shards': merged_shards}, sort_keys=True)
    with open(os.path.join(root, 'index.json'), 'w') as out:
        out.write(merged_text)
class DownloadingIterable:
    """Iterable yielding the text of each named object as ``{'text': ...}``."""

    # ``ObjectStore`` is quoted as a forward reference: this module does not
    # import it, so an unquoted annotation would raise NameError when the
    # class is defined.
    def __init__(self, object_names: List[str], output_folder: str, object_store: Optional['ObjectStore']):
        """Iterable that downloads files from an object store before yielding.

        If object_store is None, each entry of object_names is treated as a
        local path and read directly; output_folder is not used in that case.

        Args:
            object_names (List[str]): Names of objects to download
            output_folder (str): Local folder to write downloaded files to
            object_store (Optional[ObjectStore]): Object store to download from
        """
        self.object_names = object_names
        self.object_store = object_store
        self.output_folder = output_folder

    def __iter__(self):
        """Yield one ``{'text': <file contents>}`` dict per object name.

        When an object store is configured, each object is downloaded
        (overwriting any existing local copy) right before it is read, so
        downloads happen lazily as the iterable is consumed.
        """
        for object_name in self.object_names:
            # Without an object store the name itself is the local file path.
            output_filename = object_name
            if self.object_store is not None:
                # Strip surrounding slashes so the join stays inside
                # output_folder rather than resolving to an absolute path.
                output_filename = os.path.join(self.output_folder, object_name.strip('/'))
                self.object_store.download_object(object_name=object_name, filename=output_filename, overwrite=True)

            with open(output_filename) as _txt_file:
                txt = _txt_file.read()

            yield {'text': txt}