# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


# copied from fairseq/fairseq/data/indexed_dataset.py
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
# other slight modifications to remove fairseq dependencies
# Added document index to index file and made it accessible.
#    An empty sentence no longer separates documents.

import os
import shutil
import struct
from functools import lru_cache
from itertools import accumulate

import numpy as np
import torch

from megatron import print_rank_0


def __best_fitting_dtype(vocab_size=None):
    """Pick the token dtype for a builder.

    Returns ``np.uint16`` when ``vocab_size`` is given and is below 65500,
    otherwise ``np.int32``.
    """
    if vocab_size is not None and vocab_size < 65500:
        return np.uint16
    return np.int32


def infer_dataset_impl(path):
    """Guess the dataset implementation from the magic bytes of the index file.

    Args:
        path: basename to which ``.idx`` / ``.bin`` are appended.

    Returns:
        ``"cached"``, ``"mmap"``, or ``None`` when the files are missing or
        the magic bytes match neither format.
    """
    if not IndexedDataset.exists(path):
        print(f"Dataset does not exist: {path}")
        print(
            "Path should be a basename that both .idx and .bin can be appended to get full filenames."
        )
        return None

    with open(index_file_path(path), "rb") as f:
        magic = f.read(8)

    if magic == IndexedDataset._HDR_MAGIC:
        return "cached"
    if magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
        return "mmap"
    return None


def make_builder(out_file, impl, vocab_size=None):
    """Create a dataset builder writing to ``out_file``.

    ``impl == "mmap"`` yields an ``MMapIndexedDatasetBuilder`` whose dtype is
    chosen from ``vocab_size``; any other value yields an
    ``IndexedDatasetBuilder``.
    """
    if impl == "mmap":
        return MMapIndexedDatasetBuilder(
            out_file, dtype=__best_fitting_dtype(vocab_size)
        )
    return IndexedDatasetBuilder(out_file)


def make_dataset(path, impl, skip_warmup=False):
    """Open the dataset stored under basename ``path``.

    Args:
        path: basename to which ``.idx`` / ``.bin`` are appended.
        impl: ``"infer"``, ``"cached"`` or ``"mmap"``.
        skip_warmup: forwarded to ``MMapIndexedDataset``.

    Returns:
        The dataset object, or ``None`` (after printing a message) when the
        files are missing or the implementation is not recognised.
    """
    if not IndexedDataset.exists(path):
        print(f"Dataset does not exist: {path}")
        print(
            "Path should be a basename that both .idx and .bin can be appended to get full filenames."
        )
        return None

    if impl == "infer":
        impl = infer_dataset_impl(path)

    # A fresh `if` (not `elif`): the inferred implementation must still be
    # dispatched, otherwise "infer" always falls through to the error path.
    if impl == "cached" and IndexedDataset.exists(path):
        return IndexedCachedDataset(path)
    elif impl == "mmap" and MMapIndexedDataset.exists(path):
        return MMapIndexedDataset(path, skip_warmup)

    print(f"Unknown dataset implementation: {impl}")
    return None


def dataset_exists(path, impl):
    """Return whether the files for ``path`` exist for the given ``impl``."""
    if impl == "mmap":
        return MMapIndexedDataset.exists(path)
    return IndexedDataset.exists(path)


def read_longs(f, n):
    """Read ``n`` int64 values from binary file ``f`` into a new array."""
    a = np.empty(n, dtype=np.int64)
    f.readinto(a)
    return a


def write_longs(f, a):
    """Write sequence ``a`` to binary file ``f`` as int64 values."""
    f.write(np.array(a, dtype=np.int64))


# Numeric dtype codes stored in the index file.
dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float32,
    7: np.float64,
    8: np.uint16,
}


def code(dtype):
    """Return the numeric code registered for ``dtype`` in ``dtypes``.

    Raises:
        ValueError: if ``dtype`` is not one of the registered dtypes.
    """
    for key, value in dtypes.items():
        if value == dtype:
            return key
    raise ValueError(dtype)


def index_file_path(prefix_path):
    """Return the index file name for a dataset basename."""
    return prefix_path + ".idx"


def data_file_path(prefix_path):
    """Return the data file name for a dataset basename."""
    return prefix_path + ".bin"


def create_doc_idx(sizes):
    """Build a document index from item sizes.

    Starts with 0 and adds ``i + 1`` for every position ``i`` whose size is 0.
    """
    doc_idx = [0]
    doc_idx.extend(i + 1 for i, s in enumerate(sizes) if s == 0)
    return doc_idx


