File size: 13,562 Bytes
f1342ba
1e3f619
 
 
7b897df
 
9b707db
 
c2f2340
1e3f619
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b897df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e3f619
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b897df
1e3f619
 
 
 
7b897df
1e3f619
 
 
 
 
 
 
7b897df
1e3f619
 
 
7b897df
1e3f619
 
 
 
7b897df
1e3f619
 
 
7b897df
1e3f619
 
 
 
 
 
 
 
5596129
1e3f619
507d746
7b897df
 
 
 
9b707db
6237635
 
7b897df
 
 
 
 
 
 
f1342ba
5596129
 
 
 
 
 
 
 
 
 
 
 
6237635
 
5596129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1342ba
 
 
9b707db
 
 
 
 
 
 
c2f2340
 
 
 
1e3f619
 
5596129
 
 
 
 
c2f2340
f1342ba
 
7b897df
 
 
 
 
 
9b707db
 
1e3f619
 
c2f2340
1e3f619
c2f2340
 
 
 
 
 
7b897df
 
1e3f619
 
7b897df
 
 
 
c80303c
c2f2340
 
 
7b897df
 
 
 
 
5596129
 
7b897df
1e3f619
 
7b897df
1e3f619
 
7b897df
 
 
9b707db
 
 
5596129
 
 
 
 
9b707db
5596129
 
 
 
 
 
 
 
 
 
 
9b707db
 
 
 
 
 
 
 
c2f2340
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b707db
 
 
7b897df
 
fec4816
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
import streamlit as st
from utils.kg.construct_kg import get_graph
from utils.audit.rag import get_text_from_content_for_doc,get_text_from_content_for_audio
from streamlit_agraph import agraph, Node, Edge, Config
import random
import math
from utils.audit.response_llm import generate_response_via_langchain
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import PromptTemplate

def if_node_exists(nodes, node_id):
    """
    Check whether a node with the given id is already present in a list of nodes.

    Args:
        nodes (iterable): Node objects exposing an ``id`` attribute (for example the
            agraph ``Node`` instances built in ``convert_neo4j_to_agraph``).
        node_id (str): The id to look for.

    Returns:
        bool: True if some node in ``nodes`` has ``id == node_id``, False otherwise
        (including when ``nodes`` is empty).
    """
    return any(node.id == node_id for node in nodes)

def generate_random_color():
    """Return a random light RGB colour as an ``(r, g, b)`` tuple.

    Each channel is drawn independently, in r, g, b order, with
    ``random.randint(180, 255)`` (both bounds inclusive), which keeps the
    colours pale.
    """
    low, high = 180, 255
    return tuple(random.randint(low, high) for _channel in range(3))

def rgb_to_hex(rgb):
    """Format the first three components of *rgb* as a ``#rrggbb`` string.

    Intended for integer channels in 0..255 (as produced by
    ``generate_random_color``); each is rendered as two lowercase hex digits.
    """
    red, green, blue = rgb[0], rgb[1], rgb[2]
    return f"#{red:02x}{green:02x}{blue:02x}"

def get_node_types(graph):
    """Collect every node type that appears in *graph*.

    Types are taken from ``graph.nodes`` and also from the ``source`` and
    ``target`` endpoints of every item in ``graph.relationships``, so
    endpoints missing from the node list still contribute their type.

    Returns:
        set: the distinct ``type`` values.
    """
    collected = set()
    collected.update(n.type for n in graph.nodes)
    for rel in graph.relationships:
        collected.update((rel.source.type, rel.target.type))
    return collected

def color_distance(color1, color2):
    """Return the Euclidean distance between two RGB colours.

    Only the first three components of each colour are compared.
    """
    d_red = color1[0] - color2[0]
    d_green = color1[1] - color2[1]
    d_blue = color1[2] - color2[2]
    return math.sqrt(d_red ** 2 + d_green ** 2 + d_blue ** 2)

