alfiannajih committed · commit 172ecbe
Parent(s): 04d0886

    update pipeline

Files changed:
- .gitignore        +4   -0
- app.py            +107 -57
- configuration.py  +40  -0
- kg_retrieval.py   +168 -0
- requirements.txt  +12  -1
- utils.py          +104 -0
.gitignore  ADDED
@@ -0,0 +1,4 @@
+venv
+flagged
+__pycache__
+.env
app.py  CHANGED
@@ -1,63 +1,113 @@
 import gradio as gr
-from huggingface_hub import InferenceClient
-…
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
+from gradio_pdf import PDF
+from transformers import pipeline
+
+from configuration import ConfigurationManager
+from kg_retrieval import Neo4JConnection, KnowledgeGraphRetrievalPipeline
+
+pipe = pipeline(
+    "g-retriever-task",
+    model="alfiannajih/g-retriever",
+)
 
+config = ConfigurationManager()
+
+neo4j_config = config.get_neo4j_config()
+neo4j_connection = Neo4JConnection(neo4j_config)
+
+kg_retrieval_config = config.get_kg_retrieval_config()
+kg_retrieval_pipeline = KnowledgeGraphRetrievalPipeline(kg_retrieval_config, neo4j_connection)
+
+def parsing_pdf(pdf_file):
+    pass
+
+def retrieve_kg(pdf_file, description):
+    resume = parsing_pdf(pdf_file)
+
+    subgraph, textualized_graph = kg_retrieval_pipeline.graph_retrieval_pipeline(resume, description)
+
+    return subgraph, textualized_graph
+
+def get_feedback(pdf_file, description, max_new_tokens, temperature, top_p):
+    graph, textualized_graph = retrieve_kg(pdf_file, description)
+
+    inputs = {
+        "inputs": description,
+        "textualized_graph": textualized_graph,
+        "graph": graph
+    }
+
+    generate_kwargs = {
+        "max_new_tokens": max_new_tokens,
+        "temperature": temperature,
+        "top_p": top_p,
+        "do_sample": True
+    }
+
+    generated = pipe(
+        generate_kwargs=generate_kwargs,
+        **inputs
+    )
+
+    return generated
+
+with gr.Blocks() as demo:
+    gr.Markdown(
+        """
+        <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
+            Mentor AI
+        </div>
+        <div style="text-align: center; font-size: 20px; margin-bottom: 20px;">
+            A demo application that provides feedback on your resume using the G-Retriever framework, an LLM powered by a Knowledge Graph.
+        </div>
+        """
+    )
+    with gr.Row(equal_height=True):
+        with gr.Column():
+            pdf_file = PDF(label="Resume File")
+        with gr.Column():
+            feedback = gr.Textbox(label="Feedback", interactive=False, lines=15)
+
+    with gr.Row():
+        with gr.Column():
+            description = gr.Textbox(
+                label="Description",
+                lines=3,
+            )
+            submit = gr.Button("Get Feedback")
+
+        with gr.Column():
+            max_new_tokens = gr.Slider(
+                minimum=0,
+                maximum=256,
+                value=128,
+                step=1.0,
+                label="Maximum Output Length",
+                interactive=True
+            )
+            top_p = gr.Slider(
+                minimum=0,
+                maximum=1,
+                value=0.8,
+                step=0.01,
+                label="Top P",
+                interactive=True
+            )
+            temperature = gr.Slider(
+                minimum=0.01,
+                maximum=2,
+                value=1,
+                step=0.01,
+                label="Temperature",
+                interactive=True
+            )
+
+    submit.click(
+        fn=get_feedback,
+        inputs=[pdf_file, description, max_new_tokens, temperature, top_p],
+        outputs=[feedback],
+        show_progress=True
+    )
 
 if __name__ == "__main__":
     demo.launch()
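Note on `parsing_pdf`: it is still a stub (`pass`), so `retrieve_kg` hands `None` to `graph_retrieval_pipeline`, which then breaks at `resume + [desc]` in `triples_retrieval`, since that code expects a list of strings. A minimal sketch of a compatible implementation, assuming the `pypdf` package (not in requirements.txt) and that the `gradio_pdf` component delivers a file path:

from pypdf import PdfReader  # assumed dependency, not pinned in requirements.txt

def parsing_pdf(pdf_file):
    # gradio_pdf's PDF component is assumed to pass the uploaded file's path
    reader = PdfReader(pdf_file)
    # One chunk of text per page; triples_retrieval expects a list of strings
    return [page.extract_text() or "" for page in reader.pages]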
configuration.py  ADDED
@@ -0,0 +1,40 @@
+from dataclasses import dataclass
+from dotenv import load_dotenv
+import os
+
+load_dotenv()
+
+@dataclass(frozen=True)
+class Neo4jConfig:
+    neo4j_uri: str
+    neo4j_user: str
+    neo4j_password: str
+    neo4j_db: str
+
+@dataclass(frozen=True)
+class KGRetrievalConfig(Neo4jConfig):
+    embedding_model: str
+
+class ConfigurationManager:
+    def __init__(self):
+        pass
+
+    def get_neo4j_config(self) -> Neo4jConfig:
+        config = Neo4jConfig(
+            neo4j_uri=os.getenv("NEO4J_URI"),
+            neo4j_user=os.getenv("NEO4J_USER"),
+            neo4j_password=os.getenv("NEO4J_PASSWORD"),
+            neo4j_db=os.getenv("NEO4J_DB")
+        )
+        return config
+
+    def get_kg_retrieval_config(self) -> KGRetrievalConfig:
+        config = KGRetrievalConfig(
+            neo4j_uri=os.getenv("NEO4J_URI"),
+            neo4j_user=os.getenv("NEO4J_USER"),
+            neo4j_password=os.getenv("NEO4J_PASSWORD"),
+            neo4j_db=os.getenv("NEO4J_DB"),
+            embedding_model="thenlper/gte-base"
+        )
+
+        return config
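configuration.py resolves all four `NEO4J_*` variables through `load_dotenv()`, which is why `.env` is git-ignored in this commit. A hypothetical `.env` with placeholder values (the real credentials are, correctly, not in the repo):

# .env — illustrative placeholders only
NEO4J_URI=neo4j+s://<your-instance>:7687
NEO4J_USER=neo4j
NEO4J_PASSWORD=<password>
NEO4J_DB=neo4j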
kg_retrieval.py  ADDED
@@ -0,0 +1,168 @@
+from neo4j import GraphDatabase
+import pandas as pd
+import torch
+from torch_geometric.data import Data
+
+from configuration import Neo4jConfig, KGRetrievalConfig
+from utils import get_emb_model, retrieval_via_pcst
+
+class Neo4JConnection:
+    def __init__(self, config: Neo4jConfig):
+        self.driver = GraphDatabase.driver(config.neo4j_uri, auth=(config.neo4j_user, config.neo4j_password))
+        self.db = config.neo4j_db
+        self.driver.verify_connectivity()
+
+    def get_session(self):
+        return self.driver.session(database=self.db)
+
+    def get_head_node(self, relation_ids):
+        node_ids = self.driver.execute_query(
+            """
+            MATCH (h)-[r]->()
+            WHERE elementId(r) IN {}
+            RETURN DISTINCT elementId(h) AS id
+            """.format(relation_ids)
+        )
+        nodes = [node.value() for node in node_ids.records]
+
+        return nodes
+
+    def get_tail_node(self, relation_ids):
+        node_ids = self.driver.execute_query(
+            """
+            MATCH ()-[r]->(t)
+            WHERE elementId(r) IN {}
+            RETURN DISTINCT elementId(t) AS id
+            """.format(relation_ids)
+        )
+        nodes = [node.value() for node in node_ids.records]
+
+        return nodes
+
+    def get_tail_connection_from_head(self, head_ids):
+        relation_ids = self.driver.execute_query(
+            """
+            MATCH (h)-[r]->()
+            WHERE elementId(h) IN {}
+            RETURN DISTINCT elementId(r) AS id LIMIT 50
+            """.format(head_ids)
+        )
+        relations = [relation.value() for relation in relation_ids.records]
+
+        return relations
+
+    def close(self):
+        self.driver.close()
+
+class KnowledgeGraphRetrieval:
+    def __init__(self, config: KGRetrievalConfig, neo4j_connection: Neo4JConnection):
+        self.config = config
+        self.neo4j_connection = neo4j_connection
+
+        self.embedding_model = get_emb_model(self.config.embedding_model)
+
+    def query_relationship_from_node(self, query, n_query):
+        similar_relations = self.neo4j_connection.driver.execute_query(
+            """
+            CALL db.index.vector.queryNodes('JobTitleIndex', {}, {})
+            YIELD node, score
+            MATCH p=(node)-[r:offered_by]->(connectedNode)
+            RETURN elementId(r) AS id, r.job_description, r.location
+            """.format(n_query, query)
+        )
+
+        relations = []
+        for relation in similar_relations.records:
+            _id = relation.get("id")
+            text = "Job description: {}\nLocation: {}".format(relation.get("r.job_description"), relation.get("r.location"))
+
+            relations.append({"rel_id": _id, "text": text})
+
+        return relations
+
+class KnowledgeGraphRetrievalPipeline(KnowledgeGraphRetrieval):
+    def __init__(
+        self,
+        config: KGRetrievalConfig,
+        neo4j_connection: Neo4JConnection
+    ):
+        KnowledgeGraphRetrieval.__init__(self, config, neo4j_connection)
+
+    def triples_retrieval(self, resume, desc, top_emb=5):
+        query = resume + [desc]
+        query_emb = self.embedding_model.encode(query, show_progress_bar=False).mean(axis=0).tolist()
+
+        relations = self.query_relationship_from_node(query_emb, top_emb)
+        relation_ids = [r["rel_id"] for r in relations]
+
+        tail_ids = self.neo4j_connection.get_tail_node(relation_ids)
+        tail_connection = self.neo4j_connection.get_tail_connection_from_head(tail_ids)
+
+        head_ids = self.neo4j_connection.get_head_node(relation_ids)
+        head_connection = self.neo4j_connection.get_tail_connection_from_head(head_ids)
+
+        return relation_ids + tail_connection + head_connection, torch.tensor(query_emb)
+
+    def build_graph(self, triples, query_emb):
+        with self.neo4j_connection.get_session() as session:
+            result = session.run(
+                """
+                MATCH (h)-[r]->(t)
+                WHERE elementId(r) IN {}
+                RETURN h.name AS h_name, h.embedding AS h_embedding, TYPE(r) AS r_type, r.embedding AS r_embedding, r.description AS job_description, t.embedding AS t_embedding, t.name AS t_name
+                """.format(triples)
+            )
+
+            head_nodes = []
+            tail_nodes = []
+            node_embedding = []
+            node_mapping = {}
+            edge_attr = []
+            edges = []
+            nodes = {}
+
+            for rec in result:
+                if rec.get("h_name") not in node_mapping:
+                    node_embedding.append(rec.get("h_embedding"))
+                    nodes[len(node_mapping)] = rec.get("h_name")
+                    node_mapping[rec.get("h_name")] = len(node_mapping)
+
+                if rec.get("t_name") not in node_mapping:
+                    node_embedding.append(rec.get("t_embedding"))
+                    nodes[len(node_mapping)] = rec.get("t_name")
+                    node_mapping[rec.get("t_name")] = len(node_mapping)
+
+                head_nodes.append(rec.get("h_name"))
+                tail_nodes.append(rec.get("t_name"))
+                edge_attr.append(rec.get("r_embedding"))
+
+                if rec.get("job_description") is not None:
+                    textualized_prop = "{}\nJob Description: {}".format(rec.get("r_type"), rec.get("job_description"))
+                else:
+                    textualized_prop = rec.get("r_type")
+
+                edges.append({
+                    "src": node_mapping[rec.get("h_name")],
+                    "edge_attr": textualized_prop,
+                    "dst": node_mapping[rec.get("t_name")]
+                })
+
+        src = [node_mapping[index] for index in head_nodes]
+        dst = [node_mapping[index] for index in tail_nodes]
+
+        edge_index = torch.tensor([src, dst])
+        edge_attr = torch.tensor(edge_attr)
+
+        graph = Data(x=torch.tensor(node_embedding), edge_index=edge_index, edge_attr=edge_attr)
+        nodes = pd.DataFrame([{'node_id': k, 'node_attr': v} for k, v in nodes.items()], columns=['node_id', 'node_attr'])
+        edges = pd.DataFrame(edges, columns=['src', 'edge_attr', 'dst'])
+
+        subgraph, desc = retrieval_via_pcst(graph, query_emb, nodes, edges, topk=10, topk_e=3, cost_e=0.5)
+
+        return subgraph, desc
+
+    def graph_retrieval_pipeline(self, resume, desc, top_emb=5):
+        triples, query_emb = self.triples_retrieval(resume, desc, top_emb)
+        subgraph, textualize_graph = self.build_graph(triples, query_emb)
+
+        return subgraph, textualize_graph
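A caveat on the Cypher above: the queries splice Python lists directly into the statement text with `str.format()`, which only works because a Python list literal happens to be valid Cypher, and it defeats the server's query-plan cache and invites injection. A sketch of `get_head_node` rewritten with driver parameters instead (`$relation_ids` is an illustrative parameter name; in the neo4j 5.x driver, `execute_query` forwards extra keyword arguments as query parameters):

# Sketch: same lookup as get_head_node above, but parameterized.
def get_head_node(self, relation_ids):
    result = self.driver.execute_query(
        """
        MATCH (h)-[r]->()
        WHERE elementId(r) IN $relation_ids
        RETURN DISTINCT elementId(h) AS id
        """,
        relation_ids=relation_ids,
        database_=self.db,
    )
    return [record.value() for record in result.records]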
requirements.txt  CHANGED
@@ -1 +1,12 @@
-…
+transformers==4.44.0
+gradio==4.42.0
+torch_geometric==2.5.3
+torch==2.4.0
+neo4j==5.22.0
+numpy==1.26.4
+pandas==2.2.2
+pcst_fast==1.0.10
+sentence-transformers==3.0.1
+python-dotenv==1.0.1
+sentencepiece==0.2.0
+gradio_pdf  # app.py's PDF component; python-dotenv is already pinned above
utils.py  ADDED
@@ -0,0 +1,104 @@
+import numpy as np
+from pcst_fast import pcst_fast
+from sentence_transformers import SentenceTransformer
+import torch
+from torch_geometric.data import Data
+
+def retrieval_via_pcst(graph, q_emb, textual_nodes, textual_edges, topk=3, topk_e=3, cost_e=0.5):
+    c = 0.01
+    if len(textual_nodes) == 0 or len(textual_edges) == 0:
+        desc = textual_nodes.to_csv(index=False) + '\n' + textual_edges.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])
+        graph = Data(x=graph.x, edge_index=graph.edge_index, edge_attr=graph.edge_attr, num_nodes=graph.num_nodes)
+        return graph, desc
+
+    root = -1  # unrooted
+    num_clusters = 1
+    pruning = 'gw'
+    verbosity_level = 0
+    if topk > 0:
+        n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, graph.x)
+        topk = min(topk, graph.num_nodes)
+        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
+
+        n_prizes = torch.zeros_like(n_prizes)
+        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
+    else:
+        n_prizes = torch.zeros(graph.num_nodes)
+
+    if topk_e > 0:
+        e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, graph.edge_attr)
+        topk_e = min(topk_e, e_prizes.unique().size(0))
+
+        topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e, largest=True)
+        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
+        last_topk_e_value = topk_e
+        for k in range(topk_e):
+            indices = e_prizes == topk_e_values[k]
+            value = min((topk_e - k) / sum(indices), last_topk_e_value)
+            e_prizes[indices] = value
+            last_topk_e_value = value * (1 - c)
+        # reduce the cost of the edges such that at least one edge is selected
+        cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
+    else:
+        e_prizes = torch.zeros(graph.num_edges)
+
+    costs = []
+    edges = []
+    virtual_n_prizes = []
+    virtual_edges = []
+    virtual_costs = []
+    mapping_n = {}
+    mapping_e = {}
+    for i, (src, dst) in enumerate(graph.edge_index.T.numpy()):
+        prize_e = e_prizes[i]
+        if prize_e <= cost_e:
+            mapping_e[len(edges)] = i
+            edges.append((src, dst))
+            costs.append(cost_e - prize_e)
+        else:
+            virtual_node_id = graph.num_nodes + len(virtual_n_prizes)
+            mapping_n[virtual_node_id] = i
+            virtual_edges.append((src, virtual_node_id))
+            virtual_edges.append((virtual_node_id, dst))
+            virtual_costs.append(0)
+            virtual_costs.append(0)
+            virtual_n_prizes.append(prize_e - cost_e)
+
+    prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
+    num_edges = len(edges)
+    if len(virtual_costs) > 0:
+        costs = np.array(costs + virtual_costs)
+        edges = np.array(edges + virtual_edges)
+
+    vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters, pruning, verbosity_level)
+
+    selected_nodes = vertices[vertices < graph.num_nodes]
+    selected_edges = [mapping_e[e] for e in edges if e < num_edges]
+    virtual_vertices = vertices[vertices >= graph.num_nodes]
+    if len(virtual_vertices) > 0:
+        virtual_vertices = vertices[vertices >= graph.num_nodes]
+        virtual_edges = [mapping_n[i] for i in virtual_vertices]
+        selected_edges = np.array(selected_edges + virtual_edges)
+
+    edge_index = graph.edge_index[:, selected_edges]
+    selected_nodes = np.unique(np.concatenate([selected_nodes, edge_index[0].numpy(), edge_index[1].numpy()]))
+
+    n = textual_nodes.iloc[selected_nodes]
+    e = textual_edges.iloc[selected_edges]
+    desc = n.to_csv(index=False) + '\n' + e.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])
+
+    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
+
+    x = graph.x[selected_nodes]
+    edge_attr = graph.edge_attr[selected_edges]
+    src = [mapping[i] for i in edge_index[0].tolist()]
+    dst = [mapping[i] for i in edge_index[1].tolist()]
+    edge_index = torch.LongTensor([src, dst])
+    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, num_nodes=len(selected_nodes))
+
+    return data, desc
+
+def get_emb_model(path):
+    model = SentenceTransformer(model_name_or_path=path, trust_remote_code=True)
+
+    return model
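To sanity-check what `retrieval_via_pcst` returns, here is a toy invocation with made-up fixtures; the shapes and column names mirror what `build_graph` in kg_retrieval.py constructs, while the embedding width, texts, and prize/cost settings are arbitrary:

# Toy fixtures: a 4-node path graph A -> B -> C -> D with random embeddings.
import pandas as pd
import torch
from torch_geometric.data import Data
from utils import retrieval_via_pcst

d = 8  # arbitrary embedding width; the real pipeline uses gte-base vectors
graph = Data(
    x=torch.randn(4, d),                              # node embeddings
    edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]),  # src row, dst row
    edge_attr=torch.randn(3, d),                      # relation embeddings
    num_nodes=4,
)
textual_nodes = pd.DataFrame({"node_id": range(4), "node_attr": ["A", "B", "C", "D"]})
textual_edges = pd.DataFrame({"src": [0, 1, 2], "edge_attr": ["r1", "r2", "r3"], "dst": [1, 2, 3]})

q_emb = torch.randn(d)  # stand-in for the encoded resume + description
subgraph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes, textual_edges, topk=2, topk_e=2, cost_e=0.5)
print(subgraph)  # pruned torch_geometric Data object fed to the model
print(desc)      # CSV textualization that becomes part of the prompt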