# HaileyStorm's picture
# Rename scripts/preference_dataset.py to scripts/custom_datasets/preference_dataset.py
# af6cb0e verified
from datasets import load_dataset
from torchtune.data import StackExchangedPairedTemplate
from torchtune.datasets._preference import PreferenceDataset
from torchtune.modules.tokenizers import Tokenizer
from typing import Optional, Tuple, List
def extract_assistant_content(sample):
    """
    Replace the 'chosen'/'rejected' message lists with the final assistant reply text.

    Each of the two fields arrives as a list of chat messages (dicts with a
    'content' key); only the last message's text is kept, so downstream code
    sees plain strings instead of message lists.

    Args:
        sample (dict): A sample holding 'prompt', 'chosen', and 'rejected' entries.

    Returns:
        dict: The same sample object, mutated in place, with string-valued
        'chosen' and 'rejected' fields.
    """
    for field in ('chosen', 'rejected'):
        sample[field] = sample[field][-1]['content']
    return sample
class ModifiedPreferenceDataset(PreferenceDataset):
    """PreferenceDataset variant that yields a flat 4-tuple of token-id lists
    (chosen ids/labels, rejected ids/labels) instead of a dict."""

    def __getitem__(self, index: int) -> Tuple[List[int], List[int], List[int], List[int]]:
        """Prepare the sample at *index* and unpack it into a fixed-order tuple.

        Returns:
            (chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels)
        """
        prepared = self._prepare_sample(self._data[index])
        fields = (
            "chosen_input_ids",
            "chosen_labels",
            "rejected_input_ids",
            "rejected_labels",
        )
        return tuple(prepared[field] for field in fields)
def orpo_dpo_mix_40k_dataset(
    tokenizer: Tokenizer,
    *,
    max_seq_len: int = 8192,
    data_dir: str = "data",
) -> ModifiedPreferenceDataset:
    """
    Preference dataset for the 'mlabonne/orpo-dpo-mix-40k' dataset.

    Builds a ModifiedPreferenceDataset over the Hugging Face
    'mlabonne/orpo-dpo-mix-40k' train split, extracting the final assistant
    reply from each chosen/rejected message list via extract_assistant_content.

    Args:
        tokenizer (Tokenizer): Tokenizer used to encode data.
        max_seq_len (int): Maximum number of tokens in the returned input and label token id lists.
            Default is 8192.
        data_dir (str): Directory to store the downloaded dataset. Default is "data".
            (Previously hard-coded; now exposed as a keyword-only parameter to match
            this docstring — existing callers are unaffected.)

    Returns:
        ModifiedPreferenceDataset: The modified preference dataset built from the 'mlabonne/orpo-dpo-mix-40k' dataset.
    """
    return ModifiedPreferenceDataset(
        tokenizer=tokenizer,
        source="mlabonne/orpo-dpo-mix-40k",
        template=StackExchangedPairedTemplate(),
        transform=extract_assistant_content,
        # Source columns already use the expected names; the identity map is kept
        # explicit for clarity.
        column_map={
            "prompt": "prompt",
            "chosen": "chosen",
            "rejected": "rejected",
        },
        max_seq_len=max_seq_len,
        split="train",
        data_dir=data_dir,
    )