Spaces:
Running
Running
File size: 13,562 Bytes
f1342ba 1e3f619 7b897df 9b707db c2f2340 1e3f619 7b897df 1e3f619 7b897df 1e3f619 7b897df 1e3f619 7b897df 1e3f619 7b897df 1e3f619 7b897df 1e3f619 7b897df 1e3f619 5596129 1e3f619 507d746 7b897df 9b707db 6237635 7b897df f1342ba 5596129 6237635 5596129 f1342ba 9b707db c2f2340 1e3f619 5596129 c2f2340 f1342ba 7b897df 9b707db 1e3f619 c2f2340 1e3f619 c2f2340 7b897df 1e3f619 7b897df c80303c c2f2340 7b897df 5596129 7b897df 1e3f619 7b897df 1e3f619 7b897df 9b707db 5596129 9b707db 5596129 9b707db c2f2340 9b707db 7b897df fec4816 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 |
import streamlit as st
from utils.kg.construct_kg import get_graph
from utils.audit.rag import get_text_from_content_for_doc,get_text_from_content_for_audio
from streamlit_agraph import agraph, Node, Edge, Config
import random
import math
from utils.audit.response_llm import generate_response_via_langchain
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import PromptTemplate
def if_node_exists(nodes, node_id):
    """
    Check whether a node with the given id is present in a collection of nodes.

    Args:
        nodes (iterable): Node objects exposing an ``id`` attribute
            (e.g. the agraph ``Node`` instances built in this module).
        node_id (str): The id to look for.

    Returns:
        bool: True if some node has ``node.id == node_id``, False otherwise
        (including when ``nodes`` is empty).
    """
    return any(node.id == node_id for node in nodes)
def generate_random_color():
    """Return a random light colour as an ``(r, g, b)`` tuple, each channel drawn from [180, 255]."""
    # One randint call per channel, in r, g, b order.
    return tuple(random.randint(180, 255) for _channel in range(3))
def rgb_to_hex(rgb):
    """Format the first three items of ``rgb`` as a lowercase ``#rrggbb`` string."""
    red, green, blue = rgb[0], rgb[1], rgb[2]
    return f"#{red:02x}{green:02x}{blue:02x}"
def get_node_types(graph):
    """
    Collect the set of node types appearing in ``graph``.

    Types are read from ``graph.nodes`` and from the ``source``/``target``
    endpoints of every item in ``graph.relationships``, so endpoint types that
    are absent from ``graph.nodes`` are still included.
    """
    found = {node.type for node in graph.nodes}
    for rel in graph.relationships:
        found.update((rel.source.type, rel.target.type))
    return found
def color_distance(color1, color2):
    """Return the Euclidean distance between two RGB colours (only indices 0-2 are used)."""
    squared = sum((color1[axis] - color2[axis]) ** 2 for axis in range(3))
    return math.sqrt(squared)
def generate_distinct_colors(num_colors, min_distance=30, max_attempts=1000):
    """
    Generate ``num_colors`` random light colours, ideally pairwise ``min_distance`` apart.

    Candidates come from ``generate_random_color()``, so every channel lies in
    [180, 255]. That range is narrow: with many requested colours no candidate
    may satisfy the spacing, and pure rejection sampling would then loop forever.
    To guarantee termination, after ``max_attempts`` consecutive rejections the
    required spacing is halved; once it would drop below 1 it becomes 0, at which
    point every candidate is accepted.

    Args:
        num_colors (int): How many colours to return; values <= 0 yield ``[]``.
        min_distance (float): Desired minimum Euclidean distance between colours.
        max_attempts (int): Consecutive rejected candidates tolerated before the
            spacing is relaxed.

    Returns:
        list[str]: ``num_colors`` colours formatted as ``#rrggbb`` strings.
    """
    colors = []
    rejected = 0
    while len(colors) < num_colors:
        candidate = generate_random_color()
        if all(color_distance(candidate, existing) >= min_distance for existing in colors):
            colors.append(candidate)
            rejected = 0
            continue
        rejected += 1
        if rejected >= max_attempts:
            # Relax the spacing so the loop always makes progress.
            min_distance = min_distance / 2 if min_distance >= 1 else 0
            rejected = 0
    return [rgb_to_hex(color) for color in colors]