class IndexedDataset(torch.utils.data.Dataset):
    """Loader for IndexedDataset"""

    _HDR_MAGIC = b"TNTIDX\x00\x00"

    def __init__(self, path):
        super().__init__()
        self.path = path
        self.data_file = None  # opened lazily on first item access
        self.read_index(path)

    def read_index(self, path):
        """Load the index file for ``path`` into instance attributes."""
        with open(index_file_path(path), "rb") as f:
            magic = f.read(8)
            assert magic == self._HDR_MAGIC, (
                "Index file doesn't match expected format. "
                "Make sure that --dataset-impl is configured properly."
            )
            version = f.read(8)
            # NOTE(review): everything from this assert's format string down
            # to the comparison in check_index() was missing from the text
            # under review (apparently removed by an HTML-tag stripper). The
            # lines below are reconstructed to mirror finalize() and to set
            # every attribute the rest of the class reads. Confirm the exact
            # field order against the upstream file before merging.
            assert struct.unpack("<Q", version) == (1,)
            dtype_code, self.element_size = struct.unpack("<QQ", f.read(16))
            self.dtype = dtypes[dtype_code]
            self._len, self.s = struct.unpack("<QQ", f.read(16))
            (self.doc_count,) = struct.unpack("<Q", f.read(8))
            self.dim_offsets = read_longs(f, self._len + 1)
            self.data_offsets = read_longs(f, self._len + 1)
            self.sizes = read_longs(f, self.s)
            self.doc_idx = read_longs(f, self.doc_count)

    def read_data(self, path):
        """Open the data file for ``path`` and keep the handle on the instance."""
        # NOTE(review): reconstructed (see read_index); the unbuffered mode
        # is an assumption, confirm against upstream.
        self.data_file = open(data_file_path(path), "rb", buffering=0)

    def check_index(self, i):
        """Raise ``IndexError`` unless ``0 <= i < len(self)``."""
        # NOTE(review): the left half of this condition is reconstructed.
        if i < 0 or i >= self._len:
            raise IndexError("index out of range")

    def __del__(self):
        # getattr: tolerate an instance whose data_file was never assigned.
        data_file = getattr(self, "data_file", None)
        if data_file:
            data_file.close()

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        """Return one item (int index) or a list of items (contiguous slice).

        Raises:
            IndexError: integer index out of range.
            ValueError: slice with a step other than 1.
            TypeError: ``idx`` is neither an integer nor a slice.
        """
        if not self.data_file:
            self.read_data(self.path)

        # np.integer is accepted too: indices often come from NumPy arrays
        # and are not instances of the builtin int.
        if isinstance(idx, (int, np.integer)):
            i = int(idx)
            self.check_index(i)
            tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            return a

        if isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            if stop <= start:
                # np.split(a, []) would yield one empty array, not zero items.
                return []
            sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]]
            size = sum(sizes)
            a = np.empty(size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[start] * self.element_size)
            self.data_file.readinto(a)
            # Cut the single contiguous read back into per-item arrays.
            offsets = list(accumulate(sizes))
            return np.split(a, offsets[:-1])

        raise TypeError(
            f"indices must be integers or slices, not {type(idx).__name__}"
        )

    def __len__(self):
        return self._len

    def num_tokens(self, index):
        """Return ``self.sizes[index]``."""
        return self.sizes[index]

    def size(self, index):
        """Return ``self.sizes[index]``."""
        return self.sizes[index]

    @staticmethod
    def exists(path):
        """Return True when both the index and data files for ``path`` exist."""
        return os.path.exists(index_file_path(path)) and os.path.exists(
            data_file_path(path)
        )

    @property
    def supports_prefetch(self):
        return False  # avoid prefetching to save memory


