import gradio as gr
import networkx as nx
import spacy
import plotly.graph_objects as go
from sources import RSS_FEEDS
from fetch import fetch_articles
# Imports for the LLM knowledge graph transformer
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI
from langchain_experimental.graph_transformers import LLMGraphTransformer
# Load the spaCy model (kept for other tasks if needed)
nlp = spacy.load("en_core_web_sm")
def build_interactive_knowledge_graph(feed_items):
"""
Build an interactive knowledge graph from aggregated RSS feed text using an LLM.
Steps:
1. Combine the title and summary of all feed items.
2. Create a Document and split it into chunks.
3. Use ChatOpenAI and LLMGraphTransformer to get graph information.
4. Merge nodes and relationships into a directed NetworkX graph.
5. Compute a spring layout and convert the graph to a Plotly figure.
6. Compute node hover text showing all outgoing/incoming connections.
7. Re-add arrow annotations to indicate direction (with no extra text).
8. Return the Plotly figure.
"""
# 1. Combine all feed items into one aggregated text.
combined_text = "\n\n".join([f"{item['title']}. {item['summary']}" for item in feed_items])
# 2. Create a Document and split it.
doc = Document(page_content=combined_text)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = text_splitter.split_documents([doc])
# 3. Initialize the LLM and transformer.
llm = ChatOpenAI(temperature=0, model="gpt-4o")
llm_transformer = LLMGraphTransformer(llm=llm)
graph_documents = llm_transformer.convert_to_graph_documents(docs)
# 4. Build a directed NetworkX graph.
G = nx.DiGraph()
for graph_doc in graph_documents:
# Convert the Pydantic model to a dictionary.
gdoc = graph_doc.model_dump()
nodes = gdoc.get("nodes", [])
# In these documents, relationships are stored under "relationships".
relationships = gdoc.get("relationships", [])
# Add nodes.
for node in nodes:
node_id = node.get("id") or node.get("name")
if node_id:
G.add_node(node_id)
# Add relationships as directed edges.
for rel in relationships:
source_obj = rel.get("source", {})
target_obj = rel.get("target", {})
source = source_obj.get("id")
target = target_obj.get("id")
rel_type = rel.get("type", "")
if source and target:
if G.has_edge(source, target):
if "relation_types" in G[source][target]:
if rel_type not in G[source][target]["relation_types"]:
G[source][target]["relation_types"].append(rel_type)
else:
G[source][target]["relation_types"] = [rel_type]
G[source][target]["weight"] += 1
else:
G.add_edge(source, target, weight=1, relation_types=[rel_type])
# 5. Compute positions using a spring layout.
pos = nx.spring_layout(G, k=1.2)
#pos = nx.kamada_kawai_layout(G)
# 6. Prepare node hover text.
# For each node, list all outgoing and incoming connection details.
node_hover = {}
for node in G.nodes():
outgoing = []
for u, v, data in G.out_edges(node, data=True):
rels = ", ".join(data.get("relation_types", []))
outgoing.append(f"Out: {node} - {rels} -> {v}")
incoming = []
for u, v, data in G.in_edges(node, data=True):
rels = ", ".join(data.get("relation_types", []))
incoming.append(f"In: {u} - {rels} -> {node}")
details = outgoing + incoming
if details:
node_hover[node] = "
".join(details)
else:
node_hover[node] = node # Fallback if there are no connections.
# 7. Create node trace using calculated positions and hover text.
node_x = []
node_y = []
node_text = [] # Displayed text is just the node name.
node_hover_list = [] # Custom hover info with connection details.
for node in G.nodes():
x, y = pos[node]
node_x.append(x)
node_y.append(y)
node_text.append(node)
node_hover_list.append(node_hover.get(node, node))
node_trace = go.Scatter(
x=node_x,
y=node_y,
mode='markers+text',
text=node_text,
textposition="top center",
hoverinfo='text',
hovertext=node_hover_list,
marker=dict(
size=10,
color='#1f78b4'
)
)
# 8. Create edge traces: one trace per edge.
edge_traces = []
for edge in G.edges(data=True):
x0, y0 = pos[edge[0]]
x1, y1 = pos[edge[1]]
edge_trace = go.Scatter(
x=[x0, x1],
y=[y0, y1],
mode='lines',
line=dict(width=1, color='#888'),
hoverinfo='none'
)
edge_traces.append(edge_trace)
# 9. Build the interactive Plotly figure.
fig = go.Figure(
data=edge_traces + [node_trace],
layout=go.Layout(
title='
Interactive Knowledge Graph (LLM-derived)',
showlegend=False,
hovermode='closest',
margin=dict(b=20, l=5, r=5, t=40),
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
width=1200, # wider figure
height=800, # taller figure
dragmode='pan'
)
)
# 10. Re-add arrow annotations for each edge (without hover text).
for edge in G.edges(data=True):
x0, y0 = pos[edge[0]]
x1, y1 = pos[edge[1]]
fig.add_annotation(
x=x1,
y=y1,
ax=x0,
ay=y0,
xref='x',
yref='y',
axref='x',
ayref='y',
showarrow=True,
arrowhead=3,
arrowcolor='#888',
arrowwidth=2,
text="", # No text; rely on node hover for details.
)
return fig
def get_combined_feed(source_choice, selected_news_sites):
"""
Create an aggregated feed from selected RSS sources
and build an interactive Plotly knowledge graph.
"""
feed_items = []
# Fetch articles from selected news sites.
if "News" in source_choice and selected_news_sites:
selected_feeds = {name: url for name, url in RSS_FEEDS.items() if name in selected_news_sites}
feed_items += fetch_articles(selected_feeds, limit=6)
# Aggregate feed text.
feed_text = "\n\n".join([f"🔹 {item['title']} ({item['published']})\n{item['link']}" for item in feed_items])
# Build an interactive knowledge graph using Plotly.
graph_fig = build_interactive_knowledge_graph(feed_items)
return feed_text, graph_fig
# Define the Gradio interface with a button to trigger processing.
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
source_selector = gr.CheckboxGroup(
["News"], value=["News"], label="Select Sources"
)
news_site_selector = gr.CheckboxGroup(
list(RSS_FEEDS.keys()), value=["BBC", "Wired"], label="News Sites"
)
with gr.Column():
feed_output = gr.Textbox(label="Aggregated Feed", lines=20)
with gr.Row():
with gr.Column():
graph_output = gr.Plot(label="Interactive Knowledge Graph")
# Button to trigger graph generation.
generate_button = gr.Button("Generate Graph")
generate_button.click(
fn=get_combined_feed,
inputs=[source_selector, news_site_selector],
outputs=[feed_output, graph_output]
)
demo.launch()