def list_to_dict_colors(node_types: set):
    """
    Assign each node type its own colour.

    Args:
        node_types (set): The node types to colour.

    Returns:
        dict: Maps every type in ``node_types`` to a ``#rrggbb`` string produced
        by ``generate_distinct_colors``.
    """
    palette = generate_distinct_colors(len(node_types))
    return dict(zip(node_types, palette))
def _to_agraph_node(entity, node_colors):
    """Build the agraph ``Node`` for one graph entity (an object with ``id`` and ``type`` attributes)."""
    entity_type = entity.type
    return Node(
        id=entity.id.replace(" ", "_"),  # Replace spaces with underscores for ids
        title=entity_type,  # the title carries the node type; filter_nodes_by_types reads it
        label=entity.id,
        size=25,  # Default size, can be customized
        shape="circle",  # Default shape, can be customized
        color=node_colors[entity_type],
    )


def convert_neo4j_to_agraph(neo4j_graph, node_colors):
    """
    Convert an extracted graph into the nodes, edges and config used by streamlit-agraph.

    Args:
        neo4j_graph: Object exposing ``nodes`` (items with ``id`` and ``type``)
            and ``relationships`` (items with ``source``, ``target`` and ``type``,
            where ``source``/``target`` are node-like items). Presumably the graph
            document returned by ``get_graph``; confirm against that helper.
        node_colors (dict): Maps each node type to a colour. Every type present in
            the graph must have an entry; a missing one raises ``KeyError``.

    Returns:
        tuple: ``(edges, nodes, config)``, a list of agraph ``Edge``, a list of
        agraph ``Node`` and the agraph ``Config``.

    Node ids are the entity ids with spaces replaced by underscores. Each id is
    emitted only once, even when it is repeated in ``neo4j_graph.nodes``, appears
    as a relationship endpoint, or two ids collide after the replacement (e.g.
    "New York" / "New_York"). The vis.js-based agraph frontend rejects duplicate
    node ids.
    """
    nodes = []
    edges = []
    seen_ids = set()  # O(1) membership instead of rescanning `nodes` per endpoint

    def add_node(entity):
        # Build the node, append it only if its id is new, and return the id.
        node = _to_agraph_node(entity, node_colors)
        if node.id not in seen_ids:
            seen_ids.add(node.id)
            nodes.append(node)
        return node.id

    for entity in neo4j_graph.nodes:
        add_node(entity)

    # Relationship endpoints are added too, in case they are missing from `nodes`.
    for relationship in neo4j_graph.relationships:
        source_id = add_node(relationship.source)
        target_id = add_node(relationship.target)
        edges.append(Edge(source=source_id, label=relationship.type, target=target_id))

    # Define the configuration for Agraph
    config = Config(width=1200, height=800, directed=True, physics=True, hierarchical=True, from_json="config.json")
    return edges, nodes, config
def display_graph(edges, nodes, config):
    """
    Render the graph with streamlit-agraph and hand back agraph's return value.

    kg_main stores the result as the current selection and tests it against
    ``None``. Presumably it is the clicked node, or None when nothing is
    selected; confirm against the streamlit-agraph docs.
    """
    rendered = agraph(nodes=nodes, edges=edges, config=config)
    return rendered
def filter_nodes_by_types(nodes: list[Node], node_types_filter: list) -> list[Node]:
    """
    Keep only the nodes whose type is listed in ``node_types_filter``.

    The node type is read from ``Node.title``, which is where
    convert_neo4j_to_agraph stores it. Input order is preserved.
    """
    return [candidate for candidate in nodes if candidate.title in node_types_filter]
