from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import argilla as rg
import os
from tqdm import tqdm

dir_path = os.path.dirname(os.path.realpath(__file__))

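# Connect to the Argilla server; replace the placeholders with your own API URL and key.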
rg.init(
    api_url="<argilla-api-url>",
    api_key="<argilla-api-key>"
)

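# Create a feedback dataset using Argilla's preference-modeling template (two responses per prompt).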
ds = rg.FeedbackDataset.for_preference_modeling(
    number_of_responses=2,
    context=False,
    use_markdown=False,
    guidelines=None,
    metadata_properties=None,
    vectors_settings=None,
)

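# Point at the merged LoRA output and load its tokenizer.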
model_dir = os.path.join(dir_path, "lora-out", "merged")

tokenizer = AutoTokenizer.from_pretrained(model_dir)

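# Load the test split of the ShareGPT-formatted Glaive function-calling dataset.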
dataset = load_dataset("gardner/glaive-function-calling-v2-sharegpt", split="test")

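# Convert each ShareGPT conversation into a chat-template prompt and hold out the
# final assistant turn as the reference answer.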
def preprocess_function_calling(examples):
    texts = []
    answers = []
    for chat in examples["conversations"]:
        for turn in chat:
            # Rename ShareGPT keys ('from'/'value') to chat-template keys ('role'/'content').
            turn['role'] = turn.pop('from')
            turn['content'] = turn.pop('value')

            if turn['role'] == 'human':
                turn['role'] = 'user'
            if turn['role'] == 'gpt':
                turn['role'] = 'assistant'

        # Keep the final assistant turn as the reference answer and drop it from the prompt.
        if chat[-1]['role'] == 'assistant':
            answers.append(chat[-1]['content'])
            del chat[-1]

        texts.append(tokenizer.apply_chat_template(chat, tokenize=False))

    return {"texts": texts, "answer": answers}

texts = dataset.map(preprocess_function_calling, batched=True)

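# Load the merged model onto the GPU for generation.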
model = AutoModelForCausalLM.from_pretrained(model_dir).to("cuda")

records = []

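# Generate one candidate response per prompt and pair it with the dataset's reference answer.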
for text in tqdm(texts):
    prompt = text['texts'] + "<|im_start|>assistant\n"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    # Sampling parameters belong to generate(), not decode().
    outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.9)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    prompt_len = len(prompt)

    response1 = response[prompt_len:].replace("<|im_end|>\n", "").strip()
    response2 = text['answer']
    print(response1)

    records.append(rg.FeedbackRecord(fields={
        "prompt": prompt,
        "response1": response1,
        "response2": response2,
    }))

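# Upload the records and push the dataset to Argilla for annotation.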
ds.add_records(records)
ds.push_to_argilla(name="function-calling", workspace="argilla")