from typing import Optional, Callable, Any

from datasets import load_dataset
from transformers import AutoTokenizer


def load_text_dataset(tokenizer: AutoTokenizer,
                      kind: str,
                      path: str,
                      name: Optional[str] = None,
                      data_dir: Optional[str] = None,
                      data_files: Optional[str] = None,
                      keep_in_memory: bool = False,
                      revision: Optional[str] = None,
                      split: str = 'train',
                      num_proc: Optional[int] = None,
                      format: Optional[Callable | str] = None) -> Any:
    """Load a plain-text ("base") dataset and render every row into a single
    `text` column, appending the tokenizer's EOS token to each example."""
    assert isinstance(format, str) or callable(format), f'{path=} {format=}'
    assert kind == 'base'

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    EOS_TOKEN = tokenizer.eos_token

    def format_dataset(batch):
        texts: list = []
        # Re-assemble the columnar batch into per-row dicts.
        rows = [dict(zip(batch.keys(), values)) for values in zip(*batch.values())]
        for row in rows:
            # `format` is either a callable taking the row dict, or a
            # str.format template over the row's columns.
            text = format(row) if callable(format) else format.format(**row)
            if not text:
                text = '[NONE]'
            text += EOS_TOKEN
            texts.append(text)
        return {'text': texts}

    dataset = dataset.map(format_dataset, batched=True)
    return dataset


def load_chat_dataset(tokenizer: AutoTokenizer,
                      kind: str,
                      path: str,
                      name: Optional[str] = None,
                      data_dir: Optional[str] = None,
                      data_files: Optional[str] = None,
                      keep_in_memory: bool = False,
                      revision: Optional[str] = None,
                      split: str = 'train',
                      num_proc: Optional[int] = None,
                      field: Optional[str] = None,
                      transform: Optional[Callable] = None) -> Any:
    """Load a conversational ("instruct") dataset and render every row into a
    single `text` column via the tokenizer's chat template, appending EOS."""
    assert kind == 'instruct'

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    EOS_TOKEN = tokenizer.eos_token

    def format_dataset(batch):
        texts: list = []
        # Re-assemble the columnar batch into per-row dicts.
        rows = [dict(zip(batch.keys(), values)) for values in zip(*batch.values())]
        for row in rows:
            if callable(transform):
                # `transform` maps the row (or the selected column) to a
                # messages list that the chat template understands.
                messages = transform(row[field]) if field else transform(row)
            elif field:
                # Without a transform, the column must already hold the messages.
                messages = row[field]
            else:
                raise ValueError(f'either {field=} or a transform is required')
            text = tokenizer.apply_chat_template(messages, tokenize=False)
            text += EOS_TOKEN
            texts.append(text)
        return {'text': texts}

    dataset = dataset.map(format_dataset, batched=True)
    return dataset
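

# ---------------------------------------------------------------------------
# Usage sketch. The tokenizer, dataset paths, splits, and column names below
# are illustrative assumptions chosen for the example (real Hub identifiers,
# but nothing this module itself depends on).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    tok = AutoTokenizer.from_pretrained('gpt2')  # any tokenizer with an eos_token

    # Base/completion data: render each row with a str.format template
    # over its columns ('yelp_review_full' rows have a 'text' column).
    base_ds = load_text_dataset(tok, kind='base', path='yelp_review_full',
                                format='{text}')

    # The same loader also accepts a callable that builds the text from the row dict.
    qa_ds = load_text_dataset(tok, kind='base', path='squad',
                              format=lambda row: f"{row['question']}\n{row['context']}")

    # Instruct/chat data: the `field` column already holds a messages list in the
    # [{'role': ..., 'content': ...}, ...] form expected by apply_chat_template.
    chat_ds = load_chat_dataset(tok, kind='instruct',
                                path='HuggingFaceH4/ultrachat_200k',
                                split='train_sft', field='messages')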