Spaces:
Sleeping
Sleeping
import gradio as gr | |
import networkx as nx | |
import spacy | |
import plotly.graph_objects as go | |
from sources import RSS_FEEDS | |
from fetch import fetch_articles | |
# Imports for the LLM knowledge graph transformer | |
from langchain.schema import Document | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_openai import ChatOpenAI | |
from langchain_experimental.graph_transformers import LLMGraphTransformer | |
def _merge_graph_documents(graph_documents):
    """Merge LLM graph documents into a single directed NetworkX graph.

    Each document is dumped to a plain dict; its "nodes" become graph nodes
    and its "relationships" become directed edges. A repeated
    (source, target) pair increments the edge's ``weight`` and records each
    distinct relationship type once in ``relation_types``.
    """
    graph = nx.DiGraph()
    for graph_doc in graph_documents:
        # Work on the plain-dict form of the Pydantic model.
        gdoc = graph_doc.model_dump()

        for node in gdoc.get("nodes", []):
            node_id = node.get("id") or node.get("name")
            if node_id:
                graph.add_node(node_id)

        for rel in gdoc.get("relationships", []):
            source = (rel.get("source") or {}).get("id")
            target = (rel.get("target") or {}).get("id")
            rel_type = rel.get("type", "")
            if not (source and target):
                continue
            if graph.has_edge(source, target):
                edge = graph[source][target]
                relation_types = edge.setdefault("relation_types", [])
                if rel_type not in relation_types:
                    relation_types.append(rel_type)
                edge["weight"] += 1
            else:
                graph.add_edge(
                    source, target, weight=1, relation_types=[rel_type]
                )
    return graph


def _node_hover_text(graph):
    """Return a dict mapping each node to its hover text.

    The text lists every outgoing and then every incoming connection, one per
    ``<br>``-separated line. A node without connections maps to itself.
    """
    hover = {}
    for node in graph.nodes():
        details = []
        for _, target, data in graph.out_edges(node, data=True):
            rels = ", ".join(data.get("relation_types", []))
            details.append(f"Out: {node} - {rels} -> {target}")
        for source, _, data in graph.in_edges(node, data=True):
            rels = ", ".join(data.get("relation_types", []))
            details.append(f"In: {source} - {rels} -> {node}")
        hover[node] = "<br>".join(details) if details else node
    return hover


def _figure_from_graph(graph, pos, hover):
    """Render the graph as a Plotly figure with labelled nodes and arrows.

    ``pos`` maps node -> (x, y) and ``hover`` maps node -> hover text.
    """
    nodes = list(graph.nodes())
    node_trace = go.Scatter(
        x=[pos[node][0] for node in nodes],
        y=[pos[node][1] for node in nodes],
        mode='markers+text',
        text=nodes,  # Displayed text is just the node name.
        textposition="top center",
        hoverinfo='text',
        hovertext=[hover.get(node, node) for node in nodes],
        marker=dict(
            size=10,
            color='#1f78b4'
        )
    )

    # All edges share one style and carry no hover info, so draw them as a
    # single line trace; a None entry breaks the line between segments.
    edge_x = []
    edge_y = []
    for source, target in graph.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        mode='lines',
        line=dict(width=1, color='#888'),
        hoverinfo='none'
    )

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            title='<br>Interactive Knowledge Graph (LLM-derived)',
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            width=1200,
            height=800,
            dragmode='pan'
        )
    )

    # Arrow annotations indicate edge direction (tail at source, head at
    # target); they carry no text, so details come from the node hover.
    for source, target in graph.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        fig.add_annotation(
            x=x1,
            y=y1,
            ax=x0,
            ay=y0,
            xref='x',
            yref='y',
            axref='x',
            ayref='y',
            showarrow=True,
            arrowhead=3,
            arrowcolor='#888',
            arrowwidth=2,
            text="",
        )
    return fig


def build_interactive_knowledge_graph(feed_items):
    """
    Build an interactive knowledge graph from aggregated RSS feed text using an LLM.

    Steps:
    1. Combine the title and summary of all feed items into one text.
    2. Wrap the text in a Document and split it into chunks.
    3. Use ChatOpenAI and LLMGraphTransformer to extract graph documents
       (skipped when the splitter produced no chunks).
    4. Merge nodes and relationships into a directed NetworkX graph.
    5. Compute a spring layout.
    6. Render a Plotly figure with node hover text listing all outgoing and
       incoming connections, plus arrow annotations for edge direction.

    Args:
        feed_items: Iterable of mappings; each must provide 'title' and
            'summary' keys.

    Returns:
        The Plotly figure for the graph.
    """
    # 1. Combine all feed items into one aggregated text.
    combined_text = "\n\n".join(
        f"{item['title']}. {item['summary']}" for item in feed_items
    )

    # 2. Create a Document and split it.
    doc = Document(page_content=combined_text)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    docs = text_splitter.split_documents([doc])

    # 3. Run the LLM extraction only when there is something to extract, so an
    #    empty feed does not construct an OpenAI client for no work.
    if docs:
        llm = ChatOpenAI(temperature=0, model="gpt-4o")
        llm_transformer = LLMGraphTransformer(llm=llm)
        graph_documents = llm_transformer.convert_to_graph_documents(docs)
    else:
        graph_documents = []

    # 4. Build the directed graph.
    graph = _merge_graph_documents(graph_documents)

    # 5. Compute positions using a spring layout.
    pos = nx.spring_layout(graph, k=1.2)

    # 6. Render the figure.
    return _figure_from_graph(graph, pos, _node_hover_text(graph))
def get_combined_feed(source_choice, selected_news_sites):
    """
    Aggregate articles from the chosen RSS sources.

    Returns a 2-tuple: a text listing of the fetched articles (title,
    published date and link per entry) and the Plotly knowledge-graph figure
    built from those same articles.
    """
    articles = []

    # Only the "News" source exists; fetch when it is ticked and at least one
    # site is selected.
    wants_news = "News" in source_choice
    if wants_news and selected_news_sites:
        chosen_feeds = {}
        for name, url in RSS_FEEDS.items():
            if name in selected_news_sites:
                chosen_feeds[name] = url
        articles.extend(fetch_articles(chosen_feeds, limit=6))

    # One listing entry per article, separated by blank lines.
    listing = [
        f"🔹 {entry['title']} ({entry['published']})\n{entry['link']}"
        for entry in articles
    ]
    feed_text = "\n\n".join(listing)

    # Render the interactive knowledge graph for the same articles.
    figure = build_interactive_knowledge_graph(articles)
    return feed_text, figure
# Gradio UI: source/site pickers next to the aggregated feed text, the graph
# plot underneath, and a button that runs get_combined_feed on click.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            source_selector = gr.CheckboxGroup(
                choices=["News"],
                value=["News"],
                label="Select Sources",
            )
            news_site_selector = gr.CheckboxGroup(
                choices=list(RSS_FEEDS),
                value=["BBC", "Wired"],
                label="News Sites",
            )
        with gr.Column():
            feed_output = gr.Textbox(label="Aggregated Feed", lines=20)

    with gr.Row(), gr.Column():
        graph_output = gr.Plot(label="Interactive Knowledge Graph")

    # Clicking the button fills both the feed textbox and the graph plot.
    generate_button = gr.Button("Generate Graph")
    generate_button.click(
        get_combined_feed,
        [source_selector, news_site_selector],
        [feed_output, graph_output],
    )

demo.launch()