import os
import re
from collections.abc import Iterator
from typing import cast, Any

from datasets import load_dataset, Dataset as HfDataset

from unfat.extract import Extractor
from unfat.client import OpenAiCompatClient
from unfat.datasets import Dataset, Prompts, hub_prompts, HubSplit
from unfat.together import llama_3_1_70b_together
from unfat.lora import LoraSettings


def _load_train_split(ds_name: str) -> HfDataset:
    """Load the ``train`` split of the Hugging Face Hub dataset *ds_name*."""
    return cast(HfDataset, load_dataset(ds_name, split="train"))


def _hub_output_path(ds_name: str) -> str:
    """Return the ``output_path`` handed to ``Prompts`` for *ds_name*."""
    return f"hub/{ds_name}.jsonl"


def _require_env(*names: str) -> dict[str, str]:
    """Return ``{name: value}`` for each environment variable in *names*.

    Raises:
        KeyError: if any variable is unset; the message lists every missing
            name rather than only the first one encountered.
    """
    missing = [name for name in names if name not in os.environ]
    if missing:
        raise KeyError(
            "missing required environment variables: " + ", ".join(missing)
        )
    return {name: os.environ[name] for name in names}


def gen_prompts(
    ds_name: str,
    text_field: str,
    start_regex: re.Pattern | None = None,
    end_regex: re.Pattern | None = None,
):
    """Build ``Prompts`` from one text column of a Hub dataset's train split.

    Every row yields exactly one prompt: the value stored under *text_field*,
    with all matches of *start_regex* removed and then all matches of
    *end_regex* removed. Either pattern may be omitted.

    Args:
        ds_name: Hugging Face Hub dataset id, e.g. ``"org/name"``.
        text_field: Column holding the prompt text.
        start_regex: Optional pattern stripped first (e.g. a speaker prefix).
        end_regex: Optional pattern stripped second (e.g. a trailing cue).

    Returns:
        A ``Prompts`` whose ``count`` is the number of rows in the split.
    """
    ds = _load_train_split(ds_name)

    def items() -> Iterator[str]:
        for row in ds:
            text = cast(dict[Any, Any], row)[text_field]
            # The prefix pattern is applied before the suffix pattern.
            if start_regex is not None:
                text = start_regex.sub("", text)
            if end_regex is not None:
                text = end_regex.sub("", text)
            yield text

    return Prompts(
        output_path=_hub_output_path(ds_name),
        count=lambda: len(ds),
        items=items,
    )


def extract_prompts_from_convos(
    ds_name: str,
    messages_field: str,
    role_field: str,
    content_field: str,
    user_role: str,
):
    """Build ``Prompts`` from the first user turn of each conversation.

    For every row of the train split, scans the message list stored under
    *messages_field* and yields the *content_field* of the first message whose
    *role_field* equals *user_role*. Rows with no such message yield nothing.

    Args:
        ds_name: Hugging Face Hub dataset id, e.g. ``"org/name"``.
        messages_field: Column holding the list of message dicts.
        role_field: Key inside each message naming the speaker.
        content_field: Key inside each message holding the text.
        user_role: Speaker value that identifies a user turn.

    Returns:
        A ``Prompts`` whose ``count`` is the number of rows in the split. This
        can exceed the number of prompts actually yielded when some
        conversations contain no user turn.
    """
    ds = _load_train_split(ds_name)

    def items() -> Iterator[str]:
        for row in ds:
            messages = cast(dict[Any, Any], row)[messages_field]
            for message in messages:
                if message[role_field] == user_role:
                    yield message[content_field]
                    break  # only the first user turn per conversation

    return Prompts(
        output_path=_hub_output_path(ds_name),
        count=lambda: len(ds),
        items=items,
    )


