import os
import pickle
import json
from typing import List, Dict, Union

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaForSequenceClassification
from sklearn.tree import DecisionTreeClassifier
from huggingface_hub import hf_hub_download
import numpy as np


# Special token that separates the turns of a multi-turn HelpSteer2 prompt.
# NOTE(review): the literal was empty in the reviewed source (which made
# ``str.split`` raise "empty separator"); restored from the HelpSteer2 prompt
# format -- confirm against the dataset actually used.
HELPSTEER2_TURN_TOKEN = "<extra_id_1>"


def convert_to_chat_format(prompt, response=None):
    """Convert a raw prompt string into a list of chat messages.

    A prompt containing ``HELPSTEER2_TURN_TOKEN`` is treated as a packed
    multi-turn conversation: the text before the first token is the first
    user message, and every following turn consists of a role line, a
    newline, and the message content. A role line equal to ``Assistant``
    (ignoring surrounding whitespace) maps to the ``assistant`` role; any
    other role line maps to ``user``. A turn with no newline yields an empty
    content string instead of raising.

    Args:
        prompt: The raw prompt text.
        response: Optional assistant reply appended as the final message.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dictionaries.
    """
    if HELPSTEER2_TURN_TOKEN in prompt:
        first_turn, *later_turns = prompt.split(HELPSTEER2_TURN_TOKEN)
        conversation = [{"role": "user", "content": first_turn}]
        for turn in later_turns:
            # partition() never fails on a turn that lacks a newline.
            role, _, content = turn.partition("\n")
            conversation.append({
                "role": "assistant" if role.strip() == "Assistant" else "user",
                "content": content,
            })
    else:
        conversation = [{"role": "user", "content": prompt}]
    if response is not None:
        conversation.append({"role": "assistant", "content": response})
    return conversation


def process_conversation(conversation):
    """Strip trailing newlines from every message's content.

    The message dictionaries are modified in place and the same list object
    is returned; pass copies if the originals must stay untouched.
    """
    for message in conversation:
        message["content"] = message["content"].rstrip('\n')
    return conversation


class LlamaForDecisionTreeRewardModel(LlamaForSequenceClassification):
    """Llama sequence classifier whose score head feeds a decision tree.

    The linear ``score`` head produces one value per label for a
    conversation; ``compare`` passes the difference of two such vectors to a
    ``DecisionTreeClassifier`` loaded via ``load_decision_tree``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Replace the inherited score head with a linear layer that has a bias.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=True)
        # The decision tree is loaded separately by load_decision_tree().
        self.tree = None
        # Default attribute names (from HelpSteer2); overwritten by
        # load_decision_tree() with the labels found in config.json.
        self.attributes = ['helpfulness', 'correctness', 'coherence', 'complexity', 'verbosity']
        print("Initialized LlamaForDecisionTreeRewardModel")

    def load_decision_tree(self, repo_id, filename="decision_tree.pkl"):
        """Download and install the decision tree and attribute labels.

        Args:
            repo_id: Hugging Face Hub repository to download from.
            filename: Name of the pickled tree file inside the repository.

        Raises:
            AssertionError: If the unpickled object is not a
                ``DecisionTreeClassifier``.
            KeyError: If ``config.json`` has no ``label2id`` entry.
        """
        tree_path = hf_hub_download(repo_id=repo_id, filename=filename)
        with open(tree_path, "rb") as f:
            # SECURITY: pickle.load can execute arbitrary code contained in
            # the file. Only call this with repositories you trust.
            tree = pickle.load(f)
        assert isinstance(tree, DecisionTreeClassifier), f"The tree is not a DecisionTreeClassifier. It is a {type(tree)}"
        # Assign only after validation so a bad download never leaves an
        # unusable object on the model.
        self.tree = tree

        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        label2id_map = config["label2id"]
        # Order the label names by their ids.
        self.attributes = [
            label for label, _ in sorted(label2id_map.items(), key=lambda item: item[1])
        ]

    def _score_response(self, conversation, response, tokenizer, device):
        """Return the score-head output for ``conversation`` plus ``response``.

        The prompt messages are copied before cleanup so the caller's
        dictionaries are never modified.
        """
        messages = [dict(message) for message in conversation]
        messages.append({"role": "assistant", "content": response})
        messages = process_conversation(messages)
        input_ids = tokenizer.apply_chat_template(
            messages, tokenize=True, return_tensors="pt"
        ).to(device)
        outputs = self.forward(input_ids, output_hidden_states=True)
        # Last entry of hidden_states, taken at the last sequence position.
        embedding = outputs.hidden_states[-1][:, -1].float().cpu().numpy()
        weight = self.score.weight.float().cpu().numpy()
        bias = self.score.bias.float().cpu().numpy()
        return embedding @ weight.T + bias

    @torch.no_grad()
    def compare(self, prompt: Union[str, List[Dict[str, str]]], response_1: str, response_2: str, tokenizer, device):
        """Score two responses to the same prompt and predict a preference.

        Args:
            prompt: Either a raw prompt string (converted with
                ``convert_to_chat_format``) or a list of chat messages whose
                last message has the ``user`` role. The list and its
                messages are not modified.
            response_1: First candidate assistant response.
            response_2: Second candidate assistant response.
            tokenizer: Tokenizer providing ``apply_chat_template``.
            device: Device the tokenized inputs are moved to.

        Returns:
            A dictionary with:
            ``"preference"``: first element of the tree's prediction for
            ``rewards_2 - rewards_1``;
            ``"rewards"``: the two score vectors concatenated, response 1
            first;
            ``"attributes"``: the attribute names in ``self.attributes``.

        Raises:
            AssertionError: If no tree is loaded, the conversation is empty,
                or its last message is not from the user.
            ValueError: If ``prompt`` is neither a string nor a list.
        """
        assert self.tree is not None, "The decision tree is not loaded. Please call load_decision_tree(repo_id, filename) first."

        if isinstance(prompt, str):
            conversation = convert_to_chat_format(prompt)
        elif isinstance(prompt, list):
            conversation = prompt
        else:
            raise ValueError(f"The prompt must be a string or a list of dictionaries, but got {type(prompt)}")

        assert len(conversation) >= 1, "The conversation must have at least one message (as prompt)"
        assert conversation[-1]["role"] == "user", "The last message in the conversation must be from the user"

        rewards_1 = self._score_response(conversation, response_1, tokenizer, device)
        rewards_2 = self._score_response(conversation, response_2, tokenizer, device)
        rewards_diff = rewards_2 - rewards_1

        return {
            "preference": self.tree.predict(rewards_diff)[0],
            "rewards": np.concatenate([rewards_1, rewards_2]),
            "attributes": self.attributes
        }