alfiannajih
commited on
Update g_retriever_pipeline.py
Browse files- g_retriever_pipeline.py +3 -0
g_retriever_pipeline.py
CHANGED
@@ -11,6 +11,7 @@ class GRetrieverPipeline(Pipeline):
|
|
11 |
self.max_txt_len = self.model.config.max_txt_len
|
12 |
self.bos_length = len(self.model.config.bos_id)
|
13 |
self.input_length = 0
|
|
|
14 |
|
15 |
def _sanitize_parameters(self, **kwargs):
|
16 |
preprocess_kwargs = {}
|
@@ -27,6 +28,7 @@ class GRetrieverPipeline(Pipeline):
|
|
27 |
|
28 |
def preprocess(self, inputs, textualized_graph, graph, generate_kwargs=None):
|
29 |
textualized_graph_ids = self.tokenizer(textualized_graph, add_special_tokens=False)["input_ids"][:self.max_txt_len]
|
|
|
30 |
question_ids = self.tokenizer(inputs, add_special_tokens=False)["input_ids"]
|
31 |
eos_user_ids = self.tokenizer(self.eos_user, add_special_tokens=False)["input_ids"]
|
32 |
|
@@ -34,6 +36,7 @@ class GRetrieverPipeline(Pipeline):
|
|
34 |
[-1]*(self.bos_length + 1)
|
35 |
+ textualized_graph_ids
|
36 |
+ question_ids
|
|
|
37 |
+ eos_user_ids
|
38 |
])
|
39 |
model_inputs = {
|
|
|
11 |
self.max_txt_len = self.model.config.max_txt_len
|
12 |
self.bos_length = len(self.model.config.bos_id)
|
13 |
self.input_length = 0
|
14 |
+
self.prompt = "Generate a detailed review of a resume in relation to the current job market, presented as a textual graph. The review should be divided into three sections: strengths, weaknesses, and improvements."
|
15 |
|
16 |
def _sanitize_parameters(self, **kwargs):
|
17 |
preprocess_kwargs = {}
|
|
|
28 |
|
29 |
def preprocess(self, inputs, textualized_graph, graph, generate_kwargs=None):
|
30 |
textualized_graph_ids = self.tokenizer(textualized_graph, add_special_tokens=False)["input_ids"][:self.max_txt_len]
|
31 |
+
prompt_ids = self.tokenizer(self.prompt, add_special_tokens=False)["input_ids"]
|
32 |
question_ids = self.tokenizer(inputs, add_special_tokens=False)["input_ids"]
|
33 |
eos_user_ids = self.tokenizer(self.eos_user, add_special_tokens=False)["input_ids"]
|
34 |
|
|
|
36 |
[-1]*(self.bos_length + 1)
|
37 |
+ textualized_graph_ids
|
38 |
+ question_ids
|
39 |
+
+ prompt_ids
|
40 |
+ eos_user_ids
|
41 |
])
|
42 |
model_inputs = {
|