def main():
    """Collect prompts, run extraction, then launch a Together LoRA fine-tune.

    Gathers prompts from several Hugging Face datasets and runs an
    ``Extractor`` against the OpenAI-compatible endpoint at glhf.chat
    (model ``hf:TheDrummer/Behemoth-123B-v1.2``). It then passes the
    extractor's output dataset to ``llama_3_1_70b_together`` and calls
    ``upload_files`` and ``finetune`` on the resulting config.

    Raises:
        KeyError: if any of ``GLHF_API_KEY``, ``TOGETHER_API_KEY`` or
            ``WANDB_API_KEY`` is unset. All three are checked before any
            work starts.
    """
    # Read every credential up front. TOGETHER_API_KEY and WANDB_API_KEY are
    # only used after extraction, so a missing one would otherwise surface
    # only once the extraction run has already finished.
    env = _require_env("GLHF_API_KEY", "TOGETHER_API_KEY", "WANDB_API_KEY")

    output_dir = "output"

    rp_english = extract_prompts_from_convos(
        ds_name="OdiaGenAI/roleplay_english",
        messages_field="conversations",
        role_field="from",
        content_field="value",
        user_role="user",
    )
    bluemoon = extract_prompts_from_convos(
        ds_name="xDAN2099/RolePlay-Mixed-Bluemoon-Limarp",
        messages_field="conversations",
        role_field="from",
        content_field="value",
        user_role="human",
    )
    # Strip the leading "User: " label and the trailing "Bot:" cue.
    roleplay_prompts = gen_prompts(
        ds_name="AlekseyKorshuk/roleplay-io",
        text_field="input_text",
        start_regex=re.compile(r'^User: '),
        end_regex=re.compile(r'Bot:\s*$'),
    )
    # Strip the leading "Human: " label and the trailing "Assistant:" cue.
    roleplay_instr_prompts = gen_prompts(
        ds_name="iamketan25/roleplay-instructions-dataset",
        text_field="prompt",
        start_regex=re.compile(r'^Human: '),
        end_regex=re.compile(r'Assistant:\s*$'),
    )

    extractor = Extractor(
        max_concurrent=50,
        output_dir=output_dir,
        client=OpenAiCompatClient(
            base_url="https://glhf.chat/api/openai/v1",
            api_key=env["GLHF_API_KEY"],
            model="hf:TheDrummer/Behemoth-123B-v1.2",
            retries=20,
        ),
        dataset=Dataset(
            train=[
                hub_prompts(
                    name="mlabonne/harmful_behaviors",
                    text_field="text",
                    split="train",
                ),
                roleplay_instr_prompts,
                roleplay_prompts,
                rp_english,
                bluemoon,
                hub_prompts(
                    name="TheDrummer/AmoralQA-v2",
                    text_field="prompt",
                    split="train",
                ),
                hub_prompts(
                    name="vicgalle/OpenHermesPreferences-roleplay",
                    text_field="prompt",
                    split="train",
                ),
                hub_prompts(
                    name="mrcuddle/DPO_Pairs_Roleplay-Alpaca",
                    text_field="prompt",
                    split="train",
                ),
                hub_prompts(
                    name="ResplendentAI/theory_of_mind_fixed_output",
                    text_field="instruction",
                    split="train",
                ),
                # Cap this source at 1000 rows via HubSplit.max_rows.
                hub_prompts(
                    name="mlabonne/harmless_alpaca",
                    text_field="text",
                    split=HubSplit(name="train", max_rows=1000),
                ),
            ],
        ),
    )
    extractor.run()
    dataset = extractor.output_dataset()

    together_config = llama_3_1_70b_together(
        output_dir=output_dir,
        dataset=dataset,
        api_key=env["TOGETHER_API_KEY"],
        settings=LoraSettings(
            rank=32,
            alpha=16,
            dropout=0.01,
            num_epochs=2,
            learning_rate=4e-4,
            evals_per_epoch=0,
            wandb_project="behemoth-distill",
            wandb_api_key=env["WANDB_API_KEY"],
        ),
    )
    files = together_config.upload_files()
    together_config.finetune(files)


if __name__ == "__main__":
    main()