alfiannajih
commited on
Upload 4 files
Browse files- g_retriever/__init__.py +0 -0
- g_retriever/g_retriever_config.py +31 -0
- g_retriever/g_retriever_model.py +215 -0
- g_retriever/g_retriever_pipeline.py +51 -0
g_retriever/__init__.py
ADDED
File without changes
|
g_retriever/g_retriever_config.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import LlamaConfig
|
2 |
+
|
3 |
+
class GRetrieverConfig(LlamaConfig):
|
4 |
+
model_type = "llama"
|
5 |
+
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
max_txt_len: int = 1024,
|
9 |
+
max_new_tokens: int = 256,
|
10 |
+
gnn_num_layers: int = 4,
|
11 |
+
gnn_in_dim: int = 768,
|
12 |
+
gnn_hidden_dim: int = 1024,
|
13 |
+
gnn_num_heads: int = 4,
|
14 |
+
gnn_dropout: int = 0,
|
15 |
+
bos_id: list = [128000, 128006, 882, 128007],
|
16 |
+
**kwargs
|
17 |
+
):
|
18 |
+
pretrained_config = LlamaConfig.from_pretrained("NousResearch/Hermes-3-Llama-3.1-8B")
|
19 |
+
pretrained_config.update(kwargs)
|
20 |
+
|
21 |
+
self.max_txt_len = max_txt_len
|
22 |
+
self.max_new_tokens = max_new_tokens
|
23 |
+
self.gnn_num_layers = gnn_num_layers
|
24 |
+
self.gnn_in_dim = gnn_in_dim
|
25 |
+
self.gnn_hidden_dim = gnn_hidden_dim
|
26 |
+
self.gnn_num_heads = gnn_num_heads
|
27 |
+
self.gnn_dropout = gnn_dropout
|
28 |
+
self.bos_id = bos_id
|
29 |
+
|
30 |
+
super().__init__(**pretrained_config.to_dict())
|
31 |
+
self.pad_token_id = pretrained_config.eos_token_id
|
g_retriever/g_retriever_model.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from transformers import LlamaForCausalLM
|
6 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
7 |
+
from transformers.cache_utils import StaticCache
|
8 |
+
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask_with_cache_position
|
9 |
+
from .g_retriever_config import GRetrieverConfig
|
10 |
+
from .gnn import GAT
|
11 |
+
|
12 |
+
from functools import wraps
|
13 |
+
from torch_geometric.nn.pool import global_mean_pool
|
14 |
+
|
15 |
+
class GRetrieverModel(LlamaForCausalLM):
|
16 |
+
config_class = GRetrieverConfig
|
17 |
+
|
18 |
+
def __init__(self, config):
|
19 |
+
super().__init__(config)
|
20 |
+
self.graph_encoder = GAT(
|
21 |
+
in_channels=config.gnn_in_dim,
|
22 |
+
out_channels=config.gnn_hidden_dim,
|
23 |
+
hidden_channels=config.gnn_hidden_dim,
|
24 |
+
num_layers=config.gnn_num_layers,
|
25 |
+
dropout=config.gnn_dropout,
|
26 |
+
num_heads=config.gnn_num_heads,
|
27 |
+
).to(self.model.dtype)
|
28 |
+
|
29 |
+
self.projector = nn.Sequential(
|
30 |
+
nn.Linear(config.gnn_hidden_dim, 2048),
|
31 |
+
nn.Sigmoid(),
|
32 |
+
nn.Linear(2048, self.get_input_embeddings().embedding_dim),
|
33 |
+
).to(self.model.dtype)
|
34 |
+
|
35 |
+
def encode_graphs(self, graph):
|
36 |
+
n_embeds, _ = self.graph_encoder(
|
37 |
+
graph.x.to(self.model.dtype),
|
38 |
+
graph.edge_index.long(),
|
39 |
+
graph.edge_attr.to(self.model.dtype)
|
40 |
+
)
|
41 |
+
|
42 |
+
# mean pooling
|
43 |
+
g_embeds = global_mean_pool(n_embeds, graph.batch.to(n_embeds.device))
|
44 |
+
|
45 |
+
return g_embeds
|
46 |
+
|
47 |
+
@wraps(LlamaForCausalLM.forward)
|
48 |
+
def forward(
|
49 |
+
self,
|
50 |
+
input_ids=None,
|
51 |
+
graph=None,
|
52 |
+
attention_mask=None,
|
53 |
+
position_ids=None,
|
54 |
+
past_key_values=None,
|
55 |
+
inputs_embeds=None,
|
56 |
+
labels=None,
|
57 |
+
use_cache=None,
|
58 |
+
output_attentions=None,
|
59 |
+
output_hidden_states=None,
|
60 |
+
return_dict=None,
|
61 |
+
cache_position=None
|
62 |
+
):
|
63 |
+
inputs = input_ids.clone()
|
64 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
65 |
+
output_hidden_states = (
|
66 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
67 |
+
)
|
68 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
69 |
+
|
70 |
+
if (inputs==-1).any():
|
71 |
+
# embed bos prompt
|
72 |
+
bos_embeds = self.get_input_embeddings()(torch.tensor(
|
73 |
+
self.config.bos_id,
|
74 |
+
device=self.model.device
|
75 |
+
))
|
76 |
+
|
77 |
+
# encode graph
|
78 |
+
graph_embeds = self.encode_graphs(graph)
|
79 |
+
graph_embeds = self.projector(graph_embeds).to(self.model.device)
|
80 |
+
|
81 |
+
# prepare for reserved ids (bos+graph)
|
82 |
+
non_tokenized_ids = (inputs == -1).nonzero()
|
83 |
+
non_tokenized_shape = non_tokenized_ids[:, 0], non_tokenized_ids[:, 1]
|
84 |
+
|
85 |
+
# embed inputs
|
86 |
+
inputs[non_tokenized_shape] = self.config.pad_token_id
|
87 |
+
temp_inputs_embeds = self.get_input_embeddings()(inputs)
|
88 |
+
non_tokenized_embeds = torch.cat([bos_embeds.repeat(len(inputs), 1, 1), graph_embeds.unsqueeze(1)], dim=1)
|
89 |
+
|
90 |
+
# replace reserved ids with bos+graph
|
91 |
+
inputs_embeds = temp_inputs_embeds.clone()
|
92 |
+
inputs_embeds[non_tokenized_shape] = non_tokenized_embeds.view(len(non_tokenized_ids), -1)
|
93 |
+
|
94 |
+
else:
|
95 |
+
inputs_embeds = self.get_input_embeddings()(inputs)
|
96 |
+
|
97 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
98 |
+
outputs = self.model(
|
99 |
+
attention_mask=attention_mask,
|
100 |
+
position_ids=position_ids,
|
101 |
+
past_key_values=past_key_values,
|
102 |
+
inputs_embeds=inputs_embeds,
|
103 |
+
use_cache=use_cache,
|
104 |
+
output_attentions=output_attentions,
|
105 |
+
output_hidden_states=output_hidden_states,
|
106 |
+
return_dict=return_dict,
|
107 |
+
cache_position=cache_position,
|
108 |
+
)
|
109 |
+
|
110 |
+
hidden_states = outputs[0]
|
111 |
+
|
112 |
+
if self.config.pretraining_tp > 1:
|
113 |
+
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
114 |
+
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
115 |
+
logits = torch.cat(logits, dim=-1)
|
116 |
+
else:
|
117 |
+
logits = self.lm_head(hidden_states)
|
118 |
+
logits = logits.float()
|
119 |
+
|
120 |
+
loss = None
|
121 |
+
if labels is not None:
|
122 |
+
# Shift so that tokens < n predict n
|
123 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
124 |
+
shift_labels = labels[..., 1:].contiguous()
|
125 |
+
# Flatten the tokens
|
126 |
+
loss_fct = nn.CrossEntropyLoss()
|
127 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
128 |
+
shift_labels = shift_labels.view(-1)
|
129 |
+
# Enable model parallelism
|
130 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
131 |
+
loss = loss_fct(shift_logits, shift_labels)
|
132 |
+
|
133 |
+
if not return_dict:
|
134 |
+
output = (logits,) + outputs[1:]
|
135 |
+
return (loss,) + output if loss is not None else output
|
136 |
+
|
137 |
+
return CausalLMOutputWithPast(
|
138 |
+
loss=loss,
|
139 |
+
logits=logits,
|
140 |
+
past_key_values=outputs.past_key_values,
|
141 |
+
hidden_states=outputs.hidden_states,
|
142 |
+
attentions=outputs.attentions,
|
143 |
+
)
|
144 |
+
|
145 |
+
def prepare_inputs_for_generation(
|
146 |
+
self,
|
147 |
+
input_ids,
|
148 |
+
graph=None,
|
149 |
+
past_key_values=None,
|
150 |
+
attention_mask=None,
|
151 |
+
inputs_embeds=None,
|
152 |
+
cache_position=None,
|
153 |
+
position_ids=None,
|
154 |
+
use_cache=True,
|
155 |
+
**kwargs,
|
156 |
+
):
|
157 |
+
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
158 |
+
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
159 |
+
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
160 |
+
if past_key_values is not None:
|
161 |
+
if inputs_embeds is not None: # Exception 1
|
162 |
+
input_ids = input_ids[:, -cache_position.shape[0] :]
|
163 |
+
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
164 |
+
input_ids = input_ids[:, cache_position]
|
165 |
+
|
166 |
+
if attention_mask is not None and position_ids is None:
|
167 |
+
# create position_ids on the fly for batch generation
|
168 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
169 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
170 |
+
if past_key_values:
|
171 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
172 |
+
|
173 |
+
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
174 |
+
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
175 |
+
|
176 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
177 |
+
if inputs_embeds is not None and cache_position[0] == 0:
|
178 |
+
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
179 |
+
else:
|
180 |
+
# The clone here is for the same reason as for `position_ids`.
|
181 |
+
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
182 |
+
|
183 |
+
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
184 |
+
if model_inputs["inputs_embeds"] is not None:
|
185 |
+
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
186 |
+
device = model_inputs["inputs_embeds"].device
|
187 |
+
else:
|
188 |
+
batch_size, sequence_length = model_inputs["input_ids"].shape
|
189 |
+
device = model_inputs["input_ids"].device
|
190 |
+
|
191 |
+
dtype = self.lm_head.weight.dtype
|
192 |
+
min_dtype = torch.finfo(dtype).min
|
193 |
+
|
194 |
+
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
195 |
+
attention_mask,
|
196 |
+
sequence_length=sequence_length,
|
197 |
+
target_length=past_key_values.get_max_length(),
|
198 |
+
dtype=dtype,
|
199 |
+
device=device,
|
200 |
+
min_dtype=min_dtype,
|
201 |
+
cache_position=cache_position,
|
202 |
+
batch_size=batch_size,
|
203 |
+
)
|
204 |
+
|
205 |
+
model_inputs.update(
|
206 |
+
{
|
207 |
+
"graph": graph,
|
208 |
+
"position_ids": position_ids,
|
209 |
+
"cache_position": cache_position,
|
210 |
+
"past_key_values": past_key_values,
|
211 |
+
"use_cache": use_cache,
|
212 |
+
"attention_mask": attention_mask,
|
213 |
+
}
|
214 |
+
)
|
215 |
+
return model_inputs
|
g_retriever/g_retriever_pipeline.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import Pipeline, AutoTokenizer
|
2 |
+
from torch_geometric.data import Batch
|
3 |
+
import torch
|
4 |
+
|
5 |
+
class GRetrieverPipeline(Pipeline):
|
6 |
+
def __init__(self, **kwargs):
|
7 |
+
Pipeline.__init__(self, **kwargs)
|
8 |
+
|
9 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
|
10 |
+
self.eos_user = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
|
11 |
+
self.max_txt_len = self.model.config.max_txt_len
|
12 |
+
self.bos_length = len(self.model.config.bos_id)
|
13 |
+
|
14 |
+
def _sanitize_parameters(self, **kwargs):
|
15 |
+
preprocess_kwargs = {}
|
16 |
+
if "textualized_graph" in kwargs:
|
17 |
+
preprocess_kwargs["textualized_graph"] = kwargs["textualized_graph"]
|
18 |
+
|
19 |
+
if "graph" in kwargs:
|
20 |
+
preprocess_kwargs["graph"] = kwargs["graph"]
|
21 |
+
|
22 |
+
return preprocess_kwargs, {}, {}
|
23 |
+
|
24 |
+
def preprocess(self, inputs, textualized_graph, graph):
|
25 |
+
textualized_graph_ids = self.tokenizer(textualized_graph, add_special_tokens=False)["input_ids"][:self.max_txt_len]
|
26 |
+
question_ids = self.tokenizer(inputs, add_special_tokens=False)["input_ids"]
|
27 |
+
eos_user_ids = self.tokenizer(self.eos_user, add_special_tokens=False)["input_ids"]
|
28 |
+
|
29 |
+
input_ids = torch.tensor([
|
30 |
+
[-1]*(self.bos_length + 1)
|
31 |
+
+ textualized_graph_ids
|
32 |
+
+ question_ids
|
33 |
+
+ eos_user_ids
|
34 |
+
])
|
35 |
+
model_inputs = {
|
36 |
+
"input_ids": input_ids,
|
37 |
+
"attention_mask": torch.ones_like(input_ids)
|
38 |
+
}
|
39 |
+
model_inputs.update({
|
40 |
+
"graph": Batch.from_data_list([graph])
|
41 |
+
})
|
42 |
+
|
43 |
+
return model_inputs
|
44 |
+
|
45 |
+
def _forward(self, model_inputs):
|
46 |
+
model_outputs = self.model.generate(**model_inputs)
|
47 |
+
|
48 |
+
return model_outputs
|
49 |
+
|
50 |
+
def postprocess(self, model_outputs):
|
51 |
+
return model_outputs
|