import json
from pathlib import Path
from typing import Callable, Optional

import torch
from megatron.core import parallel_state
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import (
    MegatronPretrainingRandomSampler,
    MegatronPretrainingSampler,
)
from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (
    MegatronPretrainingBatchSampler,
    MegatronPretrainingRandomBatchSampler,
)
from nemo.core.classes import Dataset
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero
from omegaconf import DictConfig
from torch.utils.data import DataLoader


def build_dataloader(
    dataset: Dataset,
    consumed_samples: int,
    micro_batch_size: int,
    global_batch_size: int,
    collate_fn: Optional[Callable] = None,
    seed: Optional[int] = None,
) -> DataLoader:
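    """Wrap ``dataset`` in a DataLoader driven by a Megatron batch sampler.

    A non-None, non-negative ``seed`` selects ``MegatronPretrainingRandomBatchSampler``
    (shuffled order); otherwise the deterministic ``MegatronPretrainingBatchSampler``
    is used. Incomplete global batches are dropped (``drop_last=True``) and samples
    are not padded to the global batch size.
    """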
    common_params: dict = {
        "total_samples": len(dataset),
        "consumed_samples": consumed_samples,
        "micro_batch_size": micro_batch_size,
        "global_batch_size": global_batch_size,
        "data_parallel_rank": parallel_state.get_data_parallel_rank(),
        "data_parallel_size": parallel_state.get_data_parallel_world_size(),
        "drop_last": True,
        "pad_samples_to_global_batch_size": False,
    }

    if seed is not None and seed >= 0:
        batch_sampler = MegatronPretrainingRandomBatchSampler(
            **common_params, seed=seed
        )
    else:
        batch_sampler = MegatronPretrainingBatchSampler(**common_params)

    return DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=0,
        pin_memory=True,
        collate_fn=collate_fn,
    )


def custom_build_dataloader(
    dataset: Dataset,
    consumed_samples: int,
    mbs: int,
    gbs: int,
    num_workers: int = 0,
    drop_last: bool = True,
    pad_samples_to_global_batch_size: bool = False,
    load_gbs: bool = True,
    seed: Optional[int] = None,
    use_random_sampler: bool = True,
    collate_fn: Optional[Callable] = None,
) -> DataLoader:
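    """Build a DataLoader with a configurable Megatron sampler.

    ``use_random_sampler`` chooses between the random (shuffled) and sequential
    sampler families, and ``load_gbs`` selects the global-batch
    ``*BatchSampler`` variants rather than the per-micro-batch samplers.
    ``seed`` is only forwarded to the random samplers.
    """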
    common_params = {
        "total_samples": len(dataset),
        "consumed_samples": consumed_samples,
        "micro_batch_size": mbs,
        "data_parallel_rank": parallel_state.get_data_parallel_rank(),
        "data_parallel_size": parallel_state.get_data_parallel_world_size(),
        "drop_last": drop_last,
        "global_batch_size": gbs,
        "pad_samples_to_global_batch_size": pad_samples_to_global_batch_size,
    }

    if use_random_sampler:
        cls = (
            MegatronPretrainingRandomBatchSampler
            if load_gbs
            else MegatronPretrainingRandomSampler
        )
        common_params["seed"] = seed
    else:
        cls = (
            MegatronPretrainingBatchSampler if load_gbs else MegatronPretrainingSampler
        )
    batch_sampler = cls(**common_params)

    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )


def load_datasets(cfg: DictConfig) -> tuple[list[dict], list[dict]]:
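    """Read the JSONL datasets listed in ``cfg.datasets`` from ``cfg.data_dir``.

    Each dataset may be skipped (``max_train_samples == 0``), truncated,
    upsampled, and optionally split into a dev portion controlled by
    ``cfg.max_dev_samples`` and ``cfg.max_dev_ratio``. Returns the combined
    train and dev example lists and logs a per-dataset summary on rank 0.
    """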
    data_name2num_examples: dict[str, dict[str, int]] = {}
    total_train_examples: list[dict] = []
    total_dev_examples: list[dict] = []
    for data_name, data_info in cfg.datasets.items():
        dataset_path: Path = Path(f"{cfg.data_dir}/{data_name}.jsonl")
        if not dataset_path.exists():
            raise FileNotFoundError(f"{dataset_path} does not exist.")
        if data_info.max_train_samples == 0:
            if is_global_rank_zero():
                logging.info(
                    f"max_train_samples for {data_name} is set to 0. Skipping this dataset."
                )
            continue

        if is_global_rank_zero():
            logging.info(f"Processing {dataset_path}...")
        loaded_examples: list[dict] = []
        with dataset_path.open(encoding="utf-8") as f:
            for line in f:
                loaded_examples.append(json.loads(line))

        if data_info.max_train_samples > len(loaded_examples) and is_global_rank_zero():
            logging.warning(
                f"{data_name} has only {len(loaded_examples)} examples, "
                f"but max_train_samples is set to {data_info.max_train_samples}. "
                "Using all examples."
            )

        max_train_samples: int = (
            data_info.max_train_samples
            if data_info.max_train_samples != -1
            else len(loaded_examples)
        )
        max_dev_samples: int = 0
        if data_info.split_dev:
            max_dev_samples = min(
                cfg.max_dev_samples,
                int(len(loaded_examples) * cfg.max_dev_ratio),
            )
        train_examples: list[dict] = (
            loaded_examples[max_dev_samples : max_dev_samples + max_train_samples]
            * data_info.upsampling_factor
        )
        dev_examples: list[dict] = (
            loaded_examples[:max_dev_samples] * data_info.upsampling_factor
        )

        total_train_examples.extend(train_examples)
        total_dev_examples.extend(dev_examples)
        data_name2num_examples[data_name] = {
            "train": len(train_examples),
            "dev": len(dev_examples),
            "original": len(loaded_examples),
            "upsampling_factor": data_info.upsampling_factor,
        }

    if is_global_rank_zero():
        num_total_original_examples: int = 0
        logging.info("------------------------------")
        logging.info("Dataset summary (original -> train/dev)")
        for data_name, num_examples in data_name2num_examples.items():
            num_total_original_examples += num_examples["original"]
            logging.info(
                f"{data_name}: {num_examples['original']} -> "
                f"{num_examples['train']}/{num_examples['dev']} "
                f"(upsampling factor: {num_examples['upsampling_factor']})"
            )
        logging.info(
            f"Total: {num_total_original_examples} -> "
            f"{len(total_train_examples)}/{len(total_dev_examples)}"
        )
        logging.info("------------------------------")

    return total_train_examples, total_dev_examples


class LLMJPSFTDataset(Dataset):
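    """Packed supervised fine-tuning dataset for LLM-jp chat data.

    Conversations are rendered with the LLM-jp instruction template
    ("### 指示:" / "### 応答:"), tokenized, concatenated into one token
    stream, and chunked into fixed-length examples of ``max_seq_length + 1``
    tokens so that inputs and labels can be produced by shifting.
    """
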
    def __init__(
        self,
        loaded_examples: list[dict],
        tokenizer: TokenizerSpec,
        use_loss_mask: bool,
        max_seq_length: int = 4096,
    ):
        self.tokenizer = tokenizer
        self.use_loss_mask: bool = use_loss_mask
        self.max_seq_length: int = max_seq_length

        # Tokenize and pack all conversations up front.
        self.examples: list[dict[str, list[int]]] = self._process_examples(
            loaded_examples
        )

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict[str, list[int]]:
        return self.examples[idx]

    def _process_examples(
        self, loaded_examples: list[dict]
    ) -> list[dict[str, list[int]]]:
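        """Tokenize conversations and pack them into fixed-length chunks.

        Every conversation must start with a system message followed by
        alternating user/assistant turns. When ``use_loss_mask`` is enabled,
        only assistant tokens (and the EOS appended after each response)
        contribute to the loss; otherwise every token does. The packed stream
        is split into chunks of ``max_seq_length + 1`` tokens; only full
        chunks are kept, and chunks whose loss mask is all zeros are discarded.
        """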
        all_input_ids: list[int] = []
        all_loss_mask: list[int] = []
        for example_idx, loaded_example in enumerate(loaded_examples):
            conversation: list[dict[str, str]] = loaded_example["messages"]
            assert len(conversation) >= 3
            assert conversation[0]["role"] == "system"

            # System prompt, prefixed with BOS.
            input_ids: list[int] = [self.tokenizer.bos_id] + self.tokenizer.text_to_ids(
                conversation[0]["content"]
            )
            loss_mask: list[int] = (
                [0] * len(input_ids) if self.use_loss_mask else [1] * len(input_ids)
            )
            # Alternating user/assistant turns.
            for turn_idx in range(1, len(conversation[1:]) // 2 + 1):
                user_message: dict[str, str] = conversation[2 * turn_idx - 1]
                assistant_message: dict[str, str] = conversation[2 * turn_idx]
                assert user_message["role"] == "user"
                assert assistant_message["role"] == "assistant"

                if self.use_loss_mask:
                    prompt_ids: list[int] = self.tokenizer.text_to_ids(
                        f"\n\n### 指示:\n{user_message['content']}\n\n### 応答:\n"
                    )[1:]
                    response_ids: list[int] = self.tokenizer.text_to_ids(
                        f"\n{assistant_message['content']}"
                    )[2:] + [self.tokenizer.eos_id]
                    input_ids.extend(prompt_ids + response_ids)
                    loss_mask.extend([0] * len(prompt_ids) + [1] * len(response_ids))
                else:
                    prompt_response_ids: list[int] = self.tokenizer.text_to_ids(
                        f"\n\n### 指示:\n{user_message['content']}\n\n### 応答:\n{assistant_message['content']}"
                    )[1:] + [self.tokenizer.eos_id]
                    input_ids.extend(prompt_response_ids)
                    loss_mask.extend([1] * len(prompt_response_ids))

            if is_global_rank_zero() and example_idx < 2:
                logging.info(f"{example_idx = }")
                logging.info(f"{input_ids = }")
                logging.info(f"{loss_mask = }")

            all_input_ids.extend(input_ids)
            all_loss_mask.extend(loss_mask)

        # Pack the concatenated stream into chunks of max_seq_length + 1 tokens.
        examples: list[dict[str, list[int]]] = []
        for i in range(0, len(all_input_ids), self.max_seq_length + 1):
            chunked_input_ids: list[int] = all_input_ids[i : i + self.max_seq_length + 1]
            chunked_loss_mask: list[int] = all_loss_mask[i : i + self.max_seq_length + 1]
            if len(chunked_input_ids) == self.max_seq_length + 1:
                # Skip chunks that would contribute no loss at all.
                if set(chunked_loss_mask) == {0}:
                    continue
                examples.append(
                    {"input_ids": chunked_input_ids, "loss_mask": chunked_loss_mask}
                )
        return examples

    @torch.no_grad()
    def _create_attention_mask(self, seq_length: int) -> torch.Tensor:
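        """Return a causal mask of shape (1, seq_length, seq_length).

        Following the Megatron convention, ``True`` marks positions that must
        not be attended to (the strictly upper triangle).
        """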
        attention_mask = torch.tril(torch.ones((seq_length, seq_length))).unsqueeze(0)
        # Invert: True means "masked out" (do not attend).
        attention_mask = attention_mask < 0.5
        return attention_mask

    def collate_fn(self, batch: list[dict[str, list[int]]]) -> dict[str, torch.Tensor]:
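        """Collate packed examples into a Megatron-style batch dict.

        Inputs are the first ``max_seq_length`` tokens of each chunk, while the
        labels and loss mask are the same chunk shifted by one token.
        """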
        input_ids: list[list[int]] = [item["input_ids"][:-1] for item in batch]
        labels: list[list[int]] = [item["input_ids"][1:] for item in batch]
        loss_mask: list[list[int]] = [item["loss_mask"][1:] for item in batch]

        processed_batch = {
            "tokens": torch.LongTensor(input_ids),
            "position_ids": torch.LongTensor(
                [list(range(self.max_seq_length)) for _ in batch]
            ),
            "attention_mask": torch.stack(
                [self._create_attention_mask(self.max_seq_length) for _ in batch]
            ),
            "labels": torch.LongTensor(labels),
            "loss_mask": torch.LongTensor(loss_mask),
        }

        return processed_batch
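

# Example wiring (a minimal sketch, assuming Megatron parallel state is already
# initialized and `tokenizer` is a NeMo TokenizerSpec; `cfg.micro_batch_size`,
# `cfg.global_batch_size`, and `cfg.seed` are illustrative config fields, not
# part of this module):
#
#   train_examples, dev_examples = load_datasets(cfg)
#   train_dataset = LLMJPSFTDataset(train_examples, tokenizer, use_loss_mask=True)
#   train_dataloader = build_dataloader(
#       train_dataset,
#       consumed_samples=0,
#       micro_batch_size=cfg.micro_batch_size,
#       global_batch_size=cfg.global_batch_size,
#       collate_fn=train_dataset.collate_fn,
#       seed=cfg.seed,
#   )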