# Decision-Tree-Reward-Llama-3.1-8B / modeling_decision_tree_reward_model.py
# Min-Li's picture
# Update modeling_decision_tree_reward_model.py
# 3ea974c verified
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaForSequenceClassification
from sklearn.tree import DecisionTreeClassifier
import os
import pickle
import json
from huggingface_hub import hf_hub_download
from typing import List, Dict, Union
import numpy as np
def convert_to_chat_format(prompt, response=None):
    """Convert a prompt (and optional response) into a list of chat messages.

    Prompts containing the special token ``<extra_id_1>`` (HelpSteer2
    multi-turn conversations) are split on that token. The text before the
    first token becomes the first user message. Every later segment is read
    as a role line, a newline, then the message content; a role line equal to
    ``"Assistant"`` maps to the ``"assistant"`` role and anything else maps to
    ``"user"``. A segment that has no newline is treated as a role line with
    empty content.

    Prompts without the token become a single user message.

    Args:
        prompt: The prompt string, optionally in HelpSteer2 multi-turn format.
        response: Optional assistant reply appended as the final message.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dictionaries.
    """
    if "<extra_id_1>" in prompt:
        first_turn, *later_turns = prompt.split("<extra_id_1>")
        conversation = [{"role": "user", "content": first_turn}]
        for turn in later_turns:
            # partition() always yields three parts, so a segment without a
            # newline gives empty content instead of raising IndexError.
            role, _, content = turn.partition("\n")
            conversation.append({
                "role": "assistant" if role == "Assistant" else "user",
                "content": content,
            })
    else:
        conversation = [{"role": "user", "content": prompt}]
    if response is not None:
        conversation.append({"role": "assistant", "content": response})
    return conversation
def process_conversation(conversation):
    """Strip trailing newline characters from every message's content.

    The message dictionaries are modified in place and the same list object
    that was passed in is returned.
    """
    for turn in conversation:
        text = turn["content"]
        turn["content"] = text.rstrip("\n")
    return conversation
class LlamaForDecisionTreeRewardModel(LlamaForSequenceClassification):
    """Llama sequence classifier combined with a decision tree over its scores.

    The linear ``score`` head produces one reward per attribute for a
    conversation. ``compare`` feeds the difference between the reward vectors
    of two candidate responses to a ``DecisionTreeClassifier`` to obtain a
    preference. The tree is not part of the model weights and must be loaded
    with ``load_decision_tree`` before ``compare`` is called.
    """

    def __init__(self, config):
        """Build the model and (re)define the score head as a biased linear layer."""
        super().__init__(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=True)
        # The decision tree is loaded separately via load_decision_tree().
        self.tree = None
        # Default attribute names (from HelpSteer2); replaced by
        # load_decision_tree() with the labels found in config.json.
        self.attributes = ['helpfulness', 'correctness', 'coherence', 'complexity', 'verbosity']
        print("Initialized LlamaForDecisionTreeRewardModel")

    def load_decision_tree(self, repo_id, filename="decision_tree.pkl"):
        """Download the pickled decision tree and the attribute labels.

        Args:
            repo_id: Hub repository id passed to ``hf_hub_download``.
            filename: Name of the pickled tree file inside the repository.

        Raises:
            AssertionError: If the unpickled object is not a
                ``DecisionTreeClassifier``.
            KeyError: If ``config.json`` has no ``"label2id"`` entry.
        """
        tree_path = hf_hub_download(repo_id=repo_id, filename=filename)
        with open(tree_path, "rb") as f:
            # SECURITY NOTE: pickle.load can execute arbitrary code while
            # deserializing. Only load trees from repositories you trust.
            tree = pickle.load(f)
        # Validate before storing so a wrong-type object never ends up on the model.
        assert isinstance(tree, DecisionTreeClassifier), f"The tree is not a DecisionTreeClassifier. It is a {type(tree)}"
        self.tree = tree

        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        label2id_map = config["label2id"]
        # Attribute names ordered by their ids; an empty mapping yields an empty list.
        self.attributes = [
            label for label, _ in sorted(label2id_map.items(), key=lambda item: item[1])
        ]

    @torch.no_grad()
    def compare(self, prompt: Union[str, List[Dict[str, str]]], response_1: str, response_2: str, tokenizer, device):
        """Score two responses to the same prompt and predict a preference.

        Args:
            prompt: Either a prompt string (converted with
                ``convert_to_chat_format``) or a list of chat messages whose
                last message has role ``"user"``. The caller's messages are
                not modified.
            response_1: First candidate assistant response.
            response_2: Second candidate assistant response.
            tokenizer: Tokenizer providing ``apply_chat_template``.
            device: Device the tokenized inputs are moved to.

        Returns:
            A dict with:
              * ``"preference"``: first element of the tree's prediction on
                ``rewards_2 - rewards_1`` (the label meaning depends on how
                the tree was trained).
              * ``"rewards"``: reward arrays of response 1 and response 2
                concatenated along the first axis.
              * ``"attributes"``: the attribute names in ``self.attributes``.

        Raises:
            AssertionError: If the tree is not loaded or the conversation is
                empty or does not end with a user message.
            ValueError: If ``prompt`` is neither a string nor a list.
        """
        assert self.tree is not None, "The decision tree is not loaded. Please call load_decision_tree(repo_id, filename) first."
        if isinstance(prompt, str):
            conversation = convert_to_chat_format(prompt)
        elif isinstance(prompt, list):
            conversation = prompt
        else:
            raise ValueError(f"The prompt must be a string or a list of dictionaries, but got {type(prompt)}")
        assert isinstance(conversation, list), "The conversation must be a list of dictionaries"
        assert len(conversation) >= 1, "The conversation must have at least one message (as prompt)"
        assert conversation[-1]["role"] == "user", "The last message in the conversation must be from the user"

        weight = self.score.weight.float().cpu().numpy()
        bias = self.score.bias.float().cpu().numpy()

        rewards = []
        for response in (response_1, response_2):
            # Copy every message: process_conversation edits messages in place
            # and must not alter the caller's prompt dictionaries.
            messages = [dict(message) for message in conversation]
            messages.append({"role": "assistant", "content": response})
            messages = process_conversation(messages)
            input_ids = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt").to(device)
            # Last entry of hidden_states, taken at the final sequence position.
            embedding = self.forward(input_ids, output_hidden_states=True).hidden_states[-1][:, -1].float().cpu().numpy()
            # Apply the score head in float32 NumPy.
            rewards.append(embedding @ weight.T + bias)
        rewards_1, rewards_2 = rewards

        rewards_diff = rewards_2 - rewards_1
        return {
            "preference": self.tree.predict(rewards_diff)[0],
            "rewards": np.concatenate([rewards_1, rewards_2]),
            "attributes": self.attributes
        }