Add content recommandation (#17)
Browse files- First commit CQA with Agents (481f3b1453fde4c19018915d101d575b6ea25a3e)
- Connecting to front (088e816846227b694f2d56ca3af739cc010de4bc)
- Update app.py (fd67e156abd0293625d2b73765bda2d3905fa5de)
- agents mode (99e91d83efb40b6cfec5a887f0d464eaffd09431)
- bugfixs (72edd2d9e6ad64e3ecb59505b744cd415b9a6776)
- Update requirements.txt (ae857ef845ac5b3baed5ef7de1e1b8b63874947e)
- Update requirements.txt (25e32e6bdf0ca289bef8617d92ad77d7edeac19f)
- add dora graph recommandation (6b43c8608bbebdffdebdbd315d70c7df60199fab)
- Merge branch 'bugfix/add_dummy_searchs' into feature/graph_recommandation (aa904c191cf4dd783e3ec870883a06746fe52bf8)
- Update .gitignore (8b71e5e7c71f665a8515d8bdbfee913fdaff12f0)
- Update .gitignore (9dd246e7f975322a6be247188bffb7aa0f6d954e)
- add message types (bed4e9bbfb6f7c823789daf54c443f3f27198b45)
- WIP (6a39d69f772f97cef8b0b551a888ace822713753)
- WIP (7da1a3ac2237ded7b9891fdcda32d0674a9b7b4a)
- add steps and fix css for gradio 5 (df5d08d8710beacb04a9c0c281a195c6dd7cc800)
- remove unused code (4ab651938b1c2af7c0d8f155488820e47b42c6c8)
- Update .gitignore (5228f5c0f26f78825d572edbe82200ee3ece6a60)
- fixs (ccd4b9e7b0d2b6d7c0a8b9f2a6513609b4bfe3e1)
- Merge branch 'pr/14' into feature/graph_recommandation (6edd6c2f1c5215b2e3e72b2c36759b934d748606)
- Merge branch 'feature/add_steps_display' into feature/graph_recommandation (89a69e623fa7fec59d135eb57624bcb8d8f67985)
- WIP (49acaf1b850b0914bc1f5d52ceab47d2a22fe944)
- add graphs from ourworldindata (7335378313ce70a4d4ec305a6114e8e6d167af12)
- Merge branch 'main' into feature/graph_recommandation (0c4d82b36b5d6d2f79215460720118c746d88804)
- fix merge (57a1ed70b9a0ad0e48de283e2f044e1e38eac8e1)
- Merge branch 'main' into feature/graph_recommandation (196d79336e51d5deef0215353095954d98d4165b)
- add notification on tabs (484fc0d6f3d80e3fe3afb6ffff20560ef35b6b7c)
- switch owid vectorstore to pinecone (5664fc84569d8e455ce09e2e29c48d7881e126e3)
- remove unused code (fd2ccc64d648490a0ee5acf2159827091d9fc123)
- Move configuration to an other pannel (5c3a3a4b99323e93ebf0c852bc2c2a5401929dae)
- Display images before answering (12c9afe58a2de4b7413bc860f808501c6ee2a6aa)
- Display image in a parallel task (6c5a20c1cab94c503e004acc0991ce4149835a7f)
- Merge branch 'main' into feature/graph_recommandation (b58c53f25d1cc85cfe82bbd452f8d5d11c306da3)
- Add step by step notebook to run steps of langgraphs (a059c938ce111e6673981ad32785d0d4e1c0d177)
- put event handling in separate file (76603dfba448efa3334c2bcb0169f8e1fbd92c60)
- Rerank documents and force summary for policy makers (d562d3805e6230ad8db525229b4bdde42185e721)
- edit prints for logs (9609df9642e477c67d7bf03a83becaef5c3e2b6f)
- Add OpenAlex papers recommandation (c3b815e6e630f551740188fdef719d0df16acd7e)
- Merge branch 'add_openalex_papers' into feature/graph_recommandation (6541df34bd07c0c7e2d1f1c4ffbf03c2778187a3)
- fix merge (d78271b334fb803de4c424420c297cc6983f0c93)
- add time measure (09457a7da47af1e265d9c5c906cecbf2d4586174)
- fix css (22ac4eb7878d7c9e0f3581972f399ad3119dfdde)
- remove unused code (4c4fe76848d84079766d5ec6e94c1e941d5fea01)
- fix answer latency when having multiple sources (40084ba7b8c0424741cfe3d2142eea4d24683c07)
- remove unused code (58bf75084a4d964a486a5a602e81a17e56e1cc82)
- Change number of figures (781788244a88876bdecfc5b3ddbcf65e8bd9ead6)
- front UI change (c9346b33c3251304f5b52f5c837347f10842b87c)
- add owid subtitles (7ec5d9ecfda9f0688da0521126034de80cf9dffc)
- move code from papers in separate file (363fe2eb8548665fb4fe577db1dce7ea682bac8d)
- add search only button (be2863be1c4cfdedca78497d548be321d218c312)
- config in a modal object (7283e6a6e173ab3ced6c21e292e1b6f94e659141)
- few code cleaning (094ee349527297a47eaf9bbb0903651170be47ec)
- update display and fix search only (d396732ee8df5f4aa33c10cca64d6b05d197e4d5)
- Update 20241104 - CQA - StepByStep CQA.ipynb (d7adcaada52ac5feae76017b457623ab308bbfc8)
Co-authored-by: Theo Alves <[email protected]>
- app.py +313 -213
- climateqa/constants.py +22 -1
- climateqa/engine/chains/answer_chitchat.py +5 -1
- climateqa/engine/chains/answer_rag.py +31 -18
- climateqa/engine/chains/chitchat_categorization.py +43 -0
- climateqa/engine/chains/graph_retriever.py +128 -0
- climateqa/engine/chains/intent_categorization.py +10 -6
- climateqa/engine/chains/prompts.py +24 -1
- climateqa/engine/chains/query_transformation.py +10 -2
- climateqa/engine/chains/retrieve_documents.py +223 -81
- climateqa/engine/chains/retrieve_papers.py +95 -0
- climateqa/engine/chains/retriever.py +126 -0
- climateqa/engine/chains/set_defaults.py +13 -0
- climateqa/engine/chains/translation.py +2 -1
- climateqa/engine/graph.py +55 -14
- climateqa/engine/graph_retriever.py +88 -0
- climateqa/engine/keywords.py +3 -1
- climateqa/engine/reranker.py +12 -2
- climateqa/engine/vectorstore.py +8 -2
- climateqa/event_handler.py +123 -0
- climateqa/knowledge/openalex.py +10 -6
- climateqa/knowledge/retriever.py +102 -81
- climateqa/utils.py +13 -0
- front/utils.py +148 -3
- sandbox/20240310 - CQA - Semantic Routing 1.ipynb +0 -0
- sandbox/20240702 - CQA - Graph Functionality.ipynb +0 -0
- sandbox/20241104 - CQA - StepByStep CQA.ipynb +0 -0
- style.css +209 -54
@@ -1,13 +1,12 @@
|
|
1 |
from climateqa.engine.embeddings import get_embeddings_function
|
2 |
embeddings_function = get_embeddings_function()
|
3 |
|
4 |
-
from climateqa.knowledge.openalex import OpenAlex
|
5 |
from sentence_transformers import CrossEncoder
|
6 |
|
7 |
# reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
|
8 |
-
oa = OpenAlex()
|
9 |
|
10 |
import gradio as gr
|
|
|
11 |
import pandas as pd
|
12 |
import numpy as np
|
13 |
import os
|
@@ -29,7 +28,9 @@ from utils import create_user_id
|
|
29 |
|
30 |
from gradio_modal import Modal
|
31 |
|
|
|
32 |
|
|
|
33 |
|
34 |
# ClimateQ&A imports
|
35 |
from climateqa.engine.llm import get_llm
|
@@ -39,13 +40,15 @@ from climateqa.engine.reranker import get_reranker
|
|
39 |
from climateqa.engine.embeddings import get_embeddings_function
|
40 |
from climateqa.engine.chains.prompts import audience_prompts
|
41 |
from climateqa.sample_questions import QUESTIONS
|
42 |
-
from climateqa.constants import POSSIBLE_REPORTS
|
43 |
from climateqa.utils import get_image_from_azure_blob_storage
|
44 |
-
from climateqa.engine.
|
45 |
-
|
46 |
-
from climateqa.engine.
|
|
|
|
|
47 |
|
48 |
-
from
|
49 |
|
50 |
# Load environment variables in local mode
|
51 |
try:
|
@@ -54,6 +57,8 @@ try:
|
|
54 |
except Exception as e:
|
55 |
pass
|
56 |
|
|
|
|
|
57 |
# Set up Gradio Theme
|
58 |
theme = gr.themes.Base(
|
59 |
primary_hue="blue",
|
@@ -104,52 +109,47 @@ CITATION_TEXT = r"""@misc{climateqa,
|
|
104 |
|
105 |
|
106 |
# Create vectorstore and retriever
|
107 |
-
vectorstore = get_pinecone_vectorstore(embeddings_function)
|
108 |
-
|
109 |
-
reranker = get_reranker("large")
|
110 |
-
agent = make_graph_agent(llm,vectorstore,reranker)
|
111 |
|
|
|
|
|
112 |
|
|
|
113 |
|
|
|
|
|
|
|
114 |
|
115 |
-
async def chat(query,history,audience,sources,reports):
|
116 |
"""taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
|
117 |
(messages in gradio format, messages in langchain format, source documents)"""
|
118 |
|
119 |
date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
120 |
print(f">> NEW QUESTION ({date_now}) : {query}")
|
121 |
|
122 |
-
|
123 |
-
audience_prompt = audience_prompts["children"]
|
124 |
-
elif audience == "General public":
|
125 |
-
audience_prompt = audience_prompts["general"]
|
126 |
-
elif audience == "Experts":
|
127 |
-
audience_prompt = audience_prompts["experts"]
|
128 |
-
else:
|
129 |
-
audience_prompt = audience_prompts["experts"]
|
130 |
|
131 |
# Prepare default values
|
132 |
-
if len(sources) == 0:
|
133 |
-
sources = ["IPCC"]
|
134 |
|
135 |
-
|
136 |
-
|
137 |
|
138 |
-
inputs = {"user_input": query,"audience": audience_prompt,"sources_input":sources}
|
139 |
result = agent.astream_events(inputs,version = "v1")
|
140 |
-
|
141 |
-
# path_reformulation = "/logs/reformulation/final_output"
|
142 |
-
# path_keywords = "/logs/keywords/final_output"
|
143 |
-
# path_retriever = "/logs/find_documents/final_output"
|
144 |
-
# path_answer = "/logs/answer/streamed_output_str/-"
|
145 |
|
146 |
docs = []
|
|
|
|
|
147 |
docs_html = ""
|
148 |
output_query = ""
|
149 |
output_language = ""
|
150 |
output_keywords = ""
|
151 |
-
gallery = []
|
152 |
start_streaming = False
|
|
|
153 |
figures = '<div class="figures-container"><p></p> </div>'
|
154 |
|
155 |
steps_display = {
|
@@ -166,36 +166,29 @@ async def chat(query,history,audience,sources,reports):
|
|
166 |
node = event["metadata"]["langgraph_node"]
|
167 |
|
168 |
if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" :# when documents are retrieved
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
except Exception as e:
|
183 |
-
print(f"Error getting documents: {e}")
|
184 |
-
print(event)
|
185 |
-
|
186 |
elif event["name"] in steps_display.keys() and event["event"] == "on_chain_start": #display steps
|
187 |
-
event_description,display_output = steps_display[node]
|
188 |
if not hasattr(history[-1], 'metadata') or history[-1].metadata["title"] != event_description: # if a new step begins
|
189 |
history.append(ChatMessage(role="assistant", content = "", metadata={'title' :event_description}))
|
190 |
|
191 |
elif event["name"] != "transform_query" and event["event"] == "on_chat_model_stream" and node in ["answer_rag", "answer_search","answer_chitchat"]:# if streaming answer
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
history[-1] = ChatMessage(role="assistant", content = answer_message_content)
|
198 |
-
# history.append(ChatMessage(role="assistant", content = new_message_content))
|
199 |
|
200 |
if event["name"] == "transform_query" and event["event"] =="on_chain_end":
|
201 |
if hasattr(history[-1],"content"):
|
@@ -204,7 +197,7 @@ async def chat(query,history,audience,sources,reports):
|
|
204 |
if event["name"] == "categorize_intent" and event["event"] == "on_chain_start":
|
205 |
print("X")
|
206 |
|
207 |
-
yield history,docs_html,output_query,output_language,
|
208 |
|
209 |
except Exception as e:
|
210 |
print(event, "has failed")
|
@@ -232,68 +225,7 @@ async def chat(query,history,audience,sources,reports):
|
|
232 |
print(f"Error logging on Azure Blob Storage: {e}")
|
233 |
raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
|
234 |
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
# image_dict = {}
|
239 |
-
# for i,doc in enumerate(docs):
|
240 |
-
|
241 |
-
# if doc.metadata["chunk_type"] == "image":
|
242 |
-
# try:
|
243 |
-
# key = f"Image {i+1}"
|
244 |
-
# image_path = doc.metadata["image_path"].split("documents/")[1]
|
245 |
-
# img = get_image_from_azure_blob_storage(image_path)
|
246 |
-
|
247 |
-
# # Convert the image to a byte buffer
|
248 |
-
# buffered = BytesIO()
|
249 |
-
# img.save(buffered, format="PNG")
|
250 |
-
# img_str = base64.b64encode(buffered.getvalue()).decode()
|
251 |
-
|
252 |
-
# # Embedding the base64 string in Markdown
|
253 |
-
# markdown_image = f"![Alt text](data:image/png;base64,{img_str})"
|
254 |
-
# image_dict[key] = {"img":img,"md":markdown_image,"short_name": doc.metadata["short_name"],"figure_code":doc.metadata["figure_code"],"caption":doc.page_content,"key":key,"figure_code":doc.metadata["figure_code"], "img_str" : img_str}
|
255 |
-
# except Exception as e:
|
256 |
-
# print(f"Skipped adding image {i} because of {e}")
|
257 |
-
|
258 |
-
# if len(image_dict) > 0:
|
259 |
-
|
260 |
-
# gallery = [x["img"] for x in list(image_dict.values())]
|
261 |
-
# img = list(image_dict.values())[0]
|
262 |
-
# img_md = img["md"]
|
263 |
-
# img_caption = img["caption"]
|
264 |
-
# img_code = img["figure_code"]
|
265 |
-
# if img_code != "N/A":
|
266 |
-
# img_name = f"{img['key']} - {img['figure_code']}"
|
267 |
-
# else:
|
268 |
-
# img_name = f"{img['key']}"
|
269 |
-
|
270 |
-
# history.append(ChatMessage(role="assistant", content = f"\n\n{img_md}\n<p class='chatbot-caption'><b>{img_name}</b> - {img_caption}</p>"))
|
271 |
-
|
272 |
-
docs_figures = [d for d in docs if d.metadata["chunk_type"] == "image"]
|
273 |
-
for i, doc in enumerate(docs_figures):
|
274 |
-
if doc.metadata["chunk_type"] == "image":
|
275 |
-
try:
|
276 |
-
key = f"Image {i+1}"
|
277 |
-
|
278 |
-
image_path = doc.metadata["image_path"].split("documents/")[1]
|
279 |
-
img = get_image_from_azure_blob_storage(image_path)
|
280 |
-
|
281 |
-
# Convert the image to a byte buffer
|
282 |
-
buffered = BytesIO()
|
283 |
-
img.save(buffered, format="PNG")
|
284 |
-
img_str = base64.b64encode(buffered.getvalue()).decode()
|
285 |
-
|
286 |
-
figures = figures + make_html_figure_sources(doc, i, img_str)
|
287 |
-
|
288 |
-
gallery.append(img)
|
289 |
-
|
290 |
-
except Exception as e:
|
291 |
-
print(f"Skipped adding image {i} because of {e}")
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
yield history,docs_html,output_query,output_language,gallery, figures#,output_query,output_keywords
|
297 |
|
298 |
|
299 |
def save_feedback(feed: str, user_id):
|
@@ -317,29 +249,9 @@ def log_on_azure(file, logs, share_client):
|
|
317 |
file_client.upload_file(logs)
|
318 |
|
319 |
|
320 |
-
def generate_keywords(query):
|
321 |
-
chain = make_keywords_chain(llm)
|
322 |
-
keywords = chain.invoke(query)
|
323 |
-
keywords = " AND ".join(keywords["keywords"])
|
324 |
-
return keywords
|
325 |
|
326 |
|
327 |
|
328 |
-
papers_cols_widths = {
|
329 |
-
"doc":50,
|
330 |
-
"id":100,
|
331 |
-
"title":300,
|
332 |
-
"doi":100,
|
333 |
-
"publication_year":100,
|
334 |
-
"abstract":500,
|
335 |
-
"rerank_score":100,
|
336 |
-
"is_oa":50,
|
337 |
-
}
|
338 |
-
|
339 |
-
papers_cols = list(papers_cols_widths.keys())
|
340 |
-
papers_cols_widths = list(papers_cols_widths.values())
|
341 |
-
|
342 |
-
|
343 |
# --------------------------------------------------------------------
|
344 |
# Gradio
|
345 |
# --------------------------------------------------------------------
|
@@ -370,10 +282,23 @@ def vote(data: gr.LikeData):
|
|
370 |
else:
|
371 |
print(data)
|
372 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
|
374 |
|
375 |
with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=theme,elem_id = "main-component") as demo:
|
|
|
|
|
|
|
|
|
376 |
|
|
|
377 |
with gr.Tab("ClimateQ&A"):
|
378 |
|
379 |
with gr.Row(elem_id="chatbot-row"):
|
@@ -396,12 +321,16 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
396 |
|
397 |
with gr.Row(elem_id = "input-message"):
|
398 |
textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox")
|
399 |
-
|
|
|
|
|
|
|
|
|
400 |
|
401 |
-
with gr.Column(scale=
|
402 |
|
403 |
|
404 |
-
with gr.Tabs() as tabs:
|
405 |
with gr.TabItem("Examples",elem_id = "tab-examples",id = 0):
|
406 |
|
407 |
examples_hidden = gr.Textbox(visible = False)
|
@@ -427,91 +356,210 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
427 |
)
|
428 |
|
429 |
samples.append(group_examples)
|
|
|
|
|
|
|
430 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
431 |
|
432 |
-
|
433 |
-
|
434 |
-
docs_textbox = gr.State("")
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
|
439 |
-
# with Modal(visible = False) as config_modal:
|
440 |
-
with gr.Tab("Configuration",elem_id = "tab-config",id = 2):
|
441 |
|
442 |
-
|
|
|
443 |
|
444 |
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
451 |
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
459 |
|
460 |
-
|
461 |
-
["Children","General public","Experts"],
|
462 |
-
label="Select audience",
|
463 |
-
value="Experts",
|
464 |
-
interactive=True,
|
465 |
-
)
|
466 |
|
467 |
-
output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
|
468 |
-
output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
|
469 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
470 |
|
471 |
-
|
472 |
-
|
473 |
-
gallery_component = gr.Gallery(object_fit='scale-down',elem_id="gallery-component", height="80vh")
|
474 |
-
|
475 |
-
show_full_size_figures = gr.Button("Show figures in full size",elem_id="show-figures",interactive=True)
|
476 |
-
show_full_size_figures.click(lambda : Modal(visible=True),None,modal)
|
477 |
|
478 |
-
figures_cards = gr.HTML(show_label=False, elem_id="sources-figures")
|
479 |
-
|
480 |
|
|
|
|
|
|
|
|
|
|
|
|
|
481 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
482 |
|
|
|
483 |
|
484 |
|
485 |
#---------------------------------------------------------------------------------------
|
486 |
# OTHER TABS
|
487 |
#---------------------------------------------------------------------------------------
|
488 |
|
|
|
489 |
|
490 |
-
#
|
491 |
-
# gallery_component = gr.Gallery(object_fit='cover')
|
492 |
|
493 |
-
# with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"):
|
494 |
|
495 |
-
#
|
496 |
-
#
|
497 |
-
#
|
498 |
-
#
|
499 |
-
#
|
500 |
-
#
|
501 |
|
502 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
503 |
|
504 |
-
#
|
505 |
-
#
|
|
|
|
|
|
|
|
|
506 |
|
507 |
-
# with gr.Tab("Relevant papers",elem_id="papers-results-tab"):
|
508 |
-
# papers_dataframe = gr.Dataframe(visible=True,elem_id="papers-table",headers = papers_cols)
|
509 |
|
510 |
-
#
|
511 |
-
#
|
512 |
|
513 |
|
514 |
-
|
515 |
with gr.Tab("About",elem_classes = "max-height other-tabs"):
|
516 |
with gr.Row():
|
517 |
with gr.Column(scale=1):
|
@@ -519,13 +567,15 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
519 |
|
520 |
|
521 |
|
522 |
-
gr.Markdown(
|
523 |
-
|
524 |
-
|
525 |
-
-
|
526 |
-
|
527 |
-
|
528 |
-
|
|
|
|
|
529 |
with gr.Accordion(CITATION_LABEL,elem_id="citation", open = False,):
|
530 |
# # Display citation label and text)
|
531 |
gr.Textbox(
|
@@ -538,25 +588,61 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
538 |
|
539 |
|
540 |
|
541 |
-
def start_chat(query,history):
|
542 |
-
# history = history + [(query,None)]
|
543 |
-
# history = [tuple(x) for x in history]
|
544 |
history = history + [ChatMessage(role="user", content=query)]
|
545 |
-
|
|
|
|
|
|
|
546 |
|
547 |
def finish_chat():
|
548 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
549 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
550 |
(textbox
|
551 |
-
.submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
|
552 |
-
.then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports],
|
553 |
.then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
|
|
|
554 |
)
|
555 |
|
556 |
(examples_hidden
|
557 |
-
.change(start_chat, [examples_hidden,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
|
558 |
-
.then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports],
|
559 |
.then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
|
|
|
560 |
)
|
561 |
|
562 |
|
@@ -567,9 +653,23 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
567 |
return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
|
568 |
|
569 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
570 |
|
|
|
571 |
dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
|
572 |
|
|
|
|
|
|
|
|
|
|
|
|
|
573 |
|
574 |
demo.queue()
|
575 |
|
|
|
1 |
from climateqa.engine.embeddings import get_embeddings_function
|
2 |
embeddings_function = get_embeddings_function()
|
3 |
|
|
|
4 |
from sentence_transformers import CrossEncoder
|
5 |
|
6 |
# reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
|
|
|
7 |
|
8 |
import gradio as gr
|
9 |
+
from gradio_modal import Modal
|
10 |
import pandas as pd
|
11 |
import numpy as np
|
12 |
import os
|
|
|
28 |
|
29 |
from gradio_modal import Modal
|
30 |
|
31 |
+
from PIL import Image
|
32 |
|
33 |
+
from langchain_core.runnables.schema import StreamEvent
|
34 |
|
35 |
# ClimateQ&A imports
|
36 |
from climateqa.engine.llm import get_llm
|
|
|
40 |
from climateqa.engine.embeddings import get_embeddings_function
|
41 |
from climateqa.engine.chains.prompts import audience_prompts
|
42 |
from climateqa.sample_questions import QUESTIONS
|
43 |
+
from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES
|
44 |
from climateqa.utils import get_image_from_azure_blob_storage
|
45 |
+
from climateqa.engine.graph import make_graph_agent
|
46 |
+
from climateqa.engine.embeddings import get_embeddings_function
|
47 |
+
from climateqa.engine.chains.retrieve_papers import find_papers
|
48 |
+
|
49 |
+
from front.utils import serialize_docs,process_figures
|
50 |
|
51 |
+
from climateqa.event_handler import init_audience, handle_retrieved_documents, stream_answer,handle_retrieved_owid_graphs
|
52 |
|
53 |
# Load environment variables in local mode
|
54 |
try:
|
|
|
57 |
except Exception as e:
|
58 |
pass
|
59 |
|
60 |
+
import requests
|
61 |
+
|
62 |
# Set up Gradio Theme
|
63 |
theme = gr.themes.Base(
|
64 |
primary_hue="blue",
|
|
|
109 |
|
110 |
|
111 |
# Create vectorstore and retriever
|
112 |
+
vectorstore = get_pinecone_vectorstore(embeddings_function, index_name = os.getenv("PINECONE_API_INDEX"))
|
113 |
+
vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name = os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
|
|
|
|
|
114 |
|
115 |
+
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
|
116 |
+
reranker = get_reranker("nano")
|
117 |
|
118 |
+
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
|
119 |
|
120 |
+
def update_config_modal_visibility(config_open):
|
121 |
+
new_config_visibility_status = not config_open
|
122 |
+
return gr.update(visible=new_config_visibility_status), new_config_visibility_status
|
123 |
|
124 |
+
async def chat(query, history, audience, sources, reports, relevant_content_sources, search_only):
|
125 |
"""taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
|
126 |
(messages in gradio format, messages in langchain format, source documents)"""
|
127 |
|
128 |
date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
129 |
print(f">> NEW QUESTION ({date_now}) : {query}")
|
130 |
|
131 |
+
audience_prompt = init_audience(audience)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
|
133 |
# Prepare default values
|
134 |
+
if sources is None or len(sources) == 0:
|
135 |
+
sources = ["IPCC", "IPBES", "IPOS"]
|
136 |
|
137 |
+
if reports is None or len(reports) == 0:
|
138 |
+
reports = []
|
139 |
|
140 |
+
inputs = {"user_input": query,"audience": audience_prompt,"sources_input":sources, "relevant_content_sources" : relevant_content_sources, "search_only": search_only}
|
141 |
result = agent.astream_events(inputs,version = "v1")
|
142 |
+
|
|
|
|
|
|
|
|
|
143 |
|
144 |
docs = []
|
145 |
+
used_figures=[]
|
146 |
+
related_contents = []
|
147 |
docs_html = ""
|
148 |
output_query = ""
|
149 |
output_language = ""
|
150 |
output_keywords = ""
|
|
|
151 |
start_streaming = False
|
152 |
+
graphs_html = ""
|
153 |
figures = '<div class="figures-container"><p></p> </div>'
|
154 |
|
155 |
steps_display = {
|
|
|
166 |
node = event["metadata"]["langgraph_node"]
|
167 |
|
168 |
if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" :# when documents are retrieved
|
169 |
+
docs, docs_html, history, used_documents, related_contents = handle_retrieved_documents(event, history, used_documents)
|
170 |
+
|
171 |
+
elif event["event"] == "on_chain_end" and node == "categorize_intent" and event["name"] == "_write": # when the query is transformed
|
172 |
+
|
173 |
+
intent = event["data"]["output"]["intent"]
|
174 |
+
if "language" in event["data"]["output"]:
|
175 |
+
output_language = event["data"]["output"]["language"]
|
176 |
+
else :
|
177 |
+
output_language = "English"
|
178 |
+
history[-1].content = f"Language identified : {output_language} \n Intent identified : {intent}"
|
179 |
+
|
180 |
+
|
|
|
|
|
|
|
|
|
|
|
181 |
elif event["name"] in steps_display.keys() and event["event"] == "on_chain_start": #display steps
|
182 |
+
event_description, display_output = steps_display[node]
|
183 |
if not hasattr(history[-1], 'metadata') or history[-1].metadata["title"] != event_description: # if a new step begins
|
184 |
history.append(ChatMessage(role="assistant", content = "", metadata={'title' :event_description}))
|
185 |
|
186 |
elif event["name"] != "transform_query" and event["event"] == "on_chat_model_stream" and node in ["answer_rag", "answer_search","answer_chitchat"]:# if streaming answer
|
187 |
+
history, start_streaming, answer_message_content = stream_answer(history, event, start_streaming, answer_message_content)
|
188 |
+
|
189 |
+
elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end":
|
190 |
+
graphs_html = handle_retrieved_owid_graphs(event, graphs_html)
|
191 |
+
|
|
|
|
|
192 |
|
193 |
if event["name"] == "transform_query" and event["event"] =="on_chain_end":
|
194 |
if hasattr(history[-1],"content"):
|
|
|
197 |
if event["name"] == "categorize_intent" and event["event"] == "on_chain_start":
|
198 |
print("X")
|
199 |
|
200 |
+
yield history, docs_html, output_query, output_language, related_contents , graphs_html, #,output_query,output_keywords
|
201 |
|
202 |
except Exception as e:
|
203 |
print(event, "has failed")
|
|
|
225 |
print(f"Error logging on Azure Blob Storage: {e}")
|
226 |
raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
|
227 |
|
228 |
+
yield history, docs_html, output_query, output_language, related_contents, graphs_html
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
|
230 |
|
231 |
def save_feedback(feed: str, user_id):
|
|
|
249 |
file_client.upload_file(logs)
|
250 |
|
251 |
|
|
|
|
|
|
|
|
|
|
|
252 |
|
253 |
|
254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
# --------------------------------------------------------------------
|
256 |
# Gradio
|
257 |
# --------------------------------------------------------------------
|
|
|
282 |
else:
|
283 |
print(data)
|
284 |
|
285 |
+
def save_graph(saved_graphs_state, embedding, category):
|
286 |
+
print(f"\nCategory:\n{saved_graphs_state}\n")
|
287 |
+
if category not in saved_graphs_state:
|
288 |
+
saved_graphs_state[category] = []
|
289 |
+
if embedding not in saved_graphs_state[category]:
|
290 |
+
saved_graphs_state[category].append(embedding)
|
291 |
+
return saved_graphs_state, gr.Button("Graph Saved")
|
292 |
+
|
293 |
|
294 |
|
295 |
with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=theme,elem_id = "main-component") as demo:
|
296 |
+
chat_completed_state = gr.State(0)
|
297 |
+
current_graphs = gr.State([])
|
298 |
+
saved_graphs = gr.State({})
|
299 |
+
config_open = gr.State(False)
|
300 |
|
301 |
+
|
302 |
with gr.Tab("ClimateQ&A"):
|
303 |
|
304 |
with gr.Row(elem_id="chatbot-row"):
|
|
|
321 |
|
322 |
with gr.Row(elem_id = "input-message"):
|
323 |
textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox")
|
324 |
+
|
325 |
+
config_button = gr.Button("",elem_id="config-button")
|
326 |
+
# config_checkbox_button = gr.Checkbox(label = '⚙️', value="show",visible=True, interactive=True, elem_id="checkbox-config")
|
327 |
+
|
328 |
+
|
329 |
|
330 |
+
with gr.Column(scale=2, variant="panel",elem_id = "right-panel"):
|
331 |
|
332 |
|
333 |
+
with gr.Tabs(elem_id = "right_panel_tab") as tabs:
|
334 |
with gr.TabItem("Examples",elem_id = "tab-examples",id = 0):
|
335 |
|
336 |
examples_hidden = gr.Textbox(visible = False)
|
|
|
356 |
)
|
357 |
|
358 |
samples.append(group_examples)
|
359 |
+
|
360 |
+
# with gr.Tab("Configuration", id = 10, ) as tab_config:
|
361 |
+
# # gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!")
|
362 |
|
363 |
+
# pass
|
364 |
+
|
365 |
+
# with gr.Row():
|
366 |
+
|
367 |
+
# dropdown_sources = gr.CheckboxGroup(
|
368 |
+
# ["IPCC", "IPBES","IPOS"],
|
369 |
+
# label="Select source",
|
370 |
+
# value=["IPCC"],
|
371 |
+
# interactive=True,
|
372 |
+
# )
|
373 |
+
# dropdown_external_sources = gr.CheckboxGroup(
|
374 |
+
# ["IPCC figures","OpenAlex", "OurWorldInData"],
|
375 |
+
# label="Select database to search for relevant content",
|
376 |
+
# value=["IPCC figures"],
|
377 |
+
# interactive=True,
|
378 |
+
# )
|
379 |
+
|
380 |
+
# dropdown_reports = gr.Dropdown(
|
381 |
+
# POSSIBLE_REPORTS,
|
382 |
+
# label="Or select specific reports",
|
383 |
+
# multiselect=True,
|
384 |
+
# value=None,
|
385 |
+
# interactive=True,
|
386 |
+
# )
|
387 |
+
|
388 |
+
# search_only = gr.Checkbox(label="Search only without chating", value=False, interactive=True, elem_id="checkbox-chat")
|
389 |
+
|
390 |
+
|
391 |
+
# dropdown_audience = gr.Dropdown(
|
392 |
+
# ["Children","General public","Experts"],
|
393 |
+
# label="Select audience",
|
394 |
+
# value="Experts",
|
395 |
+
# interactive=True,
|
396 |
+
# )
|
397 |
+
|
398 |
+
|
399 |
+
# after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers", visible=False)
|
400 |
+
|
401 |
|
402 |
+
# output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False, visible= False)
|
403 |
+
# output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False, visible= False)
|
|
|
|
|
|
|
|
|
404 |
|
|
|
|
|
405 |
|
406 |
+
# dropdown_external_sources.change(lambda x: gr.update(visible = True ) if "OpenAlex" in x else gr.update(visible=False) , inputs=[dropdown_external_sources], outputs=[after])
|
407 |
+
# # dropdown_external_sources.change(lambda x: gr.update(visible = True ) if "OpenAlex" in x else gr.update(visible=False) , inputs=[dropdown_external_sources], outputs=[after], visible=True)
|
408 |
|
409 |
|
410 |
+
with gr.Tab("Sources",elem_id = "tab-sources",id = 1) as tab_sources:
|
411 |
+
sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
|
412 |
+
|
413 |
+
|
414 |
+
|
415 |
+
with gr.Tab("Recommended content", elem_id="tab-recommended_content",id=2) as tab_recommended_content:
|
416 |
+
with gr.Tabs(elem_id = "group-subtabs") as tabs_recommended_content:
|
417 |
+
|
418 |
+
with gr.Tab("Figures",elem_id = "tab-figures",id = 3) as tab_figures:
|
419 |
+
sources_raw = gr.State()
|
420 |
+
|
421 |
+
with Modal(visible=False, elem_id="modal_figure_galery") as figure_modal:
|
422 |
+
gallery_component = gr.Gallery(object_fit='scale-down',elem_id="gallery-component", height="80vh")
|
423 |
+
|
424 |
+
show_full_size_figures = gr.Button("Show figures in full size",elem_id="show-figures",interactive=True)
|
425 |
+
show_full_size_figures.click(lambda : Modal(visible=True),None,figure_modal)
|
426 |
+
|
427 |
+
figures_cards = gr.HTML(show_label=False, elem_id="sources-figures")
|
428 |
+
|
429 |
+
|
430 |
+
|
431 |
+
with gr.Tab("Papers",elem_id = "tab-citations",id = 4) as tab_papers:
|
432 |
+
# btn_summary = gr.Button("Summary")
|
433 |
+
# Fenêtre simulée pour le Summary
|
434 |
+
with gr.Accordion(visible=True, elem_id="papers-summary-popup", label= "See summary of relevant papers", open= False) as summary_popup:
|
435 |
+
papers_summary = gr.Markdown("", visible=True, elem_id="papers-summary")
|
436 |
+
|
437 |
+
# btn_relevant_papers = gr.Button("Relevant papers")
|
438 |
+
# Fenêtre simulée pour les Relevant Papers
|
439 |
+
with gr.Accordion(visible=True, elem_id="papers-relevant-popup",label= "See relevant papers", open= False) as relevant_popup:
|
440 |
+
papers_html = gr.HTML(show_label=False, elem_id="papers-textbox")
|
441 |
+
|
442 |
+
btn_citations_network = gr.Button("Explore papers citations network")
|
443 |
+
# Fenêtre simulée pour le Citations Network
|
444 |
+
with Modal(visible=False) as papers_modal:
|
445 |
+
citations_network = gr.HTML("<h3>Citations Network Graph</h3>", visible=True, elem_id="papers-citations-network")
|
446 |
+
btn_citations_network.click(lambda: Modal(visible=True), None, papers_modal)
|
447 |
+
|
448 |
+
|
449 |
+
|
450 |
+
with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
|
451 |
+
|
452 |
+
graphs_container = gr.HTML("<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",elem_id="graphs-container")
|
453 |
+
current_graphs.change(lambda x : x, inputs=[current_graphs], outputs=[graphs_container])
|
454 |
+
|
455 |
+
with Modal(visible=False,elem_id="modal-config") as config_modal:
|
456 |
+
gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!")
|
457 |
|
458 |
+
|
459 |
+
# with gr.Row():
|
460 |
+
|
461 |
+
dropdown_sources = gr.CheckboxGroup(
|
462 |
+
["IPCC", "IPBES","IPOS"],
|
463 |
+
label="Select source (by default search in all sources)",
|
464 |
+
value=["IPCC"],
|
465 |
+
interactive=True,
|
466 |
+
)
|
467 |
+
|
468 |
+
dropdown_reports = gr.Dropdown(
|
469 |
+
POSSIBLE_REPORTS,
|
470 |
+
label="Or select specific reports",
|
471 |
+
multiselect=True,
|
472 |
+
value=None,
|
473 |
+
interactive=True,
|
474 |
+
)
|
475 |
+
|
476 |
+
dropdown_external_sources = gr.CheckboxGroup(
|
477 |
+
["IPCC figures","OpenAlex", "OurWorldInData"],
|
478 |
+
label="Select database to search for relevant content",
|
479 |
+
value=["IPCC figures"],
|
480 |
+
interactive=True,
|
481 |
+
)
|
482 |
|
483 |
+
search_only = gr.Checkbox(label="Search only for recommended content without chating", value=False, interactive=True, elem_id="checkbox-chat")
|
|
|
|
|
|
|
|
|
|
|
484 |
|
|
|
|
|
485 |
|
486 |
+
dropdown_audience = gr.Dropdown(
|
487 |
+
["Children","General public","Experts"],
|
488 |
+
label="Select audience",
|
489 |
+
value="Experts",
|
490 |
+
interactive=True,
|
491 |
+
)
|
492 |
+
|
493 |
+
|
494 |
+
after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers", visible=False)
|
495 |
+
|
496 |
|
497 |
+
output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False, visible= False)
|
498 |
+
output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False, visible= False)
|
|
|
|
|
|
|
|
|
499 |
|
|
|
|
|
500 |
|
501 |
+
dropdown_external_sources.change(lambda x: gr.update(visible = True ) if "OpenAlex" in x else gr.update(visible=False) , inputs=[dropdown_external_sources], outputs=[after])
|
502 |
+
|
503 |
+
close_config_modal = gr.Button("Validate and Close",elem_id="close-config-modal")
|
504 |
+
close_config_modal.click(fn=update_config_modal_visibility, inputs=[config_open], outputs=[config_modal, config_open])
|
505 |
+
# dropdown_external_sources.change(lambda x: gr.update(visible = True ) if "OpenAlex" in x else gr.update(visible=False) , inputs=[dropdown_external_sources], outputs=[after], visible=True)
|
506 |
+
|
507 |
|
508 |
+
|
509 |
+
config_button.click(fn=update_config_modal_visibility, inputs=[config_open], outputs=[config_modal, config_open])
|
510 |
+
|
511 |
+
# with gr.Tab("OECD",elem_id = "tab-oecd",id = 6):
|
512 |
+
# oecd_indicator = "RIVER_FLOOD_RP100_POP_SH"
|
513 |
+
# oecd_topic = "climate"
|
514 |
+
# oecd_latitude = "46.8332"
|
515 |
+
# oecd_longitude = "5.3725"
|
516 |
+
# oecd_zoom = "5.6442"
|
517 |
+
# # Create the HTML content with the iframe
|
518 |
+
# iframe_html = f"""
|
519 |
+
# <iframe src="https://localdataportal.oecd.org/maps.html?indicator={oecd_indicator}&topic={oecd_topic}&latitude={oecd_latitude}&longitude={oecd_longitude}&zoom={oecd_zoom}"
|
520 |
+
# width="100%" height="600" frameborder="0" style="border:0;" allowfullscreen></iframe>
|
521 |
+
# """
|
522 |
+
# oecd_textbox = gr.HTML(iframe_html, show_label=False, elem_id="oecd-textbox")
|
523 |
|
524 |
+
|
525 |
|
526 |
|
527 |
#---------------------------------------------------------------------------------------
|
528 |
# OTHER TABS
|
529 |
#---------------------------------------------------------------------------------------
|
530 |
|
531 |
+
# with gr.Tab("Settings",elem_id = "tab-config",id = 2):
|
532 |
|
533 |
+
# gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")
|
|
|
534 |
|
|
|
535 |
|
536 |
+
# dropdown_sources = gr.CheckboxGroup(
|
537 |
+
# ["IPCC", "IPBES","IPOS", "OpenAlex"],
|
538 |
+
# label="Select source",
|
539 |
+
# value=["IPCC"],
|
540 |
+
# interactive=True,
|
541 |
+
# )
|
542 |
|
543 |
+
# dropdown_reports = gr.Dropdown(
|
544 |
+
# POSSIBLE_REPORTS,
|
545 |
+
# label="Or select specific reports",
|
546 |
+
# multiselect=True,
|
547 |
+
# value=None,
|
548 |
+
# interactive=True,
|
549 |
+
# )
|
550 |
|
551 |
+
# dropdown_audience = gr.Dropdown(
|
552 |
+
# ["Children","General public","Experts"],
|
553 |
+
# label="Select audience",
|
554 |
+
# value="Experts",
|
555 |
+
# interactive=True,
|
556 |
+
# )
|
557 |
|
|
|
|
|
558 |
|
559 |
+
# output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
|
560 |
+
# output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
|
561 |
|
562 |
|
|
|
563 |
with gr.Tab("About",elem_classes = "max-height other-tabs"):
|
564 |
with gr.Row():
|
565 |
with gr.Column(scale=1):
|
|
|
567 |
|
568 |
|
569 |
|
570 |
+
gr.Markdown(
|
571 |
+
"""
|
572 |
+
### More info
|
573 |
+
- See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)
|
574 |
+
- Feedbacks on this [form](https://forms.office.com/e/1Yzgxm6jbp)
|
575 |
+
|
576 |
+
### Citation
|
577 |
+
"""
|
578 |
+
)
|
579 |
with gr.Accordion(CITATION_LABEL,elem_id="citation", open = False,):
|
580 |
# # Display citation label and text)
|
581 |
gr.Textbox(
|
|
|
588 |
|
589 |
|
590 |
|
591 |
+
def start_chat(query,history,search_only):
|
|
|
|
|
592 |
history = history + [ChatMessage(role="user", content=query)]
|
593 |
+
if search_only:
|
594 |
+
return (gr.update(interactive = False),gr.update(selected=1),history)
|
595 |
+
else:
|
596 |
+
return (gr.update(interactive = False),gr.update(selected=2),history)
|
597 |
|
598 |
def finish_chat():
|
599 |
+
return gr.update(interactive = True,value = "")
|
600 |
+
|
601 |
+
# Initialize visibility states
|
602 |
+
summary_visible = False
|
603 |
+
relevant_visible = False
|
604 |
+
|
605 |
+
# Functions to toggle visibility
|
606 |
+
def toggle_summary_visibility():
|
607 |
+
global summary_visible
|
608 |
+
summary_visible = not summary_visible
|
609 |
+
return gr.update(visible=summary_visible)
|
610 |
+
|
611 |
+
def toggle_relevant_visibility():
|
612 |
+
global relevant_visible
|
613 |
+
relevant_visible = not relevant_visible
|
614 |
+
return gr.update(visible=relevant_visible)
|
615 |
+
|
616 |
|
617 |
+
def change_completion_status(current_state):
|
618 |
+
current_state = 1 - current_state
|
619 |
+
return current_state
|
620 |
+
|
621 |
+
def update_sources_number_display(sources_textbox, figures_cards, current_graphs, papers_html):
|
622 |
+
sources_number = sources_textbox.count("<h2>")
|
623 |
+
figures_number = figures_cards.count("<h2>")
|
624 |
+
graphs_number = current_graphs.count("<iframe")
|
625 |
+
papers_number = papers_html.count("<h2>")
|
626 |
+
sources_notif_label = f"Sources ({sources_number})"
|
627 |
+
figures_notif_label = f"Figures ({figures_number})"
|
628 |
+
graphs_notif_label = f"Graphs ({graphs_number})"
|
629 |
+
papers_notif_label = f"Papers ({papers_number})"
|
630 |
+
recommended_content_notif_label = f"Recommended content ({figures_number + graphs_number + papers_number})"
|
631 |
+
|
632 |
+
return gr.update(label = recommended_content_notif_label), gr.update(label = sources_notif_label), gr.update(label = figures_notif_label), gr.update(label = graphs_notif_label), gr.update(label = papers_notif_label)
|
633 |
+
|
634 |
(textbox
|
635 |
+
.submit(start_chat, [textbox,chatbot, search_only], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
|
636 |
+
.then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, dropdown_external_sources, search_only] ,[chatbot,sources_textbox,output_query,output_language, sources_raw, current_graphs],concurrency_limit = 8,api_name = "chat_textbox")
|
637 |
.then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
|
638 |
+
# .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_sources, tab_figures, tab_graphs, tab_papers] )
|
639 |
)
|
640 |
|
641 |
(examples_hidden
|
642 |
+
.change(start_chat, [examples_hidden,chatbot, search_only], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
|
643 |
+
.then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, dropdown_external_sources, search_only] ,[chatbot,sources_textbox,output_query,output_language, sources_raw, current_graphs],concurrency_limit = 8,api_name = "chat_textbox")
|
644 |
.then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
|
645 |
+
# .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_sources, tab_figures, tab_graphs, tab_papers] )
|
646 |
)
|
647 |
|
648 |
|
|
|
653 |
return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
|
654 |
|
655 |
|
656 |
+
sources_raw.change(process_figures, inputs=[sources_raw], outputs=[figures_cards, gallery_component])
|
657 |
+
|
658 |
+
# update sources numbers
|
659 |
+
sources_textbox.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
|
660 |
+
figures_cards.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
|
661 |
+
current_graphs.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
|
662 |
+
papers_html.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
|
663 |
|
664 |
+
# other questions examples
|
665 |
dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
|
666 |
|
667 |
+
# search for papers
|
668 |
+
textbox.submit(find_papers,[textbox,after, dropdown_external_sources], [papers_html,citations_network,papers_summary])
|
669 |
+
examples_hidden.change(find_papers,[examples_hidden,after,dropdown_external_sources], [papers_html,citations_network,papers_summary])
|
670 |
+
|
671 |
+
# btn_summary.click(toggle_summary_visibility, outputs=summary_popup)
|
672 |
+
# btn_relevant_papers.click(toggle_relevant_visibility, outputs=relevant_popup)
|
673 |
|
674 |
demo.queue()
|
675 |
|
@@ -42,4 +42,25 @@ POSSIBLE_REPORTS = [
|
|
42 |
"IPBES IAS A C5",
|
43 |
"IPBES IAS A C6",
|
44 |
"IPBES IAS A SPM"
|
45 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
"IPBES IAS A C5",
|
43 |
"IPBES IAS A C6",
|
44 |
"IPBES IAS A SPM"
|
45 |
+
]
|
46 |
+
|
47 |
+
OWID_CATEGORIES = ['Access to Energy', 'Agricultural Production',
|
48 |
+
'Agricultural Regulation & Policy', 'Air Pollution',
|
49 |
+
'Animal Welfare', 'Antibiotics', 'Biodiversity', 'Biofuels',
|
50 |
+
'Biological & Chemical Weapons', 'CO2 & Greenhouse Gas Emissions',
|
51 |
+
'COVID-19', 'Clean Water', 'Clean Water & Sanitation',
|
52 |
+
'Climate Change', 'Crop Yields', 'Diet Compositions',
|
53 |
+
'Electricity', 'Electricity Mix', 'Energy', 'Energy Efficiency',
|
54 |
+
'Energy Prices', 'Environmental Impacts of Food Production',
|
55 |
+
'Environmental Protection & Regulation', 'Famines', 'Farm Size',
|
56 |
+
'Fertilizers', 'Fish & Overfishing', 'Food Supply', 'Food Trade',
|
57 |
+
'Food Waste', 'Food and Agriculture', 'Forests & Deforestation',
|
58 |
+
'Fossil Fuels', 'Future Population Growth',
|
59 |
+
'Hunger & Undernourishment', 'Indoor Air Pollution', 'Land Use',
|
60 |
+
'Land Use & Yields in Agriculture', 'Lead Pollution',
|
61 |
+
'Meat & Dairy Production', 'Metals & Minerals',
|
62 |
+
'Natural Disasters', 'Nuclear Energy', 'Nuclear Weapons',
|
63 |
+
'Oil Spills', 'Outdoor Air Pollution', 'Ozone Layer', 'Pandemics',
|
64 |
+
'Pesticides', 'Plastic Pollution', 'Renewable Energy', 'Soil',
|
65 |
+
'Transport', 'Urbanization', 'Waste Management', 'Water Pollution',
|
66 |
+
'Water Use & Stress', 'Wildfires']
|
@@ -45,8 +45,12 @@ def make_chitchat_node(llm):
|
|
45 |
chitchat_chain = make_chitchat_chain(llm)
|
46 |
|
47 |
async def answer_chitchat(state,config):
|
|
|
|
|
48 |
answer = await chitchat_chain.ainvoke({"question":state["user_input"]},config)
|
49 |
-
|
|
|
|
|
50 |
|
51 |
return answer_chitchat
|
52 |
|
|
|
45 |
chitchat_chain = make_chitchat_chain(llm)
|
46 |
|
47 |
async def answer_chitchat(state,config):
|
48 |
+
print("---- Answer chitchat ----")
|
49 |
+
|
50 |
answer = await chitchat_chain.ainvoke({"question":state["user_input"]},config)
|
51 |
+
state["answer"] = answer
|
52 |
+
return state
|
53 |
+
# return {"answer":answer}
|
54 |
|
55 |
return answer_chitchat
|
56 |
|
@@ -7,6 +7,9 @@ from langchain_core.prompts.base import format_document
|
|
7 |
|
8 |
from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
|
9 |
from climateqa.engine.chains.prompts import papers_prompt_template
|
|
|
|
|
|
|
10 |
|
11 |
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
|
12 |
|
@@ -40,6 +43,7 @@ def make_rag_chain(llm):
|
|
40 |
prompt = ChatPromptTemplate.from_template(answer_prompt_template)
|
41 |
chain = ({
|
42 |
"context":lambda x : _combine_documents(x["documents"]),
|
|
|
43 |
"query":itemgetter("query"),
|
44 |
"language":itemgetter("language"),
|
45 |
"audience":itemgetter("audience"),
|
@@ -51,7 +55,6 @@ def make_rag_chain_without_docs(llm):
|
|
51 |
chain = prompt | llm | StrOutputParser()
|
52 |
return chain
|
53 |
|
54 |
-
|
55 |
def make_rag_node(llm,with_docs = True):
|
56 |
|
57 |
if with_docs:
|
@@ -60,7 +63,17 @@ def make_rag_node(llm,with_docs = True):
|
|
60 |
rag_chain = make_rag_chain_without_docs(llm)
|
61 |
|
62 |
async def answer_rag(state,config):
|
|
|
|
|
|
|
63 |
answer = await rag_chain.ainvoke(state,config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
return {"answer":answer}
|
65 |
|
66 |
return answer_rag
|
@@ -68,32 +81,32 @@ def make_rag_node(llm,with_docs = True):
|
|
68 |
|
69 |
|
70 |
|
71 |
-
|
72 |
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
|
79 |
-
|
80 |
-
|
81 |
|
82 |
-
|
83 |
|
84 |
|
85 |
|
86 |
|
87 |
|
88 |
|
89 |
-
|
90 |
|
91 |
-
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
|
98 |
-
|
99 |
-
|
|
|
7 |
|
8 |
from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
|
9 |
from climateqa.engine.chains.prompts import papers_prompt_template
|
10 |
+
import time
|
11 |
+
from ..utils import rename_chain, pass_values
|
12 |
+
|
13 |
|
14 |
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
|
15 |
|
|
|
43 |
prompt = ChatPromptTemplate.from_template(answer_prompt_template)
|
44 |
chain = ({
|
45 |
"context":lambda x : _combine_documents(x["documents"]),
|
46 |
+
"context_length":lambda x : print("CONTEXT LENGTH : " , len(_combine_documents(x["documents"]))),
|
47 |
"query":itemgetter("query"),
|
48 |
"language":itemgetter("language"),
|
49 |
"audience":itemgetter("audience"),
|
|
|
55 |
chain = prompt | llm | StrOutputParser()
|
56 |
return chain
|
57 |
|
|
|
58 |
def make_rag_node(llm,with_docs = True):
|
59 |
|
60 |
if with_docs:
|
|
|
63 |
rag_chain = make_rag_chain_without_docs(llm)
|
64 |
|
65 |
async def answer_rag(state,config):
|
66 |
+
print("---- Answer RAG ----")
|
67 |
+
start_time = time.time()
|
68 |
+
|
69 |
answer = await rag_chain.ainvoke(state,config)
|
70 |
+
|
71 |
+
end_time = time.time()
|
72 |
+
elapsed_time = end_time - start_time
|
73 |
+
print("RAG elapsed time: ", elapsed_time)
|
74 |
+
print("Answer size : ", len(answer))
|
75 |
+
# print(f"\n\nAnswer:\n{answer}")
|
76 |
+
|
77 |
return {"answer":answer}
|
78 |
|
79 |
return answer_rag
|
|
|
81 |
|
82 |
|
83 |
|
84 |
+
def make_rag_papers_chain(llm):
|
85 |
|
86 |
+
prompt = ChatPromptTemplate.from_template(papers_prompt_template)
|
87 |
+
input_documents = {
|
88 |
+
"context":lambda x : _combine_documents(x["docs"]),
|
89 |
+
**pass_values(["question","language"])
|
90 |
+
}
|
91 |
|
92 |
+
chain = input_documents | prompt | llm | StrOutputParser()
|
93 |
+
chain = rename_chain(chain,"answer")
|
94 |
|
95 |
+
return chain
|
96 |
|
97 |
|
98 |
|
99 |
|
100 |
|
101 |
|
102 |
+
def make_illustration_chain(llm):
|
103 |
|
104 |
+
prompt_with_images = ChatPromptTemplate.from_template(answer_prompt_images_template)
|
105 |
|
106 |
+
input_description_images = {
|
107 |
+
"images":lambda x : _combine_documents(get_image_docs(x["docs"])),
|
108 |
+
**pass_values(["question","audience","language","answer"]),
|
109 |
+
}
|
110 |
|
111 |
+
illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
|
112 |
+
return illustration_chain
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
3 |
+
from typing import List
|
4 |
+
from typing import Literal
|
5 |
+
from langchain.prompts import ChatPromptTemplate
|
6 |
+
from langchain_core.utils.function_calling import convert_to_openai_function
|
7 |
+
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
8 |
+
|
9 |
+
|
10 |
+
class IntentCategorizer(BaseModel):
|
11 |
+
"""Analyzing the user message input"""
|
12 |
+
|
13 |
+
environment: bool = Field(
|
14 |
+
description="Return 'True' if the question relates to climate change, the environment, nature, etc. (Example: should I eat fish?). Return 'False' if the question is just chit chat or not related to the environment or climate change.",
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
def make_chitchat_intent_categorization_chain(llm):
|
19 |
+
|
20 |
+
openai_functions = [convert_to_openai_function(IntentCategorizer)]
|
21 |
+
llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"IntentCategorizer"})
|
22 |
+
|
23 |
+
prompt = ChatPromptTemplate.from_messages([
|
24 |
+
("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
|
25 |
+
("user", "input: {input}")
|
26 |
+
])
|
27 |
+
|
28 |
+
chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
|
29 |
+
return chain
|
30 |
+
|
31 |
+
|
32 |
+
def make_chitchat_intent_categorization_node(llm):
|
33 |
+
|
34 |
+
categorization_chain = make_chitchat_intent_categorization_chain(llm)
|
35 |
+
|
36 |
+
def categorize_message(state):
|
37 |
+
output = categorization_chain.invoke({"input": state["user_input"]})
|
38 |
+
print(f"\n\nChit chat output intent categorization: {output}\n")
|
39 |
+
state["search_graphs_chitchat"] = output["environment"]
|
40 |
+
print(f"\n\nChit chat output intent categorization: {state}\n")
|
41 |
+
return state
|
42 |
+
|
43 |
+
return categorize_message
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
from contextlib import contextmanager
|
4 |
+
|
5 |
+
from ..reranker import rerank_docs
|
6 |
+
from ..graph_retriever import retrieve_graphs # GraphRetriever
|
7 |
+
from ...utils import remove_duplicates_keep_highest_score
|
8 |
+
|
9 |
+
|
10 |
+
def divide_into_parts(target, parts):
|
11 |
+
# Base value for each part
|
12 |
+
base = target // parts
|
13 |
+
# Remainder to distribute
|
14 |
+
remainder = target % parts
|
15 |
+
# List to hold the result
|
16 |
+
result = []
|
17 |
+
|
18 |
+
for i in range(parts):
|
19 |
+
if i < remainder:
|
20 |
+
# These parts get base value + 1
|
21 |
+
result.append(base + 1)
|
22 |
+
else:
|
23 |
+
# The rest get the base value
|
24 |
+
result.append(base)
|
25 |
+
|
26 |
+
return result
|
27 |
+
|
28 |
+
|
29 |
+
@contextmanager
|
30 |
+
def suppress_output():
|
31 |
+
# Open a null device
|
32 |
+
with open(os.devnull, 'w') as devnull:
|
33 |
+
# Store the original stdout and stderr
|
34 |
+
old_stdout = sys.stdout
|
35 |
+
old_stderr = sys.stderr
|
36 |
+
# Redirect stdout and stderr to the null device
|
37 |
+
sys.stdout = devnull
|
38 |
+
sys.stderr = devnull
|
39 |
+
try:
|
40 |
+
yield
|
41 |
+
finally:
|
42 |
+
# Restore stdout and stderr
|
43 |
+
sys.stdout = old_stdout
|
44 |
+
sys.stderr = old_stderr
|
45 |
+
|
46 |
+
|
47 |
+
def make_graph_retriever_node(vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100):
|
48 |
+
|
49 |
+
async def node_retrieve_graphs(state):
|
50 |
+
print("---- Retrieving graphs ----")
|
51 |
+
|
52 |
+
POSSIBLE_SOURCES = ["IEA", "OWID"]
|
53 |
+
questions = state["remaining_questions"] if state["remaining_questions"] is not None and state["remaining_questions"]!=[] else [state["query"]]
|
54 |
+
# sources_input = state["sources_input"]
|
55 |
+
sources_input = ["auto"]
|
56 |
+
|
57 |
+
auto_mode = "auto" in sources_input
|
58 |
+
|
59 |
+
# There are several options to get the final top k
|
60 |
+
# Option 1 - Get 100 documents by question and rerank by question
|
61 |
+
# Option 2 - Get 100/n documents by question and rerank the total
|
62 |
+
if rerank_by_question:
|
63 |
+
k_by_question = divide_into_parts(k_final,len(questions))
|
64 |
+
|
65 |
+
docs = []
|
66 |
+
|
67 |
+
for i,q in enumerate(questions):
|
68 |
+
|
69 |
+
question = q["question"] if isinstance(q, dict) else q
|
70 |
+
|
71 |
+
print(f"Subquestion {i}: {question}")
|
72 |
+
|
73 |
+
# If auto mode, we use all sources
|
74 |
+
if auto_mode:
|
75 |
+
sources = POSSIBLE_SOURCES
|
76 |
+
# Otherwise, we use the config
|
77 |
+
else:
|
78 |
+
sources = sources_input
|
79 |
+
|
80 |
+
if any([x in POSSIBLE_SOURCES for x in sources]):
|
81 |
+
|
82 |
+
sources = [x for x in sources if x in POSSIBLE_SOURCES]
|
83 |
+
|
84 |
+
# Search the document store using the retriever
|
85 |
+
docs_question = await retrieve_graphs(
|
86 |
+
query = question,
|
87 |
+
vectorstore = vectorstore,
|
88 |
+
sources = sources,
|
89 |
+
k_total = k_before_reranking,
|
90 |
+
threshold = 0.5,
|
91 |
+
)
|
92 |
+
# docs_question = retriever.get_relevant_documents(question)
|
93 |
+
|
94 |
+
# Rerank
|
95 |
+
if reranker is not None and docs_question!=[]:
|
96 |
+
with suppress_output():
|
97 |
+
docs_question = rerank_docs(reranker,docs_question,question)
|
98 |
+
else:
|
99 |
+
# Add a default reranking score
|
100 |
+
for doc in docs_question:
|
101 |
+
doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
|
102 |
+
|
103 |
+
# If rerank by question we select the top documents for each question
|
104 |
+
if rerank_by_question:
|
105 |
+
docs_question = docs_question[:k_by_question[i]]
|
106 |
+
|
107 |
+
# Add sources used in the metadata
|
108 |
+
for doc in docs_question:
|
109 |
+
doc.metadata["sources_used"] = sources
|
110 |
+
|
111 |
+
print(f"{len(docs_question)} graphs retrieved for subquestion {i + 1}: {docs_question}")
|
112 |
+
|
113 |
+
docs.extend(docs_question)
|
114 |
+
|
115 |
+
else:
|
116 |
+
print(f"There are no graphs which match the sources filtered on. Sources filtered on: {sources}. Sources available: {POSSIBLE_SOURCES}.")
|
117 |
+
|
118 |
+
# Remove duplicates and keep the duplicate document with the highest reranking score
|
119 |
+
docs = remove_duplicates_keep_highest_score(docs)
|
120 |
+
|
121 |
+
# Sorting the list in descending order by rerank_score
|
122 |
+
# Then select the top k
|
123 |
+
docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
|
124 |
+
docs = docs[:k_final]
|
125 |
+
|
126 |
+
return {"recommended_content": docs}
|
127 |
+
|
128 |
+
return node_retrieve_graphs
|
@@ -17,8 +17,8 @@ class IntentCategorizer(BaseModel):
|
|
17 |
intent: str = Field(
|
18 |
enum=[
|
19 |
"ai_impact",
|
20 |
-
"geo_info",
|
21 |
-
"esg",
|
22 |
"search",
|
23 |
"chitchat",
|
24 |
],
|
@@ -28,11 +28,12 @@ class IntentCategorizer(BaseModel):
|
|
28 |
|
29 |
Examples:
|
30 |
- ai_impact = Environmental impacts of AI: "What are the environmental impacts of AI", "How does AI affect the environment"
|
31 |
-
- geo_info = Geolocated info about climate change: Any question where the user wants to know localized impacts of climate change, eg: "What will be the temperature in Marseille in 2050"
|
32 |
-
- esg = Any question about the ESG regulation, frameworks and standards like the CSRD, TCFD, SASB, GRI, CDP, etc.
|
33 |
- search = Searching for any quesiton about climate change, energy, biodiversity, nature, and everything we can find the IPCC or IPBES reports or scientific papers,
|
34 |
- chitchat = Any general question that is not related to the environment or climate change or just conversational, or if you don't think searching the IPCC or IPBES reports would be relevant
|
35 |
""",
|
|
|
|
|
|
|
36 |
)
|
37 |
|
38 |
|
@@ -43,7 +44,7 @@ def make_intent_categorization_chain(llm):
|
|
43 |
llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"IntentCategorizer"})
|
44 |
|
45 |
prompt = ChatPromptTemplate.from_messages([
|
46 |
-
("system", "You are a helpful assistant, you will analyze, translate and
|
47 |
("user", "input: {input}")
|
48 |
])
|
49 |
|
@@ -56,7 +57,10 @@ def make_intent_categorization_node(llm):
|
|
56 |
categorization_chain = make_intent_categorization_chain(llm)
|
57 |
|
58 |
def categorize_message(state):
|
59 |
-
|
|
|
|
|
|
|
60 |
if "language" not in output: output["language"] = "English"
|
61 |
output["query"] = state["user_input"]
|
62 |
return output
|
|
|
17 |
intent: str = Field(
|
18 |
enum=[
|
19 |
"ai_impact",
|
20 |
+
# "geo_info",
|
21 |
+
# "esg",
|
22 |
"search",
|
23 |
"chitchat",
|
24 |
],
|
|
|
28 |
|
29 |
Examples:
|
30 |
- ai_impact = Environmental impacts of AI: "What are the environmental impacts of AI", "How does AI affect the environment"
|
|
|
|
|
31 |
- search = Searching for any quesiton about climate change, energy, biodiversity, nature, and everything we can find the IPCC or IPBES reports or scientific papers,
|
32 |
- chitchat = Any general question that is not related to the environment or climate change or just conversational, or if you don't think searching the IPCC or IPBES reports would be relevant
|
33 |
""",
|
34 |
+
# - geo_info = Geolocated info about climate change: Any question where the user wants to know localized impacts of climate change, eg: "What will be the temperature in Marseille in 2050"
|
35 |
+
# - esg = Any question about the ESG regulation, frameworks and standards like the CSRD, TCFD, SASB, GRI, CDP, etc.
|
36 |
+
|
37 |
)
|
38 |
|
39 |
|
|
|
44 |
llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"IntentCategorizer"})
|
45 |
|
46 |
prompt = ChatPromptTemplate.from_messages([
|
47 |
+
("system", "You are a helpful assistant, you will analyze, translate and categorize the user input message using the function provided. Categorize the user input as ai ONLY if it is related to Artificial Intelligence, search if it is related to the environment, climate change, energy, biodiversity, nature, etc. and chitchat if it is just general conversation."),
|
48 |
("user", "input: {input}")
|
49 |
])
|
50 |
|
|
|
57 |
categorization_chain = make_intent_categorization_chain(llm)
|
58 |
|
59 |
def categorize_message(state):
|
60 |
+
print("---- Categorize_message ----")
|
61 |
+
|
62 |
+
output = categorization_chain.invoke({"input": state["user_input"]})
|
63 |
+
print(f"\n\nOutput intent categorization: {output}\n")
|
64 |
if "language" not in output: output["language"] = "English"
|
65 |
output["query"] = state["user_input"]
|
66 |
return output
|
@@ -147,4 +147,27 @@ audience_prompts = {
|
|
147 |
"children": "6 year old children that don't know anything about science and climate change and need metaphors to learn",
|
148 |
"general": "the general public who know the basics in science and climate change and want to learn more about it without technical terms. Still use references to passages.",
|
149 |
"experts": "expert and climate scientists that are not afraid of technical terms",
|
150 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
"children": "6 year old children that don't know anything about science and climate change and need metaphors to learn",
|
148 |
"general": "the general public who know the basics in science and climate change and want to learn more about it without technical terms. Still use references to passages.",
|
149 |
"experts": "expert and climate scientists that are not afraid of technical terms",
|
150 |
+
}
|
151 |
+
|
152 |
+
|
153 |
+
answer_prompt_graph_template = """
|
154 |
+
Given the user question and a list of graphs which are related to the question, rank the graphs based on relevance to the user question. ALWAYS follow the guidelines given below.
|
155 |
+
|
156 |
+
### Guidelines ###
|
157 |
+
- Keep all the graphs that are given to you.
|
158 |
+
- NEVER modify the graph HTML embedding, the category or the source leave them exactly as they are given.
|
159 |
+
- Return the ranked graphs as a list of dictionaries with keys 'embedding', 'category', and 'source'.
|
160 |
+
- Return a valid JSON output.
|
161 |
+
|
162 |
+
-----------------------
|
163 |
+
User question:
|
164 |
+
{query}
|
165 |
+
|
166 |
+
Graphs and their HTML embedding:
|
167 |
+
{recommended_content}
|
168 |
+
|
169 |
+
-----------------------
|
170 |
+
{format_instructions}
|
171 |
+
|
172 |
+
Output the result as json with a key "graphs" containing a list of dictionaries of the relevant graphs with keys 'embedding', 'category', and 'source'. Do not modify the graph HTML embedding, the category or the source. Do not put any message or text before or after the JSON output.
|
173 |
+
"""
|
@@ -69,15 +69,15 @@ class QueryAnalysis(BaseModel):
|
|
69 |
# """
|
70 |
# )
|
71 |
|
72 |
-
sources: List[Literal["IPCC", "IPBES", "IPOS"
|
73 |
...,
|
74 |
description="""
|
75 |
Given a user question choose which documents would be most relevant for answering their question,
|
76 |
- IPCC is for questions about climate change, energy, impacts, and everything we can find the IPCC reports
|
77 |
- IPBES is for questions about biodiversity and nature
|
78 |
- IPOS is for questions about the ocean and deep sea mining
|
79 |
-
- OpenAlex is for any other questions that are not in the previous categories but could be found in the scientific litterature
|
80 |
""",
|
|
|
81 |
)
|
82 |
# topics: List[Literal[
|
83 |
# "Climate change",
|
@@ -138,6 +138,8 @@ def make_query_transform_node(llm,k_final=15):
|
|
138 |
rewriter_chain = make_query_rewriter_chain(llm)
|
139 |
|
140 |
def transform_query(state):
|
|
|
|
|
141 |
|
142 |
if "sources_auto" not in state or state["sources_auto"] is None or state["sources_auto"] is False:
|
143 |
auto_mode = False
|
@@ -158,6 +160,12 @@ def make_query_transform_node(llm,k_final=15):
|
|
158 |
for question in new_state["questions"]:
|
159 |
question_state = {"question":question}
|
160 |
analysis_output = rewriter_chain.invoke({"input":question})
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
question_state.update(analysis_output)
|
162 |
questions.append(question_state)
|
163 |
|
|
|
69 |
# """
|
70 |
# )
|
71 |
|
72 |
+
sources: List[Literal["IPCC", "IPBES", "IPOS"]] = Field( #,"OpenAlex"]] = Field(
|
73 |
...,
|
74 |
description="""
|
75 |
Given a user question choose which documents would be most relevant for answering their question,
|
76 |
- IPCC is for questions about climate change, energy, impacts, and everything we can find the IPCC reports
|
77 |
- IPBES is for questions about biodiversity and nature
|
78 |
- IPOS is for questions about the ocean and deep sea mining
|
|
|
79 |
""",
|
80 |
+
# - OpenAlex is for any other questions that are not in the previous categories but could be found in the scientific litterature
|
81 |
)
|
82 |
# topics: List[Literal[
|
83 |
# "Climate change",
|
|
|
138 |
rewriter_chain = make_query_rewriter_chain(llm)
|
139 |
|
140 |
def transform_query(state):
|
141 |
+
print("---- Transform query ----")
|
142 |
+
|
143 |
|
144 |
if "sources_auto" not in state or state["sources_auto"] is None or state["sources_auto"] is False:
|
145 |
auto_mode = False
|
|
|
160 |
for question in new_state["questions"]:
|
161 |
question_state = {"question":question}
|
162 |
analysis_output = rewriter_chain.invoke({"input":question})
|
163 |
+
|
164 |
+
# TODO WARNING llm should always return smthg
|
165 |
+
# The case when the llm does not return any sources
|
166 |
+
if not analysis_output["sources"] or not all(source in ["IPCC", "IPBS", "IPOS"] for source in analysis_output["sources"]):
|
167 |
+
analysis_output["sources"] = ["IPCC", "IPBES", "IPOS"]
|
168 |
+
|
169 |
question_state.update(analysis_output)
|
170 |
questions.append(question_state)
|
171 |
|
@@ -8,10 +8,13 @@ from langchain_core.runnables import RunnableParallel, RunnablePassthrough
|
|
8 |
from langchain_core.runnables import RunnableLambda
|
9 |
|
10 |
from ..reranker import rerank_docs
|
11 |
-
from ...knowledge.retriever import ClimateQARetriever
|
12 |
from ...knowledge.openalex import OpenAlexRetriever
|
13 |
from .keywords_extraction import make_keywords_extraction_chain
|
14 |
from ..utils import log_event
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
|
@@ -57,105 +60,244 @@ def query_retriever(question):
|
|
57 |
"""Just a dummy tool to simulate the retriever query"""
|
58 |
return question
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
|
|
|
|
|
|
|
|
|
|
63 |
|
|
|
|
|
64 |
|
|
|
|
|
|
|
|
|
65 |
|
66 |
-
|
|
|
|
|
|
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
|
|
|
|
79 |
|
80 |
-
#
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
docs = []
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
k_by_question = k_final // state["n_questions"]
|
93 |
-
|
94 |
-
sources = current_question["sources"]
|
95 |
-
question = current_question["question"]
|
96 |
-
index = current_question["index"]
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
|
102 |
-
if index == "Vector":
|
103 |
-
|
104 |
-
# Search the document store using the retriever
|
105 |
-
# Configure high top k for further reranking step
|
106 |
-
retriever = ClimateQARetriever(
|
107 |
-
vectorstore=vectorstore,
|
108 |
-
sources = sources,
|
109 |
-
min_size = 200,
|
110 |
-
k_summary = k_summary,
|
111 |
-
k_total = k_before_reranking,
|
112 |
-
threshold = 0.5,
|
113 |
-
)
|
114 |
-
docs_question = await retriever.ainvoke(question,config)
|
115 |
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
-
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
-
print(f"... OpenAlex query: {openalex_query}")
|
122 |
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
docs_question = docs_question[:k_by_question]
|
145 |
-
|
146 |
-
# Add sources used in the metadata
|
147 |
for doc in docs_question:
|
148 |
-
doc.metadata["
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
docs.extend(docs_question)
|
154 |
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
|
160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
|
|
|
8 |
from langchain_core.runnables import RunnableLambda
|
9 |
|
10 |
from ..reranker import rerank_docs
|
11 |
+
# from ...knowledge.retriever import ClimateQARetriever
|
12 |
from ...knowledge.openalex import OpenAlexRetriever
|
13 |
from .keywords_extraction import make_keywords_extraction_chain
|
14 |
from ..utils import log_event
|
15 |
+
from langchain_core.vectorstores import VectorStore
|
16 |
+
from typing import List
|
17 |
+
from langchain_core.documents.base import Document
|
18 |
|
19 |
|
20 |
|
|
|
60 |
"""Just a dummy tool to simulate the retriever query"""
|
61 |
return question
|
62 |
|
63 |
+
def _add_sources_used_in_metadata(docs,sources,question,index):
|
64 |
+
for doc in docs:
|
65 |
+
doc.metadata["sources_used"] = sources
|
66 |
+
doc.metadata["question_used"] = question
|
67 |
+
doc.metadata["index_used"] = index
|
68 |
+
return docs
|
69 |
|
70 |
+
def _get_k_summary_by_question(n_questions):
|
71 |
+
if n_questions == 0:
|
72 |
+
return 0
|
73 |
+
elif n_questions == 1:
|
74 |
+
return 5
|
75 |
+
elif n_questions == 2:
|
76 |
+
return 3
|
77 |
+
elif n_questions == 3:
|
78 |
+
return 2
|
79 |
+
else:
|
80 |
+
return 1
|
81 |
+
|
82 |
+
def _get_k_images_by_question(n_questions):
|
83 |
+
if n_questions == 0:
|
84 |
+
return 0
|
85 |
+
elif n_questions == 1:
|
86 |
+
return 7
|
87 |
+
elif n_questions == 2:
|
88 |
+
return 5
|
89 |
+
elif n_questions == 3:
|
90 |
+
return 2
|
91 |
+
else:
|
92 |
+
return 1
|
93 |
+
|
94 |
+
def _add_metadata_and_score(docs: List) -> Document:
|
95 |
+
# Add score to metadata
|
96 |
+
docs_with_metadata = []
|
97 |
+
for i,(doc,score) in enumerate(docs):
|
98 |
+
doc.page_content = doc.page_content.replace("\r\n"," ")
|
99 |
+
doc.metadata["similarity_score"] = score
|
100 |
+
doc.metadata["content"] = doc.page_content
|
101 |
+
doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
|
102 |
+
# doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
|
103 |
+
docs_with_metadata.append(doc)
|
104 |
+
return docs_with_metadata
|
105 |
|
106 |
+
async def get_IPCC_relevant_documents(
|
107 |
+
query: str,
|
108 |
+
vectorstore:VectorStore,
|
109 |
+
sources:list = ["IPCC","IPBES","IPOS"],
|
110 |
+
search_figures:bool = False,
|
111 |
+
reports:list = [],
|
112 |
+
threshold:float = 0.6,
|
113 |
+
k_summary:int = 3,
|
114 |
+
k_total:int = 10,
|
115 |
+
k_images: int = 5,
|
116 |
+
namespace:str = "vectors",
|
117 |
+
min_size:int = 200,
|
118 |
+
search_only:bool = False,
|
119 |
+
) :
|
120 |
|
121 |
+
# Check if all elements in the list are either IPCC or IPBES
|
122 |
+
assert isinstance(sources,list)
|
123 |
+
assert sources
|
124 |
+
assert all([x in ["IPCC","IPBES","IPOS"] for x in sources])
|
125 |
+
assert k_total > k_summary, "k_total should be greater than k_summary"
|
126 |
|
127 |
+
# Prepare base search kwargs
|
128 |
+
filters = {}
|
129 |
|
130 |
+
if len(reports) > 0:
|
131 |
+
filters["short_name"] = {"$in":reports}
|
132 |
+
else:
|
133 |
+
filters["source"] = { "$in": sources}
|
134 |
|
135 |
+
# INIT
|
136 |
+
docs_summaries = []
|
137 |
+
docs_full = []
|
138 |
+
docs_images = []
|
139 |
|
140 |
+
if search_only:
|
141 |
+
# Only search for images if search_only is True
|
142 |
+
if search_figures:
|
143 |
+
filters_image = {
|
144 |
+
**filters,
|
145 |
+
"chunk_type":"image"
|
146 |
+
}
|
147 |
+
docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
|
148 |
+
docs_images = _add_metadata_and_score(docs_images)
|
149 |
+
else:
|
150 |
+
# Regular search flow for text and optionally images
|
151 |
+
# Search for k_summary documents in the summaries dataset
|
152 |
+
filters_summaries = {
|
153 |
+
**filters,
|
154 |
+
"chunk_type":"text",
|
155 |
+
"report_type": { "$in":["SPM"]},
|
156 |
+
}
|
157 |
|
158 |
+
docs_summaries = vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = k_summary)
|
159 |
+
docs_summaries = [x for x in docs_summaries if x[1] > threshold]
|
160 |
|
161 |
+
# Search for k_total - k_summary documents in the full reports dataset
|
162 |
+
filters_full = {
|
163 |
+
**filters,
|
164 |
+
"chunk_type":"text",
|
165 |
+
"report_type": { "$nin":["SPM"]},
|
166 |
+
}
|
167 |
+
k_full = k_total - len(docs_summaries)
|
168 |
+
docs_full = vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
+
if search_figures:
|
171 |
+
# Images
|
172 |
+
filters_image = {
|
173 |
+
**filters,
|
174 |
+
"chunk_type":"image"
|
175 |
+
}
|
176 |
+
docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
|
177 |
|
178 |
+
docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
|
179 |
+
|
180 |
+
# Filter if length are below threshold
|
181 |
+
docs_summaries = [x for x in docs_summaries if len(x.page_content) > min_size]
|
182 |
+
docs_full = [x for x in docs_full if len(x.page_content) > min_size]
|
183 |
+
|
184 |
+
return {
|
185 |
+
"docs_summaries" : docs_summaries,
|
186 |
+
"docs_full" : docs_full,
|
187 |
+
"docs_images" : docs_images,
|
188 |
+
}
|
189 |
|
190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
|
192 |
+
# The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
|
193 |
+
# @chain
|
194 |
+
async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5, k_images=5):
|
195 |
+
"""
|
196 |
+
Retrieve and rerank documents based on the current question in the state.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
state (dict): The current state containing documents, related content, relevant content sources, remaining questions and n_questions.
|
200 |
+
config (dict): Configuration settings for logging and other purposes.
|
201 |
+
vectorstore (object): The vector store used to retrieve relevant documents.
|
202 |
+
reranker (object): The reranker used to rerank the retrieved documents.
|
203 |
+
llm (object): The language model used for processing.
|
204 |
+
rerank_by_question (bool, optional): Whether to rerank documents by question. Defaults to True.
|
205 |
+
k_final (int, optional): The final number of documents to retrieve. Defaults to 15.
|
206 |
+
k_before_reranking (int, optional): The number of documents to retrieve before reranking. Defaults to 100.
|
207 |
+
k_summary (int, optional): The number of summary documents to retrieve. Defaults to 5.
|
208 |
+
k_images (int, optional): The number of image documents to retrieve. Defaults to 5.
|
209 |
+
Returns:
|
210 |
+
dict: The updated state containing the retrieved and reranked documents, related content, and remaining questions.
|
211 |
+
"""
|
212 |
+
print("---- Retrieve documents ----")
|
213 |
+
|
214 |
+
# Get the documents from the state
|
215 |
+
if "documents" in state and state["documents"] is not None:
|
216 |
+
docs = state["documents"]
|
217 |
+
else:
|
218 |
+
docs = []
|
219 |
+
# Get the related_content from the state
|
220 |
+
if "related_content" in state and state["related_content"] is not None:
|
221 |
+
related_content = state["related_content"]
|
222 |
+
else:
|
223 |
+
related_content = []
|
224 |
+
|
225 |
+
search_figures = "IPCC figures" in state["relevant_content_sources"]
|
226 |
+
search_only = state["search_only"]
|
227 |
|
228 |
+
# Get the current question
|
229 |
+
current_question = state["remaining_questions"][0]
|
230 |
+
remaining_questions = state["remaining_questions"][1:]
|
231 |
+
|
232 |
+
k_by_question = k_final // state["n_questions"]
|
233 |
+
k_summary_by_question = _get_k_summary_by_question(state["n_questions"])
|
234 |
+
k_images_by_question = _get_k_images_by_question(state["n_questions"])
|
235 |
+
|
236 |
+
sources = current_question["sources"]
|
237 |
+
question = current_question["question"]
|
238 |
+
index = current_question["index"]
|
239 |
+
|
240 |
+
print(f"Retrieve documents for question: {question}")
|
241 |
+
await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
|
242 |
|
|
|
243 |
|
244 |
+
if index == "Vector": # always true for now
|
245 |
+
docs_question_dict = await get_IPCC_relevant_documents(
|
246 |
+
query = question,
|
247 |
+
vectorstore=vectorstore,
|
248 |
+
search_figures = search_figures,
|
249 |
+
sources = sources,
|
250 |
+
min_size = 200,
|
251 |
+
k_summary = k_summary_by_question,
|
252 |
+
k_total = k_before_reranking,
|
253 |
+
k_images = k_images_by_question,
|
254 |
+
threshold = 0.5,
|
255 |
+
search_only = search_only,
|
256 |
+
)
|
257 |
|
258 |
+
|
259 |
+
# Rerank
|
260 |
+
if reranker is not None:
|
261 |
+
with suppress_output():
|
262 |
+
docs_question_summary_reranked = rerank_docs(reranker,docs_question_dict["docs_summaries"],question)
|
263 |
+
docs_question_fulltext_reranked = rerank_docs(reranker,docs_question_dict["docs_full"],question)
|
264 |
+
docs_question_images_reranked = rerank_docs(reranker,docs_question_dict["docs_images"],question)
|
265 |
+
if rerank_by_question:
|
266 |
+
docs_question_summary_reranked = sorted(docs_question_summary_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
|
267 |
+
docs_question_fulltext_reranked = sorted(docs_question_fulltext_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
|
268 |
+
docs_question_images_reranked = sorted(docs_question_images_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
|
269 |
+
else:
|
270 |
+
docs_question = docs_question_dict["docs_summaries"] + docs_question_dict["docs_full"]
|
271 |
+
# Add a default reranking score
|
|
|
|
|
|
|
272 |
for doc in docs_question:
|
273 |
+
doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
|
274 |
+
|
275 |
+
docs_question = docs_question_summary_reranked + docs_question_fulltext_reranked
|
276 |
+
docs_question = docs_question[:k_by_question]
|
277 |
+
images_question = docs_question_images_reranked[:k_images]
|
|
|
278 |
|
279 |
+
if reranker is not None and rerank_by_question:
|
280 |
+
docs_question = sorted(docs_question, key=lambda x: x.metadata["reranking_score"], reverse=True)
|
281 |
+
|
282 |
+
# Add sources used in the metadata
|
283 |
+
docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index)
|
284 |
+
images_question = _add_sources_used_in_metadata(images_question,sources,question,index)
|
285 |
+
|
286 |
+
# Add to the list of docs
|
287 |
+
docs.extend(docs_question)
|
288 |
+
related_content.extend(images_question)
|
289 |
+
new_state = {"documents":docs, "related_contents": related_content,"remaining_questions":remaining_questions}
|
290 |
+
return new_state
|
291 |
|
292 |
+
|
293 |
+
|
294 |
+
def make_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
|
295 |
+
@chain
|
296 |
+
async def retrieve_docs(state, config):
|
297 |
+
state = await retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question, k_final, k_before_reranking, k_summary)
|
298 |
+
return state
|
299 |
+
|
300 |
+
return retrieve_docs
|
301 |
+
|
302 |
+
|
303 |
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from climateqa.engine.keywords import make_keywords_chain
|
2 |
+
from climateqa.engine.llm import get_llm
|
3 |
+
from climateqa.knowledge.openalex import OpenAlex
|
4 |
+
from climateqa.engine.chains.answer_rag import make_rag_papers_chain
|
5 |
+
from front.utils import make_html_papers
|
6 |
+
from climateqa.engine.reranker import get_reranker
|
7 |
+
|
8 |
+
oa = OpenAlex()
|
9 |
+
|
10 |
+
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
|
11 |
+
reranker = get_reranker("nano")
|
12 |
+
|
13 |
+
|
14 |
+
papers_cols_widths = {
|
15 |
+
"id":100,
|
16 |
+
"title":300,
|
17 |
+
"doi":100,
|
18 |
+
"publication_year":100,
|
19 |
+
"abstract":500,
|
20 |
+
"is_oa":50,
|
21 |
+
}
|
22 |
+
|
23 |
+
papers_cols = list(papers_cols_widths.keys())
|
24 |
+
papers_cols_widths = list(papers_cols_widths.values())
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
def generate_keywords(query):
|
29 |
+
chain = make_keywords_chain(llm)
|
30 |
+
keywords = chain.invoke(query)
|
31 |
+
keywords = " AND ".join(keywords["keywords"])
|
32 |
+
return keywords
|
33 |
+
|
34 |
+
|
35 |
+
async def find_papers(query,after, relevant_content_sources, reranker= reranker):
|
36 |
+
if "OpenAlex" in relevant_content_sources:
|
37 |
+
summary = ""
|
38 |
+
keywords = generate_keywords(query)
|
39 |
+
df_works = oa.search(keywords,after = after)
|
40 |
+
|
41 |
+
print(f"Found {len(df_works)} papers")
|
42 |
+
|
43 |
+
if not df_works.empty:
|
44 |
+
df_works = df_works.dropna(subset=["abstract"])
|
45 |
+
df_works = df_works[df_works["abstract"] != ""].reset_index(drop = True)
|
46 |
+
df_works = oa.rerank(query,df_works,reranker)
|
47 |
+
df_works = df_works.sort_values("rerank_score",ascending=False)
|
48 |
+
docs_html = []
|
49 |
+
for i in range(10):
|
50 |
+
docs_html.append(make_html_papers(df_works, i))
|
51 |
+
docs_html = "".join(docs_html)
|
52 |
+
G = oa.make_network(df_works)
|
53 |
+
|
54 |
+
height = "750px"
|
55 |
+
network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
|
56 |
+
network_html = network.generate_html()
|
57 |
+
|
58 |
+
network_html = network_html.replace("'", "\"")
|
59 |
+
css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
|
60 |
+
network_html = network_html + css_to_inject
|
61 |
+
|
62 |
+
|
63 |
+
network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
|
64 |
+
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
|
65 |
+
allow-scripts allow-same-origin allow-popups
|
66 |
+
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
|
67 |
+
allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
|
68 |
+
|
69 |
+
|
70 |
+
docs = df_works["content"].head(10).tolist()
|
71 |
+
|
72 |
+
df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
|
73 |
+
df_works["doc"] = df_works["doc"] + 1
|
74 |
+
df_works = df_works[papers_cols]
|
75 |
+
|
76 |
+
yield docs_html, network_html, summary
|
77 |
+
|
78 |
+
chain = make_rag_papers_chain(llm)
|
79 |
+
result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
|
80 |
+
path_answer = "/logs/StrOutputParser/streamed_output/-"
|
81 |
+
|
82 |
+
async for op in result:
|
83 |
+
|
84 |
+
op = op.ops[0]
|
85 |
+
|
86 |
+
if op['path'] == path_answer: # reforulated question
|
87 |
+
new_token = op['value'] # str
|
88 |
+
summary += new_token
|
89 |
+
else:
|
90 |
+
continue
|
91 |
+
yield docs_html, network_html, summary
|
92 |
+
else :
|
93 |
+
print("No papers found")
|
94 |
+
else :
|
95 |
+
yield "","", ""
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import sys
|
2 |
+
# import os
|
3 |
+
# from contextlib import contextmanager
|
4 |
+
|
5 |
+
# from ..reranker import rerank_docs
|
6 |
+
# from ...knowledge.retriever import ClimateQARetriever
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
# def divide_into_parts(target, parts):
|
12 |
+
# # Base value for each part
|
13 |
+
# base = target // parts
|
14 |
+
# # Remainder to distribute
|
15 |
+
# remainder = target % parts
|
16 |
+
# # List to hold the result
|
17 |
+
# result = []
|
18 |
+
|
19 |
+
# for i in range(parts):
|
20 |
+
# if i < remainder:
|
21 |
+
# # These parts get base value + 1
|
22 |
+
# result.append(base + 1)
|
23 |
+
# else:
|
24 |
+
# # The rest get the base value
|
25 |
+
# result.append(base)
|
26 |
+
|
27 |
+
# return result
|
28 |
+
|
29 |
+
|
30 |
+
# @contextmanager
|
31 |
+
# def suppress_output():
|
32 |
+
# # Open a null device
|
33 |
+
# with open(os.devnull, 'w') as devnull:
|
34 |
+
# # Store the original stdout and stderr
|
35 |
+
# old_stdout = sys.stdout
|
36 |
+
# old_stderr = sys.stderr
|
37 |
+
# # Redirect stdout and stderr to the null device
|
38 |
+
# sys.stdout = devnull
|
39 |
+
# sys.stderr = devnull
|
40 |
+
# try:
|
41 |
+
# yield
|
42 |
+
# finally:
|
43 |
+
# # Restore stdout and stderr
|
44 |
+
# sys.stdout = old_stdout
|
45 |
+
# sys.stderr = old_stderr
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
# def make_retriever_node(vectorstore,reranker,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
|
50 |
+
|
51 |
+
# def retrieve_documents(state):
|
52 |
+
|
53 |
+
# POSSIBLE_SOURCES = ["IPCC","IPBES","IPOS"] # ,"OpenAlex"]
|
54 |
+
# questions = state["questions"]
|
55 |
+
|
56 |
+
# # Use sources from the user input or from the LLM detection
|
57 |
+
# if "sources_input" not in state or state["sources_input"] is None:
|
58 |
+
# sources_input = ["auto"]
|
59 |
+
# else:
|
60 |
+
# sources_input = state["sources_input"]
|
61 |
+
# auto_mode = "auto" in sources_input
|
62 |
+
|
63 |
+
# # There are several options to get the final top k
|
64 |
+
# # Option 1 - Get 100 documents by question and rerank by question
|
65 |
+
# # Option 2 - Get 100/n documents by question and rerank the total
|
66 |
+
# if rerank_by_question:
|
67 |
+
# k_by_question = divide_into_parts(k_final,len(questions))
|
68 |
+
|
69 |
+
# docs = []
|
70 |
+
|
71 |
+
# for i,q in enumerate(questions):
|
72 |
+
|
73 |
+
# sources = q["sources"]
|
74 |
+
# question = q["question"]
|
75 |
+
|
76 |
+
# # If auto mode, we use the sources detected by the LLM
|
77 |
+
# if auto_mode:
|
78 |
+
# sources = [x for x in sources if x in POSSIBLE_SOURCES]
|
79 |
+
|
80 |
+
# # Otherwise, we use the config
|
81 |
+
# else:
|
82 |
+
# sources = sources_input
|
83 |
+
|
84 |
+
# # Search the document store using the retriever
|
85 |
+
# # Configure high top k for further reranking step
|
86 |
+
# retriever = ClimateQARetriever(
|
87 |
+
# vectorstore=vectorstore,
|
88 |
+
# sources = sources,
|
89 |
+
# # reports = ias_reports,
|
90 |
+
# min_size = 200,
|
91 |
+
# k_summary = k_summary,
|
92 |
+
# k_total = k_before_reranking,
|
93 |
+
# threshold = 0.5,
|
94 |
+
# )
|
95 |
+
# docs_question = retriever.get_relevant_documents(question)
|
96 |
+
|
97 |
+
# # Rerank
|
98 |
+
# if reranker is not None:
|
99 |
+
# with suppress_output():
|
100 |
+
# docs_question = rerank_docs(reranker,docs_question,question)
|
101 |
+
# else:
|
102 |
+
# # Add a default reranking score
|
103 |
+
# for doc in docs_question:
|
104 |
+
# doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
|
105 |
+
|
106 |
+
# # If rerank by question we select the top documents for each question
|
107 |
+
# if rerank_by_question:
|
108 |
+
# docs_question = docs_question[:k_by_question[i]]
|
109 |
+
|
110 |
+
# # Add sources used in the metadata
|
111 |
+
# for doc in docs_question:
|
112 |
+
# doc.metadata["sources_used"] = sources
|
113 |
+
|
114 |
+
# # Add to the list of docs
|
115 |
+
# docs.extend(docs_question)
|
116 |
+
|
117 |
+
# # Sorting the list in descending order by rerank_score
|
118 |
+
# # Then select the top k
|
119 |
+
# docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
|
120 |
+
# docs = docs[:k_final]
|
121 |
+
|
122 |
+
# new_state = {"documents":docs}
|
123 |
+
# return new_state
|
124 |
+
|
125 |
+
# return retrieve_documents
|
126 |
+
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def set_defaults(state):
|
2 |
+
print("---- Setting defaults ----")
|
3 |
+
|
4 |
+
if not state["audience"] or state["audience"] is None:
|
5 |
+
state.update({"audience": "experts"})
|
6 |
+
|
7 |
+
sources_input = state["sources_input"] if "sources_input" in state else ["auto"]
|
8 |
+
state.update({"sources_input": sources_input})
|
9 |
+
|
10 |
+
# if not state["sources_input"] or state["sources_input"] is None:
|
11 |
+
# state.update({"sources_input": ["auto"]})
|
12 |
+
|
13 |
+
return state
|
@@ -30,10 +30,11 @@ def make_translation_chain(llm):
|
|
30 |
|
31 |
|
32 |
def make_translation_node(llm):
|
33 |
-
|
34 |
translation_chain = make_translation_chain(llm)
|
35 |
|
36 |
def translate_query(state):
|
|
|
|
|
37 |
user_input = state["user_input"]
|
38 |
translation = translation_chain.invoke({"input":user_input})
|
39 |
return {"query":translation["translation"]}
|
|
|
30 |
|
31 |
|
32 |
def make_translation_node(llm):
|
|
|
33 |
translation_chain = make_translation_chain(llm)
|
34 |
|
35 |
def translate_query(state):
|
36 |
+
print("---- Translate query ----")
|
37 |
+
|
38 |
user_input = state["user_input"]
|
39 |
translation = translation_chain.invoke({"input":user_input})
|
40 |
return {"query":translation["translation"]}
|
@@ -7,7 +7,7 @@ from langgraph.graph import END, StateGraph
|
|
7 |
from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod
|
8 |
|
9 |
from typing_extensions import TypedDict
|
10 |
-
from typing import List
|
11 |
|
12 |
from IPython.display import display, HTML, Image
|
13 |
|
@@ -18,6 +18,9 @@ from .chains.translation import make_translation_node
|
|
18 |
from .chains.intent_categorization import make_intent_categorization_node
|
19 |
from .chains.retrieve_documents import make_retriever_node
|
20 |
from .chains.answer_rag import make_rag_node
|
|
|
|
|
|
|
21 |
|
22 |
class GraphState(TypedDict):
|
23 |
"""
|
@@ -26,16 +29,21 @@ class GraphState(TypedDict):
|
|
26 |
user_input : str
|
27 |
language : str
|
28 |
intent : str
|
|
|
29 |
query: str
|
30 |
remaining_questions : List[dict]
|
31 |
n_questions : int
|
32 |
answer: str
|
33 |
audience: str = "experts"
|
34 |
sources_input: List[str] = ["IPCC","IPBES"]
|
|
|
35 |
sources_auto: bool = True
|
36 |
min_year: int = 1960
|
37 |
max_year: int = None
|
38 |
documents: List[Document]
|
|
|
|
|
|
|
39 |
|
40 |
def search(state): #TODO
|
41 |
return state
|
@@ -52,6 +60,13 @@ def route_intent(state):
|
|
52 |
else:
|
53 |
# Search route
|
54 |
return "search"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
def route_translation(state):
|
57 |
if state["language"].lower() == "english":
|
@@ -66,11 +81,18 @@ def route_based_on_relevant_docs(state,threshold_docs=0.2):
|
|
66 |
else:
|
67 |
return "answer_rag_no_docs"
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
def make_id_dict(values):
|
71 |
return {k:k for k in values}
|
72 |
|
73 |
-
def make_graph_agent(llm,
|
74 |
|
75 |
workflow = StateGraph(GraphState)
|
76 |
|
@@ -80,21 +102,26 @@ def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
|
|
80 |
translate_query = make_translation_node(llm)
|
81 |
answer_chitchat = make_chitchat_node(llm)
|
82 |
answer_ai_impact = make_ai_impact_node(llm)
|
83 |
-
retrieve_documents = make_retriever_node(
|
84 |
-
|
85 |
-
|
|
|
|
|
86 |
|
87 |
# Define the nodes
|
|
|
88 |
workflow.add_node("categorize_intent", categorize_intent)
|
89 |
workflow.add_node("search", search)
|
90 |
workflow.add_node("answer_search", answer_search)
|
91 |
workflow.add_node("transform_query", transform_query)
|
92 |
workflow.add_node("translate_query", translate_query)
|
93 |
workflow.add_node("answer_chitchat", answer_chitchat)
|
94 |
-
|
95 |
-
workflow.add_node("
|
96 |
-
workflow.add_node("
|
97 |
-
workflow.add_node("
|
|
|
|
|
98 |
|
99 |
# Entry point
|
100 |
workflow.set_entry_point("categorize_intent")
|
@@ -106,6 +133,12 @@ def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
|
|
106 |
make_id_dict(["answer_chitchat","search"])
|
107 |
)
|
108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
workflow.add_conditional_edges(
|
110 |
"search",
|
111 |
route_translation,
|
@@ -113,8 +146,9 @@ def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
|
|
113 |
)
|
114 |
workflow.add_conditional_edges(
|
115 |
"retrieve_documents",
|
116 |
-
lambda state : "retrieve_documents" if len(state["remaining_questions"]) > 0 else "answer_search",
|
117 |
-
|
|
|
118 |
)
|
119 |
|
120 |
workflow.add_conditional_edges(
|
@@ -122,14 +156,21 @@ def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
|
|
122 |
lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
|
123 |
make_id_dict(["answer_rag","answer_rag_no_docs"])
|
124 |
)
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
# Define the edges
|
127 |
workflow.add_edge("translate_query", "transform_query")
|
128 |
workflow.add_edge("transform_query", "retrieve_documents")
|
|
|
|
|
129 |
workflow.add_edge("answer_rag", END)
|
130 |
workflow.add_edge("answer_rag_no_docs", END)
|
131 |
-
workflow.add_edge("answer_chitchat",
|
132 |
-
|
133 |
|
134 |
# Compile
|
135 |
app = workflow.compile()
|
@@ -146,4 +187,4 @@ def display_graph(app):
|
|
146 |
draw_method=MermaidDrawMethod.API,
|
147 |
)
|
148 |
)
|
149 |
-
)
|
|
|
7 |
from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod
|
8 |
|
9 |
from typing_extensions import TypedDict
|
10 |
+
from typing import List, Dict
|
11 |
|
12 |
from IPython.display import display, HTML, Image
|
13 |
|
|
|
18 |
from .chains.intent_categorization import make_intent_categorization_node
|
19 |
from .chains.retrieve_documents import make_retriever_node
|
20 |
from .chains.answer_rag import make_rag_node
|
21 |
+
from .chains.graph_retriever import make_graph_retriever_node
|
22 |
+
from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
|
23 |
+
# from .chains.set_defaults import set_defaults
|
24 |
|
25 |
class GraphState(TypedDict):
|
26 |
"""
|
|
|
29 |
user_input : str
|
30 |
language : str
|
31 |
intent : str
|
32 |
+
search_graphs_chitchat : bool
|
33 |
query: str
|
34 |
remaining_questions : List[dict]
|
35 |
n_questions : int
|
36 |
answer: str
|
37 |
audience: str = "experts"
|
38 |
sources_input: List[str] = ["IPCC","IPBES"]
|
39 |
+
relevant_content_sources: List[str] = ["IPCC figures"]
|
40 |
sources_auto: bool = True
|
41 |
min_year: int = 1960
|
42 |
max_year: int = None
|
43 |
documents: List[Document]
|
44 |
+
related_contents : Dict[str,Document]
|
45 |
+
recommended_content : List[Document]
|
46 |
+
search_only : bool = False
|
47 |
|
48 |
def search(state): #TODO
|
49 |
return state
|
|
|
60 |
else:
|
61 |
# Search route
|
62 |
return "search"
|
63 |
+
|
64 |
+
def chitchat_route_intent(state):
|
65 |
+
intent = state["search_graphs_chitchat"]
|
66 |
+
if intent is True:
|
67 |
+
return "retrieve_graphs_chitchat"
|
68 |
+
elif intent is False:
|
69 |
+
return END
|
70 |
|
71 |
def route_translation(state):
|
72 |
if state["language"].lower() == "english":
|
|
|
81 |
else:
|
82 |
return "answer_rag_no_docs"
|
83 |
|
84 |
+
def route_retrieve_documents(state):
|
85 |
+
if state["search_only"] :
|
86 |
+
return END
|
87 |
+
elif len(state["remaining_questions"]) > 0:
|
88 |
+
return "retrieve_documents"
|
89 |
+
else:
|
90 |
+
return "answer_search"
|
91 |
|
92 |
def make_id_dict(values):
|
93 |
return {k:k for k in values}
|
94 |
|
95 |
+
def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, threshold_docs=0.2):
|
96 |
|
97 |
workflow = StateGraph(GraphState)
|
98 |
|
|
|
102 |
translate_query = make_translation_node(llm)
|
103 |
answer_chitchat = make_chitchat_node(llm)
|
104 |
answer_ai_impact = make_ai_impact_node(llm)
|
105 |
+
retrieve_documents = make_retriever_node(vectorstore_ipcc, reranker, llm)
|
106 |
+
retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
|
107 |
+
answer_rag = make_rag_node(llm, with_docs=True)
|
108 |
+
answer_rag_no_docs = make_rag_node(llm, with_docs=False)
|
109 |
+
chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
|
110 |
|
111 |
# Define the nodes
|
112 |
+
# workflow.add_node("set_defaults", set_defaults)
|
113 |
workflow.add_node("categorize_intent", categorize_intent)
|
114 |
workflow.add_node("search", search)
|
115 |
workflow.add_node("answer_search", answer_search)
|
116 |
workflow.add_node("transform_query", transform_query)
|
117 |
workflow.add_node("translate_query", translate_query)
|
118 |
workflow.add_node("answer_chitchat", answer_chitchat)
|
119 |
+
workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
|
120 |
+
workflow.add_node("retrieve_graphs", retrieve_graphs)
|
121 |
+
workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
|
122 |
+
workflow.add_node("retrieve_documents", retrieve_documents)
|
123 |
+
workflow.add_node("answer_rag", answer_rag)
|
124 |
+
workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
|
125 |
|
126 |
# Entry point
|
127 |
workflow.set_entry_point("categorize_intent")
|
|
|
133 |
make_id_dict(["answer_chitchat","search"])
|
134 |
)
|
135 |
|
136 |
+
workflow.add_conditional_edges(
|
137 |
+
"chitchat_categorize_intent",
|
138 |
+
chitchat_route_intent,
|
139 |
+
make_id_dict(["retrieve_graphs_chitchat", END])
|
140 |
+
)
|
141 |
+
|
142 |
workflow.add_conditional_edges(
|
143 |
"search",
|
144 |
route_translation,
|
|
|
146 |
)
|
147 |
workflow.add_conditional_edges(
|
148 |
"retrieve_documents",
|
149 |
+
# lambda state : "retrieve_documents" if len(state["remaining_questions"]) > 0 else "answer_search",
|
150 |
+
route_retrieve_documents,
|
151 |
+
make_id_dict([END,"retrieve_documents","answer_search"])
|
152 |
)
|
153 |
|
154 |
workflow.add_conditional_edges(
|
|
|
156 |
lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
|
157 |
make_id_dict(["answer_rag","answer_rag_no_docs"])
|
158 |
)
|
159 |
+
workflow.add_conditional_edges(
|
160 |
+
"transform_query",
|
161 |
+
lambda state : "retrieve_graphs" if "OurWorldInData" in state["relevant_content_sources"] else END,
|
162 |
+
make_id_dict(["retrieve_graphs", END])
|
163 |
+
)
|
164 |
|
165 |
# Define the edges
|
166 |
workflow.add_edge("translate_query", "transform_query")
|
167 |
workflow.add_edge("transform_query", "retrieve_documents")
|
168 |
+
|
169 |
+
workflow.add_edge("retrieve_graphs", END)
|
170 |
workflow.add_edge("answer_rag", END)
|
171 |
workflow.add_edge("answer_rag_no_docs", END)
|
172 |
+
workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
|
173 |
+
|
174 |
|
175 |
# Compile
|
176 |
app = workflow.compile()
|
|
|
187 |
draw_method=MermaidDrawMethod.API,
|
188 |
)
|
189 |
)
|
190 |
+
)
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_core.retrievers import BaseRetriever
|
2 |
+
from langchain_core.documents.base import Document
|
3 |
+
from langchain_core.vectorstores import VectorStore
|
4 |
+
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
5 |
+
|
6 |
+
from typing import List
|
7 |
+
|
8 |
+
# class GraphRetriever(BaseRetriever):
|
9 |
+
# vectorstore:VectorStore
|
10 |
+
# sources:list = ["OWID"] # plus tard ajouter OurWorldInData # faudra integrate avec l'autre retriever
|
11 |
+
# threshold:float = 0.5
|
12 |
+
# k_total:int = 10
|
13 |
+
|
14 |
+
# def _get_relevant_documents(
|
15 |
+
# self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
16 |
+
# ) -> List[Document]:
|
17 |
+
|
18 |
+
# # Check if all elements in the list are IEA or OWID
|
19 |
+
# assert isinstance(self.sources,list)
|
20 |
+
# assert self.sources
|
21 |
+
# assert any([x in ["OWID"] for x in self.sources])
|
22 |
+
|
23 |
+
# # Prepare base search kwargs
|
24 |
+
# filters = {}
|
25 |
+
|
26 |
+
# filters["source"] = {"$in": self.sources}
|
27 |
+
|
28 |
+
# docs = self.vectorstore.similarity_search_with_score(query=query, filter=filters, k=self.k_total)
|
29 |
+
|
30 |
+
# # Filter if scores are below threshold
|
31 |
+
# docs = [x for x in docs if x[1] > self.threshold]
|
32 |
+
|
33 |
+
# # Remove duplicate documents
|
34 |
+
# unique_docs = []
|
35 |
+
# seen_docs = []
|
36 |
+
# for i, doc in enumerate(docs):
|
37 |
+
# if doc[0].page_content not in seen_docs:
|
38 |
+
# unique_docs.append(doc)
|
39 |
+
# seen_docs.append(doc[0].page_content)
|
40 |
+
|
41 |
+
# # Add score to metadata
|
42 |
+
# results = []
|
43 |
+
# for i,(doc,score) in enumerate(unique_docs):
|
44 |
+
# doc.metadata["similarity_score"] = score
|
45 |
+
# doc.metadata["content"] = doc.page_content
|
46 |
+
# results.append(doc)
|
47 |
+
|
48 |
+
# return results
|
49 |
+
|
50 |
+
async def retrieve_graphs(
|
51 |
+
query: str,
|
52 |
+
vectorstore:VectorStore,
|
53 |
+
sources:list = ["OWID"], # plus tard ajouter OurWorldInData # faudra integrate avec l'autre retriever
|
54 |
+
threshold:float = 0.5,
|
55 |
+
k_total:int = 10,
|
56 |
+
)-> List[Document]:
|
57 |
+
|
58 |
+
# Check if all elements in the list are IEA or OWID
|
59 |
+
assert isinstance(sources,list)
|
60 |
+
assert sources
|
61 |
+
assert any([x in ["OWID"] for x in sources])
|
62 |
+
|
63 |
+
# Prepare base search kwargs
|
64 |
+
filters = {}
|
65 |
+
|
66 |
+
filters["source"] = {"$in": sources}
|
67 |
+
|
68 |
+
docs = vectorstore.similarity_search_with_score(query=query, filter=filters, k=k_total)
|
69 |
+
|
70 |
+
# Filter if scores are below threshold
|
71 |
+
docs = [x for x in docs if x[1] > threshold]
|
72 |
+
|
73 |
+
# Remove duplicate documents
|
74 |
+
unique_docs = []
|
75 |
+
seen_docs = []
|
76 |
+
for i, doc in enumerate(docs):
|
77 |
+
if doc[0].page_content not in seen_docs:
|
78 |
+
unique_docs.append(doc)
|
79 |
+
seen_docs.append(doc[0].page_content)
|
80 |
+
|
81 |
+
# Add score to metadata
|
82 |
+
results = []
|
83 |
+
for i,(doc,score) in enumerate(unique_docs):
|
84 |
+
doc.metadata["similarity_score"] = score
|
85 |
+
doc.metadata["content"] = doc.page_content
|
86 |
+
results.append(doc)
|
87 |
+
|
88 |
+
return results
|
@@ -11,10 +11,12 @@ class KeywordsOutput(BaseModel):
|
|
11 |
|
12 |
keywords: list = Field(
|
13 |
description="""
|
14 |
-
Generate 1 or 2 relevant keywords from the user query to ask a search engine for scientific research papers.
|
|
|
15 |
|
16 |
Example:
|
17 |
- "What is the impact of deep sea mining ?" -> ["deep sea mining"]
|
|
|
18 |
- "How will El Nino be impacted by climate change" -> ["el nino"]
|
19 |
- "Is climate change a hoax" -> [Climate change","hoax"]
|
20 |
"""
|
|
|
11 |
|
12 |
keywords: list = Field(
|
13 |
description="""
|
14 |
+
Generate 1 or 2 relevant keywords from the user query to ask a search engine for scientific research papers. Answer only with English keywords.
|
15 |
+
Do not use special characters or accents.
|
16 |
|
17 |
Example:
|
18 |
- "What is the impact of deep sea mining ?" -> ["deep sea mining"]
|
19 |
+
- "Quel est l'impact de l'exploitation minière en haute mer ?" -> ["deep sea mining"]
|
20 |
- "How will El Nino be impacted by climate change" -> ["el nino"]
|
21 |
- "Is climate change a hoax" -> [Climate change","hoax"]
|
22 |
"""
|
@@ -1,11 +1,14 @@
|
|
1 |
import os
|
|
|
2 |
from scipy.special import expit, logit
|
3 |
from rerankers import Reranker
|
|
|
4 |
|
|
|
5 |
|
6 |
-
def get_reranker(model = "nano",cohere_api_key = None):
|
7 |
|
8 |
-
assert model in ["nano","tiny","small","large"]
|
9 |
|
10 |
if model == "nano":
|
11 |
reranker = Reranker('ms-marco-TinyBERT-L-2-v2', model_type='flashrank')
|
@@ -17,11 +20,18 @@ def get_reranker(model = "nano",cohere_api_key = None):
|
|
17 |
if cohere_api_key is None:
|
18 |
cohere_api_key = os.environ["COHERE_API_KEY"]
|
19 |
reranker = Reranker("cohere", lang='en', api_key = cohere_api_key)
|
|
|
|
|
|
|
|
|
|
|
20 |
return reranker
|
21 |
|
22 |
|
23 |
|
24 |
def rerank_docs(reranker,docs,query):
|
|
|
|
|
25 |
|
26 |
# Get a list of texts from langchain docs
|
27 |
input_docs = [x.page_content for x in docs]
|
|
|
1 |
import os
|
2 |
+
from dotenv import load_dotenv
|
3 |
from scipy.special import expit, logit
|
4 |
from rerankers import Reranker
|
5 |
+
from sentence_transformers import CrossEncoder
|
6 |
|
7 |
+
load_dotenv()
|
8 |
|
9 |
+
def get_reranker(model = "nano", cohere_api_key = None):
|
10 |
|
11 |
+
assert model in ["nano","tiny","small","large", "jina"]
|
12 |
|
13 |
if model == "nano":
|
14 |
reranker = Reranker('ms-marco-TinyBERT-L-2-v2', model_type='flashrank')
|
|
|
20 |
if cohere_api_key is None:
|
21 |
cohere_api_key = os.environ["COHERE_API_KEY"]
|
22 |
reranker = Reranker("cohere", lang='en', api_key = cohere_api_key)
|
23 |
+
elif model == "jina":
|
24 |
+
# Reached token quota so does not work
|
25 |
+
reranker = Reranker("jina-reranker-v2-base-multilingual", api_key = os.getenv("JINA_RERANKER_API_KEY"))
|
26 |
+
# marche pas sans gpu ? et anyways returns with another structure donc faudrait changer le code du retriever node
|
27 |
+
# reranker = CrossEncoder("jinaai/jina-reranker-v2-base-multilingual", automodel_args={"torch_dtype": "auto"}, trust_remote_code=True,)
|
28 |
return reranker
|
29 |
|
30 |
|
31 |
|
32 |
def rerank_docs(reranker,docs,query):
|
33 |
+
if docs == []:
|
34 |
+
return []
|
35 |
|
36 |
# Get a list of texts from langchain docs
|
37 |
input_docs = [x.page_content for x in docs]
|
@@ -4,6 +4,7 @@
|
|
4 |
import os
|
5 |
from pinecone import Pinecone
|
6 |
from langchain_community.vectorstores import Pinecone as PineconeVectorstore
|
|
|
7 |
|
8 |
# LOAD ENVIRONMENT VARIABLES
|
9 |
try:
|
@@ -13,7 +14,12 @@ except:
|
|
13 |
pass
|
14 |
|
15 |
|
16 |
-
def
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
# # initialize pinecone
|
19 |
# pinecone.init(
|
@@ -27,7 +33,7 @@ def get_pinecone_vectorstore(embeddings,text_key = "content"):
|
|
27 |
# return vectorstore
|
28 |
|
29 |
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
|
30 |
-
index = pc.Index(
|
31 |
|
32 |
vectorstore = PineconeVectorstore(
|
33 |
index, embeddings, text_key,
|
|
|
4 |
import os
|
5 |
from pinecone import Pinecone
|
6 |
from langchain_community.vectorstores import Pinecone as PineconeVectorstore
|
7 |
+
from langchain_chroma import Chroma
|
8 |
|
9 |
# LOAD ENVIRONMENT VARIABLES
|
10 |
try:
|
|
|
14 |
pass
|
15 |
|
16 |
|
17 |
+
def get_chroma_vectorstore(embedding_function, persist_directory="/home/dora/climate-question-answering/data/vectorstore"):
|
18 |
+
vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embedding_function)
|
19 |
+
return vectorstore
|
20 |
+
|
21 |
+
|
22 |
+
def get_pinecone_vectorstore(embeddings,text_key = "content", index_name = os.getenv("PINECONE_API_INDEX")):
|
23 |
|
24 |
# # initialize pinecone
|
25 |
# pinecone.init(
|
|
|
33 |
# return vectorstore
|
34 |
|
35 |
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
|
36 |
+
index = pc.Index(index_name)
|
37 |
|
38 |
vectorstore = PineconeVectorstore(
|
39 |
index, embeddings, text_key,
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_core.runnables.schema import StreamEvent
|
2 |
+
from gradio import ChatMessage
|
3 |
+
from climateqa.engine.chains.prompts import audience_prompts
|
4 |
+
from front.utils import make_html_source,parse_output_llm_with_sources,serialize_docs,make_toolbox,generate_html_graphs
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
def init_audience(audience :str) -> str:
|
8 |
+
if audience == "Children":
|
9 |
+
audience_prompt = audience_prompts["children"]
|
10 |
+
elif audience == "General public":
|
11 |
+
audience_prompt = audience_prompts["general"]
|
12 |
+
elif audience == "Experts":
|
13 |
+
audience_prompt = audience_prompts["experts"]
|
14 |
+
else:
|
15 |
+
audience_prompt = audience_prompts["experts"]
|
16 |
+
return audience_prompt
|
17 |
+
|
18 |
+
def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage], used_documents : list[str]) -> tuple[str, list[ChatMessage], list[str]]:
|
19 |
+
"""
|
20 |
+
Handles the retrieved documents and returns the HTML representation of the documents
|
21 |
+
|
22 |
+
Args:
|
23 |
+
event (StreamEvent): The event containing the retrieved documents
|
24 |
+
history (list[ChatMessage]): The current message history
|
25 |
+
used_documents (list[str]): The list of used documents
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
tuple[str, list[ChatMessage], list[str]]: The updated HTML representation of the documents, the updated message history and the updated list of used documents
|
29 |
+
"""
|
30 |
+
try:
|
31 |
+
docs = event["data"]["output"]["documents"]
|
32 |
+
docs_html = []
|
33 |
+
textual_docs = [d for d in docs if d.metadata["chunk_type"] == "text"]
|
34 |
+
for i, d in enumerate(textual_docs, 1):
|
35 |
+
if d.metadata["chunk_type"] == "text":
|
36 |
+
docs_html.append(make_html_source(d, i))
|
37 |
+
|
38 |
+
used_documents = used_documents + [f"{d.metadata['short_name']} - {d.metadata['name']}" for d in docs]
|
39 |
+
if used_documents!=[]:
|
40 |
+
history[-1].content = "Adding sources :\n\n - " + "\n - ".join(np.unique(used_documents))
|
41 |
+
|
42 |
+
docs_html = "".join(docs_html)
|
43 |
+
|
44 |
+
related_contents = event["data"]["output"]["related_contents"]
|
45 |
+
|
46 |
+
except Exception as e:
|
47 |
+
print(f"Error getting documents: {e}")
|
48 |
+
print(event)
|
49 |
+
return docs, docs_html, history, used_documents, related_contents
|
50 |
+
|
51 |
+
def stream_answer(history: list[ChatMessage], event : StreamEvent, start_streaming : bool, answer_message_content : str)-> tuple[list[ChatMessage], bool, str]:
|
52 |
+
"""
|
53 |
+
Handles the streaming of the answer and updates the history with the new message content
|
54 |
+
|
55 |
+
Args:
|
56 |
+
history (list[ChatMessage]): The current message history
|
57 |
+
event (StreamEvent): The event containing the streamed answer
|
58 |
+
start_streaming (bool): A flag indicating if the streaming has started
|
59 |
+
new_message_content (str): The content of the new message
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
tuple[list[ChatMessage], bool, str]: The updated history, the updated streaming flag and the updated message content
|
63 |
+
"""
|
64 |
+
if start_streaming == False:
|
65 |
+
start_streaming = True
|
66 |
+
history.append(ChatMessage(role="assistant", content = ""))
|
67 |
+
answer_message_content += event["data"]["chunk"].content
|
68 |
+
answer_message_content = parse_output_llm_with_sources(answer_message_content)
|
69 |
+
history[-1] = ChatMessage(role="assistant", content = answer_message_content)
|
70 |
+
# history.append(ChatMessage(role="assistant", content = new_message_content))
|
71 |
+
return history, start_streaming, answer_message_content
|
72 |
+
|
73 |
+
def handle_retrieved_owid_graphs(event :StreamEvent, graphs_html: str) -> str:
|
74 |
+
"""
|
75 |
+
Handles the retrieved OWID graphs and returns the HTML representation of the graphs
|
76 |
+
|
77 |
+
Args:
|
78 |
+
event (StreamEvent): The event containing the retrieved graphs
|
79 |
+
graphs_html (str): The current HTML representation of the graphs
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
str: The updated HTML representation
|
83 |
+
"""
|
84 |
+
try:
|
85 |
+
recommended_content = event["data"]["output"]["recommended_content"]
|
86 |
+
|
87 |
+
unique_graphs = []
|
88 |
+
seen_embeddings = set()
|
89 |
+
|
90 |
+
for x in recommended_content:
|
91 |
+
embedding = x.metadata["returned_content"]
|
92 |
+
|
93 |
+
# Check if the embedding has already been seen
|
94 |
+
if embedding not in seen_embeddings:
|
95 |
+
unique_graphs.append({
|
96 |
+
"embedding": embedding,
|
97 |
+
"metadata": {
|
98 |
+
"source": x.metadata["source"],
|
99 |
+
"category": x.metadata["category"]
|
100 |
+
}
|
101 |
+
})
|
102 |
+
# Add the embedding to the seen set
|
103 |
+
seen_embeddings.add(embedding)
|
104 |
+
|
105 |
+
|
106 |
+
categories = {}
|
107 |
+
for graph in unique_graphs:
|
108 |
+
category = graph['metadata']['category']
|
109 |
+
if category not in categories:
|
110 |
+
categories[category] = []
|
111 |
+
categories[category].append(graph['embedding'])
|
112 |
+
|
113 |
+
|
114 |
+
for category, embeddings in categories.items():
|
115 |
+
graphs_html += f"<h3>{category}</h3>"
|
116 |
+
for embedding in embeddings:
|
117 |
+
graphs_html += f"<div>{embedding}</div>"
|
118 |
+
|
119 |
+
|
120 |
+
except Exception as e:
|
121 |
+
print(f"Error getting graphs: {e}")
|
122 |
+
|
123 |
+
return graphs_html
|
@@ -41,6 +41,10 @@ class OpenAlex():
|
|
41 |
break
|
42 |
|
43 |
df_works = pd.DataFrame(page)
|
|
|
|
|
|
|
|
|
44 |
df_works = df_works.dropna(subset = ["title"])
|
45 |
df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
|
46 |
df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x)).fillna("")
|
@@ -51,8 +55,9 @@ class OpenAlex():
|
|
51 |
df_works["num_tokens"] = df_works["content"].map(lambda x : num_tokens_from_string(x))
|
52 |
|
53 |
df_works = df_works.drop(columns = ["abstract_inverted_index"])
|
54 |
-
|
55 |
-
|
|
|
56 |
return df_works
|
57 |
else:
|
58 |
raise Exception("Keywords must be a string")
|
@@ -62,11 +67,10 @@ class OpenAlex():
|
|
62 |
|
63 |
scores = reranker.rank(
|
64 |
query,
|
65 |
-
df["content"].tolist()
|
66 |
-
top_k = len(df),
|
67 |
)
|
68 |
-
scores.
|
69 |
-
scores = [x
|
70 |
df["rerank_score"] = scores
|
71 |
return df
|
72 |
|
|
|
41 |
break
|
42 |
|
43 |
df_works = pd.DataFrame(page)
|
44 |
+
|
45 |
+
if df_works.empty:
|
46 |
+
return df_works
|
47 |
+
|
48 |
df_works = df_works.dropna(subset = ["title"])
|
49 |
df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
|
50 |
df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x)).fillna("")
|
|
|
55 |
df_works["num_tokens"] = df_works["content"].map(lambda x : num_tokens_from_string(x))
|
56 |
|
57 |
df_works = df_works.drop(columns = ["abstract_inverted_index"])
|
58 |
+
df_works["display_name"] = df_works["primary_location"].apply(lambda x :x["source"] if type(x) == dict and 'source' in x else "").apply(lambda x : x["display_name"] if type(x) == dict and "display_name" in x else "")
|
59 |
+
df_works["subtitle"] = df_works["title"].astype(str) + " - " + df_works["display_name"].astype(str) + " - " + df_works["publication_year"].astype(str)
|
60 |
+
|
61 |
return df_works
|
62 |
else:
|
63 |
raise Exception("Keywords must be a string")
|
|
|
67 |
|
68 |
scores = reranker.rank(
|
69 |
query,
|
70 |
+
df["content"].tolist()
|
|
|
71 |
)
|
72 |
+
scores = sorted(scores.results, key = lambda x : x.document.doc_id)
|
73 |
+
scores = [x.score for x in scores]
|
74 |
df["rerank_score"] = scores
|
75 |
return df
|
76 |
|
@@ -1,81 +1,102 @@
|
|
1 |
-
# https://github.com/langchain-ai/langchain/issues/8623
|
2 |
-
|
3 |
-
import pandas as pd
|
4 |
-
|
5 |
-
from langchain_core.retrievers import BaseRetriever
|
6 |
-
from langchain_core.vectorstores import VectorStoreRetriever
|
7 |
-
from langchain_core.documents.base import Document
|
8 |
-
from langchain_core.vectorstores import VectorStore
|
9 |
-
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
10 |
-
|
11 |
-
from typing import List
|
12 |
-
from pydantic import Field
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# # https://github.com/langchain-ai/langchain/issues/8623
|
2 |
+
|
3 |
+
# import pandas as pd
|
4 |
+
|
5 |
+
# from langchain_core.retrievers import BaseRetriever
|
6 |
+
# from langchain_core.vectorstores import VectorStoreRetriever
|
7 |
+
# from langchain_core.documents.base import Document
|
8 |
+
# from langchain_core.vectorstores import VectorStore
|
9 |
+
# from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
10 |
+
|
11 |
+
# from typing import List
|
12 |
+
# from pydantic import Field
|
13 |
+
|
14 |
+
# def _add_metadata_and_score(docs: List) -> Document:
|
15 |
+
# # Add score to metadata
|
16 |
+
# docs_with_metadata = []
|
17 |
+
# for i,(doc,score) in enumerate(docs):
|
18 |
+
# doc.page_content = doc.page_content.replace("\r\n"," ")
|
19 |
+
# doc.metadata["similarity_score"] = score
|
20 |
+
# doc.metadata["content"] = doc.page_content
|
21 |
+
# doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
|
22 |
+
# # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
|
23 |
+
# docs_with_metadata.append(doc)
|
24 |
+
# return docs_with_metadata
|
25 |
+
|
26 |
+
# class ClimateQARetriever(BaseRetriever):
|
27 |
+
# vectorstore:VectorStore
|
28 |
+
# sources:list = ["IPCC","IPBES","IPOS"]
|
29 |
+
# reports:list = []
|
30 |
+
# threshold:float = 0.6
|
31 |
+
# k_summary:int = 3
|
32 |
+
# k_total:int = 10
|
33 |
+
# namespace:str = "vectors",
|
34 |
+
# min_size:int = 200,
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
# def _get_relevant_documents(
|
39 |
+
# self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
40 |
+
# ) -> List[Document]:
|
41 |
+
|
42 |
+
# # Check if all elements in the list are either IPCC or IPBES
|
43 |
+
# assert isinstance(self.sources,list)
|
44 |
+
# assert self.sources
|
45 |
+
# assert all([x in ["IPCC","IPBES","IPOS"] for x in self.sources])
|
46 |
+
# assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
|
47 |
+
|
48 |
+
# # Prepare base search kwargs
|
49 |
+
# filters = {}
|
50 |
+
|
51 |
+
# if len(self.reports) > 0:
|
52 |
+
# filters["short_name"] = {"$in":self.reports}
|
53 |
+
# else:
|
54 |
+
# filters["source"] = { "$in":self.sources}
|
55 |
+
|
56 |
+
# # Search for k_summary documents in the summaries dataset
|
57 |
+
# filters_summaries = {
|
58 |
+
# **filters,
|
59 |
+
# "chunk_type":"text",
|
60 |
+
# "report_type": { "$in":["SPM"]},
|
61 |
+
# }
|
62 |
+
|
63 |
+
# docs_summaries = self.vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = self.k_summary)
|
64 |
+
# docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
|
65 |
+
# # docs_summaries = []
|
66 |
+
|
67 |
+
# # Search for k_total - k_summary documents in the full reports dataset
|
68 |
+
# filters_full = {
|
69 |
+
# **filters,
|
70 |
+
# "chunk_type":"text",
|
71 |
+
# "report_type": { "$nin":["SPM"]},
|
72 |
+
# }
|
73 |
+
# k_full = self.k_total - len(docs_summaries)
|
74 |
+
# docs_full = self.vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
|
75 |
+
|
76 |
+
# # Images
|
77 |
+
# filters_image = {
|
78 |
+
# **filters,
|
79 |
+
# "chunk_type":"image"
|
80 |
+
# }
|
81 |
+
# docs_images = self.vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_full)
|
82 |
+
|
83 |
+
# # docs_images = []
|
84 |
+
|
85 |
+
# # Concatenate documents
|
86 |
+
# # docs = docs_summaries + docs_full + docs_images
|
87 |
+
|
88 |
+
# # Filter if scores are below threshold
|
89 |
+
# # docs = [x for x in docs if x[1] > self.threshold]
|
90 |
+
|
91 |
+
# docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
|
92 |
+
|
93 |
+
# # Filter if length are below threshold
|
94 |
+
# docs_summaries = [x for x in docs_summaries if len(x.page_content) > self.min_size]
|
95 |
+
# docs_full = [x for x in docs_full if len(x.page_content) > self.min_size]
|
96 |
+
|
97 |
+
|
98 |
+
# return {
|
99 |
+
# "docs_summaries" : docs_summaries,
|
100 |
+
# "docs_full" : docs_full,
|
101 |
+
# "docs_images" : docs_images,
|
102 |
+
# }
|
@@ -20,3 +20,16 @@ def get_image_from_azure_blob_storage(path):
|
|
20 |
file_object = get_file_from_azure_blob_storage(path)
|
21 |
image = Image.open(file_object)
|
22 |
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
file_object = get_file_from_azure_blob_storage(path)
|
21 |
image = Image.open(file_object)
|
22 |
return image
|
23 |
+
|
24 |
+
def remove_duplicates_keep_highest_score(documents):
|
25 |
+
unique_docs = {}
|
26 |
+
|
27 |
+
for doc in documents:
|
28 |
+
doc_id = doc.metadata.get('doc_id')
|
29 |
+
if doc_id in unique_docs:
|
30 |
+
if doc.metadata['reranking_score'] > unique_docs[doc_id].metadata['reranking_score']:
|
31 |
+
unique_docs[doc_id] = doc
|
32 |
+
else:
|
33 |
+
unique_docs[doc_id] = doc
|
34 |
+
|
35 |
+
return list(unique_docs.values())
|
@@ -1,12 +1,19 @@
|
|
1 |
|
2 |
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
|
|
|
5 |
"""from a list of even lenght, make tupple pairs"""
|
6 |
return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
|
7 |
|
8 |
|
9 |
-
def serialize_docs(docs):
|
10 |
new_docs = []
|
11 |
for doc in docs:
|
12 |
new_doc = {}
|
@@ -17,7 +24,7 @@ def serialize_docs(docs):
|
|
17 |
|
18 |
|
19 |
|
20 |
-
def parse_output_llm_with_sources(output):
|
21 |
# Split the content into a list of text and "[Doc X]" references
|
22 |
content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output)
|
23 |
parts = []
|
@@ -32,6 +39,119 @@ def parse_output_llm_with_sources(output):
|
|
32 |
content_parts = "".join(parts)
|
33 |
return content_parts
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
def make_html_source(source,i):
|
37 |
meta = source.metadata
|
@@ -108,6 +228,31 @@ def make_html_source(source,i):
|
|
108 |
return card
|
109 |
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
def make_html_figure_sources(source,i,img_str):
|
112 |
meta = source.metadata
|
113 |
content = source.page_content.strip()
|
|
|
1 |
|
2 |
import re
|
3 |
+
from collections import defaultdict
|
4 |
+
from climateqa.utils import get_image_from_azure_blob_storage
|
5 |
+
from climateqa.engine.chains.prompts import audience_prompts
|
6 |
+
from PIL import Image
|
7 |
+
from io import BytesIO
|
8 |
+
import base64
|
9 |
|
10 |
+
|
11 |
+
def make_pairs(lst:list)->list:
|
12 |
"""from a list of even lenght, make tupple pairs"""
|
13 |
return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
|
14 |
|
15 |
|
16 |
+
def serialize_docs(docs:list)->list:
|
17 |
new_docs = []
|
18 |
for doc in docs:
|
19 |
new_doc = {}
|
|
|
24 |
|
25 |
|
26 |
|
27 |
+
def parse_output_llm_with_sources(output:str)->str:
|
28 |
# Split the content into a list of text and "[Doc X]" references
|
29 |
content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output)
|
30 |
parts = []
|
|
|
39 |
content_parts = "".join(parts)
|
40 |
return content_parts
|
41 |
|
42 |
+
def process_figures(docs:list)->tuple:
|
43 |
+
gallery=[]
|
44 |
+
used_figures =[]
|
45 |
+
figures = '<div class="figures-container"><p></p> </div>'
|
46 |
+
docs_figures = [d for d in docs if d.metadata["chunk_type"] == "image"]
|
47 |
+
for i, doc in enumerate(docs_figures):
|
48 |
+
if doc.metadata["chunk_type"] == "image":
|
49 |
+
if doc.metadata["figure_code"] != "N/A":
|
50 |
+
title = f"{doc.metadata['figure_code']} - {doc.metadata['short_name']}"
|
51 |
+
else:
|
52 |
+
title = f"{doc.metadata['short_name']}"
|
53 |
+
|
54 |
+
|
55 |
+
if title not in used_figures:
|
56 |
+
used_figures.append(title)
|
57 |
+
try:
|
58 |
+
key = f"Image {i+1}"
|
59 |
+
|
60 |
+
image_path = doc.metadata["image_path"].split("documents/")[1]
|
61 |
+
img = get_image_from_azure_blob_storage(image_path)
|
62 |
+
|
63 |
+
# Convert the image to a byte buffer
|
64 |
+
buffered = BytesIO()
|
65 |
+
max_image_length = 500
|
66 |
+
img_resized = img.resize((max_image_length, int(max_image_length * img.size[1]/img.size[0])))
|
67 |
+
img_resized.save(buffered, format="PNG")
|
68 |
+
|
69 |
+
img_str = base64.b64encode(buffered.getvalue()).decode()
|
70 |
+
|
71 |
+
figures = figures + make_html_figure_sources(doc, i, img_str)
|
72 |
+
gallery.append(img)
|
73 |
+
except Exception as e:
|
74 |
+
print(f"Skipped adding image {i} because of {e}")
|
75 |
+
|
76 |
+
return figures, gallery
|
77 |
+
|
78 |
+
|
79 |
+
def generate_html_graphs(graphs:list)->str:
|
80 |
+
# Organize graphs by category
|
81 |
+
categories = defaultdict(list)
|
82 |
+
for graph in graphs:
|
83 |
+
category = graph['metadata']['category']
|
84 |
+
categories[category].append(graph['embedding'])
|
85 |
+
|
86 |
+
# Begin constructing the HTML
|
87 |
+
html_code = '''
|
88 |
+
<!DOCTYPE html>
|
89 |
+
<html lang="en">
|
90 |
+
<head>
|
91 |
+
<meta charset="UTF-8">
|
92 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
93 |
+
<title>Graphs by Category</title>
|
94 |
+
<style>
|
95 |
+
.tab-content {
|
96 |
+
display: none;
|
97 |
+
}
|
98 |
+
.tab-content.active {
|
99 |
+
display: block;
|
100 |
+
}
|
101 |
+
.tabs {
|
102 |
+
margin-bottom: 20px;
|
103 |
+
}
|
104 |
+
.tab-button {
|
105 |
+
background-color: #ddd;
|
106 |
+
border: none;
|
107 |
+
padding: 10px 20px;
|
108 |
+
cursor: pointer;
|
109 |
+
margin-right: 5px;
|
110 |
+
}
|
111 |
+
.tab-button.active {
|
112 |
+
background-color: #ccc;
|
113 |
+
}
|
114 |
+
</style>
|
115 |
+
<script>
|
116 |
+
function showTab(tabId) {
|
117 |
+
var contents = document.getElementsByClassName('tab-content');
|
118 |
+
var buttons = document.getElementsByClassName('tab-button');
|
119 |
+
for (var i = 0; i < contents.length; i++) {
|
120 |
+
contents[i].classList.remove('active');
|
121 |
+
buttons[i].classList.remove('active');
|
122 |
+
}
|
123 |
+
document.getElementById(tabId).classList.add('active');
|
124 |
+
document.querySelector('button[data-tab="'+tabId+'"]').classList.add('active');
|
125 |
+
}
|
126 |
+
</script>
|
127 |
+
</head>
|
128 |
+
<body>
|
129 |
+
<div class="tabs">
|
130 |
+
'''
|
131 |
+
|
132 |
+
# Add buttons for each category
|
133 |
+
for i, category in enumerate(categories.keys()):
|
134 |
+
active_class = 'active' if i == 0 else ''
|
135 |
+
html_code += f'<button class="tab-button {active_class}" onclick="showTab(\'tab-{i}\')" data-tab="tab-{i}">{category}</button>'
|
136 |
+
|
137 |
+
html_code += '</div>'
|
138 |
+
|
139 |
+
# Add content for each category
|
140 |
+
for i, (category, embeds) in enumerate(categories.items()):
|
141 |
+
active_class = 'active' if i == 0 else ''
|
142 |
+
html_code += f'<div id="tab-{i}" class="tab-content {active_class}">'
|
143 |
+
for embed in embeds:
|
144 |
+
html_code += embed
|
145 |
+
html_code += '</div>'
|
146 |
+
|
147 |
+
html_code += '''
|
148 |
+
</body>
|
149 |
+
</html>
|
150 |
+
'''
|
151 |
+
|
152 |
+
return html_code
|
153 |
+
|
154 |
+
|
155 |
|
156 |
def make_html_source(source,i):
|
157 |
meta = source.metadata
|
|
|
228 |
return card
|
229 |
|
230 |
|
231 |
+
def make_html_papers(df,i):
|
232 |
+
title = df['title'][i]
|
233 |
+
content = df['abstract'][i]
|
234 |
+
url = df['doi'][i]
|
235 |
+
publication_date = df['publication_year'][i]
|
236 |
+
subtitle = df['subtitle'][i]
|
237 |
+
|
238 |
+
card = f"""
|
239 |
+
<div class="card" id="doc{i}">
|
240 |
+
<div class="card-content">
|
241 |
+
<h2>Doc {i+1} - {title}</h2>
|
242 |
+
<p>{content}</p>
|
243 |
+
</div>
|
244 |
+
<div class="card-footer">
|
245 |
+
<span>{subtitle}</span>
|
246 |
+
<a href="{url}" target="_blank" class="pdf-link">
|
247 |
+
<span role="img" aria-label="Open paper">🔗</span>
|
248 |
+
</a>
|
249 |
+
</div>
|
250 |
+
</div>
|
251 |
+
"""
|
252 |
+
|
253 |
+
return card
|
254 |
+
|
255 |
+
|
256 |
def make_html_figure_sources(source,i,img_str):
|
257 |
meta = source.metadata
|
258 |
content = source.page_content.strip()
|
The diff for this file is too large to render.
See raw diff
|
|
The diff for this file is too large to render.
See raw diff
|
|
The diff for this file is too large to render.
See raw diff
|
|
@@ -3,6 +3,61 @@
|
|
3 |
--user-image: url('https://ih1.redbubble.net/image.4776899543.6215/st,small,507x507-pad,600x600,f8f8f8.jpg');
|
4 |
} */
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
/* fix for huggingface infinite growth*/
|
8 |
main.flex.flex-1.flex-col {
|
@@ -85,7 +140,12 @@ body.dark .tip-box * {
|
|
85 |
font-size:14px !important;
|
86 |
|
87 |
}
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
a {
|
91 |
text-decoration: none;
|
@@ -161,60 +221,111 @@ a {
|
|
161 |
border:none;
|
162 |
}
|
163 |
|
164 |
-
/* .gallery-item > div:hover{
|
165 |
-
background-color:#7494b0 !important;
|
166 |
-
color:white!important;
|
167 |
-
}
|
168 |
|
169 |
-
.
|
170 |
-
|
171 |
}
|
172 |
|
173 |
-
|
174 |
-
|
175 |
-
color:#577b9b!important;
|
176 |
}
|
177 |
|
178 |
-
.
|
179 |
-
|
180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
-
/*
|
183 |
-
|
|
|
|
|
184 |
} */
|
185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
|
|
|
|
|
|
|
|
|
187 |
|
188 |
-
|
189 |
-
background:
|
190 |
-
|
191 |
-
} */
|
192 |
|
193 |
-
|
194 |
-
|
195 |
-
content: '';
|
196 |
position: absolute;
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
|
|
|
|
204 |
border-radius: 50%;
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
label.selected{
|
210 |
-
background:none !important;
|
211 |
}
|
212 |
-
|
213 |
-
|
214 |
-
padding:0px !important;
|
215 |
}
|
216 |
|
|
|
|
|
217 |
@media screen and (min-width: 1024px) {
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
.gradio-container {
|
219 |
max-height: calc(100vh - 190px) !important;
|
220 |
overflow: hidden;
|
@@ -225,6 +336,8 @@ label.selected{
|
|
225 |
|
226 |
} */
|
227 |
|
|
|
|
|
228 |
div#tab-examples{
|
229 |
height:calc(100vh - 190px) !important;
|
230 |
overflow-y: scroll !important;
|
@@ -236,6 +349,10 @@ label.selected{
|
|
236 |
overflow-y: scroll !important;
|
237 |
/* overflow-y: auto !important; */
|
238 |
}
|
|
|
|
|
|
|
|
|
239 |
|
240 |
div#sources-figures{
|
241 |
height:calc(100vh - 300px) !important;
|
@@ -243,6 +360,18 @@ label.selected{
|
|
243 |
overflow-y: scroll !important;
|
244 |
}
|
245 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
246 |
div#tab-config{
|
247 |
height:calc(100vh - 190px) !important;
|
248 |
overflow-y: scroll !important;
|
@@ -409,8 +538,7 @@ span.chatbot > p > img{
|
|
409 |
}
|
410 |
|
411 |
#dropdown-samples{
|
412 |
-
|
413 |
-
/*! border-width:0px !important; */
|
414 |
background:none !important;
|
415 |
|
416 |
}
|
@@ -468,6 +596,10 @@ span.chatbot > p > img{
|
|
468 |
input[type="checkbox"]:checked + .dropdown-content {
|
469 |
display: block;
|
470 |
}
|
|
|
|
|
|
|
|
|
471 |
|
472 |
.dropdown-content {
|
473 |
display: none;
|
@@ -489,7 +621,7 @@ span.chatbot > p > img{
|
|
489 |
border-bottom: 5px solid black;
|
490 |
}
|
491 |
|
492 |
-
|
493 |
border: 1px solid #d0d0d0 !important; /* Light grey background */
|
494 |
border-top: 1px solid #db3434 !important; /* Blue color */
|
495 |
border-right: 1px solid #3498db !important; /* Blue color */
|
@@ -499,41 +631,64 @@ span.chatbot > p > img{
|
|
499 |
animation: spin 2s linear infinite;
|
500 |
display:inline-block;
|
501 |
margin-right:10px !important;
|
502 |
-
|
503 |
|
504 |
-
|
505 |
color:green !important;
|
506 |
font-size:18px;
|
507 |
margin-right:10px !important;
|
508 |
-
|
509 |
|
510 |
-
|
511 |
0% { transform: rotate(0deg); }
|
512 |
100% { transform: rotate(360deg); }
|
513 |
-
|
514 |
|
515 |
|
516 |
-
|
517 |
margin-top:10px !important;
|
518 |
font-size:10px !important;
|
519 |
font-style:italic;
|
520 |
-
|
521 |
|
522 |
-
|
523 |
color:green !important;
|
524 |
-
|
525 |
|
526 |
-
|
527 |
color:orange !important;
|
528 |
-
|
529 |
|
530 |
-
|
531 |
color:red !important;
|
532 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
533 |
.message-buttons-left.panel.message-buttons.with-avatar {
|
534 |
display: none;
|
535 |
}
|
536 |
|
|
|
537 |
/* Specific fixes for Hugging Face Space iframe */
|
538 |
.h-full {
|
539 |
height: auto !important;
|
|
|
3 |
--user-image: url('https://ih1.redbubble.net/image.4776899543.6215/st,small,507x507-pad,600x600,f8f8f8.jpg');
|
4 |
} */
|
5 |
|
6 |
+
#tab-recommended_content{
|
7 |
+
padding-top: 0px;
|
8 |
+
padding-left : 0px;
|
9 |
+
padding-right: 0px;
|
10 |
+
}
|
11 |
+
#group-subtabs {
|
12 |
+
/* display: block; */
|
13 |
+
width: 100%; /* Ensures the parent uses the full width */
|
14 |
+
position : sticky;
|
15 |
+
}
|
16 |
+
|
17 |
+
#group-subtabs .tab-container {
|
18 |
+
display: flex;
|
19 |
+
text-align: center;
|
20 |
+
width: 100%; /* Ensures the tabs span the full width */
|
21 |
+
}
|
22 |
+
|
23 |
+
#group-subtabs .tab-container button {
|
24 |
+
flex: 1; /* Makes each button take equal width */
|
25 |
+
}
|
26 |
+
|
27 |
+
|
28 |
+
#papers-summary-popup button span{
|
29 |
+
/* make label of accordio in bold, center, and bigger */
|
30 |
+
font-size: 16px;
|
31 |
+
font-weight: bold;
|
32 |
+
text-align: center;
|
33 |
+
|
34 |
+
}
|
35 |
+
|
36 |
+
#papers-relevant-popup span{
|
37 |
+
/* make label of accordio in bold, center, and bigger */
|
38 |
+
font-size: 16px;
|
39 |
+
font-weight: bold;
|
40 |
+
text-align: center;
|
41 |
+
}
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
#tab-citations .button{
|
46 |
+
padding: 12px 16px;
|
47 |
+
font-size: 16px;
|
48 |
+
font-weight: bold;
|
49 |
+
cursor: pointer;
|
50 |
+
border: none;
|
51 |
+
outline: none;
|
52 |
+
text-align: left;
|
53 |
+
transition: background-color 0.3s ease;
|
54 |
+
}
|
55 |
+
|
56 |
+
|
57 |
+
.gradio-container {
|
58 |
+
width: 100%!important;
|
59 |
+
max-width: 100% !important;
|
60 |
+
}
|
61 |
|
62 |
/* fix for huggingface infinite growth*/
|
63 |
main.flex.flex-1.flex-col {
|
|
|
140 |
font-size:14px !important;
|
141 |
|
142 |
}
|
143 |
+
.card-content img {
|
144 |
+
display: block;
|
145 |
+
margin: auto;
|
146 |
+
max-width: 100%; /* Ensures the image is responsive */
|
147 |
+
height: auto;
|
148 |
+
}
|
149 |
|
150 |
a {
|
151 |
text-decoration: none;
|
|
|
221 |
border:none;
|
222 |
}
|
223 |
|
|
|
|
|
|
|
|
|
224 |
|
225 |
+
label.selected{
|
226 |
+
background: #93c5fd !important;
|
227 |
}
|
228 |
|
229 |
+
#submit-button{
|
230 |
+
padding:0px !important;
|
|
|
231 |
}
|
232 |
|
233 |
+
#modal-config .block.modal-block.padded {
|
234 |
+
padding-top: 25px;
|
235 |
+
height: 100vh;
|
236 |
+
|
237 |
+
}
|
238 |
+
#modal-config .modal-container{
|
239 |
+
margin: 0px;
|
240 |
+
padding: 0px;
|
241 |
+
}
|
242 |
+
/* Modal styles */
|
243 |
+
#modal-config {
|
244 |
+
position: fixed;
|
245 |
+
top: 0;
|
246 |
+
left: 0;
|
247 |
+
height: 100vh;
|
248 |
+
width: 500px;
|
249 |
+
background-color: white;
|
250 |
+
box-shadow: 2px 0 10px rgba(0, 0, 0, 0.1);
|
251 |
+
z-index: 1000;
|
252 |
+
padding: 15px;
|
253 |
+
transform: none;
|
254 |
+
}
|
255 |
+
#modal-config .close{
|
256 |
+
display: none;
|
257 |
+
}
|
258 |
|
259 |
+
/* Push main content to the right when modal is open */
|
260 |
+
/* .modal ~ * {
|
261 |
+
margin-left: 300px;
|
262 |
+
transition: margin-left 0.3s ease;
|
263 |
} */
|
264 |
|
265 |
+
#modal-config .modal .wrap ul{
|
266 |
+
position:static;
|
267 |
+
top: 100%;
|
268 |
+
left: 0;
|
269 |
+
/* min-height: 100px; */
|
270 |
+
height: 100%;
|
271 |
+
/* margin-top: 0; */
|
272 |
+
z-index: 9999;
|
273 |
+
pointer-events: auto;
|
274 |
+
height: 200px;
|
275 |
+
}
|
276 |
+
#config-button{
|
277 |
+
background: none;
|
278 |
+
border: none;
|
279 |
+
padding: 8px;
|
280 |
+
cursor: pointer;
|
281 |
+
width: 40px;
|
282 |
+
height: 40px;
|
283 |
+
display: flex;
|
284 |
+
align-items: center;
|
285 |
+
justify-content: center;
|
286 |
+
border-radius: 50%;
|
287 |
+
transition: background-color 0.2s;
|
288 |
+
}
|
289 |
|
290 |
+
#config-button::before {
|
291 |
+
content: '⚙️';
|
292 |
+
font-size: 20px;
|
293 |
+
}
|
294 |
|
295 |
+
#config-button:hover {
|
296 |
+
background-color: rgba(0, 0, 0, 0.1);
|
297 |
+
}
|
|
|
298 |
|
299 |
+
#checkbox-config{
|
300 |
+
display: block;
|
|
|
301 |
position: absolute;
|
302 |
+
background: none;
|
303 |
+
border: none;
|
304 |
+
padding: 8px;
|
305 |
+
cursor: pointer;
|
306 |
+
width: 40px;
|
307 |
+
height: 40px;
|
308 |
+
display: flex;
|
309 |
+
align-items: center;
|
310 |
+
justify-content: center;
|
311 |
border-radius: 50%;
|
312 |
+
transition: background-color 0.2s;
|
313 |
+
font-size: 20px;
|
314 |
+
text-align: center;
|
|
|
|
|
|
|
315 |
}
|
316 |
+
#checkbox-config:checked{
|
317 |
+
display: block;
|
|
|
318 |
}
|
319 |
|
320 |
+
|
321 |
+
|
322 |
@media screen and (min-width: 1024px) {
|
323 |
+
/* Additional style for scrollable tab content */
|
324 |
+
/* div#tab-recommended_content {
|
325 |
+
overflow-y: auto;
|
326 |
+
max-height: 80vh;
|
327 |
+
} */
|
328 |
+
|
329 |
.gradio-container {
|
330 |
max-height: calc(100vh - 190px) !important;
|
331 |
overflow: hidden;
|
|
|
336 |
|
337 |
} */
|
338 |
|
339 |
+
|
340 |
+
|
341 |
div#tab-examples{
|
342 |
height:calc(100vh - 190px) !important;
|
343 |
overflow-y: scroll !important;
|
|
|
349 |
overflow-y: scroll !important;
|
350 |
/* overflow-y: auto !important; */
|
351 |
}
|
352 |
+
div#graphs-container{
|
353 |
+
height:calc(100vh - 210px) !important;
|
354 |
+
overflow-y: scroll !important;
|
355 |
+
}
|
356 |
|
357 |
div#sources-figures{
|
358 |
height:calc(100vh - 300px) !important;
|
|
|
360 |
overflow-y: scroll !important;
|
361 |
}
|
362 |
|
363 |
+
div#graphs-container{
|
364 |
+
height:calc(100vh - 300px) !important;
|
365 |
+
max-height: 90vh !important;
|
366 |
+
overflow-y: scroll !important;
|
367 |
+
}
|
368 |
+
|
369 |
+
div#tab-citations{
|
370 |
+
height:calc(100vh - 300px) !important;
|
371 |
+
max-height: 90vh !important;
|
372 |
+
overflow-y: scroll !important;
|
373 |
+
}
|
374 |
+
|
375 |
div#tab-config{
|
376 |
height:calc(100vh - 190px) !important;
|
377 |
overflow-y: scroll !important;
|
|
|
538 |
}
|
539 |
|
540 |
#dropdown-samples{
|
541 |
+
|
|
|
542 |
background:none !important;
|
543 |
|
544 |
}
|
|
|
596 |
input[type="checkbox"]:checked + .dropdown-content {
|
597 |
display: block;
|
598 |
}
|
599 |
+
|
600 |
+
#checkbox-chat input[type="checkbox"] {
|
601 |
+
display: flex !important;
|
602 |
+
}
|
603 |
|
604 |
.dropdown-content {
|
605 |
display: none;
|
|
|
621 |
border-bottom: 5px solid black;
|
622 |
}
|
623 |
|
624 |
+
.loader {
|
625 |
border: 1px solid #d0d0d0 !important; /* Light grey background */
|
626 |
border-top: 1px solid #db3434 !important; /* Blue color */
|
627 |
border-right: 1px solid #3498db !important; /* Blue color */
|
|
|
631 |
animation: spin 2s linear infinite;
|
632 |
display:inline-block;
|
633 |
margin-right:10px !important;
|
634 |
+
}
|
635 |
|
636 |
+
.checkmark{
|
637 |
color:green !important;
|
638 |
font-size:18px;
|
639 |
margin-right:10px !important;
|
640 |
+
}
|
641 |
|
642 |
+
@keyframes spin {
|
643 |
0% { transform: rotate(0deg); }
|
644 |
100% { transform: rotate(360deg); }
|
645 |
+
}
|
646 |
|
647 |
|
648 |
+
.relevancy-score{
|
649 |
margin-top:10px !important;
|
650 |
font-size:10px !important;
|
651 |
font-style:italic;
|
652 |
+
}
|
653 |
|
654 |
+
.score-green{
|
655 |
color:green !important;
|
656 |
+
}
|
657 |
|
658 |
+
.score-orange{
|
659 |
color:orange !important;
|
660 |
+
}
|
661 |
|
662 |
+
.score-red{
|
663 |
color:red !important;
|
664 |
+
}
|
665 |
+
|
666 |
+
/* Mobile specific adjustments */
|
667 |
+
@media screen and (max-width: 767px) {
|
668 |
+
div#tab-recommended_content {
|
669 |
+
max-height: 50vh; /* Reduce height for smaller screens */
|
670 |
+
overflow-y: auto;
|
671 |
+
}
|
672 |
+
}
|
673 |
+
|
674 |
+
/* Additional style for scrollable tab content */
|
675 |
+
div#tab-saved-graphs {
|
676 |
+
overflow-y: auto; /* Enable vertical scrolling */
|
677 |
+
max-height: 80vh; /* Adjust height as needed */
|
678 |
+
}
|
679 |
+
|
680 |
+
/* Mobile specific adjustments */
|
681 |
+
@media screen and (max-width: 767px) {
|
682 |
+
div#tab-saved-graphs {
|
683 |
+
max-height: 50vh; /* Reduce height for smaller screens */
|
684 |
+
overflow-y: auto;
|
685 |
+
}
|
686 |
+
}
|
687 |
.message-buttons-left.panel.message-buttons.with-avatar {
|
688 |
display: none;
|
689 |
}
|
690 |
|
691 |
+
|
692 |
/* Specific fixes for Hugging Face Space iframe */
|
693 |
.h-full {
|
694 |
height: auto !important;
|