# Source: reissbaker — "Add unfat config" (commit 3d06b91)
import os
import re
from typing import cast, Any
from datasets import load_dataset, Dataset as HfDataset
from unfat.extract import Extractor
from unfat.client import OpenAiCompatClient
from unfat.datasets import Dataset, Prompts, hub_prompts, HubSplit
from unfat.together import llama_3_1_70b_together
from unfat.lora import LoraSettings
def gen_prompts(
    ds_name: str,
    text_field: str,
    start_regex: re.Pattern | None = None,
    end_regex: re.Pattern | None = None,
):
    """Build a Prompts source from one text column of a HF Hub dataset.

    Each row's ``text_field`` value is yielded as a prompt, after optionally
    stripping a leading marker (``start_regex``) and/or a trailing marker
    (``end_regex``) via regex substitution with the empty string.
    """
    dataset = cast(HfDataset, load_dataset(ds_name, split="train"))

    def iter_prompts():
        for record in dataset:
            text = cast(dict[Any, Any], record)[text_field]
            # Applying each strip independently covers all four
            # present/absent combinations of the two patterns.
            if start_regex is not None:
                text = start_regex.sub("", text)
            if end_regex is not None:
                text = end_regex.sub("", text)
            yield text

    return Prompts(
        output_path=f"hub/{ds_name}.jsonl",
        count=lambda: len(dataset),
        items=iter_prompts,
    )
def extract_prompts_from_convos(
    ds_name: str,
    messages_field: str,
    role_field: str,
    content_field: str,
    user_role: str,
):
    """Build a Prompts source from a conversation-style HF Hub dataset.

    For every conversation row, yields the content of the FIRST message
    whose role matches ``user_role`` — i.e. one opening user prompt per
    conversation; later turns are ignored.
    """
    dataset = cast(HfDataset, load_dataset(ds_name, split="train"))

    def iter_prompts():
        for record in dataset:
            for message in cast(dict[Any, Any], record)[messages_field]:
                if message[role_field] != user_role:
                    continue
                yield message[content_field]
                break  # only the opening user turn of each conversation

    return Prompts(
        output_path=f"hub/{ds_name}.jsonl",
        count=lambda: len(dataset),
        items=iter_prompts,
    )
def main():
    """Distill roleplay behavior from Behemoth-123B into a Llama 3.1 70B LoRA.

    Pipeline: extract completions for a mix of roleplay/instruction prompts
    from the teacher model via the glhf.chat OpenAI-compatible API, then
    fine-tune Llama 3.1 70B on Together using the extracted dataset.

    Requires GLHF_API_KEY, TOGETHER_API_KEY, and WANDB_API_KEY in the
    environment; raises KeyError if any is missing.
    """
    output_dir = "output"

    # Conversation-style datasets: take the opening user turn of each convo.
    english_prompts = extract_prompts_from_convos(
        ds_name="OdiaGenAI/roleplay_english",
        messages_field="conversations",
        role_field="from",
        content_field="value",
        user_role="user",
    )
    bluemoon_prompts = extract_prompts_from_convos(
        ds_name="xDAN2099/RolePlay-Mixed-Bluemoon-Limarp",
        messages_field="conversations",
        role_field="from",
        content_field="value",
        user_role="human",
    )

    # Single-text-column datasets: strip the speaker-marker scaffolding.
    io_prompts = gen_prompts(
        ds_name="AlekseyKorshuk/roleplay-io",
        text_field="input_text",
        start_regex=re.compile(r'^User: '),
        end_regex=re.compile(r'Bot:\s*$'),
    )
    instruction_prompts = gen_prompts(
        ds_name="iamketan25/roleplay-instructions-dataset",
        text_field="prompt",
        start_regex=re.compile(r'^Human: '),
        end_regex=re.compile(r'Assistant:\s*$'),
    )

    train_sources = [
        hub_prompts(
            name="mlabonne/harmful_behaviors",
            text_field="text",
            split="train",
        ),
        instruction_prompts,
        io_prompts,
        english_prompts,
        bluemoon_prompts,
        hub_prompts(
            name="TheDrummer/AmoralQA-v2",
            text_field="prompt",
            split="train",
        ),
        hub_prompts(
            name="vicgalle/OpenHermesPreferences-roleplay",
            text_field="prompt",
            split="train",
        ),
        hub_prompts(
            name="mrcuddle/DPO_Pairs_Roleplay-Alpaca",
            text_field="prompt",
            split="train",
        ),
        hub_prompts(
            name="ResplendentAI/theory_of_mind_fixed_output",
            text_field="instruction",
            split="train",
        ),
        hub_prompts(
            name="mlabonne/harmless_alpaca",
            text_field="text",
            # Cap this generic-alpaca source so it doesn't dominate the mix.
            split=HubSplit(name="train", max_rows=1000),
        ),
    ]

    # Query the teacher model for completions over every prompt source.
    extractor = Extractor(
        max_concurrent=50,
        output_dir=output_dir,
        client=OpenAiCompatClient(
            base_url="https://glhf.chat/api/openai/v1",
            api_key=os.environ["GLHF_API_KEY"],
            model="hf:TheDrummer/Behemoth-123B-v1.2",
            retries=20,
        ),
        dataset=Dataset(train=train_sources),
    )
    extractor.run()
    distilled = extractor.output_dataset()

    # Fine-tune the student on Together with the distilled dataset.
    lora_settings = LoraSettings(
        rank=32,
        alpha=16,
        dropout=0.01,
        num_epochs=2,
        learning_rate=4e-4,
        evals_per_epoch=0,
        wandb_project="behemoth-distill",
        wandb_api_key=os.environ["WANDB_API_KEY"],
    )
    together_config = llama_3_1_70b_together(
        output_dir=output_dir,
        dataset=distilled,
        api_key=os.environ["TOGETHER_API_KEY"],
        settings=lora_settings,
    )
    uploaded = together_config.upload_files()
    together_config.finetune(uploaded)
if __name__ == "__main__":
    main()