def generate_distinct_colors(num_colors, min_distance=30, max_attempts=1000):
    """Generate ``num_colors`` random light colours that are pairwise far apart.

    Candidates come from ``generate_random_color`` and are kept only if their
    Euclidean RGB distance to every colour already picked is at least
    ``min_distance``.

    The light palette (each channel in 180..255) can hold only a limited number
    of colours that are 30 apart, so with many node types an unbounded
    rejection loop could spin forever. After ``max_attempts`` consecutive
    rejections the threshold is halved; once it drops below 1 it becomes 0 and
    any candidate is accepted, so the loop always terminates.

    Args:
        num_colors (int): how many colours to return; values <= 0 yield ``[]``.
        min_distance (float): initial minimum pairwise RGB distance.
        max_attempts (int): consecutive rejections tolerated before the
            distance threshold is relaxed.

    Returns:
        list[str]: ``num_colors`` hex strings such as ``"#b4c8ff"``.
    """
    picked = []
    rejections = 0
    while len(picked) < num_colors:
        candidate = generate_random_color()
        if all(color_distance(candidate, other) >= min_distance for other in picked):
            picked.append(candidate)
            rejections = 0
            continue
        rejections += 1
        if rejections >= max_attempts:
            # Relax the constraint instead of looping forever on a crowded palette.
            min_distance = min_distance / 2 if min_distance >= 1 else 0
            rejections = 0
    return [rgb_to_hex(color) for color in picked]

def list_to_dict_colors(node_types:set):
    """Assign a generated hex colour to every node type.

    Args:
        node_types: the node types to colour (in this file, the set returned
            by ``get_node_types``).

    Returns:
        dict: maps each node type, in iteration order, to one of the
        ``#rrggbb`` strings produced by ``generate_distinct_colors``.
    """
    palette = generate_distinct_colors(len(node_types))
    return dict(zip(node_types, palette))


def _to_agraph_node(graph_node, node_colors, size=25, shape="circle"):
    """Build an agraph ``Node`` from a graph node exposing ``id`` and ``type``."""
    node_type = graph_node.type
    return Node(
        id=graph_node.id.replace(" ", "_"),  # spaces in ids are replaced with underscores
        title=node_type,  # the title carries the type; filter_nodes_by_types filters on it
        label=graph_node.id,  # the label keeps the original, human-readable id
        size=size,
        shape=shape,
        color=node_colors[node_type],
    )


def convert_neo4j_to_agraph(neo4j_graph, node_colors):
    """
    Convert a graph object into streamlit-agraph nodes, edges and config.

    Args:
        neo4j_graph: object exposing ``nodes`` (items with ``id`` and ``type``
            attributes) and ``relationships`` (items with ``source`` and
            ``target`` nodes plus a ``type``).
        node_colors (dict): maps every node type to a colour string.

    Returns:
        tuple: ``(edges, nodes, config)`` ready to be passed to ``agraph``.
        Every node of ``neo4j_graph.nodes`` is included. Relationship endpoints
        are added only when no node with the same id is present yet.

    Raises:
        KeyError: if a node type has no entry in ``node_colors``.
    """
    size = 25  # default node size, can be customized
    shape = "circle"  # default node shape, can be customized

    nodes = [_to_agraph_node(n, node_colors, size, shape) for n in neo4j_graph.nodes]
    # A set of known ids keeps the endpoint de-duplication O(1) per lookup.
    seen_ids = {n.id for n in nodes}
    edges = []

    for relationship in neo4j_graph.relationships:
        endpoint_ids = []
        for endpoint in (relationship.source, relationship.target):
            agraph_node = _to_agraph_node(endpoint, node_colors, size, shape)
            if agraph_node.id not in seen_ids:
                nodes.append(agraph_node)
                seen_ids.add(agraph_node.id)
            endpoint_ids.append(agraph_node.id)
        source_id, target_id = endpoint_ids
        edges.append(Edge(source=source_id, label=relationship.type, target=target_id))

    # Define the configuration for Agraph
    config = Config(width=1200, height=800, directed=True, physics=True, hierarchical=True,from_json="config.json")

    return edges, nodes, config

