|
import os |
|
import re |
|
from typing import cast, Any |
|
from datasets import load_dataset, Dataset as HfDataset |
|
from unfat.extract import Extractor |
|
from unfat.client import OpenAiCompatClient |
|
from unfat.datasets import Dataset, Prompts, hub_prompts, HubSplit |
|
from unfat.together import llama_3_1_70b_together |
|
from unfat.lora import LoraSettings |
|
|
|
def gen_prompts(
    ds_name: str,
    text_field: str,
    start_regex: re.Pattern | None = None,
    end_regex: re.Pattern | None = None,
):
    """Build a ``Prompts`` source from a text column of a HF hub dataset.

    Loads the ``train`` split of *ds_name* and yields one prompt per row,
    taken from *text_field*.  When supplied, *start_regex* is stripped from
    the text first, then *end_regex* (both via ``Pattern.sub("")``).
    """
    ds = cast(HfDataset, load_dataset(ds_name, split="train"))

    # Keep only the patterns that were actually provided, in the order
    # they must be applied (start first, then end).
    cleaners = [pattern for pattern in (start_regex, end_regex) if pattern is not None]

    def items():
        for record in ds:
            prompt = cast(dict[Any, Any], record)[text_field]
            for pattern in cleaners:
                prompt = pattern.sub("", prompt)
            yield prompt

    return Prompts(
        output_path=f"hub/{ds_name}.jsonl",
        count=lambda: len(ds),
        items=items,
    )
|
|
|
def extract_prompts_from_convos(
    ds_name: str,
    messages_field: str,
    role_field: str,
    content_field: str,
    user_role: str,
):
    """Build a ``Prompts`` source from a conversation-style HF hub dataset.

    Loads the ``train`` split of *ds_name* and, for each conversation found
    under *messages_field*, yields the *content_field* of the first message
    whose *role_field* equals *user_role*.  Conversations with no such
    message contribute nothing.
    """
    ds = cast(HfDataset, load_dataset(ds_name, split="train"))

    def items():
        for record in ds:
            conversation = cast(dict[Any, Any], record)[messages_field]
            for message in conversation:
                # Skip non-user turns; stop after the first user turn.
                if message[role_field] != user_role:
                    continue
                yield message[content_field]
                break

    return Prompts(
        output_path=f"hub/{ds_name}.jsonl",
        # NOTE(review): counts dataset rows, not yielded prompts — this can
        # overcount if a conversation lacks a user-role message.  Kept as-is
        # to match existing behavior; confirm how Prompts.count is consumed.
        count=lambda: len(ds),
        items=items,
    )
|
|
|
def main():
    """Extract roleplay prompts, distill completions, and launch a finetune.

    Pipeline: build prompt sources from several HF hub datasets, run the
    extractor against a teacher model served behind an OpenAI-compatible
    API, then upload the resulting dataset and kick off a Llama 3.1 70B
    LoRA finetune on Together.

    Requires the ``GLHF_API_KEY``, ``TOGETHER_API_KEY``, and
    ``WANDB_API_KEY`` environment variables.
    """
    output_dir = "output"

    # Conversation-style datasets: take each conversation's first user turn.
    rp_english = extract_prompts_from_convos(
        ds_name="OdiaGenAI/roleplay_english",
        messages_field="conversations",
        role_field="from",
        content_field="value",
        user_role="user",
    )
    bluemoon = extract_prompts_from_convos(
        ds_name="xDAN2099/RolePlay-Mixed-Bluemoon-Limarp",
        messages_field="conversations",
        role_field="from",
        content_field="value",
        user_role="human",
    )

    # Flat-text datasets: strip the speaker-prefix/suffix markers baked
    # into the text column.
    roleplay_prompts = gen_prompts(
        ds_name="AlekseyKorshuk/roleplay-io",
        text_field="input_text",
        start_regex=re.compile(r'^User: '),
        end_regex=re.compile(r'Bot:\s*$'),
    )
    roleplay_instr_prompts = gen_prompts(
        ds_name="iamketan25/roleplay-instructions-dataset",
        text_field="prompt",
        start_regex=re.compile(r'^Human: '),
        end_regex=re.compile(r'Assistant:\s*$'),
    )

    # Teacher model behind an OpenAI-compatible endpoint.
    teacher_client = OpenAiCompatClient(
        base_url="https://glhf.chat/api/openai/v1",
        api_key=os.environ["GLHF_API_KEY"],
        model="hf:TheDrummer/Behemoth-123B-v1.2",
        retries=20,
    )

    # All prompt sources to extract completions for.
    train_sources = [
        hub_prompts(
            name="mlabonne/harmful_behaviors",
            text_field="text",
            split="train",
        ),
        roleplay_instr_prompts,
        roleplay_prompts,
        rp_english,
        bluemoon,
        hub_prompts(
            name="TheDrummer/AmoralQA-v2",
            text_field="prompt",
            split="train",
        ),
        hub_prompts(
            name="vicgalle/OpenHermesPreferences-roleplay",
            text_field="prompt",
            split="train",
        ),
        hub_prompts(
            name="mrcuddle/DPO_Pairs_Roleplay-Alpaca",
            text_field="prompt",
            split="train",
        ),
        hub_prompts(
            name="ResplendentAI/theory_of_mind_fixed_output",
            text_field="instruction",
            split="train",
        ),
        hub_prompts(
            name="mlabonne/harmless_alpaca",
            text_field="text",
            split=HubSplit(name="train", max_rows=1000),
        ),
    ]

    extractor = Extractor(
        max_concurrent=50,
        output_dir=output_dir,
        client=teacher_client,
        dataset=Dataset(train=train_sources),
    )
    extractor.run()

    # Fine-tune on the distilled completions via Together.
    dataset = extractor.output_dataset()
    together_config = llama_3_1_70b_together(
        output_dir=output_dir,
        dataset=dataset,
        api_key=os.environ["TOGETHER_API_KEY"],
        settings=LoraSettings(
            rank=32,
            alpha=16,
            dropout=0.01,
            num_epochs=2,
            learning_rate=4e-4,
            evals_per_epoch=0,
            wandb_project="behemoth-distill",
            wandb_api_key=os.environ["WANDB_API_KEY"],
        )
    )
    files = together_config.upload_files()
    together_config.finetune(files)
|
|
|
|
|
# Script entry point: run the pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":

    main()
|
|