Ilyas KHIAT commited on
Commit
6237635
·
1 Parent(s): ea2f004

vue enhance

Browse files
app.py CHANGED
@@ -9,6 +9,7 @@ def main():
9
  st.set_page_config(page_title="RAG Agent", page_icon="🤖", layout="wide")
10
 
11
  audit_page = st.Page("audit_page/audit.py", title="Audit", icon="📋", default=True)
 
12
  kg_page = st.Page("audit_page/knowledge_graph.py", title="Graphe de connaissance", icon="🧠")
13
  agents_page = st.Page("agents_page/catalogue.py", title="Catalogue des agents", icon="📇")
14
  compte_rendu = st.Page("audit_page/compte_rendu.py", title="Compte rendu", icon="📝")
@@ -18,7 +19,7 @@ def main():
18
 
19
  pg = st.navigation(
20
  {
21
- "Audit de contenus": [audit_page, compte_rendu, kg_page],
22
  "Equipe d'agents IA": [recommended_agents],
23
  "Chatbot": [chatbot],
24
  "Documentation": [documentation]
 
9
  st.set_page_config(page_title="RAG Agent", page_icon="🤖", layout="wide")
10
 
11
  audit_page = st.Page("audit_page/audit.py", title="Audit", icon="📋", default=True)
12
+ dialog_page = st.Page("audit_page/dialogue_doc.py", title="Dialoguer avec le document", icon="💬")
13
  kg_page = st.Page("audit_page/knowledge_graph.py", title="Graphe de connaissance", icon="🧠")
14
  agents_page = st.Page("agents_page/catalogue.py", title="Catalogue des agents", icon="📇")
15
  compte_rendu = st.Page("audit_page/compte_rendu.py", title="Compte rendu", icon="📝")
 
19
 
20
  pg = st.navigation(
21
  {
22
+ "Audit de contenus": [audit_page,dialog_page, compte_rendu, kg_page],
23
  "Equipe d'agents IA": [recommended_agents],
24
  "Chatbot": [chatbot],
25
  "Documentation": [documentation]
audit_page/audit.py CHANGED
@@ -111,10 +111,13 @@ def handle_audit(uploaded_file,type:str):
111
  st.session_state.name_file = uploaded_file.name
112
  with st.spinner("Analyse du document..."):
113
  st.session_state.audit = {}
 
114
  st.session_state.audit = audit_descriptif_pdf(uploaded_file,100)
115
  with st.spinner("Préparation de la DB..."):
116
  vectorstore = setup_rag(type,st.session_state.audit["content"])
117
  st.session_state.vectorstore = vectorstore
 
 
118
 
119
  audit = st.session_state.audit["audit"]
120
  #global audit
@@ -136,6 +139,12 @@ def handle_audit(uploaded_file,type:str):
136
  with st.spinner("Analyse de l'audio..."):
137
  st.session_state.audit = {}
138
  st.session_state.audit = evaluate_audio_quality(uploaded_file)
 
 
 
 
 
 
139
  audit = st.session_state.audit["audit"]
140
  #audit global simplifié
141
  audit_simplified = {
@@ -218,6 +227,10 @@ def audit_main():
218
  st.session_state.audit_simplified = {}
219
  if "vectorstore" not in st.session_state:
220
  st.session_state.vectorstore = None
 
 
 
 
221
 
222
  # File uploader
223
  uploaded_file = col1.file_uploader("Télécharger un ou plusieurs documents")
@@ -234,4 +247,7 @@ def audit_main():
234
  display_audit(col1)
235
  handle_display_content(col2)
236
 
 
 
 
237
  audit_main()
 
111
  st.session_state.name_file = uploaded_file.name
112
  with st.spinner("Analyse du document..."):
113
  st.session_state.audit = {}
114
+
115
  st.session_state.audit = audit_descriptif_pdf(uploaded_file,100)
116
  with st.spinner("Préparation de la DB..."):
117
  vectorstore = setup_rag(type,st.session_state.audit["content"])
118
  st.session_state.vectorstore = vectorstore
119
+ st.session_state.graph = None
120
+ st.session_state.cr = ""
121
 
122
  audit = st.session_state.audit["audit"]
123
  #global audit
 
139
  with st.spinner("Analyse de l'audio..."):
140
  st.session_state.audit = {}
141
  st.session_state.audit = evaluate_audio_quality(uploaded_file)
142
+ with st.spinner("Préparation de la DB..."):
143
+ vectorstore = setup_rag(type,st.session_state.audit["content"])
144
+ st.session_state.vectorstore = vectorstore
145
+ st.session_state.graph = None
146
+ st.session_state.cr = ""
147
+
148
  audit = st.session_state.audit["audit"]
149
  #audit global simplifié
150
  audit_simplified = {
 
227
  st.session_state.audit_simplified = {}
228
  if "vectorstore" not in st.session_state:
229
  st.session_state.vectorstore = None
230
+ if "cr" not in st.session_state:
231
+ st.session_state.cr = ""
232
+ if "graph" not in st.session_state:
233
+ st.session_state.graph = None
234
 
235
  # File uploader
236
  uploaded_file = col1.file_uploader("Télécharger un ou plusieurs documents")
 
247
  display_audit(col1)
248
  handle_display_content(col2)
249
 
250
+ #init graph and cr
251
+
252
+
253
  audit_main()
audit_page/dialogue_doc.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from textwrap import dedent
3
+ from utils.audit.rag import get_text_from_content_for_doc,get_text_from_content_for_audio
4
+ from utils.audit.response_llm import generate_response_via_langchain
5
+ from langchain_core.messages import AIMessage, HumanMessage
6
+ import pyperclip
7
+ from utils.kg.construct_kg import get_graph
8
+ from audit_page.knowledge_graph import *
9
+
10
+ def graph_doc_to_json(graph):
11
+ nodes = []
12
+ edges = []
13
+ for node in graph.nodes:
14
+ node_id = node.id.replace(" ", "_")
15
+ label = node.id
16
+ type = node.type
17
+ nodes.append({"id": node_id, "label": label, "type": type})
18
+ for relationship in graph.relationships:
19
+ source = relationship.source
20
+ source_id = source.id.replace(" ", "_")
21
+ target = relationship.target
22
+ target_id = target.id.replace(" ", "_")
23
+ label = relationship.type
24
+ edges.append({"source": source_id, "label": label, "cible": target_id})
25
+ return {"noeuds": nodes, "relations": edges}
26
+
27
+ def chat_history_formatter(chat_history):
28
+ formatted_chat = ""
29
+ for message in chat_history:
30
+ if isinstance(message, AIMessage):
31
+ formatted_chat += f"AI:\n {message.content}\n\n"
32
+ elif isinstance(message, HumanMessage):
33
+ formatted_chat += f"Human:\n {message.content}\n\n"
34
+ return formatted_chat
35
+
36
+ def doc_dialog_main():
37
+ st.title("Dialogue avec le document")
38
+
39
+ if "audit" not in st.session_state or st.session_state.audit == {}:
40
+ st.error("Veuillez d'abord effectuer un audit pour générer le compte rendu ou le graphe de connaissance.")
41
+ return
42
+
43
+ #init cr and chat history cr
44
+ if "cr" not in st.session_state:
45
+ st.session_state.cr = ""
46
+ if "cr_chat_history" not in st.session_state:
47
+ st.session_state.cr_chat_history = [
48
+ ]
49
+
50
+ #init graph and filter views
51
+ if "graph" not in st.session_state:
52
+ st.session_state.graph = None
53
+
54
+ if "filter_views" not in st.session_state:
55
+ st.session_state.filter_views = {}
56
+ if "current_view" not in st.session_state:
57
+ st.session_state.current_view = None
58
+ if "node_types" not in st.session_state:
59
+ st.session_state.node_types = None
60
+ # if "summary" not in st.session_state:
61
+ # st.session_state.summary = None
62
+ if "chat_graph_history" not in st.session_state:
63
+ st.session_state.chat_graph_history = []
64
+
65
+ if "radio_choice" not in st.session_state:
66
+ st.session_state.radio_choice = None
67
+
68
+ options = ["compte_rendu","graphe de connaissance"]
69
+ choice = st.radio("Choisissez une option",options,index=st.session_state.radio_choice,horizontal=True,label_visibility="collapsed")
70
+ if choice:
71
+ st.session_state.radio_choice = options.index(choice)
72
+
73
+ audit = st.session_state.audit_simplified
74
+ content = st.session_state.audit["content"]
75
+
76
+ if audit["type de fichier"] == "pdf":
77
+ text = get_text_from_content_for_doc(content)
78
+ elif audit["type de fichier"] == "audio":
79
+ text = get_text_from_content_for_audio(content)
80
+
81
+
82
+
83
+ if choice == "compte_rendu":
84
+ if "cr" not in st.session_state or st.session_state.cr == "":
85
+ with st.spinner("Génération du compte rendu..."):
86
+ prompt_cr = dedent(f'''
87
+ À partir du document ci-dessous, générez un compte rendu détaillé contenant les sections suivantes :
88
+
89
+ 2. **Résumé** : Fournissez un résumé concis du document, en mettant en avant les points principaux, les relations essentielles, les concepts , les dates et les lieux, les conclusions et les détails importants.
90
+
91
+ 3. **Notes** :
92
+ - Présentez les points clés sous forme de liste à puces avec des émojis pertinents pour souligner la nature de chaque point.
93
+ - Incluez des sous-points (sans émojis) sous les points principaux pour offrir des détails ou explications supplémentaires.
94
+
95
+ 4. **Actions** : Identifiez et listez les actions spécifiques, tâches ou étapes recommandées ou nécessaires selon le contenu du document.
96
+
97
+ **Document :**
98
+
99
+ {text}
100
+
101
+ **Format de sortie :**
102
+
103
+
104
+ ### Résumé :
105
+ [Fournissez un résumé concis du document ici.]
106
+
107
+ ### Notes :
108
+ - 📌 **Point Principal 1**
109
+ - Sous-point A
110
+ - Sous-point B
111
+ - 📈 **Point Principal 2**
112
+ - Sous-point C
113
+ - Sous-point D
114
+ - 📝 **Point Principal 3**
115
+ - Sous-point E
116
+ - Sous-point F
117
+
118
+ ### Actions :
119
+ 1. [Action 1]
120
+ 2. [Action 2]
121
+ 3. [Action 3]
122
+ 4. ...
123
+
124
+ ---
125
+ ''')
126
+ cr = generate_response_via_langchain(prompt_cr,stream=False,model="gpt-4o")
127
+ st.session_state.cr = cr
128
+ st.session_state.cr_chat_history = []
129
+ else:
130
+ cr = st.session_state.cr
131
+
132
+ if cr:
133
+ col1, col2 = st.columns([2.5, 1.5])
134
+
135
+ with col1.container(border=True,height=800):
136
+ st.markdown("##### Compte rendu")
137
+ with st.container(height=650,border=False):
138
+ keywords_paragraph = f"### Mots clés extraits:\n- {audit['Mots clés'].strip()}"
139
+ st.markdown(keywords_paragraph)
140
+ st.write(cr)
141
+ col_copy , col_success = st.columns([1,11])
142
+ if col_copy.button("📋",key="copy_cr"):
143
+ pyperclip.copy(keywords_paragraph+"\n\n"+cr)
144
+ col_success.success("Compte rendu copié dans le presse-papier")
145
+
146
+ with col2.container(border=True,height=800):
147
+ st.markdown("##### Dialoguer avec le CR")
148
+
149
+ user_query = st.chat_input("Par ici ...")
150
+ if user_query is not None and user_query != "":
151
+ st.session_state.cr_chat_history.append(HumanMessage(content=user_query))
152
+
153
+ with st.container(height=600, border=False):
154
+ for message in st.session_state.cr_chat_history:
155
+ if isinstance(message, AIMessage):
156
+ with st.chat_message("AI"):
157
+ st.markdown(message.content)
158
+ elif isinstance(message, HumanMessage):
159
+ with st.chat_message("Human"):
160
+ st.write(message.content)
161
+
162
+ #check if last message is human message
163
+ if len(st.session_state.cr_chat_history) > 0:
164
+ last_message = st.session_state.cr_chat_history[-1]
165
+ if isinstance(last_message, HumanMessage):
166
+ with st.chat_message("AI"):
167
+ retreive = st.session_state.vectorstore.as_retriever()
168
+ context = retreive.invoke(last_message.content)
169
+ wrapped_prompt = f'''Étant donné le contexte suivant {context} et le compte rendu du document {cr}, {last_message.content}'''
170
+ response = st.write_stream(generate_response_via_langchain(wrapped_prompt,stream=True))
171
+ st.session_state.cr_chat_history.append(AIMessage(content=response))
172
+ col_copy_c , col_success_c = st.columns([1,7])
173
+ if col_copy_c.button("📋",key="copy_cr_chat"):
174
+ chat_formatted = chat_history_formatter(st.session_state.cr_chat_history)
175
+ pyperclip.copy(chat_formatted)
176
+ col_success_c.success("Historique copié !")
177
+
178
+ elif choice == "graphe de connaissance":
179
+ if "graph" not in st.session_state or st.session_state.graph == None:
180
+ with st.spinner("Génération du graphe..."):
181
+ keywords_list = audit["Mots clés"].strip().split(",")
182
+ allowed_nodes_types =keywords_list+ ["Person","Organization","Location","Event","Date","Time","Ressource","Concept"]
183
+ graph = get_graph(text,allowed_nodes=allowed_nodes_types)
184
+ st.session_state.graph = graph
185
+ st.session_state.filter_views = {}
186
+ st.session_state.current_view = None
187
+ st.session_state.node_types = None
188
+ st.session_state.chat_graph_history = []
189
+
190
+ node_types = get_node_types(graph[0])
191
+ nodes_type_dict = list_to_dict_colors(node_types)
192
+ st.session_state.node_types = nodes_type_dict
193
+ st.session_state.filter_views["Vue par défaut"] = list(node_types)
194
+ st.session_state.current_view = "Vue par défaut"
195
+ else:
196
+ graph = st.session_state.graph
197
+
198
+ if graph is not None:
199
+ #st.write(graph)
200
+
201
+ edges,nodes,config = convert_neo4j_to_agraph(graph[0],st.session_state.node_types)
202
+
203
+ col1, col2 = st.columns([2.5, 1.5])
204
+
205
+ with col1.container(border=True,height=800):
206
+ st.write("##### Visualisation du graphe (**"+st.session_state.current_view+"**)")
207
+ filter_col,add_view_col,change_view_col,color_col = st.columns([9,1,1,1])
208
+
209
+ if color_col.button("🎨",help="Changer la couleur"):
210
+ change_color_dialog()
211
+
212
+ if change_view_col.button("🔍",help="Changer de vue"):
213
+ change_view_dialog()
214
+
215
+
216
+ #add mots cles to evry label in audit["Mots clés"]
217
+ #filter_labels = [ label + " (mot clé)" if label.strip().lower() in audit["Mots clés"].strip().lower().split(",") else label for label in st.session_state.filter_views[st.session_state.current_view] ]
218
+ filter = filter_col.multiselect("Filtrer selon l'étiquette",st.session_state.node_types.keys(),placeholder="Sélectionner une ou plusieurs étiquettes",default=st.session_state.filter_views[st.session_state.current_view],label_visibility="collapsed")
219
+
220
+ if add_view_col.button("➕",help="Ajouter une vue"):
221
+ add_view_dialog(filter)
222
+
223
+ if filter:
224
+ nodes = filter_nodes_by_types(nodes,filter)
225
+
226
+ selected = display_graph(edges,nodes,config)
227
+
228
+ col_copy , col_success = st.columns([1,11])
229
+ if col_copy.button("📋",key="copy_graph"):
230
+ graph_json = graph_doc_to_json(graph[0])
231
+ pyperclip.copy(graph_json)
232
+ col_success.success("Graphe copié dans le presse-papier")
233
+
234
+ with col2.container(border=True,height=800):
235
+ st.markdown("##### Dialoguer avec le graphe")
236
+
237
+ user_query = st.chat_input("Par ici ...")
238
+ if user_query is not None and user_query != "":
239
+ st.session_state.chat_graph_history.append(HumanMessage(content=user_query))
240
+
241
+ with st.container(height=600, border=False):
242
+ for message in st.session_state.chat_graph_history:
243
+ if isinstance(message, AIMessage):
244
+ with st.chat_message("AI"):
245
+ st.markdown(message.content)
246
+ elif isinstance(message, HumanMessage):
247
+ with st.chat_message("Human"):
248
+ st.write(message.content)
249
+
250
+ #check if last message is human message
251
+ if len(st.session_state.chat_graph_history) > 0:
252
+ last_message = st.session_state.chat_graph_history[-1]
253
+ if isinstance(last_message, HumanMessage):
254
+ with st.chat_message("AI"):
255
+ retreive = st.session_state.vectorstore.as_retriever()
256
+ context = retreive.invoke(last_message.content)
257
+ wrapped_prompt = f"Étant donné le contexte suivant {context}, et le graph de connaissance: {graph}, {last_message.content}"
258
+ response = st.write_stream(generate_response_via_langchain(wrapped_prompt,stream=True))
259
+ st.session_state.chat_graph_history.append(AIMessage(content=response))
260
+
261
+ if selected is not None:
262
+ with st.chat_message("AI"):
263
+ st.markdown(f" EXPLORER LES DONNEES CONTENUES DANS **{selected}**")
264
+
265
+ prompts = [f"Extrait moi toutes les informations du noeud ''{selected}'' ➡️",
266
+ f"Montre moi les conversations autour du noeud ''{selected}'' ➡️"]
267
+
268
+ for i,prompt in enumerate(prompts):
269
+ button = st.button(prompt,key=f"p_{i}",on_click=lambda i=i: st.session_state.chat_graph_history.append(HumanMessage(content=prompts[i])))
270
+
271
+ col_copy_c , col_success_c = st.columns([1,7])
272
+ if col_copy_c.button("📋",key="copy_graph_chat"):
273
+ chat_formatted = chat_history_formatter(st.session_state.chat_graph_history)
274
+ pyperclip.copy(chat_formatted)
275
+ col_success_c.success("Historique copié !")
276
+
277
+
278
+
279
+
280
+
281
+
282
+
283
+ doc_dialog_main()
284
+
285
+
286
+
287
+
audit_page/knowledge_graph.py CHANGED
@@ -131,7 +131,8 @@ def convert_neo4j_to_agraph(neo4j_graph, node_colors):
131
  def display_graph(edges, nodes, config):
132
  # Display the Agraph visualization
133
  return agraph(edges=edges, nodes=nodes, config=config)
134
-
 
135
 
136
  def filter_nodes_by_types(nodes:list[Node], node_types_filter:list) -> list[Node]:
137
  filtered_nodes = []
@@ -152,7 +153,8 @@ def change_view_dialog():
152
  del st.session_state.filter_views[item]
153
  st.session_state.current_view = "Vue par défaut"
154
  st.rerun()
155
- if col3.button("✅", key=f"valid{index}"):
 
156
  st.session_state.current_view = item
157
  st.rerun()
158
  if len(st.session_state.filter_views.keys()) > index:
 
131
  def display_graph(edges, nodes, config):
132
  # Display the Agraph visualization
133
  return agraph(edges=edges, nodes=nodes, config=config)
134
+
135
+
136
 
137
  def filter_nodes_by_types(nodes:list[Node], node_types_filter:list) -> list[Node]:
138
  filtered_nodes = []
 
153
  del st.session_state.filter_views[item]
154
  st.session_state.current_view = "Vue par défaut"
155
  st.rerun()
156
+ but_content = "🔍" if st.session_state.current_view != item else "✅"
157
+ if col3.button(but_content, key=f"valid{index}"):
158
  st.session_state.current_view = item
159
  st.rerun()
160
  if len(st.session_state.filter_views.keys()) > index:
config.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "height": "650px",
3
  "width": "1200px",
4
  "autoResize": true,
5
 
 
1
  {
2
+ "height": "600px",
3
  "width": "1200px",
4
  "autoResize": true,
5
 
utils/kg/barnes_algo.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import division
2
+ import numpy as np
3
+ import math
4
+ import matplotlib.pyplot as plt
5
+ import matplotlib.animation as animation
6
+ import matplotlib.patches as patches
7
+ import random
8
+
9
+ theta = 0.5
10
+ AU = (149.6e6 * 1000) # 149.6 million km, in meters.
11
+ G = 6.67408e-11 #m^3 kg^-1 s^-2
12
+ fig1 = plt.figure()
13
+ sim = fig1.add_subplot(111, aspect='equal')
14
+ fig2 = plt.figure()
15
+ quadt = fig2.add_subplot(111, aspect='equal')
16
+
17
+ class Node:
18
+ children = None
19
+ mass = None
20
+ center_of_mass = None
21
+ bbox = None
22
+ vx = vy = None
23
+
24
+ def quad_insert(root, x, y, m):
25
+ if root.mass is None: #when the root is empty, add the first particle
26
+ root.mass = m
27
+ root.center_of_mass = [x,y]
28
+ return
29
+ elif root.children is None:
30
+ root.children = [None,None,None,None]
31
+ old_quadrant = quadrant_of_particle(root.bbox, root.center_of_mass[0], root.center_of_mass[1])
32
+ if root.children[old_quadrant] is None:
33
+ root.children[old_quadrant] = Node()
34
+ root.children[old_quadrant].bbox = quadrant_bbox(root.bbox,old_quadrant)
35
+ quad_insert(root.children[old_quadrant], root.center_of_mass[0], root.center_of_mass[1], root.mass)
36
+ new_quadrant = quadrant_of_particle(root.bbox, x, y)
37
+ if root.children[new_quadrant] is None:
38
+ root.children[new_quadrant] = Node()
39
+ root.children[new_quadrant].bbox = quadrant_bbox(root.bbox,new_quadrant)
40
+ quad_insert(root.children[new_quadrant], x, y, m)
41
+ root.center_of_mass[0] = (root.center_of_mass[0]*root.mass + x*m) / (root.mass + m)
42
+ root.center_of_mass[1] = (root.center_of_mass[1]*root.mass + y*m) / (root.mass + m)
43
+ root.mass = root.mass + m
44
+ else:
45
+ new_quadrant = quadrant_of_particle(root.bbox, x, y)
46
+ if root.children[new_quadrant] is None:
47
+ root.children[new_quadrant] = Node()
48
+ root.children[new_quadrant].bbox = quadrant_bbox(root.bbox, new_quadrant)
49
+ quad_insert(root.children[new_quadrant], x, y, m)
50
+ root.center_of_mass[0] = (root.center_of_mass[0]*root.mass + x*m) / (root.mass + m)
51
+ root.center_of_mass[1] = (root.center_of_mass[1]*root.mass + y*m) / (root.mass + m)
52
+ root.mass = root.mass + m
53
+
54
+ def display(root):
55
+ if root.mass is None:
56
+ return
57
+ if root.children is not None:
58
+ x = (root.bbox[0] + root.bbox[1]) / 2
59
+ y = (root.bbox[2] + root.bbox[3]) / 2
60
+ width = x-root.bbox[0]
61
+ plt_node(root.bbox[0], root.bbox[2], width)
62
+ plt_node(root.bbox[0], y, width)
63
+ plt_node(x, root.bbox[2], width)
64
+ plt_node(x, y, width)
65
+ for i in xrange(4):
66
+ if root.children[i] is not None:
67
+ display(root.children[i])
68
+ else:
69
+ quadt.scatter(root.center_of_mass[0], root.center_of_mass[1])
70
+
71
+ def integrate(particles):
72
+ bodies = particles
73
+ n = len(bodies)
74
+ timestep = 24*3600 #one day
75
+ years = 2 * 365 #how many Earth years that simulate
76
+ for day in xrange(years):
77
+ particles_force = {}
78
+ root = Node()
79
+ root.center_of_mass = []
80
+ root.bbox = find_root_bbox(bodies)
81
+ for i in xrange(n):
82
+ quad_insert(root, bodies[i][3], bodies[i][4], bodies[i][2])
83
+ for i in xrange(n):
84
+ total_fx, total_fy = compute_force(root,bodies[i][3],bodies[i][4],bodies[i][2])
85
+ particles_force[bodies[i][0]] = (total_fx, total_fy)
86
+ for i in xrange(n):
87
+ fx, fy = particles_force[bodies[i][0]]
88
+ bodies[i][5] += fx / bodies[i][2] * timestep
89
+ bodies[i][6] += fy / bodies[i][2] * timestep
90
+
91
+ bodies[i][3] += bodies[i][5] * timestep
92
+ bodies[i][4] += bodies[i][6] * timestep
93
+ sim.scatter(bodies[i][3], bodies[i][4], c=bodies[i][1])
94
+ display(root)
95
+ quadt.scatter(root.center_of_mass[0], root.center_of_mass[1], c='red', marker='x')
96
+
97
+ def compute_force(root,x,y,m):
98
+ if root.mass is None:
99
+ return 0, 0
100
+ if root.center_of_mass[0] == x and root.center_of_mass[1] == y and root.mass == m:
101
+ return 0, 0
102
+ d = root.bbox[1]-root.bbox[0]
103
+ r = distance(x,y, root.center_of_mass[0], root.center_of_mass[1])
104
+ if d/r < theta or root.children is None:
105
+ return force(m, x, y, root.mass, root.center_of_mass[0], root.center_of_mass[1])
106
+ else:
107
+ fx = 0.0
108
+ fy = 0.0
109
+ for i in xrange(4):
110
+ if root.children[i] is not None:
111
+ fx += compute_force(root.children[i],x,y,m)[0]
112
+ fy += compute_force(root.children[i],x,y,m)[1]
113
+ return fx, fy
114
+
115
+ ################################################# SUPPORTING FUNCTION ##############################################################
116
+
117
+ def force(m, x, y, mcm, xcm, ycm):
118
+ d = distance(x, y, xcm, ycm)
119
+ f = G*m*mcm/(d**2)
120
+ dx = xcm - x
121
+ dy = ycm - y
122
+ angle = math.atan2(dy, dx)
123
+ fx = math.cos(angle) * f
124
+ fy = math.sin(angle) * f
125
+ return fx, fy
126
+
127
+ def distance(x1, y1, x2, y2):
128
+ return math.sqrt((x2-x1)**2+(y2-y1)**2)
129
+
130
+ def plt_node(x, y, width):
131
+ quadt.add_patch(patches.Rectangle((x, y), width, width, fill = False))
132
+
133
+ def find_root_bbox(array):
134
+ """ Create a suitable square boundary box for the input particles
135
+ """
136
+ if len(array) == 0 or len(array) == 1:
137
+ return None
138
+ xmin, xmax, ymin, ymax = array[0][3], array[0][3], array[0][4], array[0][4]
139
+ for i in xrange(len(array)):
140
+ if array[i][3] > xmax:
141
+ xmax = array[i][3]
142
+ if array[i][3] < xmin:
143
+ xmin = array[i][3]
144
+ if array[i][4] > ymax:
145
+ ymax = array[i][4]
146
+ if array[i][4] < ymin:
147
+ ymin = array[i][4]
148
+ if xmax - xmin == ymax - ymin:
149
+ return xmin, xmax, ymin, ymax
150
+ elif xmax - xmin > ymax - ymin:
151
+ return xmin, xmax, ymin, ymax+(xmax-xmin-ymax+ymin)
152
+ else:
153
+ return xmin, xmax+(ymax-ymin-xmax+xmin), ymin, ymax
154
+
155
+ def quadrant_of_particle(bbox, x, y):
156
+ """Return position of quadrant of the particle (x,y)
157
+ """
158
+ if y >= (bbox[3] + bbox[2])/2:
159
+ if x <= (bbox[1] + bbox[0])/2:
160
+ return 0
161
+ else:
162
+ return 1
163
+ else:
164
+ if x >= (bbox[1] + bbox[0])/2:
165
+ return 2
166
+ else:
167
+ return 3
168
+
169
+ def quadrant_bbox(bbox,quadrant):
170
+ """Return the coordinate of the quadrant
171
+ """
172
+ x = (bbox[0] + bbox[1]) / 2
173
+ y = (bbox[2] + bbox[3]) / 2
174
+ #Quadrant 0: (xmin, x, y, ymax)
175
+ if quadrant == 0:
176
+ return bbox[0], x, y, bbox[3]
177
+ #Quadrant 1: (x, xmax, y, ymax)
178
+ elif quadrant == 1:
179
+ return x, bbox[1], y, bbox[3]
180
+ #Quadrant 2: (x, xmax, ymin, y)
181
+ elif quadrant == 2:
182
+ return x, bbox[1], bbox[2], y
183
+ #Quadrant 3: (xmin, x, ymin, y)
184
+ elif quadrant == 3:
185
+ return bbox[0], x, bbox[2], y
186
+
187
+ def data_from_file(filename, array):
188
+ with open(filename) as f:
189
+ for line in f:
190
+ if line[0] == '#':
191
+ continue
192
+ else:
193
+ name,color,m,x,y,vx,vy = line.split(',')
194
+ array.append([name,color,float(m),float(x)*AU,float(y)*AU,float(vx)*1000,float(vy)*1000])
195
+
196
+ if __name__ == '__main__':
197
+ filename = ('solar-system.txt')
198
+ particles = []
199
+ data_from_file(filename, particles)
200
+ #root = Node()
201
+ #root.center_of_mass = []
202
+ #root.bbox = find_root_bbox(particles)
203
+ #for i in xrange(len(particles)):
204
+ # quad_insert(root, particles[i][3], particles[i][4], particles[i][2])
205
+ #print 'Boundary box: ',root.bbox
206
+ #print 'Total mass: ',root.mass
207
+ #print 'Coordinate of center of mass: ',root.center_of_mass
208
+ #plt.scatter(root.center_of_mass[0], root.center_of_mass[1], c='r', marker='x', s=50)
209
+ #print 'Theta: ', theta
210
+ integrate(particles)
211
+ plt.show()