def display_graph(edges, nodes, config):
    """Render the graph with streamlit-agraph and return what ``agraph`` returns.

    ``kg_main`` keeps this return value as ``selected`` and offers prompts
    about it when it is not None.
    """
    rendered = agraph(edges=edges, nodes=nodes, config=config)
    return rendered



def filter_nodes_by_types(nodes:list[Node], node_types_filter:list) -> list[Node]:
    """Keep only the nodes whose type is listed in *node_types_filter*.

    The agraph nodes built in this file store their type in ``title``, so
    membership is tested on that attribute. The input order is preserved.
    """
    return [candidate for candidate in nodes if candidate.title in node_types_filter]

@st.dialog(title="Changer la vue")
def change_view_dialog():
    """Dialog listing the saved filter views with select and delete buttons.

    For each entry of ``st.session_state.filter_views`` it renders an expander
    with the view's node-type labels and a select button (✅ for the current
    view, 🔍 otherwise). Every view except the first also gets a delete button.
    Deleting a view resets ``current_view`` to "Vue par défaut". Both actions
    call ``st.rerun()``.
    """
    st.write("Changer la vue")
    
    # NOTE(review): this loop iterates the live keys view while the delete branch
    # mutates the dict. It only works because st.rerun() stops the script right
    # after the deletion; iterating over a list(...) copy of the keys would be safer.
    for index, item in enumerate(st.session_state.filter_views.keys()):
        emp = st.empty()
        col1, col2, col3 = emp.columns([8, 1, 1])

        # The first view cannot be deleted (presumably "Vue par défaut", which is
        # inserted first when the graph is generated — verify).
        if index > 0 and col2.button("🗑️", key=f"del{index}"):
            del st.session_state.filter_views[item]
            st.session_state.current_view = "Vue par défaut"
            st.rerun()
        but_content = "🔍" if st.session_state.current_view != item else "✅"
        if col3.button(but_content, key=f"valid{index}"):
            st.session_state.current_view = item
            st.rerun()
        # NOTE(review): while iterating the current keys this condition appears to
        # always hold, so the else branch looks unreachable.
        if len(st.session_state.filter_views.keys()) > index:
            with col1.expander(item):
                st.markdown("\n".join(f"- {label.strip()}" for label in st.session_state.filter_views[item]))
        else:
            emp.empty()

@st.dialog(title="Ajouter une vue")
def add_view_dialog(filters):
    """Dialog that saves *filters* as a named view and makes it the current view.

    Args:
        filters (list): the node-type labels currently selected in the filter
            widget.

    A blank name is rejected with a warning. Otherwise a view keyed by an
    empty string would be created and selected. The name is stripped of
    surrounding whitespace before being used as the view key.
    """
    st.write("Ajouter une vue")
    view_name = st.text_input("Nom de la vue")
    st.markdown("les filtres actuels:")
    st.write(filters)
    if st.button("Ajouter la vue"):
        name = view_name.strip()
        if not name:
            # Keep the dialog open and explain why nothing was saved.
            st.warning("Veuillez saisir un nom pour la vue.")
            return
        st.session_state.filter_views[name] = filters
        st.session_state.current_view = name
        st.rerun()

@st.dialog(title="Changer la couleur")
def change_color_dialog():
    """Dialog with one colour picker per node type.

    Each picker writes its value straight back into
    ``st.session_state.node_types``. The "Valider" button only triggers
    ``st.rerun()``.
    """
    st.write("Changer la couleur")
    palette = st.session_state.node_types
    for entity in list(palette):
        palette[entity] = st.color_picker(f"La couleur de l'entité **{entity.strip()}**", palette[entity])

    if st.button("Valider"):
        st.rerun()



