missvector commited on
Commit
14357d2
·
1 Parent(s): 40118be

Initial commit

Browse files
Files changed (3) hide show
  1. aggile.py +317 -0
  2. app.py +16 -0
  3. requirements.txt +4 -0
aggile.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # aggile.py
2
+ class Aggile:
3
+ """
4
+ Graph generator for plain text
5
+ """
6
+ def __init__(self, client):
7
+ self.client = client
8
+
9
+ n = None
10
+ self.subj_prompt = f"""
11
+ extract {n} collocations describing key concepts, keywords, named entities from the provided source
12
+ """
13
+
14
+ self.obj_prompt = """
15
+ extract 5-10 most representative collocations from the provided source that are related to the provided concept
16
+ """
17
+
18
+ self.pred_prompt = """
19
+ define the relationship between two words: generate a verb or a phrase decribing a relationship between two entities; return a predicate for a knowledge graph triplet
20
+ """
21
+
22
+ def _get_subj(self, text, n=10):
23
+ """
24
+ Extract entities from the text:
25
+ - named entities
26
+ - kewords
27
+ - concepts
28
+
29
+ :text: input text (str)
30
+ :n: the number of genrated entities (int)
31
+
32
+ :return: {core_concepts: list of extracted keywords (subjects that will form triplets)} (dict)
33
+ """
34
+ import ast
35
+ # Generate keywords from the given text using LLM
36
+ core_concepts = self.client.chat.completions.create(messages=
37
+ [
38
+ {
39
+ "role": "system",
40
+ "content": self.subj_prompt
41
+ },
42
+ {
43
+ "role": "user",
44
+ "content": text
45
+ },
46
+ ],
47
+ response_format=
48
+ {
49
+ "type": "json",
50
+ "value":
51
+ {
52
+ "properties":
53
+ {
54
+ "core_concepts":
55
+ {
56
+ "type": "array",
57
+ "items":
58
+ {
59
+ "type": "string"
60
+ }
61
+ },
62
+ }
63
+ }
64
+ },
65
+ stream=False,
66
+ max_tokens=1024,
67
+ temperature=0.5,
68
+ top_p=0.1
69
+ ).choices[0].get('message')['content']
70
+ return ast.literal_eval(core_concepts)
71
+
72
+ def __extract_relations(self, word, text):
73
+ import ast
74
+ """
75
+ Extract relation for the provided concepts (subjects) based on the information from the text:
76
+ - collocations
77
+
78
+ :text: input text (str)
79
+ :concepts: the list of kewords and other key concepts extracted with aggile._get_subj (dict)
80
+
81
+ :return: {related_concepts: list of related words and collocations (objects that will form triplets)} (dict)
82
+ """
83
+ related_concepts = self.client.chat.completions.create(messages=
84
+ [
85
+ {
86
+ "role": "system",
87
+ "content": self.obj_prompt
88
+ },
89
+ {
90
+ "role": "user",
91
+ "content": f"concept = {word}, source = {text}"
92
+ },
93
+ ],
94
+ response_format=
95
+ {
96
+ "type": "json",
97
+ "value":
98
+ {
99
+ "properties":
100
+ {
101
+ "related_concepts":
102
+ {
103
+ "type": "array",
104
+ "items":
105
+ {
106
+ "type": "string"
107
+ }
108
+ },
109
+ }
110
+ }
111
+ },
112
+ stream=False,
113
+ max_tokens=512,
114
+ temperature=0.5,
115
+ top_p=0.1
116
+ ).choices[0].get('message')['content']
117
+ return ast.literal_eval(related_concepts)
118
+
119
+ def _get_obj(self, text):
120
+ """
121
+ Execute the extraction of related concepts for the list of keywords:
122
+ - generate list of objects for each object in the dictionarytract relation for the provided concepts (subjects) based on the information from the text:
123
+
124
+ :text: input text (str)
125
+ :concepts: the list of keywords and other key concepts extracted with aggile._get_subj (dict)
126
+
127
+ :return: {related_concepts: list of related words and collocations (objects that will form triplets)} (dict)
128
+ """
129
+ # Generate list of subjects
130
+ core_concepts = self._get_subj(text, n=10)
131
+ # Get object for each subject
132
+ relations = {word: self.__extract_relations(word, text) for word in core_concepts['core_concepts']}
133
+ return relations
134
+
135
+ def __generate_predicates(self, subj, obj):
136
+ import ast
137
+ """
138
+ Generate predicates between objects and subjects
139
+
140
+ :subj: one generated subject from core_concepts (str)
141
+ :obj: one generated object from relations (str)
142
+ :text: input text (str)
143
+
144
+ :return: one relevant predicate to form triplets (str)
145
+ """
146
+ predicate = self.client.chat.completions.create(messages=
147
+ [
148
+ {
149
+ "role": "system",
150
+ "content": self.pred_prompt
151
+ },
152
+ {
153
+ "role": "user",
154
+ "content": f"what is the relationship between {subj} and {obj}? return a predicate only"
155
+ },
156
+ ],
157
+ response_format=
158
+ {
159
+ "type": "json",
160
+ "value":
161
+ {
162
+ "properties":
163
+ {
164
+ "predicate":
165
+ {
166
+ "type": "string"
167
+ },
168
+ }
169
+ }
170
+ },
171
+ stream=False,
172
+ max_tokens=512,
173
+ temperature=0.5,
174
+ top_p=0.1
175
+ ).choices[0].get('message')['content']
176
+ return ast.literal_eval(predicate)['predicate'] # Return predicate only, not the whole dictionary
177
+
178
+ def form_triples(self, text):
179
+ """
180
+ :text: input text (str) if from_string=True
181
+ """
182
+
183
+ # Generate objects from text
184
+ relations = self._get_obj(text)
185
+ # Placeholder for triplets
186
+ triplets = dict()
187
+ # Form triplets for each subject
188
+ for subj in relations:
189
+ # Placeholder for the current subject
190
+ triplets[subj] = list()
191
+ # For each object generated for this subject:
192
+ for obj in relations[subj]['related_concepts']:
193
+ # Create placeholder with the triplet structure "subject-predicate-object"
194
+ temp = {'subject': subj, 'predicate': '', 'object': ''}
195
+ # Save the object to the triplet
196
+ temp['object'] = obj
197
+ # Generate predicate between the current object and the current subject
198
+ temp['predicate'] = self.__generate_predicates(subj, obj)
199
+ # Hallucincation check: if object and subjects are the same entities, do not append them to the list of triplets
200
+ if temp['subject'] != temp['object']:
201
+ # Otherwise, append the triplet
202
+ triplets[subj].append(temp)
203
+
204
+ return triplets
205
+
206
+ class Graph:
207
+ def __init__(self, triplets):
208
+ self.triplets = triplets
209
+
210
+ def build_graph(self):
211
+ import plotly.graph_objects as go
212
+ import networkx as nx
213
+ from collections import Counter
214
+ import random
215
+
216
+ # Prepare nodes and edges
217
+ nodes = set()
218
+ edges = []
219
+
220
+ # Extract noded and edges from the set of triplets
221
+ for key, values in self.triplets.items():
222
+ for rel in values:
223
+ nodes.add(rel['subject'])
224
+ nodes.add(rel['object'])
225
+ edges.append((rel['subject'], rel['object'], rel['predicate']))
226
+
227
+ # Create a networkx graph
228
+ G = nx.Graph()
229
+
230
+ # Add nodes and edges to the graph
231
+ for edge in edges:
232
+ G.add_edge(edge[0], edge[1], label=edge[2])
233
+
234
+ # Generate positions for nodes using force-directed layout with more space
235
+ pos = nx.spring_layout(G, seed=42) # Increasing k for more spacing
236
+
237
+ # Extract node and edge data for Plotly
238
+ node_x = [pos[node][0] for node in G.nodes()]
239
+ node_y = [pos[node][1] for node in G.nodes()]
240
+ node_labels = list(G.nodes())
241
+
242
+ # Count connections
243
+ node_degrees = Counter([node for edge in edges for node in edge[:2]])
244
+
245
+ # Assign distinct colors for each predicate (use a set to avoid duplicates)
246
+ unique_predicates = list(set([edge[2] for edge in edges]))
247
+ predicate_colors = {predicate: f'rgba({random.randint(0,255)},{random.randint(0,255)},{random.randint(0,255)},1)'
248
+ for predicate in unique_predicates}
249
+
250
+ # Plotly data for edges
251
+ edge_x = []
252
+ edge_y = []
253
+
254
+ for edge in edges:
255
+ x0, y0 = pos[edge[0]]
256
+ x1, y1 = pos[edge[1]]
257
+ edge_x += [x0, x1, None]
258
+ edge_y += [y0, y1, None]
259
+
260
+ # Create the figure
261
+ fig = go.Figure()
262
+
263
+ # Add edges
264
+ fig.add_trace(go.Scatter(
265
+ x=edge_x, y=edge_y,
266
+ line=dict(width=0.5, color='#888'),
267
+ hoverinfo='text',
268
+ mode='lines'
269
+ ))
270
+
271
+ # Add nodes with uniform size and labels
272
+ fig.add_trace(go.Scatter(
273
+ x=node_x, y=node_y,
274
+ mode='markers+text',
275
+ marker=dict(
276
+ size=25, # Uniform node size for all nodes
277
+ color=[node_degrees[node] for node in node_labels],
278
+ #colorscale='Viridis',
279
+ colorbar=dict(title='Connections')
280
+ ),
281
+ text=node_labels,
282
+ hoverinfo='text',
283
+ textposition='top center',
284
+ textfont=dict(size=13, weight="bold")
285
+ ))
286
+
287
+ # Add predicate labels near the nodes with black text
288
+ for edge in edges:
289
+ x0, y0 = pos[edge[0]]
290
+ x1, y1 = pos[edge[1]]
291
+ predicate_label = edge[2]
292
+
293
+ # Calculate the midpoint of the edge and add small offsets to create spacing
294
+ mid_x = (x0 + x1) / 2
295
+ mid_y = (y0 + y1) / 2
296
+
297
+ # Add the label near the midpoint of the edge with black text
298
+ fig.add_trace(go.Scatter(
299
+ x=[mid_x], y=[mid_y],
300
+ mode='text',
301
+ text=[predicate_label],
302
+ textposition='middle center',
303
+ showlegend=False,
304
+ textfont=dict(size=10)
305
+ ))
306
+
307
+ # Update layout
308
+ fig.update_layout(
309
+ showlegend=False,
310
+ margin=dict(l=0, r=0, t=0, b=0),
311
+ xaxis=dict(showgrid=False, zeroline=False),
312
+ yaxis=dict(showgrid=False, zeroline=False),
313
+ title="Force-Directed Graph with Predicate Labels on Nodes"
314
+ )
315
+
316
+ # Save the figure as an HTML file
317
+ fig.write_html("graph_with_predicates.html")
app.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from huggingface_hub import InferenceClient
3
+ from aggile import Aggile, Graph
4
+ import streamlit as st
5
+
6
+ st.title('AGGILE: Automated Graph Generation for Inference and Language Exploration')
7
+
8
+ text = st.text_input("Enter your text", "")
9
+
10
+ # Initialize Aggile with your HuggingFace credentials; change the model if needed
11
+ client = InferenceClient("deepseek-ai/DeepSeek-R1-Distill-Qwen-32B")
12
+ aggile = Aggile(client=client)
13
+ # Form triplets from the text from string
14
+ triplets = aggile.form_triples('This is a sample text')
15
+ # Visualize graph based on generated triplets
16
+ st.html(Graph(tripletspip).build_graph()) # Saves and shows graph_with_predicates.html
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ huggingface_hub
2
+ networkx
3
+ numpy
4
+ plotly