---
license: apache-2.0
base_model:
- google/gemma-2-9b-it
---

# General Preference Representation Model (GPM)

+ **Authors** (* indicates equal contribution)

Yifan Zhang*, Ge Zhang*, Yue Wu*, Kangping Xu, Quanquan Gu

+ **Paper**: [General Preference Modeling with Preference Representations for Aligning Language Models](https://arxiv.org/abs/2410.02197)
+ **Hugging Face Daily Papers**: [https://huggingface.co/papers/2410.02197](https://huggingface.co/papers/2410.02197)
+ **Code Repository**: [general-preference/general-preference-model](https://github.com/general-preference/general-preference-model)
+ **Dataset**: [natolambert/skywork-preferences-80k-v0.1-cleaned](https://huggingface.co/datasets/natolambert/skywork-preferences-80k-v0.1-cleaned)
+ **Base Model**: [google/gemma-2-9b-it](https://huggingface.co/google/gemma-2-9b-it)

## Overview

The General Preference Representation Model (GPM) improves preference-based reward modeling by embedding responses into a latent space that efficiently captures complex, intransitive human preferences. GPM scores responses with linear query complexity while retaining an expressive preference representation, and it outperforms traditional Bradley-Terry (BT) reward models, particularly in handling cyclic preferences.
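
At a high level, each response is mapped to a unit-norm embedding, and the preference score between two responses is a skew-symmetric bilinear form of their embeddings, so swapping the two responses flips the sign of the score; intransitive (cyclic) preferences therefore become representable. The sketch below illustrates the scoring rule with hand-picked 2-dimensional vectors; the embedding values and the operator entries are illustrative placeholders, not weights from the released checkpoint.

```python
import torch
import torch.nn.functional as F

# Toy unit-norm embeddings for two responses (illustrative values only).
v_a = F.normalize(torch.tensor([0.9, 0.1]), dim=-1)
v_b = F.normalize(torch.tensor([0.2, 0.8]), dim=-1)

# A 2x2 skew-symmetric operator (R^T = -R); in GPM these block entries are predicted from the prompt.
R = torch.tensor([[0.0, -1.0],
                  [1.0,  0.0]])

score_ab = v_a @ R.T @ v_b  # preference score for "a over b"
score_ba = v_b @ R.T @ v_a  # equals -score_ab by skew-symmetry
print(score_ab.item(), score_ba.item())
```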

## Key Features
- **Preference Representation Learning**: Embeds responses in a multi-dimensional latent space to model intricate human preferences, including cyclic and intransitive structures.
- **Efficient Querying**: Scores K responses with O(K) model queries (one embedding per response) instead of the O(K²) pairwise queries required by models that must compare every pair, making GPM scalable for large response sets; see the sketch after this list.
- **General Preference Optimization (GPO)**: Introduces a preference score that integrates with reinforcement learning methods to optimize policy alignment with human preferences.
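
As referenced in the efficiency bullet above, the cost advantage comes from embedding each of the K responses once and then computing any pairwise score as a cheap bilinear product of cached embeddings, rather than running a model query per pair. A minimal sketch, using random stand-in embeddings and a fixed skew-symmetric operator in place of the model's outputs:

```python
import torch
import torch.nn.functional as F

K, dim = 8, 4  # e.g. 8 candidate responses, 4-dimensional latent space (placeholders)

# In practice these K vectors come from K forward passes of the GPM (O(K) queries);
# random stand-ins are used here.
embeddings = F.normalize(torch.randn(K, dim), dim=-1)

# Fixed block-diagonal skew-symmetric operator (GPM predicts the block entries from the prompt).
R = torch.zeros(dim, dim)
for i in range(0, dim, 2):
    R[i, i + 1], R[i + 1, i] = -1.0, 1.0

# All K x K pairwise preference scores from the cached embeddings.
scores = embeddings @ R.T @ embeddings.T
print(scores.shape)  # torch.Size([8, 8]); scores[i, j] equals -scores[j, i] up to floating point
```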

## Evaluation

GPM is evaluated on the [RewardBench](https://github.com/allenai/reward-bench) leaderboard, where it shows significant improvements over the BT model, with a performance margin of up to 5.6%. GPM also excels at modeling cyclic preferences, achieving 100% accuracy on cyclic preference datasets.
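
To see why cyclic preferences are a meaningful test: a Bradley-Terry model assigns a single scalar reward per response, so the ranking it induces is always transitive and can never fit A ≻ B ≻ C ≻ A, while a skew-symmetric scoring operator can. A minimal sketch with three hand-picked embeddings spaced 120° apart (illustrative values, not model outputs):

```python
import math
import torch

# Three unit embeddings 120 degrees apart on the circle.
angles = [0.0, 2 * math.pi / 3, 4 * math.pi / 3]
v = torch.tensor([[math.cos(t), math.sin(t)] for t in angles])

# 2x2 skew-symmetric operator, as in the model's block construction.
R = torch.tensor([[0.0, -1.0],
                  [1.0,  0.0]])

def score(a, b):
    return (a @ R.T @ b).item()  # > 0 means the first response is preferred

print(score(v[0], v[1]), score(v[1], v[2]), score(v[2], v[0]))
# All three are positive: A beats B, B beats C, C beats A -- a cycle no scalar reward can reproduce.
```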

## Usage

To use this model, please refer to the [General Preference Model Code Repository](https://github.com/general-preference/general-preference-model). The repository includes detailed instructions for finetuning, evaluation, and integration of the GPM with downstream tasks. Below is an example code snippet:

```python
from typing import Optional, List, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer

def get_tokenizer(pretrain, model, padding_side="left", use_fast=True):
    tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast)
    tokenizer.padding_side = padding_side
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
    return tokenizer

def get_reward_model(base_causal_model, base_llm_model, is_general_preference: bool=False, add_prompt_head: bool=False, value_head_dim: int=2):
    class CustomRewardModel(base_causal_model):

        def __init__(self, config: AutoConfig):
            super().__init__(config)
            setattr(self, self.base_model_prefix, base_llm_model(config))
            if not is_general_preference:
                # Scalar (Bradley-Terry style) reward head.
                self.value_head = nn.Linear(config.hidden_size, 1, bias=False)
            else:
                # Multi-dimensional preference-representation head.
                self.value_head = nn.Linear(config.hidden_size, value_head_dim, bias=False)
                if add_prompt_head:
                    # Predicts the block entries of the prompt-conditioned skew-symmetric operator.
                    self.prompt_head = nn.Linear(config.hidden_size, value_head_dim // 2, bias=False)

            self.is_general_preference = is_general_preference

            self.post_init()

        def custom_forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            return_output=False,
        ) -> torch.Tensor:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            outputs = getattr(self, self.base_model_prefix)(
                input_ids, attention_mask=attention_mask, position_ids=position_ids
            )
            last_hidden_states = outputs["last_hidden_state"]

            if not self.is_general_preference:
                values = self.value_head(last_hidden_states).squeeze(-1)
                # left padding in training mode
                if self.training:
                    reward = values[:, -1]
                else:
                    eos_indices = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(dim=1, keepdim=True)
                    reward = values.gather(dim=1, index=eos_indices).squeeze(1)
                if return_output:
                    return reward, outputs
                else:
                    return reward, None
            else:
                values = self.value_head(last_hidden_states)
                # left padding in training mode
                if self.training:
                    reward = values[:, -1, :]
                    reward = F.normalize(reward, p=2, dim=-1)  # Shape will be [batch_size, value_head_dim]
                else:
                    eos_indices = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(dim=1)
                    eos_indices = eos_indices.unsqueeze(1)  # Change shape to [batch_size, 1]
                    reward_list = []
                    for dim in range(value_head_dim):
                        reward_list.append(values[:, :, dim].gather(dim=1, index=eos_indices))
                    reward = torch.cat(reward_list, dim=1)
                    reward = F.normalize(reward, p=2, dim=-1)  # Shape will be [batch_size, value_head_dim]
                if return_output:
                    return reward, outputs
                else:
                    return reward, None

        def create_skew_symmetric_block_matrix(self, dim, device, dtype, prompt_hidden_states):
            """
            Create a batch of skew-symmetric block matrices where each matrix is data-dependent on
            the corresponding prompt_hidden_states. Only the relevant block diagonal parts are generated.

            Args:
            - dim: Dimension of the square matrix (must be even).
            - prompt_hidden_states: Tensor of shape [batch_size, hidden_dim].

            Returns:
            - batch_R_matrices: Tensor of shape [batch_size, dim, dim], with skew-symmetric block entries.
            """
            if hasattr(self, 'prompt_head'):
                batch_size = prompt_hidden_states.shape[0]

                # Ensure that dim is even, as we're creating blocks of size 2x2
                assert dim % 2 == 0, "dim must be even for skew-symmetric block generation"

                # Pass through the linear layer to get the block diagonal entries (half of the matrix's off-diagonal blocks)
                block_values = self.prompt_head(prompt_hidden_states).view(batch_size, dim // 2)
                block_values = torch.softmax(block_values, dim=-1)

                # Create a batch of zero matrices [batch_size, dim, dim]
                batch_R_matrices = torch.zeros((batch_size, dim, dim), device=device, dtype=dtype)

                # Fill only the block diagonal entries with the learned values
                for i in range(0, dim, 2):
                    batch_R_matrices[:, i, i + 1] = -block_values[:, i // 2]
                    batch_R_matrices[:, i + 1, i] = block_values[:, i // 2]  # Skew-symmetric condition
            else:
                raise AttributeError("prompt_head is not defined. Ensure 'add_prompt_head' is set to True during initialization.")

            return batch_R_matrices

    return CustomRewardModel

def generate_high_dim_result_with_prompt(model, value_head_dim, chosen_reward, rejected_reward, prompt_hidden_states):
    # Preference score of the first response over the second: v_chosen @ R^T @ v_rejected,
    # with a prompt-dependent skew-symmetric R, so swapping the responses flips the sign.
    R_matrix = model.create_skew_symmetric_block_matrix(value_head_dim, chosen_reward.device, chosen_reward.dtype, prompt_hidden_states)
    if chosen_reward.device == rejected_reward.device == R_matrix.device:
        transformed_chosen = torch.bmm(chosen_reward.view(chosen_reward.shape[0], 1, value_head_dim), R_matrix.transpose(1, 2))
        result = torch.bmm(transformed_chosen, rejected_reward.view(rejected_reward.shape[0], value_head_dim, 1))
        result = result.view(chosen_reward.shape[0])
        return result

class GPMPipeline:
    def __init__(self, model_name_or_path, device=torch.device("cuda:0"), is_general_preference: bool=True, add_prompt_head: bool=True, value_head_dim: int=2, bf16: bool=True, truncation: bool=True, max_length: int=4096, padding: bool=True, tau: float=0.1):
        self.device = device
        self.is_general_preference = is_general_preference
        self.add_prompt_head = add_prompt_head
        self.value_head_dim = value_head_dim
        self.truncation = truncation
        self.max_length = max_length
        self.padding = padding
        self.tau = tau

        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        config._attn_implementation = "flash_attention_2"  # requires the flash-attn package; switch to "eager" if it is not installed
        base_class = AutoModel._model_mapping[type(config)]
        base_causal_class = AutoModelForCausalLM._model_mapping.get(type(config), None)
        cls_class = get_reward_model(base_causal_class, base_class, is_general_preference, add_prompt_head, value_head_dim)

        # configure model
        self.model = cls_class.from_pretrained(
            model_name_or_path,
            config=config,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if bf16 else "auto",
        )
        # configure tokenizer
        self.tokenizer = get_tokenizer(model_name_or_path, self.model, "left", use_fast=True)
        self.tokenizer.truncation_side = "right"

        # prepare model
        self.model.to(device)
        self.model.eval()

    def __call__(self, samples: List[List[Dict[str, str]]], return_prompt=False):
        input_texts = [self.tokenizer.apply_chat_template(sample, tokenize=False) for sample in samples]

        inputs = self.tokenizer(
            input_texts,
            truncation=True,
            max_length=self.max_length,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        # Force the sequence to end with EOS so the reward/embedding is read at a well-defined token.
        inputs["input_ids"][:, -1] = self.tokenizer.eos_token_id
        inputs["attention_mask"][:, -1] = 1

        with torch.no_grad():
            rewards, outputs = self.model.custom_forward(**inputs, return_output=return_prompt)

        chosen_response_len_list = []
        if return_prompt:
            prompt_texts = [self.tokenizer.apply_chat_template([sample[0]], tokenize=False) for sample in samples]
            for i in range(len(input_texts)):
                prompt_token = self.tokenizer(
                    prompt_texts[i],
                    max_length=self.max_length,
                    padding=False,
                    truncation=True,
                    return_tensors="pt",
                )
                chosen_token = self.tokenizer(
                    input_texts[i],
                    max_length=self.max_length,
                    padding=False,
                    truncation=True,
                    return_tensors="pt",
                )
                chosen_response_len = chosen_token["attention_mask"].sum() - prompt_token["attention_mask"].sum()
                chosen_response_len_list.append(chosen_response_len)
        chosen_response_len = torch.tensor(chosen_response_len_list).view(-1, 1).to(self.device)

        if return_prompt:
            # Hidden state at the last prompt token, used to predict the prompt-conditioned operator R.
            chosen_last_hidden_states = outputs["last_hidden_state"]
            prompt_end_index = chosen_last_hidden_states.size(1) - chosen_response_len - 1
            prompt_end_index_expanded = prompt_end_index.unsqueeze(-1).expand(-1, -1, chosen_last_hidden_states.size(-1))
            prompt_hidden_state = torch.gather(chosen_last_hidden_states, dim=1, index=prompt_end_index_expanded).squeeze(1)
            return rewards, prompt_hidden_state
        else:
            return rewards


prompt_text = "Describe the importance of reading books in today's digital age."
response1 = "Books remain crucial in the digital era, offering in-depth knowledge and fostering critical thinking. They provide a unique, immersive experience that digital media can't replicate, contributing significantly to personal and intellectual growth."
response2 = "Books are still useful for learning new things. They help you relax and can be a good break from screens."

context1 = [
    {"role": "user", "content": prompt_text},
    {"role": "assistant", "content": response1}
]

context2 = [
    {"role": "user", "content": prompt_text},
    {"role": "assistant", "content": response2}
]

rm = GPMPipeline("general-preference/GPM-Gemma-2-9B-it", value_head_dim=4)

reward1, prompt_hidden_state = rm([context1], return_prompt=True)
reward2 = rm([context2])

result = generate_high_dim_result_with_prompt(rm.model, rm.value_head_dim, reward1, reward2, prompt_hidden_state)

result_batch = result.float().cpu().detach().numpy().tolist()

# 1 means response1 is preferred over response2 for that pair, 0 otherwise.
results = [1 if score > 0 else 0 for score in result_batch]

print(result_batch)
print(results)
```
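
In this example, each entry of `result_batch` is the preference score of `response1` over `response2` for the corresponding prompt; a positive score (a `1` in `results`) means the model prefers `response1`.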

## Citation

If you find this work useful for your research, please consider citing:

```bibtex
@article{zhang2024general,
  title={General Preference Modeling with Preference Representations for Aligning Language Models},
  author={Zhang, Yifan and Zhang, Ge and Wu, Yue and Xu, Kangping and Gu, Quanquan},
  journal={arXiv preprint arXiv:2410.02197},
  year={2024}
}
```