alfiannajih committed (verified)
Commit 6ec9443 · Parent: abe7d0e

Upload 4 files

g_retriever/__init__.py ADDED
File without changes
g_retriever/g_retriever_config.py ADDED
from transformers import LlamaConfig

class GRetrieverConfig(LlamaConfig):
    model_type = "llama"

    def __init__(
        self,
        max_txt_len: int = 1024,
        max_new_tokens: int = 256,
        gnn_num_layers: int = 4,
        gnn_in_dim: int = 768,
        gnn_hidden_dim: int = 1024,
        gnn_num_heads: int = 4,
        gnn_dropout: float = 0.0,
        bos_id: list = [128000, 128006, 882, 128007],
        **kwargs
    ):
        # Start from the base Hermes-3 Llama config; any extra kwargs override it.
        pretrained_config = LlamaConfig.from_pretrained("NousResearch/Hermes-3-Llama-3.1-8B")
        pretrained_config.update(kwargs)

        # Prompt- and graph-encoder-related hyperparameters.
        self.max_txt_len = max_txt_len
        self.max_new_tokens = max_new_tokens
        self.gnn_num_layers = gnn_num_layers
        self.gnn_in_dim = gnn_in_dim
        self.gnn_hidden_dim = gnn_hidden_dim
        self.gnn_num_heads = gnn_num_heads
        self.gnn_dropout = gnn_dropout
        self.bos_id = bos_id

        super().__init__(**pretrained_config.to_dict())
        self.pad_token_id = pretrained_config.eos_token_id
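
A minimal instantiation sketch (for illustration, not part of the upload): constructing the config fetches the Hermes-3 base config from the Hub, forwards any extra keyword arguments as overrides, and ties pad_token_id to the base model's eos_token_id.

from g_retriever.g_retriever_config import GRetrieverConfig

# Requires network access to fetch the NousResearch/Hermes-3-Llama-3.1-8B config.
config = GRetrieverConfig(max_new_tokens=128)  # extra kwargs update the base LlamaConfig
print(config.gnn_hidden_dim, config.max_new_tokens)   # 1024 128
print(config.pad_token_id == config.eos_token_id)     # True: pad is tied to the base eos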
g_retriever/g_retriever_model.py ADDED
import torch
from torch import nn
import torch.nn.functional as F

from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import StaticCache
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask_with_cache_position
from .g_retriever_config import GRetrieverConfig
from .gnn import GAT

from functools import wraps
from torch_geometric.nn.pool import global_mean_pool

class GRetrieverModel(LlamaForCausalLM):
    config_class = GRetrieverConfig

    def __init__(self, config):
        super().__init__(config)
        self.graph_encoder = GAT(
            in_channels=config.gnn_in_dim,
            out_channels=config.gnn_hidden_dim,
            hidden_channels=config.gnn_hidden_dim,
            num_layers=config.gnn_num_layers,
            dropout=config.gnn_dropout,
            num_heads=config.gnn_num_heads,
        ).to(self.model.dtype)

        self.projector = nn.Sequential(
            nn.Linear(config.gnn_hidden_dim, 2048),
            nn.Sigmoid(),
            nn.Linear(2048, self.get_input_embeddings().embedding_dim),
        ).to(self.model.dtype)

    def encode_graphs(self, graph):
        n_embeds, _ = self.graph_encoder(
            graph.x.to(self.model.dtype),
            graph.edge_index.long(),
            graph.edge_attr.to(self.model.dtype)
        )

        # mean pooling
        g_embeds = global_mean_pool(n_embeds, graph.batch.to(n_embeds.device))

        return g_embeds

    @wraps(LlamaForCausalLM.forward)
    def forward(
        self,
        input_ids=None,
        graph=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None
    ):
        inputs = input_ids.clone()
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (inputs == -1).any():
            # embed bos prompt
            bos_embeds = self.get_input_embeddings()(torch.tensor(
                self.config.bos_id,
                device=self.model.device
            ))

            # encode graph
            graph_embeds = self.encode_graphs(graph)
            graph_embeds = self.projector(graph_embeds).to(self.model.device)

            # prepare for reserved ids (bos+graph)
            non_tokenized_ids = (inputs == -1).nonzero()
            non_tokenized_shape = non_tokenized_ids[:, 0], non_tokenized_ids[:, 1]

            # embed inputs
            inputs[non_tokenized_shape] = self.config.pad_token_id
            temp_inputs_embeds = self.get_input_embeddings()(inputs)
            non_tokenized_embeds = torch.cat([bos_embeds.repeat(len(inputs), 1, 1), graph_embeds.unsqueeze(1)], dim=1)

            # replace reserved ids with bos+graph
            inputs_embeds = temp_inputs_embeds.clone()
            inputs_embeds[non_tokenized_shape] = non_tokenized_embeds.view(len(non_tokenized_ids), -1)

        else:
            inputs_embeds = self.get_input_embeddings()(inputs)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]

        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        graph=None,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0]:]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
                # `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various stride
                # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
                # batch size = 1 case, `position_ids` is already contiguous but with varying stride which
                # retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs["input_ids"].shape
                device = model_inputs["input_ids"].device

            dtype = self.lm_head.weight.dtype
            min_dtype = torch.finfo(dtype).min

            attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_length(),
                dtype=dtype,
                device=device,
                min_dtype=min_dtype,
                cache_position=cache_position,
                batch_size=batch_size,
            )

        model_inputs.update(
            {
                "graph": graph,
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
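
A minimal forward-pass sketch of the reserved-id mechanism, under stated assumptions: the sibling `gnn` module providing `GAT` (imported above but not among these four files) is available, the Hermes-3 base config can be fetched, and the size overrides and token ids below are purely illustrative so a randomly initialized model stays small; use from_pretrained(...) for trained weights instead.

import torch
from torch_geometric.data import Data, Batch

from g_retriever.g_retriever_config import GRetrieverConfig
from g_retriever.g_retriever_model import GRetrieverModel

# Illustrative size overrides (forwarded to the base LlamaConfig) so random weights fit in memory.
config = GRetrieverConfig(
    hidden_size=256, intermediate_size=512,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
)
model = GRetrieverModel(config).eval()

# Toy graph: node/edge features must match config.gnn_in_dim (768 by default).
graph = Batch.from_data_list([Data(
    x=torch.randn(3, config.gnn_in_dim),
    edge_index=torch.tensor([[0, 1], [1, 2]]),
    edge_attr=torch.randn(2, config.gnn_in_dim),
)])

# The first len(bos_id) + 1 positions are reserved (-1) and are replaced inside forward()
# by the bos prompt embeddings plus one projected graph embedding.
input_ids = torch.tensor([[-1] * (len(config.bos_id) + 1) + [15339, 1917]])  # trailing ids are arbitrary
with torch.no_grad():
    out = model(input_ids=input_ids, graph=graph, attention_mask=torch.ones_like(input_ids))
print(out.logits.shape)  # (1, sequence_length, vocab_size)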
g_retriever/g_retriever_pipeline.py ADDED
from transformers import Pipeline, AutoTokenizer
from torch_geometric.data import Batch
import torch

class GRetrieverPipeline(Pipeline):
    def __init__(self, **kwargs):
        Pipeline.__init__(self, **kwargs)

        self.tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
        self.eos_user = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
        self.max_txt_len = self.model.config.max_txt_len
        self.bos_length = len(self.model.config.bos_id)

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        if "textualized_graph" in kwargs:
            preprocess_kwargs["textualized_graph"] = kwargs["textualized_graph"]

        if "graph" in kwargs:
            preprocess_kwargs["graph"] = kwargs["graph"]

        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs, textualized_graph, graph):
        # Truncate the textualized graph to max_txt_len tokens, then build the prompt:
        # reserved -1 ids (bos + graph slot) + textualized graph + question + assistant header.
        textualized_graph_ids = self.tokenizer(textualized_graph, add_special_tokens=False)["input_ids"][:self.max_txt_len]
        question_ids = self.tokenizer(inputs, add_special_tokens=False)["input_ids"]
        eos_user_ids = self.tokenizer(self.eos_user, add_special_tokens=False)["input_ids"]

        input_ids = torch.tensor([
            [-1] * (self.bos_length + 1)
            + textualized_graph_ids
            + question_ids
            + eos_user_ids
        ])
        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids)
        }
        model_inputs.update({
            "graph": Batch.from_data_list([graph])
        })

        return model_inputs

    def _forward(self, model_inputs):
        model_outputs = self.model.generate(**model_inputs)

        return model_outputs

    def postprocess(self, model_outputs):
        return model_outputs
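
An end-to-end usage sketch with stated assumptions: a trained checkpoint directory (placeholder path below), the sibling `gnn` module, and a toy torch_geometric graph plus its textualized form; the question and CSV-style graph text are purely illustrative.

import torch
from torch_geometric.data import Data

from g_retriever.g_retriever_model import GRetrieverModel
from g_retriever.g_retriever_pipeline import GRetrieverPipeline

# Placeholder checkpoint path; the tokenizer is resolved from the model's config._name_or_path.
model = GRetrieverModel.from_pretrained("path/to/g-retriever-checkpoint")

# Toy retrieved subgraph; node/edge features must match config.gnn_in_dim (768 by default).
graph = Data(
    x=torch.randn(2, model.config.gnn_in_dim),
    edge_index=torch.tensor([[0], [1]]),
    edge_attr=torch.randn(1, model.config.gnn_in_dim),
)
textualized_graph = "node_id,node_attr\n0,example node A\n1,example node B"

pipe = GRetrieverPipeline(model=model, task="text-generation")
output_ids = pipe(
    "How are the two nodes related?",
    textualized_graph=textualized_graph,
    graph=graph,
)
# postprocess() returns the raw generate() output: prompt ids (including the reserved -1
# placeholders) followed by the generated answer ids, ready for slicing and decoding.
print(output_ids)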