def kg_main():
    """Render the "Graphe de connaissance" Streamlit page.

    * Returns early with an error unless ``st.session_state.audit`` is set and
      non-empty and ``st.session_state.cr`` exists.
    * Initialises the session-state keys used here (``graph``,
      ``filter_views``, ``current_view``, ``node_types``, ``summary``,
      ``chat_graph_history``).
    * On "Générer le graphe", builds a graph from the report text plus the
      audit keywords via ``get_graph``. It then assigns a colour per node type
      and creates the default view. Otherwise it reuses the graph stored in
      session state.
    * When a graph exists, it shows the graph with view, colour and filter
      controls. Next to it, a chat panel answers the latest user message using
      the vector-store retriever and the graph. It also offers prompt shortcuts
      for the value returned by ``display_graph``.
    """
    #st.set_page_config(page_title="Graphe de connaissance", page_icon="", layout="wide")
    

    
    if "audit" not in st.session_state or st.session_state.audit == {}:
        st.error("Veuillez d'abord effectuer un audit pour visualiser le graphe de connaissance.")
        return
    
    if "cr" not in st.session_state:
        st.error("Veuillez d'abord effectuer un compte rendu pour visualiser le graphe de connaissance.")
        return
    
    # Initialise the session-state keys used by this page on first visit.
    if "graph" not in st.session_state:
        st.session_state.graph = None
    
    if "filter_views" not in st.session_state:
        st.session_state.filter_views = {}
    if "current_view" not in st.session_state:
        st.session_state.current_view = None

    st.title("Graphe de connaissance")

    if "node_types" not in st.session_state:
        st.session_state.node_types = None
    
    if "summary" not in st.session_state:
        st.session_state.summary = None

    if "chat_graph_history" not in st.session_state:
        st.session_state.chat_graph_history = []
    
    # NOTE(review): audit_simplified is read without an existence check, unlike
    # "audit" and "cr" above — presumably set by the audit page; verify.
    audit = st.session_state.audit_simplified
    # content = st.session_state.audit["content"]

    # if audit["type de fichier"] == "pdf":
    #     text = get_text_from_content_for_doc(content)
    # elif audit["type de fichier"] == "audio":
    #     text = get_text_from_content_for_audio(content)

    # NOTE(review): the report, the literal "mots clés" and the keywords are
    # concatenated without spaces or separators; confirm this is intended.
    text = st.session_state.cr + "mots clés" + audit["Mots clés"]
    
    #summary_prompt = f"Voici un ensemble de documents : {text}. À partir de ces documents, veuillez fournir des résumés concis en vous concentrant sur l'extraction des relations essentielles et des événements. Il est crucial d'inclure les dates des actions ou des événements, car elles seront utilisées pour l'analyse chronologique. Par exemple : 'Sam a été licencié par le conseil d'administration d'OpenAI le 17 novembre 2023 (17 novembre, vendredi)', ce qui illustre la relation entre Sam et OpenAI ainsi que la date de l'événement."

    if st.button("Générer le graphe"):
        # with st.spinner("Extractions des relations..."):
        #     sum = generate_response_openai(summary_prompt,model="gpt-4o")
        #     st.session_state.summary = sum

        with st.spinner("Génération du graphe..."):
            # Keywords are comma-separated. Individual entries are not stripped, so
            # they may keep leading spaces (labels are .strip()-ed only when displayed).
            keywords_list = audit["Mots clés"].strip().split(",")
            # The audit keywords are allowed as node types alongside generic entity types.
            allowed_nodes_types =keywords_list+ ["Person","Organization","Location","Event","Date","Time","Ressource","Concept"]
            graph = get_graph(text,allowed_nodes=allowed_nodes_types)
            st.session_state.graph = graph
        
        # graph[0] is the object exposing .nodes/.relationships (presumably the
        # first graph document returned by get_graph — verify).
        node_types = get_node_types(graph[0])
        nodes_type_dict = list_to_dict_colors(node_types)
        st.session_state.node_types = nodes_type_dict
        # The default view shows every node type found in the graph.
        st.session_state.filter_views["Vue par défaut"] = list(node_types)
        st.session_state.current_view = "Vue par défaut"

    else:
        # Reuse the graph from a previous run (None until one has been generated).
        graph = st.session_state.graph

    if graph is not None:
        #st.write(graph)
        
        edges,nodes,config = convert_neo4j_to_agraph(graph[0],st.session_state.node_types)
        
        col1, col2 = st.columns([2.5, 1.5])

        with col1.container(border=True,height=800):
            st.write("##### Visualisation du graphe (**"+st.session_state.current_view+"**)")
            filter_col,add_view_col,change_view_col,color_col = st.columns([9,1,1,1])
            
            if color_col.button("🎨",help="Changer la couleur"):
                change_color_dialog()
                
            if change_view_col.button("🔍",help="Changer de vue"):
                change_view_dialog()

            
            #add mots cles to evry label in audit["Mots clés"]
            #filter_labels = [ label + " (mot clé)" if label.strip().lower() in audit["Mots clés"].strip().lower().split(",") else label for label in st.session_state.filter_views[st.session_state.current_view] ]
            # The multiselect is pre-filled with the labels of the current view.
            filter = filter_col.multiselect("Filtrer selon l'étiquette",st.session_state.node_types.keys(),placeholder="Sélectionner une ou plusieurs étiquettes",default=st.session_state.filter_views[st.session_state.current_view],label_visibility="collapsed")
            
            if add_view_col.button("➕",help="Ajouter une vue"):
                add_view_dialog(filter)
            
            # An empty selection means no filtering. Only nodes are filtered; edges are passed unchanged.
            if filter:
                nodes = filter_nodes_by_types(nodes,filter) 

            # The widget's return value is used below as the selected item.
            selected = display_graph(edges,nodes,config)

        with col2.container(border=True,height=800):
            st.markdown("##### Dialoguer avec le graphe")

            user_query = st.chat_input("Par ici ...")
            if user_query is not None and user_query != "":
                st.session_state.chat_graph_history.append(HumanMessage(content=user_query))

            with st.container(height=650, border=False):
                for message in st.session_state.chat_graph_history:
                    if isinstance(message, AIMessage):
                        with st.chat_message("AI"):
                            st.markdown(message.content)
                    elif isinstance(message, HumanMessage):
                        with st.chat_message("Moi"):
                            st.write(message.content)
                
                #check if last message is human message
                # If so, answer it: retrieve context, stream the LLM reply and store it.
                if len(st.session_state.chat_graph_history) > 0:
                    last_message = st.session_state.chat_graph_history[-1]
                    if isinstance(last_message, HumanMessage):
                        with st.chat_message("AI"):
                            # NOTE(review): assumes st.session_state.vectorstore was set
                            # elsewhere; it is not initialised on this page.
                            retreive = st.session_state.vectorstore.as_retriever()
                            context = retreive.invoke(last_message.content)
                            # The prompt embeds the retrieved context and the whole graph object as strings.
                            wrapped_prompt = f"Étant donné le contexte suivant {context}, et le graph de connaissance: {graph}, {last_message.content}"
                            response = st.write_stream(generate_response_via_langchain(wrapped_prompt,stream=True))
                            st.session_state.chat_graph_history.append(AIMessage(content=response))

                # Offer prompt shortcuts about the value returned by the graph widget
                # (presumably the clicked node's id — verify).
                if selected is not None:
                        with st.chat_message("AI"):
                            st.markdown(f" EXPLORER LES DONNEES CONTENUES DANS **{selected}**")

                            prompts = [f"Extrait moi toutes les informations du noeud ''{selected}'' ➡️",
                                    f"Montre moi les conversations autour du noeud ''{selected}'' ➡️"]
                            
                            for i,prompt in enumerate(prompts):
                                # i=i binds the loop index at definition time so each button
                                # appends its own prompt. The HumanMessage is answered by the
                                # block above on the next script run.
                                button = st.button(prompt,key=f"p_{i}",on_click=lambda i=i: st.session_state.chat_graph_history.append(HumanMessage(content=prompts[i])))
            
            


    # NOTE(review): node_types is not used in the visible code.
    node_types = st.session_state.node_types