class IndexedCachedDataset(IndexedDataset):
    """IndexedDataset that serves items from a cache filled by ``prefetch``."""

    def __init__(self, path):
        super().__init__(path)
        self.cache = None
        self.cache_index = {}  # item index -> start offset inside self.cache

    @property
    def supports_prefetch(self):
        return True

    def prefetch(self, indices):
        """Load the items at ``indices`` into one contiguous cache array.

        Does nothing when every requested index is already cached; otherwise
        the previous cache is replaced.
        """
        if all(i in self.cache_index for i in indices):
            return
        if not self.data_file:
            self.read_data(self.path)
        indices = sorted(set(indices))
        total_size = 0
        for i in indices:
            total_size += self.data_offsets[i + 1] - self.data_offsets[i]
        self.cache = np.empty(total_size, dtype=self.dtype)
        ptx = 0
        self.cache_index.clear()
        for i in indices:
            self.cache_index[i] = ptx
            size = self.data_offsets[i + 1] - self.data_offsets[i]
            a = self.cache[ptx : ptx + size]
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            ptx += size
        if self.data_file:
            # close and delete data file after prefetch so we can pickle
            self.data_file.close()
            self.data_file = None

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        """Return a copy of a cached item (int) or a list of items (slice).

        Raises:
            IndexError: integer index out of range.
            KeyError: the item was not loaded by ``prefetch``.
            TypeError: ``idx`` is neither an integer nor a slice.
        """
        if isinstance(idx, (int, np.integer)):
            i = int(idx)
            self.check_index(i)
            tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            ptx = self.cache_index[i]
            np.copyto(a, self.cache[ptx : ptx + a.size])
            return a

        if isinstance(idx, slice):
            # Hack just to make this work, can optimizer later if necessary
            return [self[i] for i in range(*idx.indices(len(self)))]

        raise TypeError(
            f"indices must be integers or slices, not {type(idx).__name__}"
        )


class IndexedDatasetBuilder(object):
    """Writes items to a data file and collects the matching index."""

    # Bytes per element for each supported dtype. np.uint16 is included
    # because `dtypes` registers it (code 8).
    element_sizes = {
        np.uint8: 1,
        np.int8: 1,
        np.int16: 2,
        np.uint16: 2,
        np.int32: 4,
        np.int64: 8,
        np.float32: 4,
        np.float64: 8,
    }

    def __init__(self, out_file, dtype=np.int32):
        self.out_file = open(out_file, "wb")
        self.dtype = dtype
        self.data_offsets = [0]
        self.dim_offsets = [0]
        self.sizes = []
        self.element_size = self.element_sizes[self.dtype]
        self.doc_idx = [0]

    def add_item(self, np_array):
        """Append one array to the data file and record its offsets and shape."""
        assert isinstance(np_array, np.ndarray) and np_array.dtype == self.dtype
        n_bytes = self.out_file.write(np_array)
        # Floor division keeps the element offsets integral.
        self.data_offsets.append(self.data_offsets[-1] + n_bytes // self.element_size)
        self.sizes.extend(np_array.shape)
        self.dim_offsets.append(self.dim_offsets[-1] + len(np_array.shape))

    def end_document(self):
        """Mark a document boundary at the current number of recorded sizes."""
        self.doc_idx.append(len(self.sizes))

    def merge_file_(self, another_file):
        """Append the dataset stored under basename ``another_file``."""
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype

        # Captured before self.sizes grows: merged document boundaries are
        # shifted by the number of sizes already recorded here.
        doc_offset = len(self.sizes)

        begin = self.data_offsets[-1]
        self.data_offsets.extend(begin + offset for offset in index.data_offsets[1:])
        self.sizes.extend(index.sizes)

        begin = self.dim_offsets[-1]
        self.dim_offsets.extend(begin + dim for dim in index.dim_offsets[1:])

        # Carry over document boundaries; entry 0 of the merged index is its
        # leading 0 and is skipped.
        # NOTE(review): relies on IndexedDataset exposing `doc_idx`, which
        # sits in the reconstructed part of read_index().
        self.doc_idx.extend(doc_offset + doc for doc in index.doc_idx[1:])

        with open(data_file_path(another_file), "rb") as f:
            shutil.copyfileobj(f, self.out_file)

    def finalize(self, index_file):
        """Close the data file and write the index to ``index_file``."""
        self.out_file.close()
        with open(index_file, "wb") as index:
            index.write(b"TNTIDX\x00\x00")
            # NOTE(review): the text under review ended inside the next
            # call's format string. The remainder is reconstructed as the
            # mirror image of IndexedDataset.read_index(); confirm against
            # the upstream file.
            index.write(struct.pack("<Q", 1))
            index.write(struct.pack("<QQ", code(self.dtype), self.element_size))
            index.write(
                struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes))
            )
            index.write(struct.pack("<Q", len(self.doc_idx)))
            write_longs(index, self.dim_offsets)
            write_longs(index, self.data_offsets)
            write_longs(index, self.sizes)
            write_longs(index, self.doc_idx)