
Rename scripts/preference_dataset.py to scripts/custom_datasets/preference_dataset.py
af6cb0e
verified
from datasets import load_dataset | |
from torchtune.data import StackExchangedPairedTemplate | |
from torchtune.datasets._preference import PreferenceDataset | |
from torchtune.modules.tokenizers import Tokenizer | |
from typing import Optional, Tuple, List | |
def extract_assistant_content(sample):
    """
    Replace the 'chosen' and 'rejected' message lists with the text of their last message.

    The last message of each list is expected to be the assistant response. The
    sample is updated in place and the same dictionary is returned. A field that
    already holds a string is left unchanged, so applying this function to an
    already-processed sample is a no-op rather than an error.

    Args:
        sample (dict): A dictionary containing the prompt, chosen, and rejected entries.
            'chosen' and 'rejected' are lists of message dicts with a 'content' key
            (or strings, if the sample has already been processed).

    Returns:
        dict: The same sample dictionary, with 'chosen' and 'rejected' holding text.
    """
    for key in ("chosen", "rejected"):
        messages = sample[key]
        # Already extracted: indexing a str with [-1]['content'] would raise TypeError.
        if isinstance(messages, str):
            continue
        sample[key] = messages[-1]["content"]
    return sample
class ModifiedPreferenceDataset(PreferenceDataset):
    """
    PreferenceDataset variant whose items are 4-tuples of token id lists.

    Indexing returns the chosen/rejected input ids and labels as a positional
    tuple rather than the keyed result produced by ``_prepare_sample``.
    """

    def __getitem__(self, index: int) -> Tuple[List[int], List[int], List[int], List[int]]:
        """
        Fetch and prepare the sample at ``index``.

        Args:
            index (int): Position of the sample in the underlying data.

        Returns:
            Tuple[List[int], List[int], List[int], List[int]]: In order, the
            chosen input ids, chosen labels, rejected input ids, and rejected labels.
        """
        prepared = self._prepare_sample(self._data[index])
        chosen_ids = prepared["chosen_input_ids"]
        chosen_labels = prepared["chosen_labels"]
        rejected_ids = prepared["rejected_input_ids"]
        rejected_labels = prepared["rejected_labels"]
        return chosen_ids, chosen_labels, rejected_ids, rejected_labels
def orpo_dpo_mix_40k_dataset(
    tokenizer: Tokenizer,
    *,
    max_seq_len: int = 8192,
    data_dir: str = "data",
) -> ModifiedPreferenceDataset:
    """
    Preference dataset for the 'mlabonne/orpo-dpo-mix-40k' dataset.

    Builds a ModifiedPreferenceDataset over the "train" split, using
    StackExchangedPairedTemplate as the prompt template and
    extract_assistant_content as the per-sample transform.

    Args:
        tokenizer (Tokenizer): Tokenizer used to encode data.
        max_seq_len (int): Maximum number of tokens in the returned input and label token id lists.
            Default is 8192.
        data_dir (str): Value passed as the ``data_dir`` keyword when constructing the dataset.
            Default is "data".
            NOTE(review): if this is forwarded to Hugging Face ``load_dataset``, it selects a
            directory of data files inside the dataset repository rather than a local download
            location - confirm against the PreferenceDataset implementation.

    Returns:
        ModifiedPreferenceDataset: The modified preference dataset built from the 'mlabonne/orpo-dpo-mix-40k' dataset.
    """
    return ModifiedPreferenceDataset(
        tokenizer=tokenizer,
        source="mlabonne/orpo-dpo-mix-40k",
        template=StackExchangedPairedTemplate(),
        transform=extract_assistant_content,
        # Identity mapping: the dataset's column names already match the expected keys.
        column_map={
            "prompt": "prompt",
            "chosen": "chosen",
            "rejected": "rejected",
        },
        max_seq_len=max_seq_len,
        split="train",
        data_dir=data_dir,
    )