import gc
from typing import Optional, Iterator, Callable

import torch
from datasets import load_dataset
from litgpt.tokenizer import Tokenizer
from transformers import AutoTokenizer
|

def batch_text_iterator(kind: str,
                        path: str,
                        name: Optional[str]=None,
                        data_dir: Optional[str]=None,
                        data_files: Optional[str]=None,
                        keep_in_memory: bool=False,
                        revision: Optional[str]=None,
                        split: str='train',
                        num_proc: Optional[int]=None,
                        format: Optional[Callable|str]=None) -> Iterator[str]:
    # `format` is effectively required despite the Optional default: it must be
    # either a str template (applied as format.format(**row)) or a callable
    # that maps a row dict to a text string.
    assert isinstance(format, str) or callable(format), f'{path=} {format=}'
    assert kind == 'base'

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    if callable(format):
        for row in dataset:
            text = format(row)
            yield text
    else:
        for row in dataset:
            text = format.format(**row)
            yield text

    # Release the dataset eagerly; these iterators can be chained over many
    # datasets, and the memory-mapped Arrow tables add up.
    del dataset
    gc.collect()
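
# Example usage (an illustrative sketch; the dataset path, config name, and
# format template below are assumptions, not part of this module):
#
#   for text in batch_text_iterator(kind='base',
#                                   path='wikitext',
#                                   name='wikitext-103-raw-v1',
#                                   split='train',
#                                   format='{text}'):
#       print(text[:80])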

def batch_chat_iterator(kind: str,
                        path: str,
                        name: Optional[str]=None,
                        data_dir: Optional[str]=None,
                        data_files: Optional[str]=None,
                        keep_in_memory: bool=False,
                        revision: Optional[str]=None,
                        split: str='train',
                        num_proc: Optional[int]=None,
                        field: Optional[str]=None,
                        transform: Optional[Callable]=None) -> Iterator[list[dict[str, str]]]:
    # Yields conversations as lists of message dicts. Either `transform` maps
    # a row (or row[field]) to messages, or `field` names a column that
    # already holds messages; at least one of the two must be provided.
    assert kind == 'instruct'

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    if callable(transform):
        for row in dataset:
            if field:
                messages = transform(row[field])
            else:
                messages = transform(row)

            yield messages
    else:
        for row in dataset:
            if field:
                messages = row[field]
            else:
                raise ValueError(f'either field or transform must be provided, got {field=} {transform=}')

            yield messages

    del dataset
    gc.collect()
|
|
|
|
|
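
# Example usage (an illustrative sketch; the dataset, split, and field name
# below are assumptions):
#
#   for messages in batch_chat_iterator(kind='instruct',
#                                       path='HuggingFaceH4/ultrachat_200k',
#                                       split='train_sft',
#                                       field='messages'):
#       ...  # each item is a list of {'role': ..., 'content': ...} dicts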

def tokenize_text_fn(dataset_config: dict, hf_tokenizer: AutoTokenizer, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    # `hf_tokenizer` is unused here; it is accepted to keep the signature
    # interchangeable with tokenize_chat_fn.
    for text in batch_text_iterator(**dataset_config):
        text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=True)
        yield text_ids


def tokenize_chat_fn(dataset_config: dict, hf_tokenizer: AutoTokenizer, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    for messages in batch_chat_iterator(**dataset_config):
        # Render the conversation with the HF chat template, then tokenize the
        # rendered text with the litgpt tokenizer. eos=False because the chat
        # template is expected to supply its own end-of-turn markers.
        text: str = hf_tokenizer.apply_chat_template(messages, tokenize=False)
        text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=False)
        yield text_ids
|
|
|
|
|
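
# Sketch: both wrappers yield one 1-D tensor of token ids per dataset row;
# `hf_tokenizer` only renders chat templates, while the litgpt `tokenizer`
# produces the ids. `chat_config` below is a hypothetical instruct-style
# dataset_config:
#
#   ids = next(tokenize_chat_fn(chat_config, hf_tokenizer, tokenizer))
#   assert ids.dim() == 1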

def tokenize_fn(dataset_config: dict, hf_tokenizer: AutoTokenizer, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    # Dispatches on dataset_config['kind'] and skips rows that fail to
    # tokenize instead of aborting the whole pass.
    if dataset_config['kind'] == 'base':
        for text in batch_text_iterator(**dataset_config):
            try:
                text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=True)
            except Exception as e:
                print(f'Skip base row: {e=} {type(text)=} {text=}')
                continue

            yield text_ids
    elif dataset_config['kind'] == 'instruct':
        for messages in batch_chat_iterator(**dataset_config):
            try:
                text: str = hf_tokenizer.apply_chat_template(messages, tokenize=False)
                text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=False)
            except Exception as e:
                print(f'Skip instruct row: {e=} {type(messages)=} {messages=}')
                continue

            yield text_ids
    else:
        raise ValueError(dataset_config['kind'])
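
# A minimal end-to-end sketch, assuming a litgpt-style checkpoint directory
# that holds both the HF tokenizer files and the litgpt tokenizer config; the
# checkpoint path and dataset below are illustrative, not part of this module.
if __name__ == '__main__':
    from pathlib import Path

    checkpoint_dir = Path('checkpoints/my-model')  # hypothetical path
    hf_tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    tokenizer = Tokenizer(checkpoint_dir)

    dataset_config = {'kind': 'base',
                      'path': 'wikitext',
                      'name': 'wikitext-103-raw-v1',
                      'split': 'train',
                      'format': '{text}'}

    # Print the shapes of the first few tokenized rows.
    for i, text_ids in enumerate(tokenize_fn(dataset_config, hf_tokenizer, tokenizer)):
        print(text_ids.shape)
        if i >= 2:
            break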