@st.dialog(title="Changer la vue")
def change_view_dialog():
    """
    Modal listing the saved filter views so the user can switch to or delete one.

    Each row of ``st.session_state.filter_views`` shows:
    - an expander with the view's node-type filters;
    - a delete button, except for the first view (index 0);
    - a select button, "✅" for the current view and "🔍" otherwise.
    Selecting or deleting a view updates ``st.session_state`` and calls ``st.rerun()``.
    """
    st.write("Changer la vue")
    for index, item in enumerate(st.session_state.filter_views.keys()):
        # One placeholder per row so the whole row can be cleared at once.
        emp = st.empty()
        col1, col2, col3 = emp.columns([8, 1, 1])
        # The first view (presumably "Vue par défaut", inserted when the graph is
        # generated) is never offered for deletion.
        if index > 0 and col2.button("🗑️", key=f"del{index}"):
            del st.session_state.filter_views[item]
            # Fall back to the default view after a deletion.
            st.session_state.current_view = "Vue par défaut"
            # NOTE(review): st.rerun() is expected to end this run immediately, so the
            # loop never continues over the mutated dict; confirm for the Streamlit version in use.
            st.rerun()
        but_content = "🔍" if st.session_state.current_view != item else "✅"
        if col3.button(but_content, key=f"valid{index}"):
            st.session_state.current_view = item
            st.rerun()
        if len(st.session_state.filter_views.keys()) > index:
            with col1.expander(item):
                # Bulleted list of the node types stored in this view.
                st.markdown("\n".join(f"- {label.strip()}" for label in st.session_state.filter_views[item]))
        else:
            # This row no longer corresponds to a view: clear its placeholder.
            emp.empty()
@st.dialog(title="Ajouter une vue")
def add_view_dialog(filters):
    """
    Modal that saves the current node-type filters as a named view and makes it current.

    The view name is stripped of surrounding whitespace. An empty name is
    rejected with a warning, because it would create a blank-titled entry in
    the view picker.

    Args:
        filters (list): The node types currently selected in the graph filter.
    """
    st.write("Ajouter une vue")
    view_name = st.text_input("Nom de la vue")
    st.markdown("les filtres actuels:")
    st.write(filters)
    if st.button("Ajouter la vue"):
        view_name = view_name.strip()
        if not view_name:
            st.warning("Veuillez saisir un nom de vue.")
        else:
            st.session_state.filter_views[view_name] = filters
            st.session_state.current_view = view_name
            st.rerun()
@st.dialog(title="Changer la couleur")
def change_color_dialog():
    """
    Modal with one colour picker per node type.

    Each picked colour is written straight back into ``st.session_state.node_types``.
    The "Valider" button reruns the app so the graph is redrawn with the new colours.
    """
    st.write("Changer la couleur")
    palette = st.session_state.node_types
    for entity_type, current in palette.items():
        # Overwriting an existing key does not change the dict's size, so this is
        # safe while iterating.
        palette[entity_type] = st.color_picker(f"La couleur de l'entité **{entity_type.strip()}**", current)
    if st.button("Valider"):
        st.rerun()
def _init_kg_session_state():
    """Create the session-state entries used by the knowledge-graph page if they are missing."""
    defaults = {
        "graph": None,
        "filter_views": {},
        "current_view": None,
        "node_types": None,
        "summary": None,
        "chat_graph_history": [],
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _generate_kg():
    """Extract the graph from the report and keywords, store it, and reset colours and the default view."""
    audit = st.session_state.audit_simplified
    keywords = audit["Mots clés"]
    # Keep the report and the keyword list separated; plain concatenation glued
    # the report's last sentence directly onto "mots clés".
    text = st.session_state.cr + "\n\nmots clés : " + keywords
    # Strip each keyword and drop empty entries (e.g. a trailing comma) before
    # offering them as node types.
    keywords_list = [kw.strip() for kw in keywords.split(",") if kw.strip()]
    allowed_nodes_types = keywords_list + ["Person", "Organization", "Location", "Event", "Date", "Time", "Ressource", "Concept"]
    graph = get_graph(text, allowed_nodes=allowed_nodes_types)
    st.session_state.graph = graph
    node_types = get_node_types(graph[0])
    st.session_state.node_types = list_to_dict_colors(node_types)
    st.session_state.filter_views["Vue par défaut"] = list(node_types)
    st.session_state.current_view = "Vue par défaut"
    return graph


def _render_kg_graph_panel(edges, nodes, config):
    """Draw the toolbar, type filter and graph for the current view; return agraph's selection value."""
    view_name = st.session_state.current_view
    st.write("##### Visualisation du graphe (**" + view_name + "**)")
    filter_col, add_view_col, change_view_col, color_col = st.columns([9, 1, 1, 1])
    if color_col.button("🎨", help="Changer la couleur"):
        change_color_dialog()
    if change_view_col.button("🔍", help="Changer de vue"):
        change_view_dialog()
    options = list(st.session_state.node_types.keys())
    # A view saved before the graph was regenerated may list types that no longer
    # exist; st.multiselect raises when a default is not among its options.
    default = [t for t in st.session_state.filter_views.get(view_name, []) if t in options]
    selected_types = filter_col.multiselect(
        "Filtrer selon l'étiquette",
        options,
        placeholder="Sélectionner une ou plusieurs étiquettes",
        default=default,
        label_visibility="collapsed",
    )
    if add_view_col.button("➕", help="Ajouter une vue"):
        add_view_dialog(selected_types)
    # An empty filter means "show everything".
    if selected_types:
        nodes = filter_nodes_by_types(nodes, selected_types)
    return display_graph(edges, nodes, config)


def _render_kg_chat_panel(graph, selected):
    """Show the graph chat, answer a pending question, and offer canned prompts for a selected node."""
    st.markdown("##### Dialoguer avec le graphe")
    user_query = st.chat_input("Par ici ...")
    history = st.session_state.chat_graph_history
    if user_query:
        history.append(HumanMessage(content=user_query))
    with st.container(height=650, border=False):
        for message in history:
            if isinstance(message, AIMessage):
                with st.chat_message("AI"):
                    st.markdown(message.content)
            elif isinstance(message, HumanMessage):
                with st.chat_message("Moi"):
                    st.write(message.content)
        # A trailing human message is a question that has not been answered yet.
        if history and isinstance(history[-1], HumanMessage):
            last_message = history[-1]
            with st.chat_message("AI"):
                retriever = st.session_state.vectorstore.as_retriever()
                context = retriever.invoke(last_message.content)
                wrapped_prompt = f"Étant donné le contexte suivant {context}, et le graph de connaissance: {graph}, {last_message.content}"
                response = st.write_stream(generate_response_via_langchain(wrapped_prompt, stream=True))
                history.append(AIMessage(content=response))
        if selected is not None:
            with st.chat_message("AI"):
                st.markdown(f" EXPLORER LES DONNEES CONTENUES DANS **{selected}**")
                prompts = [f"Extrait moi toutes les informations du noeud ''{selected}'' ➡️",
                           f"Montre moi les conversations autour du noeud ''{selected}'' ➡️"]
                for i, prompt in enumerate(prompts):
                    # The default argument binds this prompt's text. The callback runs
                    # at the start of the next rerun, which then answers the question.
                    st.button(
                        prompt,
                        key=f"p_{i}",
                        on_click=lambda text=prompt: st.session_state.chat_graph_history.append(HumanMessage(content=text)),
                    )


def kg_main():
    """
    Streamlit page: generate, display, filter and chat with the audit's knowledge graph.

    Requires a non-empty ``st.session_state.audit`` and a ``st.session_state.cr``
    report; otherwise an error is shown and the function returns early. The graph
    is only (re)generated when the "Générer le graphe" button is pressed.
    Otherwise the graph stored in session state, if any, is shown.
    """
    if "audit" not in st.session_state or st.session_state.audit == {}:
        st.error("Veuillez d'abord effectuer un audit pour visualiser le graphe de connaissance.")
        return
    if "cr" not in st.session_state:
        st.error("Veuillez d'abord effectuer un compte rendu pour visualiser le graphe de connaissance.")
        return
    _init_kg_session_state()
    st.title("Graphe de connaissance")
    if st.button("Générer le graphe"):
        with st.spinner("Génération du graphe..."):
            graph = _generate_kg()
    else:
        graph = st.session_state.graph
    if graph is None:
        return
    edges, nodes, config = convert_neo4j_to_agraph(graph[0], st.session_state.node_types)
    graph_col, chat_col = st.columns([2.5, 1.5])
    with graph_col.container(border=True, height=800):
        selected = _render_kg_graph_panel(edges, nodes, config)
    with chat_col.container(border=True, height=800):
        _render_kg_chat_panel(graph, selected)
|