timeki TheoLvs commited on
Commit
bcc8503
1 Parent(s): 14a5a97

Add content recommandation (#17)

Browse files

- First commit CQA with Agents (481f3b1453fde4c19018915d101d575b6ea25a3e)
- Connecting to front (088e816846227b694f2d56ca3af739cc010de4bc)
- Update app.py (fd67e156abd0293625d2b73765bda2d3905fa5de)
- agents mode (99e91d83efb40b6cfec5a887f0d464eaffd09431)
- bugfixs (72edd2d9e6ad64e3ecb59505b744cd415b9a6776)
- Update requirements.txt (ae857ef845ac5b3baed5ef7de1e1b8b63874947e)
- Update requirements.txt (25e32e6bdf0ca289bef8617d92ad77d7edeac19f)
- add dora graph recommandation (6b43c8608bbebdffdebdbd315d70c7df60199fab)
- Merge branch 'bugfix/add_dummy_searchs' into feature/graph_recommandation (aa904c191cf4dd783e3ec870883a06746fe52bf8)
- Update .gitignore (8b71e5e7c71f665a8515d8bdbfee913fdaff12f0)
- Update .gitignore (9dd246e7f975322a6be247188bffb7aa0f6d954e)
- add message types (bed4e9bbfb6f7c823789daf54c443f3f27198b45)
- WIP (6a39d69f772f97cef8b0b551a888ace822713753)
- WIP (7da1a3ac2237ded7b9891fdcda32d0674a9b7b4a)
- add steps and fix css for gradio 5 (df5d08d8710beacb04a9c0c281a195c6dd7cc800)
- remove unused code (4ab651938b1c2af7c0d8f155488820e47b42c6c8)
- Update .gitignore (5228f5c0f26f78825d572edbe82200ee3ece6a60)
- fixs (ccd4b9e7b0d2b6d7c0a8b9f2a6513609b4bfe3e1)
- Merge branch 'pr/14' into feature/graph_recommandation (6edd6c2f1c5215b2e3e72b2c36759b934d748606)
- Merge branch 'feature/add_steps_display' into feature/graph_recommandation (89a69e623fa7fec59d135eb57624bcb8d8f67985)
- WIP (49acaf1b850b0914bc1f5d52ceab47d2a22fe944)
- add graphs from ourworldindata (7335378313ce70a4d4ec305a6114e8e6d167af12)
- Merge branch 'main' into feature/graph_recommandation (0c4d82b36b5d6d2f79215460720118c746d88804)
- fix merge (57a1ed70b9a0ad0e48de283e2f044e1e38eac8e1)
- Merge branch 'main' into feature/graph_recommandation (196d79336e51d5deef0215353095954d98d4165b)
- add notification on tabs (484fc0d6f3d80e3fe3afb6ffff20560ef35b6b7c)
- switch owid vectorstore to pinecone (5664fc84569d8e455ce09e2e29c48d7881e126e3)
- remove unused code (fd2ccc64d648490a0ee5acf2159827091d9fc123)
- Move configuration to an other pannel (5c3a3a4b99323e93ebf0c852bc2c2a5401929dae)
- Display images before answering (12c9afe58a2de4b7413bc860f808501c6ee2a6aa)
- Display image in a parallel task (6c5a20c1cab94c503e004acc0991ce4149835a7f)
- Merge branch 'main' into feature/graph_recommandation (b58c53f25d1cc85cfe82bbd452f8d5d11c306da3)
- Add step by step notebook to run steps of langgraphs (a059c938ce111e6673981ad32785d0d4e1c0d177)
- put event handling in separate file (76603dfba448efa3334c2bcb0169f8e1fbd92c60)
- Rerank documents and force summary for policy makers (d562d3805e6230ad8db525229b4bdde42185e721)
- edit prints for logs (9609df9642e477c67d7bf03a83becaef5c3e2b6f)
- Add OpenAlex papers recommandation (c3b815e6e630f551740188fdef719d0df16acd7e)
- Merge branch 'add_openalex_papers' into feature/graph_recommandation (6541df34bd07c0c7e2d1f1c4ffbf03c2778187a3)
- fix merge (d78271b334fb803de4c424420c297cc6983f0c93)
- add time measure (09457a7da47af1e265d9c5c906cecbf2d4586174)
- fix css (22ac4eb7878d7c9e0f3581972f399ad3119dfdde)
- remove unused code (4c4fe76848d84079766d5ec6e94c1e941d5fea01)
- fix answer latency when having multiple sources (40084ba7b8c0424741cfe3d2142eea4d24683c07)
- remove unused code (58bf75084a4d964a486a5a602e81a17e56e1cc82)
- Change number of figures (781788244a88876bdecfc5b3ddbcf65e8bd9ead6)
- front UI change (c9346b33c3251304f5b52f5c837347f10842b87c)
- add owid subtitles (7ec5d9ecfda9f0688da0521126034de80cf9dffc)
- move code from papers in separate file (363fe2eb8548665fb4fe577db1dce7ea682bac8d)
- add search only button (be2863be1c4cfdedca78497d548be321d218c312)
- config in a modal object (7283e6a6e173ab3ced6c21e292e1b6f94e659141)
- few code cleaning (094ee349527297a47eaf9bbb0903651170be47ec)
- update display and fix search only (d396732ee8df5f4aa33c10cca64d6b05d197e4d5)
- Update 20241104 - CQA - StepByStep CQA.ipynb (d7adcaada52ac5feae76017b457623ab308bbfc8)


Co-authored-by: Theo Alves <[email protected]>

app.py CHANGED
@@ -1,13 +1,12 @@
1
  from climateqa.engine.embeddings import get_embeddings_function
2
  embeddings_function = get_embeddings_function()
3
 
4
- from climateqa.knowledge.openalex import OpenAlex
5
  from sentence_transformers import CrossEncoder
6
 
7
  # reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
8
- oa = OpenAlex()
9
 
10
  import gradio as gr
 
11
  import pandas as pd
12
  import numpy as np
13
  import os
@@ -29,7 +28,9 @@ from utils import create_user_id
29
 
30
  from gradio_modal import Modal
31
 
 
32
 
 
33
 
34
  # ClimateQ&A imports
35
  from climateqa.engine.llm import get_llm
@@ -39,13 +40,15 @@ from climateqa.engine.reranker import get_reranker
39
  from climateqa.engine.embeddings import get_embeddings_function
40
  from climateqa.engine.chains.prompts import audience_prompts
41
  from climateqa.sample_questions import QUESTIONS
42
- from climateqa.constants import POSSIBLE_REPORTS
43
  from climateqa.utils import get_image_from_azure_blob_storage
44
- from climateqa.engine.keywords import make_keywords_chain
45
- # from climateqa.engine.chains.answer_rag import make_rag_papers_chain
46
- from climateqa.engine.graph import make_graph_agent,display_graph
 
 
47
 
48
- from front.utils import make_html_source, make_html_figure_sources,parse_output_llm_with_sources,serialize_docs,make_toolbox
49
 
50
  # Load environment variables in local mode
51
  try:
@@ -54,6 +57,8 @@ try:
54
  except Exception as e:
55
  pass
56
 
 
 
57
  # Set up Gradio Theme
58
  theme = gr.themes.Base(
59
  primary_hue="blue",
@@ -104,52 +109,47 @@ CITATION_TEXT = r"""@misc{climateqa,
104
 
105
 
106
  # Create vectorstore and retriever
107
- vectorstore = get_pinecone_vectorstore(embeddings_function)
108
- llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
109
- reranker = get_reranker("large")
110
- agent = make_graph_agent(llm,vectorstore,reranker)
111
 
 
 
112
 
 
113
 
 
 
 
114
 
115
- async def chat(query,history,audience,sources,reports):
116
  """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
117
  (messages in gradio format, messages in langchain format, source documents)"""
118
 
119
  date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
120
  print(f">> NEW QUESTION ({date_now}) : {query}")
121
 
122
- if audience == "Children":
123
- audience_prompt = audience_prompts["children"]
124
- elif audience == "General public":
125
- audience_prompt = audience_prompts["general"]
126
- elif audience == "Experts":
127
- audience_prompt = audience_prompts["experts"]
128
- else:
129
- audience_prompt = audience_prompts["experts"]
130
 
131
  # Prepare default values
132
- if len(sources) == 0:
133
- sources = ["IPCC"]
134
 
135
- # if len(reports) == 0: # TODO
136
- reports = []
137
 
138
- inputs = {"user_input": query,"audience": audience_prompt,"sources_input":sources}
139
  result = agent.astream_events(inputs,version = "v1")
140
-
141
- # path_reformulation = "/logs/reformulation/final_output"
142
- # path_keywords = "/logs/keywords/final_output"
143
- # path_retriever = "/logs/find_documents/final_output"
144
- # path_answer = "/logs/answer/streamed_output_str/-"
145
 
146
  docs = []
 
 
147
  docs_html = ""
148
  output_query = ""
149
  output_language = ""
150
  output_keywords = ""
151
- gallery = []
152
  start_streaming = False
 
153
  figures = '<div class="figures-container"><p></p> </div>'
154
 
155
  steps_display = {
@@ -166,36 +166,29 @@ async def chat(query,history,audience,sources,reports):
166
  node = event["metadata"]["langgraph_node"]
167
 
168
  if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" :# when documents are retrieved
169
- try:
170
- docs = event["data"]["output"]["documents"]
171
- docs_html = []
172
- textual_docs = [d for d in docs if d.metadata["chunk_type"] == "text"]
173
- for i, d in enumerate(textual_docs, 1):
174
- if d.metadata["chunk_type"] == "text":
175
- docs_html.append(make_html_source(d, i))
176
-
177
- used_documents = used_documents + [f"{d.metadata['short_name']} - {d.metadata['name']}" for d in docs]
178
- history[-1].content = "Adding sources :\n\n - " + "\n - ".join(np.unique(used_documents))
179
-
180
- docs_html = "".join(docs_html)
181
-
182
- except Exception as e:
183
- print(f"Error getting documents: {e}")
184
- print(event)
185
-
186
  elif event["name"] in steps_display.keys() and event["event"] == "on_chain_start": #display steps
187
- event_description,display_output = steps_display[node]
188
  if not hasattr(history[-1], 'metadata') or history[-1].metadata["title"] != event_description: # if a new step begins
189
  history.append(ChatMessage(role="assistant", content = "", metadata={'title' :event_description}))
190
 
191
  elif event["name"] != "transform_query" and event["event"] == "on_chat_model_stream" and node in ["answer_rag", "answer_search","answer_chitchat"]:# if streaming answer
192
- if start_streaming == False:
193
- start_streaming = True
194
- history.append(ChatMessage(role="assistant", content = ""))
195
- answer_message_content += event["data"]["chunk"].content
196
- answer_message_content = parse_output_llm_with_sources(answer_message_content)
197
- history[-1] = ChatMessage(role="assistant", content = answer_message_content)
198
- # history.append(ChatMessage(role="assistant", content = new_message_content))
199
 
200
  if event["name"] == "transform_query" and event["event"] =="on_chain_end":
201
  if hasattr(history[-1],"content"):
@@ -204,7 +197,7 @@ async def chat(query,history,audience,sources,reports):
204
  if event["name"] == "categorize_intent" and event["event"] == "on_chain_start":
205
  print("X")
206
 
207
- yield history,docs_html,output_query,output_language,gallery, figures #,output_query,output_keywords
208
 
209
  except Exception as e:
210
  print(event, "has failed")
@@ -232,68 +225,7 @@ async def chat(query,history,audience,sources,reports):
232
  print(f"Error logging on Azure Blob Storage: {e}")
233
  raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
234
 
235
-
236
-
237
-
238
- # image_dict = {}
239
- # for i,doc in enumerate(docs):
240
-
241
- # if doc.metadata["chunk_type"] == "image":
242
- # try:
243
- # key = f"Image {i+1}"
244
- # image_path = doc.metadata["image_path"].split("documents/")[1]
245
- # img = get_image_from_azure_blob_storage(image_path)
246
-
247
- # # Convert the image to a byte buffer
248
- # buffered = BytesIO()
249
- # img.save(buffered, format="PNG")
250
- # img_str = base64.b64encode(buffered.getvalue()).decode()
251
-
252
- # # Embedding the base64 string in Markdown
253
- # markdown_image = f"![Alt text](data:image/png;base64,{img_str})"
254
- # image_dict[key] = {"img":img,"md":markdown_image,"short_name": doc.metadata["short_name"],"figure_code":doc.metadata["figure_code"],"caption":doc.page_content,"key":key,"figure_code":doc.metadata["figure_code"], "img_str" : img_str}
255
- # except Exception as e:
256
- # print(f"Skipped adding image {i} because of {e}")
257
-
258
- # if len(image_dict) > 0:
259
-
260
- # gallery = [x["img"] for x in list(image_dict.values())]
261
- # img = list(image_dict.values())[0]
262
- # img_md = img["md"]
263
- # img_caption = img["caption"]
264
- # img_code = img["figure_code"]
265
- # if img_code != "N/A":
266
- # img_name = f"{img['key']} - {img['figure_code']}"
267
- # else:
268
- # img_name = f"{img['key']}"
269
-
270
- # history.append(ChatMessage(role="assistant", content = f"\n\n{img_md}\n<p class='chatbot-caption'><b>{img_name}</b> - {img_caption}</p>"))
271
-
272
- docs_figures = [d for d in docs if d.metadata["chunk_type"] == "image"]
273
- for i, doc in enumerate(docs_figures):
274
- if doc.metadata["chunk_type"] == "image":
275
- try:
276
- key = f"Image {i+1}"
277
-
278
- image_path = doc.metadata["image_path"].split("documents/")[1]
279
- img = get_image_from_azure_blob_storage(image_path)
280
-
281
- # Convert the image to a byte buffer
282
- buffered = BytesIO()
283
- img.save(buffered, format="PNG")
284
- img_str = base64.b64encode(buffered.getvalue()).decode()
285
-
286
- figures = figures + make_html_figure_sources(doc, i, img_str)
287
-
288
- gallery.append(img)
289
-
290
- except Exception as e:
291
- print(f"Skipped adding image {i} because of {e}")
292
-
293
-
294
-
295
-
296
- yield history,docs_html,output_query,output_language,gallery, figures#,output_query,output_keywords
297
 
298
 
299
  def save_feedback(feed: str, user_id):
@@ -317,29 +249,9 @@ def log_on_azure(file, logs, share_client):
317
  file_client.upload_file(logs)
318
 
319
 
320
- def generate_keywords(query):
321
- chain = make_keywords_chain(llm)
322
- keywords = chain.invoke(query)
323
- keywords = " AND ".join(keywords["keywords"])
324
- return keywords
325
 
326
 
327
 
328
- papers_cols_widths = {
329
- "doc":50,
330
- "id":100,
331
- "title":300,
332
- "doi":100,
333
- "publication_year":100,
334
- "abstract":500,
335
- "rerank_score":100,
336
- "is_oa":50,
337
- }
338
-
339
- papers_cols = list(papers_cols_widths.keys())
340
- papers_cols_widths = list(papers_cols_widths.values())
341
-
342
-
343
  # --------------------------------------------------------------------
344
  # Gradio
345
  # --------------------------------------------------------------------
@@ -370,10 +282,23 @@ def vote(data: gr.LikeData):
370
  else:
371
  print(data)
372
 
 
 
 
 
 
 
 
 
373
 
374
 
375
  with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=theme,elem_id = "main-component") as demo:
 
 
 
 
376
 
 
377
  with gr.Tab("ClimateQ&A"):
378
 
379
  with gr.Row(elem_id="chatbot-row"):
@@ -396,12 +321,16 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
396
 
397
  with gr.Row(elem_id = "input-message"):
398
  textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox")
399
-
 
 
 
 
400
 
401
- with gr.Column(scale=1, variant="panel",elem_id = "right-panel"):
402
 
403
 
404
- with gr.Tabs() as tabs:
405
  with gr.TabItem("Examples",elem_id = "tab-examples",id = 0):
406
 
407
  examples_hidden = gr.Textbox(visible = False)
@@ -427,91 +356,210 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
427
  )
428
 
429
  samples.append(group_examples)
 
 
 
430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
 
432
- with gr.Tab("Sources",elem_id = "tab-citations",id = 1):
433
- sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
434
- docs_textbox = gr.State("")
435
-
436
-
437
-
438
 
439
- # with Modal(visible = False) as config_modal:
440
- with gr.Tab("Configuration",elem_id = "tab-config",id = 2):
441
 
442
- gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")
 
443
 
444
 
445
- dropdown_sources = gr.CheckboxGroup(
446
- ["IPCC", "IPBES","IPOS"],
447
- label="Select source",
448
- value=["IPCC"],
449
- interactive=True,
450
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
 
452
- dropdown_reports = gr.Dropdown(
453
- POSSIBLE_REPORTS,
454
- label="Or select specific reports",
455
- multiselect=True,
456
- value=None,
457
- interactive=True,
458
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
 
460
- dropdown_audience = gr.Dropdown(
461
- ["Children","General public","Experts"],
462
- label="Select audience",
463
- value="Experts",
464
- interactive=True,
465
- )
466
 
467
- output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
468
- output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
469
 
 
 
 
 
 
 
 
 
 
 
470
 
471
- with gr.Tab("Figures",elem_id = "tab-figures",id = 3):
472
- with Modal(visible=False, elem_id="modal_figure_galery") as modal:
473
- gallery_component = gr.Gallery(object_fit='scale-down',elem_id="gallery-component", height="80vh")
474
-
475
- show_full_size_figures = gr.Button("Show figures in full size",elem_id="show-figures",interactive=True)
476
- show_full_size_figures.click(lambda : Modal(visible=True),None,modal)
477
 
478
- figures_cards = gr.HTML(show_label=False, elem_id="sources-figures")
479
-
480
 
 
 
 
 
 
 
481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
482
 
 
483
 
484
 
485
  #---------------------------------------------------------------------------------------
486
  # OTHER TABS
487
  #---------------------------------------------------------------------------------------
488
 
 
489
 
490
- # with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"):
491
- # gallery_component = gr.Gallery(object_fit='cover')
492
 
493
- # with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"):
494
 
495
- # with gr.Row():
496
- # with gr.Column(scale=1):
497
- # query_papers = gr.Textbox(placeholder="Question",show_label=False,lines = 1,interactive = True,elem_id="query-papers")
498
- # keywords_papers = gr.Textbox(placeholder="Keywords",show_label=False,lines = 1,interactive = True,elem_id="keywords-papers")
499
- # after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers")
500
- # search_papers = gr.Button("Search",elem_id="search-papers",interactive=True)
501
 
502
- # with gr.Column(scale=7):
 
 
 
 
 
 
503
 
504
- # with gr.Tab("Summary",elem_id="papers-summary-tab"):
505
- # papers_summary = gr.Markdown(visible=True,elem_id="papers-summary")
 
 
 
 
506
 
507
- # with gr.Tab("Relevant papers",elem_id="papers-results-tab"):
508
- # papers_dataframe = gr.Dataframe(visible=True,elem_id="papers-table",headers = papers_cols)
509
 
510
- # with gr.Tab("Citations network",elem_id="papers-network-tab"):
511
- # citations_network = gr.HTML(visible=True,elem_id="papers-citations-network")
512
 
513
 
514
-
515
  with gr.Tab("About",elem_classes = "max-height other-tabs"):
516
  with gr.Row():
517
  with gr.Column(scale=1):
@@ -519,13 +567,15 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
519
 
520
 
521
 
522
- gr.Markdown("""
523
- ### More info
524
- - See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)
525
- - Feedbacks on this [form](https://forms.office.com/e/1Yzgxm6jbp)
526
-
527
- ### Citation
528
- """)
 
 
529
  with gr.Accordion(CITATION_LABEL,elem_id="citation", open = False,):
530
  # # Display citation label and text)
531
  gr.Textbox(
@@ -538,25 +588,61 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
538
 
539
 
540
 
541
- def start_chat(query,history):
542
- # history = history + [(query,None)]
543
- # history = [tuple(x) for x in history]
544
  history = history + [ChatMessage(role="user", content=query)]
545
- return (gr.update(interactive = False),gr.update(selected=1),history)
 
 
 
546
 
547
  def finish_chat():
548
- return (gr.update(interactive = True,value = ""))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550
  (textbox
551
- .submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
552
- .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component,figures_cards],concurrency_limit = 8,api_name = "chat_textbox")
553
  .then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
 
554
  )
555
 
556
  (examples_hidden
557
- .change(start_chat, [examples_hidden,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
558
- .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component, figures_cards],concurrency_limit = 8,api_name = "chat_examples")
559
  .then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
 
560
  )
561
 
562
 
@@ -567,9 +653,23 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
567
  return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
568
 
569
 
 
 
 
 
 
 
 
570
 
 
571
  dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
572
 
 
 
 
 
 
 
573
 
574
  demo.queue()
575
 
 
1
  from climateqa.engine.embeddings import get_embeddings_function
2
  embeddings_function = get_embeddings_function()
3
 
 
4
  from sentence_transformers import CrossEncoder
5
 
6
  # reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
 
7
 
8
  import gradio as gr
9
+ from gradio_modal import Modal
10
  import pandas as pd
11
  import numpy as np
12
  import os
 
28
 
29
  from gradio_modal import Modal
30
 
31
+ from PIL import Image
32
 
33
+ from langchain_core.runnables.schema import StreamEvent
34
 
35
  # ClimateQ&A imports
36
  from climateqa.engine.llm import get_llm
 
40
  from climateqa.engine.embeddings import get_embeddings_function
41
  from climateqa.engine.chains.prompts import audience_prompts
42
  from climateqa.sample_questions import QUESTIONS
43
+ from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES
44
  from climateqa.utils import get_image_from_azure_blob_storage
45
+ from climateqa.engine.graph import make_graph_agent
46
+ from climateqa.engine.embeddings import get_embeddings_function
47
+ from climateqa.engine.chains.retrieve_papers import find_papers
48
+
49
+ from front.utils import serialize_docs,process_figures
50
 
51
+ from climateqa.event_handler import init_audience, handle_retrieved_documents, stream_answer,handle_retrieved_owid_graphs
52
 
53
  # Load environment variables in local mode
54
  try:
 
57
  except Exception as e:
58
  pass
59
 
60
+ import requests
61
+
62
  # Set up Gradio Theme
63
  theme = gr.themes.Base(
64
  primary_hue="blue",
 
109
 
110
 
111
  # Create vectorstore and retriever
112
+ vectorstore = get_pinecone_vectorstore(embeddings_function, index_name = os.getenv("PINECONE_API_INDEX"))
113
+ vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name = os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
 
 
114
 
115
+ llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
116
+ reranker = get_reranker("nano")
117
 
118
+ agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
119
 
120
+ def update_config_modal_visibility(config_open):
121
+ new_config_visibility_status = not config_open
122
+ return gr.update(visible=new_config_visibility_status), new_config_visibility_status
123
 
124
+ async def chat(query, history, audience, sources, reports, relevant_content_sources, search_only):
125
  """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
126
  (messages in gradio format, messages in langchain format, source documents)"""
127
 
128
  date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
129
  print(f">> NEW QUESTION ({date_now}) : {query}")
130
 
131
+ audience_prompt = init_audience(audience)
 
 
 
 
 
 
 
132
 
133
  # Prepare default values
134
+ if sources is None or len(sources) == 0:
135
+ sources = ["IPCC", "IPBES", "IPOS"]
136
 
137
+ if reports is None or len(reports) == 0:
138
+ reports = []
139
 
140
+ inputs = {"user_input": query,"audience": audience_prompt,"sources_input":sources, "relevant_content_sources" : relevant_content_sources, "search_only": search_only}
141
  result = agent.astream_events(inputs,version = "v1")
142
+
 
 
 
 
143
 
144
  docs = []
145
+ used_figures=[]
146
+ related_contents = []
147
  docs_html = ""
148
  output_query = ""
149
  output_language = ""
150
  output_keywords = ""
 
151
  start_streaming = False
152
+ graphs_html = ""
153
  figures = '<div class="figures-container"><p></p> </div>'
154
 
155
  steps_display = {
 
166
  node = event["metadata"]["langgraph_node"]
167
 
168
  if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" :# when documents are retrieved
169
+ docs, docs_html, history, used_documents, related_contents = handle_retrieved_documents(event, history, used_documents)
170
+
171
+ elif event["event"] == "on_chain_end" and node == "categorize_intent" and event["name"] == "_write": # when the query is transformed
172
+
173
+ intent = event["data"]["output"]["intent"]
174
+ if "language" in event["data"]["output"]:
175
+ output_language = event["data"]["output"]["language"]
176
+ else :
177
+ output_language = "English"
178
+ history[-1].content = f"Language identified : {output_language} \n Intent identified : {intent}"
179
+
180
+
 
 
 
 
 
181
  elif event["name"] in steps_display.keys() and event["event"] == "on_chain_start": #display steps
182
+ event_description, display_output = steps_display[node]
183
  if not hasattr(history[-1], 'metadata') or history[-1].metadata["title"] != event_description: # if a new step begins
184
  history.append(ChatMessage(role="assistant", content = "", metadata={'title' :event_description}))
185
 
186
  elif event["name"] != "transform_query" and event["event"] == "on_chat_model_stream" and node in ["answer_rag", "answer_search","answer_chitchat"]:# if streaming answer
187
+ history, start_streaming, answer_message_content = stream_answer(history, event, start_streaming, answer_message_content)
188
+
189
+ elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end":
190
+ graphs_html = handle_retrieved_owid_graphs(event, graphs_html)
191
+
 
 
192
 
193
  if event["name"] == "transform_query" and event["event"] =="on_chain_end":
194
  if hasattr(history[-1],"content"):
 
197
  if event["name"] == "categorize_intent" and event["event"] == "on_chain_start":
198
  print("X")
199
 
200
+ yield history, docs_html, output_query, output_language, related_contents , graphs_html, #,output_query,output_keywords
201
 
202
  except Exception as e:
203
  print(event, "has failed")
 
225
  print(f"Error logging on Azure Blob Storage: {e}")
226
  raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
227
 
228
+ yield history, docs_html, output_query, output_language, related_contents, graphs_html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
 
231
  def save_feedback(feed: str, user_id):
 
249
  file_client.upload_file(logs)
250
 
251
 
 
 
 
 
 
252
 
253
 
254
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  # --------------------------------------------------------------------
256
  # Gradio
257
  # --------------------------------------------------------------------
 
282
  else:
283
  print(data)
284
 
285
+ def save_graph(saved_graphs_state, embedding, category):
286
+ print(f"\nCategory:\n{saved_graphs_state}\n")
287
+ if category not in saved_graphs_state:
288
+ saved_graphs_state[category] = []
289
+ if embedding not in saved_graphs_state[category]:
290
+ saved_graphs_state[category].append(embedding)
291
+ return saved_graphs_state, gr.Button("Graph Saved")
292
+
293
 
294
 
295
  with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=theme,elem_id = "main-component") as demo:
296
+ chat_completed_state = gr.State(0)
297
+ current_graphs = gr.State([])
298
+ saved_graphs = gr.State({})
299
+ config_open = gr.State(False)
300
 
301
+
302
  with gr.Tab("ClimateQ&A"):
303
 
304
  with gr.Row(elem_id="chatbot-row"):
 
321
 
322
  with gr.Row(elem_id = "input-message"):
323
  textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox")
324
+
325
+ config_button = gr.Button("",elem_id="config-button")
326
+ # config_checkbox_button = gr.Checkbox(label = '⚙️', value="show",visible=True, interactive=True, elem_id="checkbox-config")
327
+
328
+
329
 
330
+ with gr.Column(scale=2, variant="panel",elem_id = "right-panel"):
331
 
332
 
333
+ with gr.Tabs(elem_id = "right_panel_tab") as tabs:
334
  with gr.TabItem("Examples",elem_id = "tab-examples",id = 0):
335
 
336
  examples_hidden = gr.Textbox(visible = False)
 
356
  )
357
 
358
  samples.append(group_examples)
359
+
360
+ # with gr.Tab("Configuration", id = 10, ) as tab_config:
361
+ # # gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!")
362
 
363
+ # pass
364
+
365
+ # with gr.Row():
366
+
367
+ # dropdown_sources = gr.CheckboxGroup(
368
+ # ["IPCC", "IPBES","IPOS"],
369
+ # label="Select source",
370
+ # value=["IPCC"],
371
+ # interactive=True,
372
+ # )
373
+ # dropdown_external_sources = gr.CheckboxGroup(
374
+ # ["IPCC figures","OpenAlex", "OurWorldInData"],
375
+ # label="Select database to search for relevant content",
376
+ # value=["IPCC figures"],
377
+ # interactive=True,
378
+ # )
379
+
380
+ # dropdown_reports = gr.Dropdown(
381
+ # POSSIBLE_REPORTS,
382
+ # label="Or select specific reports",
383
+ # multiselect=True,
384
+ # value=None,
385
+ # interactive=True,
386
+ # )
387
+
388
+ # search_only = gr.Checkbox(label="Search only without chating", value=False, interactive=True, elem_id="checkbox-chat")
389
+
390
+
391
+ # dropdown_audience = gr.Dropdown(
392
+ # ["Children","General public","Experts"],
393
+ # label="Select audience",
394
+ # value="Experts",
395
+ # interactive=True,
396
+ # )
397
+
398
+
399
+ # after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers", visible=False)
400
+
401
 
402
+ # output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False, visible= False)
403
+ # output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False, visible= False)
 
 
 
 
404
 
 
 
405
 
406
+ # dropdown_external_sources.change(lambda x: gr.update(visible = True ) if "OpenAlex" in x else gr.update(visible=False) , inputs=[dropdown_external_sources], outputs=[after])
407
+ # # dropdown_external_sources.change(lambda x: gr.update(visible = True ) if "OpenAlex" in x else gr.update(visible=False) , inputs=[dropdown_external_sources], outputs=[after], visible=True)
408
 
409
 
410
+ with gr.Tab("Sources",elem_id = "tab-sources",id = 1) as tab_sources:
411
+ sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
412
+
413
+
414
+
415
+ with gr.Tab("Recommended content", elem_id="tab-recommended_content",id=2) as tab_recommended_content:
416
+ with gr.Tabs(elem_id = "group-subtabs") as tabs_recommended_content:
417
+
418
+ with gr.Tab("Figures",elem_id = "tab-figures",id = 3) as tab_figures:
419
+ sources_raw = gr.State()
420
+
421
+ with Modal(visible=False, elem_id="modal_figure_galery") as figure_modal:
422
+ gallery_component = gr.Gallery(object_fit='scale-down',elem_id="gallery-component", height="80vh")
423
+
424
+ show_full_size_figures = gr.Button("Show figures in full size",elem_id="show-figures",interactive=True)
425
+ show_full_size_figures.click(lambda : Modal(visible=True),None,figure_modal)
426
+
427
+ figures_cards = gr.HTML(show_label=False, elem_id="sources-figures")
428
+
429
+
430
+
431
+ with gr.Tab("Papers",elem_id = "tab-citations",id = 4) as tab_papers:
432
+ # btn_summary = gr.Button("Summary")
433
+ # Fenêtre simulée pour le Summary
434
+ with gr.Accordion(visible=True, elem_id="papers-summary-popup", label= "See summary of relevant papers", open= False) as summary_popup:
435
+ papers_summary = gr.Markdown("", visible=True, elem_id="papers-summary")
436
+
437
+ # btn_relevant_papers = gr.Button("Relevant papers")
438
+ # Fenêtre simulée pour les Relevant Papers
439
+ with gr.Accordion(visible=True, elem_id="papers-relevant-popup",label= "See relevant papers", open= False) as relevant_popup:
440
+ papers_html = gr.HTML(show_label=False, elem_id="papers-textbox")
441
+
442
+ btn_citations_network = gr.Button("Explore papers citations network")
443
+ # Fenêtre simulée pour le Citations Network
444
+ with Modal(visible=False) as papers_modal:
445
+ citations_network = gr.HTML("<h3>Citations Network Graph</h3>", visible=True, elem_id="papers-citations-network")
446
+ btn_citations_network.click(lambda: Modal(visible=True), None, papers_modal)
447
+
448
+
449
+
450
+ with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
451
+
452
+ graphs_container = gr.HTML("<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",elem_id="graphs-container")
453
+ current_graphs.change(lambda x : x, inputs=[current_graphs], outputs=[graphs_container])
454
+
455
+ with Modal(visible=False,elem_id="modal-config") as config_modal:
456
+ gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!")
457
 
458
+
459
+ # with gr.Row():
460
+
461
+ dropdown_sources = gr.CheckboxGroup(
462
+ ["IPCC", "IPBES","IPOS"],
463
+ label="Select source (by default search in all sources)",
464
+ value=["IPCC"],
465
+ interactive=True,
466
+ )
467
+
468
+ dropdown_reports = gr.Dropdown(
469
+ POSSIBLE_REPORTS,
470
+ label="Or select specific reports",
471
+ multiselect=True,
472
+ value=None,
473
+ interactive=True,
474
+ )
475
+
476
+ dropdown_external_sources = gr.CheckboxGroup(
477
+ ["IPCC figures","OpenAlex", "OurWorldInData"],
478
+ label="Select database to search for relevant content",
479
+ value=["IPCC figures"],
480
+ interactive=True,
481
+ )
482
 
483
+ search_only = gr.Checkbox(label="Search only for recommended content without chating", value=False, interactive=True, elem_id="checkbox-chat")
 
 
 
 
 
484
 
 
 
485
 
486
+ dropdown_audience = gr.Dropdown(
487
+ ["Children","General public","Experts"],
488
+ label="Select audience",
489
+ value="Experts",
490
+ interactive=True,
491
+ )
492
+
493
+
494
+ after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers", visible=False)
495
+
496
 
497
+ output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False, visible= False)
498
+ output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False, visible= False)
 
 
 
 
499
 
 
 
500
 
501
+ dropdown_external_sources.change(lambda x: gr.update(visible = True ) if "OpenAlex" in x else gr.update(visible=False) , inputs=[dropdown_external_sources], outputs=[after])
502
+
503
+ close_config_modal = gr.Button("Validate and Close",elem_id="close-config-modal")
504
+ close_config_modal.click(fn=update_config_modal_visibility, inputs=[config_open], outputs=[config_modal, config_open])
505
+ # dropdown_external_sources.change(lambda x: gr.update(visible = True ) if "OpenAlex" in x else gr.update(visible=False) , inputs=[dropdown_external_sources], outputs=[after], visible=True)
506
+
507
 
508
+
509
+ config_button.click(fn=update_config_modal_visibility, inputs=[config_open], outputs=[config_modal, config_open])
510
+
511
+ # with gr.Tab("OECD",elem_id = "tab-oecd",id = 6):
512
+ # oecd_indicator = "RIVER_FLOOD_RP100_POP_SH"
513
+ # oecd_topic = "climate"
514
+ # oecd_latitude = "46.8332"
515
+ # oecd_longitude = "5.3725"
516
+ # oecd_zoom = "5.6442"
517
+ # # Create the HTML content with the iframe
518
+ # iframe_html = f"""
519
+ # <iframe src="https://localdataportal.oecd.org/maps.html?indicator={oecd_indicator}&topic={oecd_topic}&latitude={oecd_latitude}&longitude={oecd_longitude}&zoom={oecd_zoom}"
520
+ # width="100%" height="600" frameborder="0" style="border:0;" allowfullscreen></iframe>
521
+ # """
522
+ # oecd_textbox = gr.HTML(iframe_html, show_label=False, elem_id="oecd-textbox")
523
 
524
+
525
 
526
 
527
  #---------------------------------------------------------------------------------------
528
  # OTHER TABS
529
  #---------------------------------------------------------------------------------------
530
 
531
+ # with gr.Tab("Settings",elem_id = "tab-config",id = 2):
532
 
533
+ # gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")
 
534
 
 
535
 
536
+ # dropdown_sources = gr.CheckboxGroup(
537
+ # ["IPCC", "IPBES","IPOS", "OpenAlex"],
538
+ # label="Select source",
539
+ # value=["IPCC"],
540
+ # interactive=True,
541
+ # )
542
 
543
+ # dropdown_reports = gr.Dropdown(
544
+ # POSSIBLE_REPORTS,
545
+ # label="Or select specific reports",
546
+ # multiselect=True,
547
+ # value=None,
548
+ # interactive=True,
549
+ # )
550
 
551
+ # dropdown_audience = gr.Dropdown(
552
+ # ["Children","General public","Experts"],
553
+ # label="Select audience",
554
+ # value="Experts",
555
+ # interactive=True,
556
+ # )
557
 
 
 
558
 
559
+ # output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
560
+ # output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
561
 
562
 
 
563
  with gr.Tab("About",elem_classes = "max-height other-tabs"):
564
  with gr.Row():
565
  with gr.Column(scale=1):
 
567
 
568
 
569
 
570
+ gr.Markdown(
571
+ """
572
+ ### More info
573
+ - See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)
574
+ - Feedbacks on this [form](https://forms.office.com/e/1Yzgxm6jbp)
575
+
576
+ ### Citation
577
+ """
578
+ )
579
  with gr.Accordion(CITATION_LABEL,elem_id="citation", open = False,):
580
  # # Display citation label and text)
581
  gr.Textbox(
 
588
 
589
 
590
 
591
+ def start_chat(query,history,search_only):
 
 
592
  history = history + [ChatMessage(role="user", content=query)]
593
+ if search_only:
594
+ return (gr.update(interactive = False),gr.update(selected=1),history)
595
+ else:
596
+ return (gr.update(interactive = False),gr.update(selected=2),history)
597
 
598
  def finish_chat():
599
+ return gr.update(interactive = True,value = "")
600
+
601
+ # Initialize visibility states
602
+ summary_visible = False
603
+ relevant_visible = False
604
+
605
+ # Functions to toggle visibility
606
+ def toggle_summary_visibility():
607
+ global summary_visible
608
+ summary_visible = not summary_visible
609
+ return gr.update(visible=summary_visible)
610
+
611
+ def toggle_relevant_visibility():
612
+ global relevant_visible
613
+ relevant_visible = not relevant_visible
614
+ return gr.update(visible=relevant_visible)
615
+
616
 
617
+ def change_completion_status(current_state):
618
+ current_state = 1 - current_state
619
+ return current_state
620
+
621
+ def update_sources_number_display(sources_textbox, figures_cards, current_graphs, papers_html):
622
+ sources_number = sources_textbox.count("<h2>")
623
+ figures_number = figures_cards.count("<h2>")
624
+ graphs_number = current_graphs.count("<iframe")
625
+ papers_number = papers_html.count("<h2>")
626
+ sources_notif_label = f"Sources ({sources_number})"
627
+ figures_notif_label = f"Figures ({figures_number})"
628
+ graphs_notif_label = f"Graphs ({graphs_number})"
629
+ papers_notif_label = f"Papers ({papers_number})"
630
+ recommended_content_notif_label = f"Recommended content ({figures_number + graphs_number + papers_number})"
631
+
632
+ return gr.update(label = recommended_content_notif_label), gr.update(label = sources_notif_label), gr.update(label = figures_notif_label), gr.update(label = graphs_notif_label), gr.update(label = papers_notif_label)
633
+
634
  (textbox
635
+ .submit(start_chat, [textbox,chatbot, search_only], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
636
+ .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, dropdown_external_sources, search_only] ,[chatbot,sources_textbox,output_query,output_language, sources_raw, current_graphs],concurrency_limit = 8,api_name = "chat_textbox")
637
  .then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
638
+ # .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_sources, tab_figures, tab_graphs, tab_papers] )
639
  )
640
 
641
  (examples_hidden
642
+ .change(start_chat, [examples_hidden,chatbot, search_only], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
643
+ .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, dropdown_external_sources, search_only] ,[chatbot,sources_textbox,output_query,output_language, sources_raw, current_graphs],concurrency_limit = 8,api_name = "chat_textbox")
644
  .then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
645
+ # .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_sources, tab_figures, tab_graphs, tab_papers] )
646
  )
647
 
648
 
 
653
  return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
654
 
655
 
656
+ sources_raw.change(process_figures, inputs=[sources_raw], outputs=[figures_cards, gallery_component])
657
+
658
+ # update sources numbers
659
+ sources_textbox.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
660
+ figures_cards.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
661
+ current_graphs.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
662
+ papers_html.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
663
 
664
+ # other questions examples
665
  dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
666
 
667
+ # search for papers
668
+ textbox.submit(find_papers,[textbox,after, dropdown_external_sources], [papers_html,citations_network,papers_summary])
669
+ examples_hidden.change(find_papers,[examples_hidden,after,dropdown_external_sources], [papers_html,citations_network,papers_summary])
670
+
671
+ # btn_summary.click(toggle_summary_visibility, outputs=summary_popup)
672
+ # btn_relevant_papers.click(toggle_relevant_visibility, outputs=relevant_popup)
673
 
674
  demo.queue()
675
 
climateqa/constants.py CHANGED
@@ -42,4 +42,25 @@ POSSIBLE_REPORTS = [
42
  "IPBES IAS A C5",
43
  "IPBES IAS A C6",
44
  "IPBES IAS A SPM"
45
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  "IPBES IAS A C5",
43
  "IPBES IAS A C6",
44
  "IPBES IAS A SPM"
45
+ ]
46
+
47
+ OWID_CATEGORIES = ['Access to Energy', 'Agricultural Production',
48
+ 'Agricultural Regulation & Policy', 'Air Pollution',
49
+ 'Animal Welfare', 'Antibiotics', 'Biodiversity', 'Biofuels',
50
+ 'Biological & Chemical Weapons', 'CO2 & Greenhouse Gas Emissions',
51
+ 'COVID-19', 'Clean Water', 'Clean Water & Sanitation',
52
+ 'Climate Change', 'Crop Yields', 'Diet Compositions',
53
+ 'Electricity', 'Electricity Mix', 'Energy', 'Energy Efficiency',
54
+ 'Energy Prices', 'Environmental Impacts of Food Production',
55
+ 'Environmental Protection & Regulation', 'Famines', 'Farm Size',
56
+ 'Fertilizers', 'Fish & Overfishing', 'Food Supply', 'Food Trade',
57
+ 'Food Waste', 'Food and Agriculture', 'Forests & Deforestation',
58
+ 'Fossil Fuels', 'Future Population Growth',
59
+ 'Hunger & Undernourishment', 'Indoor Air Pollution', 'Land Use',
60
+ 'Land Use & Yields in Agriculture', 'Lead Pollution',
61
+ 'Meat & Dairy Production', 'Metals & Minerals',
62
+ 'Natural Disasters', 'Nuclear Energy', 'Nuclear Weapons',
63
+ 'Oil Spills', 'Outdoor Air Pollution', 'Ozone Layer', 'Pandemics',
64
+ 'Pesticides', 'Plastic Pollution', 'Renewable Energy', 'Soil',
65
+ 'Transport', 'Urbanization', 'Waste Management', 'Water Pollution',
66
+ 'Water Use & Stress', 'Wildfires']
climateqa/engine/chains/answer_chitchat.py CHANGED
@@ -45,8 +45,12 @@ def make_chitchat_node(llm):
45
  chitchat_chain = make_chitchat_chain(llm)
46
 
47
  async def answer_chitchat(state,config):
 
 
48
  answer = await chitchat_chain.ainvoke({"question":state["user_input"]},config)
49
- return {"answer":answer}
 
 
50
 
51
  return answer_chitchat
52
 
 
45
  chitchat_chain = make_chitchat_chain(llm)
46
 
47
  async def answer_chitchat(state,config):
48
+ print("---- Answer chitchat ----")
49
+
50
  answer = await chitchat_chain.ainvoke({"question":state["user_input"]},config)
51
+ state["answer"] = answer
52
+ return state
53
+ # return {"answer":answer}
54
 
55
  return answer_chitchat
56
 
climateqa/engine/chains/answer_rag.py CHANGED
@@ -7,6 +7,9 @@ from langchain_core.prompts.base import format_document
7
 
8
  from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
9
  from climateqa.engine.chains.prompts import papers_prompt_template
 
 
 
10
 
11
  DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
12
 
@@ -40,6 +43,7 @@ def make_rag_chain(llm):
40
  prompt = ChatPromptTemplate.from_template(answer_prompt_template)
41
  chain = ({
42
  "context":lambda x : _combine_documents(x["documents"]),
 
43
  "query":itemgetter("query"),
44
  "language":itemgetter("language"),
45
  "audience":itemgetter("audience"),
@@ -51,7 +55,6 @@ def make_rag_chain_without_docs(llm):
51
  chain = prompt | llm | StrOutputParser()
52
  return chain
53
 
54
-
55
  def make_rag_node(llm,with_docs = True):
56
 
57
  if with_docs:
@@ -60,7 +63,17 @@ def make_rag_node(llm,with_docs = True):
60
  rag_chain = make_rag_chain_without_docs(llm)
61
 
62
  async def answer_rag(state,config):
 
 
 
63
  answer = await rag_chain.ainvoke(state,config)
 
 
 
 
 
 
 
64
  return {"answer":answer}
65
 
66
  return answer_rag
@@ -68,32 +81,32 @@ def make_rag_node(llm,with_docs = True):
68
 
69
 
70
 
71
- # def make_rag_papers_chain(llm):
72
 
73
- # prompt = ChatPromptTemplate.from_template(papers_prompt_template)
74
- # input_documents = {
75
- # "context":lambda x : _combine_documents(x["docs"]),
76
- # **pass_values(["question","language"])
77
- # }
78
 
79
- # chain = input_documents | prompt | llm | StrOutputParser()
80
- # chain = rename_chain(chain,"answer")
81
 
82
- # return chain
83
 
84
 
85
 
86
 
87
 
88
 
89
- # def make_illustration_chain(llm):
90
 
91
- # prompt_with_images = ChatPromptTemplate.from_template(answer_prompt_images_template)
92
 
93
- # input_description_images = {
94
- # "images":lambda x : _combine_documents(get_image_docs(x["docs"])),
95
- # **pass_values(["question","audience","language","answer"]),
96
- # }
97
 
98
- # illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
99
- # return illustration_chain
 
7
 
8
  from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
9
  from climateqa.engine.chains.prompts import papers_prompt_template
10
+ import time
11
+ from ..utils import rename_chain, pass_values
12
+
13
 
14
  DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
15
 
 
43
  prompt = ChatPromptTemplate.from_template(answer_prompt_template)
44
  chain = ({
45
  "context":lambda x : _combine_documents(x["documents"]),
46
+ "context_length":lambda x : print("CONTEXT LENGTH : " , len(_combine_documents(x["documents"]))),
47
  "query":itemgetter("query"),
48
  "language":itemgetter("language"),
49
  "audience":itemgetter("audience"),
 
55
  chain = prompt | llm | StrOutputParser()
56
  return chain
57
 
 
58
  def make_rag_node(llm,with_docs = True):
59
 
60
  if with_docs:
 
63
  rag_chain = make_rag_chain_without_docs(llm)
64
 
65
  async def answer_rag(state,config):
66
+ print("---- Answer RAG ----")
67
+ start_time = time.time()
68
+
69
  answer = await rag_chain.ainvoke(state,config)
70
+
71
+ end_time = time.time()
72
+ elapsed_time = end_time - start_time
73
+ print("RAG elapsed time: ", elapsed_time)
74
+ print("Answer size : ", len(answer))
75
+ # print(f"\n\nAnswer:\n{answer}")
76
+
77
  return {"answer":answer}
78
 
79
  return answer_rag
 
81
 
82
 
83
 
84
+ def make_rag_papers_chain(llm):
85
 
86
+ prompt = ChatPromptTemplate.from_template(papers_prompt_template)
87
+ input_documents = {
88
+ "context":lambda x : _combine_documents(x["docs"]),
89
+ **pass_values(["question","language"])
90
+ }
91
 
92
+ chain = input_documents | prompt | llm | StrOutputParser()
93
+ chain = rename_chain(chain,"answer")
94
 
95
+ return chain
96
 
97
 
98
 
99
 
100
 
101
 
102
+ def make_illustration_chain(llm):
103
 
104
+ prompt_with_images = ChatPromptTemplate.from_template(answer_prompt_images_template)
105
 
106
+ input_description_images = {
107
+ "images":lambda x : _combine_documents(get_image_docs(x["docs"])),
108
+ **pass_values(["question","audience","language","answer"]),
109
+ }
110
 
111
+ illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
112
+ return illustration_chain
climateqa/engine/chains/chitchat_categorization.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from langchain_core.pydantic_v1 import BaseModel, Field
3
+ from typing import List
4
+ from typing import Literal
5
+ from langchain.prompts import ChatPromptTemplate
6
+ from langchain_core.utils.function_calling import convert_to_openai_function
7
+ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
8
+
9
+
10
+ class IntentCategorizer(BaseModel):
11
+ """Analyzing the user message input"""
12
+
13
+ environment: bool = Field(
14
+ description="Return 'True' if the question relates to climate change, the environment, nature, etc. (Example: should I eat fish?). Return 'False' if the question is just chit chat or not related to the environment or climate change.",
15
+ )
16
+
17
+
18
+ def make_chitchat_intent_categorization_chain(llm):
19
+
20
+ openai_functions = [convert_to_openai_function(IntentCategorizer)]
21
+ llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"IntentCategorizer"})
22
+
23
+ prompt = ChatPromptTemplate.from_messages([
24
+ ("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
25
+ ("user", "input: {input}")
26
+ ])
27
+
28
+ chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
29
+ return chain
30
+
31
+
32
+ def make_chitchat_intent_categorization_node(llm):
33
+
34
+ categorization_chain = make_chitchat_intent_categorization_chain(llm)
35
+
36
+ def categorize_message(state):
37
+ output = categorization_chain.invoke({"input": state["user_input"]})
38
+ print(f"\n\nChit chat output intent categorization: {output}\n")
39
+ state["search_graphs_chitchat"] = output["environment"]
40
+ print(f"\n\nChit chat output intent categorization: {state}\n")
41
+ return state
42
+
43
+ return categorize_message
climateqa/engine/chains/graph_retriever.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ from contextlib import contextmanager
4
+
5
+ from ..reranker import rerank_docs
6
+ from ..graph_retriever import retrieve_graphs # GraphRetriever
7
+ from ...utils import remove_duplicates_keep_highest_score
8
+
9
+
10
+ def divide_into_parts(target, parts):
11
+ # Base value for each part
12
+ base = target // parts
13
+ # Remainder to distribute
14
+ remainder = target % parts
15
+ # List to hold the result
16
+ result = []
17
+
18
+ for i in range(parts):
19
+ if i < remainder:
20
+ # These parts get base value + 1
21
+ result.append(base + 1)
22
+ else:
23
+ # The rest get the base value
24
+ result.append(base)
25
+
26
+ return result
27
+
28
+
29
+ @contextmanager
30
+ def suppress_output():
31
+ # Open a null device
32
+ with open(os.devnull, 'w') as devnull:
33
+ # Store the original stdout and stderr
34
+ old_stdout = sys.stdout
35
+ old_stderr = sys.stderr
36
+ # Redirect stdout and stderr to the null device
37
+ sys.stdout = devnull
38
+ sys.stderr = devnull
39
+ try:
40
+ yield
41
+ finally:
42
+ # Restore stdout and stderr
43
+ sys.stdout = old_stdout
44
+ sys.stderr = old_stderr
45
+
46
+
47
+ def make_graph_retriever_node(vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100):
48
+
49
+ async def node_retrieve_graphs(state):
50
+ print("---- Retrieving graphs ----")
51
+
52
+ POSSIBLE_SOURCES = ["IEA", "OWID"]
53
+ questions = state["remaining_questions"] if state["remaining_questions"] is not None and state["remaining_questions"]!=[] else [state["query"]]
54
+ # sources_input = state["sources_input"]
55
+ sources_input = ["auto"]
56
+
57
+ auto_mode = "auto" in sources_input
58
+
59
+ # There are several options to get the final top k
60
+ # Option 1 - Get 100 documents by question and rerank by question
61
+ # Option 2 - Get 100/n documents by question and rerank the total
62
+ if rerank_by_question:
63
+ k_by_question = divide_into_parts(k_final,len(questions))
64
+
65
+ docs = []
66
+
67
+ for i,q in enumerate(questions):
68
+
69
+ question = q["question"] if isinstance(q, dict) else q
70
+
71
+ print(f"Subquestion {i}: {question}")
72
+
73
+ # If auto mode, we use all sources
74
+ if auto_mode:
75
+ sources = POSSIBLE_SOURCES
76
+ # Otherwise, we use the config
77
+ else:
78
+ sources = sources_input
79
+
80
+ if any([x in POSSIBLE_SOURCES for x in sources]):
81
+
82
+ sources = [x for x in sources if x in POSSIBLE_SOURCES]
83
+
84
+ # Search the document store using the retriever
85
+ docs_question = await retrieve_graphs(
86
+ query = question,
87
+ vectorstore = vectorstore,
88
+ sources = sources,
89
+ k_total = k_before_reranking,
90
+ threshold = 0.5,
91
+ )
92
+ # docs_question = retriever.get_relevant_documents(question)
93
+
94
+ # Rerank
95
+ if reranker is not None and docs_question!=[]:
96
+ with suppress_output():
97
+ docs_question = rerank_docs(reranker,docs_question,question)
98
+ else:
99
+ # Add a default reranking score
100
+ for doc in docs_question:
101
+ doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
102
+
103
+ # If rerank by question we select the top documents for each question
104
+ if rerank_by_question:
105
+ docs_question = docs_question[:k_by_question[i]]
106
+
107
+ # Add sources used in the metadata
108
+ for doc in docs_question:
109
+ doc.metadata["sources_used"] = sources
110
+
111
+ print(f"{len(docs_question)} graphs retrieved for subquestion {i + 1}: {docs_question}")
112
+
113
+ docs.extend(docs_question)
114
+
115
+ else:
116
+ print(f"There are no graphs which match the sources filtered on. Sources filtered on: {sources}. Sources available: {POSSIBLE_SOURCES}.")
117
+
118
+ # Remove duplicates and keep the duplicate document with the highest reranking score
119
+ docs = remove_duplicates_keep_highest_score(docs)
120
+
121
+ # Sorting the list in descending order by rerank_score
122
+ # Then select the top k
123
+ docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
124
+ docs = docs[:k_final]
125
+
126
+ return {"recommended_content": docs}
127
+
128
+ return node_retrieve_graphs
climateqa/engine/chains/intent_categorization.py CHANGED
@@ -17,8 +17,8 @@ class IntentCategorizer(BaseModel):
17
  intent: str = Field(
18
  enum=[
19
  "ai_impact",
20
- "geo_info",
21
- "esg",
22
  "search",
23
  "chitchat",
24
  ],
@@ -28,11 +28,12 @@ class IntentCategorizer(BaseModel):
28
 
29
  Examples:
30
  - ai_impact = Environmental impacts of AI: "What are the environmental impacts of AI", "How does AI affect the environment"
31
- - geo_info = Geolocated info about climate change: Any question where the user wants to know localized impacts of climate change, eg: "What will be the temperature in Marseille in 2050"
32
- - esg = Any question about the ESG regulation, frameworks and standards like the CSRD, TCFD, SASB, GRI, CDP, etc.
33
  - search = Searching for any quesiton about climate change, energy, biodiversity, nature, and everything we can find the IPCC or IPBES reports or scientific papers,
34
  - chitchat = Any general question that is not related to the environment or climate change or just conversational, or if you don't think searching the IPCC or IPBES reports would be relevant
35
  """,
 
 
 
36
  )
37
 
38
 
@@ -43,7 +44,7 @@ def make_intent_categorization_chain(llm):
43
  llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"IntentCategorizer"})
44
 
45
  prompt = ChatPromptTemplate.from_messages([
46
- ("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
47
  ("user", "input: {input}")
48
  ])
49
 
@@ -56,7 +57,10 @@ def make_intent_categorization_node(llm):
56
  categorization_chain = make_intent_categorization_chain(llm)
57
 
58
  def categorize_message(state):
59
- output = categorization_chain.invoke({"input":state["user_input"]})
 
 
 
60
  if "language" not in output: output["language"] = "English"
61
  output["query"] = state["user_input"]
62
  return output
 
17
  intent: str = Field(
18
  enum=[
19
  "ai_impact",
20
+ # "geo_info",
21
+ # "esg",
22
  "search",
23
  "chitchat",
24
  ],
 
28
 
29
  Examples:
30
  - ai_impact = Environmental impacts of AI: "What are the environmental impacts of AI", "How does AI affect the environment"
 
 
31
  - search = Searching for any quesiton about climate change, energy, biodiversity, nature, and everything we can find the IPCC or IPBES reports or scientific papers,
32
  - chitchat = Any general question that is not related to the environment or climate change or just conversational, or if you don't think searching the IPCC or IPBES reports would be relevant
33
  """,
34
+ # - geo_info = Geolocated info about climate change: Any question where the user wants to know localized impacts of climate change, eg: "What will be the temperature in Marseille in 2050"
35
+ # - esg = Any question about the ESG regulation, frameworks and standards like the CSRD, TCFD, SASB, GRI, CDP, etc.
36
+
37
  )
38
 
39
 
 
44
  llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"IntentCategorizer"})
45
 
46
  prompt = ChatPromptTemplate.from_messages([
47
+ ("system", "You are a helpful assistant, you will analyze, translate and categorize the user input message using the function provided. Categorize the user input as ai ONLY if it is related to Artificial Intelligence, search if it is related to the environment, climate change, energy, biodiversity, nature, etc. and chitchat if it is just general conversation."),
48
  ("user", "input: {input}")
49
  ])
50
 
 
57
  categorization_chain = make_intent_categorization_chain(llm)
58
 
59
  def categorize_message(state):
60
+ print("---- Categorize_message ----")
61
+
62
+ output = categorization_chain.invoke({"input": state["user_input"]})
63
+ print(f"\n\nOutput intent categorization: {output}\n")
64
  if "language" not in output: output["language"] = "English"
65
  output["query"] = state["user_input"]
66
  return output
climateqa/engine/chains/prompts.py CHANGED
@@ -147,4 +147,27 @@ audience_prompts = {
147
  "children": "6 year old children that don't know anything about science and climate change and need metaphors to learn",
148
  "general": "the general public who know the basics in science and climate change and want to learn more about it without technical terms. Still use references to passages.",
149
  "experts": "expert and climate scientists that are not afraid of technical terms",
150
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  "children": "6 year old children that don't know anything about science and climate change and need metaphors to learn",
148
  "general": "the general public who know the basics in science and climate change and want to learn more about it without technical terms. Still use references to passages.",
149
  "experts": "expert and climate scientists that are not afraid of technical terms",
150
+ }
151
+
152
+
153
+ answer_prompt_graph_template = """
154
+ Given the user question and a list of graphs which are related to the question, rank the graphs based on relevance to the user question. ALWAYS follow the guidelines given below.
155
+
156
+ ### Guidelines ###
157
+ - Keep all the graphs that are given to you.
158
+ - NEVER modify the graph HTML embedding, the category or the source leave them exactly as they are given.
159
+ - Return the ranked graphs as a list of dictionaries with keys 'embedding', 'category', and 'source'.
160
+ - Return a valid JSON output.
161
+
162
+ -----------------------
163
+ User question:
164
+ {query}
165
+
166
+ Graphs and their HTML embedding:
167
+ {recommended_content}
168
+
169
+ -----------------------
170
+ {format_instructions}
171
+
172
+ Output the result as json with a key "graphs" containing a list of dictionaries of the relevant graphs with keys 'embedding', 'category', and 'source'. Do not modify the graph HTML embedding, the category or the source. Do not put any message or text before or after the JSON output.
173
+ """
climateqa/engine/chains/query_transformation.py CHANGED
@@ -69,15 +69,15 @@ class QueryAnalysis(BaseModel):
69
  # """
70
  # )
71
 
72
- sources: List[Literal["IPCC", "IPBES", "IPOS","OpenAlex"]] = Field(
73
  ...,
74
  description="""
75
  Given a user question choose which documents would be most relevant for answering their question,
76
  - IPCC is for questions about climate change, energy, impacts, and everything we can find the IPCC reports
77
  - IPBES is for questions about biodiversity and nature
78
  - IPOS is for questions about the ocean and deep sea mining
79
- - OpenAlex is for any other questions that are not in the previous categories but could be found in the scientific litterature
80
  """,
 
81
  )
82
  # topics: List[Literal[
83
  # "Climate change",
@@ -138,6 +138,8 @@ def make_query_transform_node(llm,k_final=15):
138
  rewriter_chain = make_query_rewriter_chain(llm)
139
 
140
  def transform_query(state):
 
 
141
 
142
  if "sources_auto" not in state or state["sources_auto"] is None or state["sources_auto"] is False:
143
  auto_mode = False
@@ -158,6 +160,12 @@ def make_query_transform_node(llm,k_final=15):
158
  for question in new_state["questions"]:
159
  question_state = {"question":question}
160
  analysis_output = rewriter_chain.invoke({"input":question})
 
 
 
 
 
 
161
  question_state.update(analysis_output)
162
  questions.append(question_state)
163
 
 
69
  # """
70
  # )
71
 
72
+ sources: List[Literal["IPCC", "IPBES", "IPOS"]] = Field( #,"OpenAlex"]] = Field(
73
  ...,
74
  description="""
75
  Given a user question choose which documents would be most relevant for answering their question,
76
  - IPCC is for questions about climate change, energy, impacts, and everything we can find the IPCC reports
77
  - IPBES is for questions about biodiversity and nature
78
  - IPOS is for questions about the ocean and deep sea mining
 
79
  """,
80
+ # - OpenAlex is for any other questions that are not in the previous categories but could be found in the scientific litterature
81
  )
82
  # topics: List[Literal[
83
  # "Climate change",
 
138
  rewriter_chain = make_query_rewriter_chain(llm)
139
 
140
  def transform_query(state):
141
+ print("---- Transform query ----")
142
+
143
 
144
  if "sources_auto" not in state or state["sources_auto"] is None or state["sources_auto"] is False:
145
  auto_mode = False
 
160
  for question in new_state["questions"]:
161
  question_state = {"question":question}
162
  analysis_output = rewriter_chain.invoke({"input":question})
163
+
164
+ # TODO WARNING llm should always return smthg
165
+ # The case when the llm does not return any sources
166
+ if not analysis_output["sources"] or not all(source in ["IPCC", "IPBS", "IPOS"] for source in analysis_output["sources"]):
167
+ analysis_output["sources"] = ["IPCC", "IPBES", "IPOS"]
168
+
169
  question_state.update(analysis_output)
170
  questions.append(question_state)
171
 
climateqa/engine/chains/retrieve_documents.py CHANGED
@@ -8,10 +8,13 @@ from langchain_core.runnables import RunnableParallel, RunnablePassthrough
8
  from langchain_core.runnables import RunnableLambda
9
 
10
  from ..reranker import rerank_docs
11
- from ...knowledge.retriever import ClimateQARetriever
12
  from ...knowledge.openalex import OpenAlexRetriever
13
  from .keywords_extraction import make_keywords_extraction_chain
14
  from ..utils import log_event
 
 
 
15
 
16
 
17
 
@@ -57,105 +60,244 @@ def query_retriever(question):
57
  """Just a dummy tool to simulate the retriever query"""
58
  return question
59
 
 
 
 
 
 
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
 
 
 
 
 
63
 
 
 
64
 
 
 
 
 
65
 
66
- def make_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
 
 
 
67
 
68
- # The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
69
- @chain
70
- async def retrieve_documents(state,config):
71
-
72
- keywords_extraction = make_keywords_extraction_chain(llm)
73
-
74
- current_question = state["remaining_questions"][0]
75
- remaining_questions = state["remaining_questions"][1:]
76
-
77
- # ToolMessage(f"Retrieving documents for question: {current_question['question']}",tool_call_id = "retriever")
 
 
 
 
 
 
 
78
 
 
 
79
 
80
- # # There are several options to get the final top k
81
- # # Option 1 - Get 100 documents by question and rerank by question
82
- # # Option 2 - Get 100/n documents by question and rerank the total
83
- # if rerank_by_question:
84
- # k_by_question = divide_into_parts(k_final,len(questions))
85
- if "documents" in state and state["documents"] is not None:
86
- docs = state["documents"]
87
- else:
88
- docs = []
89
-
90
-
91
-
92
- k_by_question = k_final // state["n_questions"]
93
-
94
- sources = current_question["sources"]
95
- question = current_question["question"]
96
- index = current_question["index"]
97
 
 
 
 
 
 
 
 
98
 
99
- await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
 
 
 
 
 
 
 
 
 
 
100
 
101
 
102
- if index == "Vector":
103
-
104
- # Search the document store using the retriever
105
- # Configure high top k for further reranking step
106
- retriever = ClimateQARetriever(
107
- vectorstore=vectorstore,
108
- sources = sources,
109
- min_size = 200,
110
- k_summary = k_summary,
111
- k_total = k_before_reranking,
112
- threshold = 0.5,
113
- )
114
- docs_question = await retriever.ainvoke(question,config)
115
 
116
- elif index == "OpenAlex":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
- keywords = keywords_extraction.invoke(question)["keywords"]
119
- openalex_query = " AND ".join(keywords)
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- print(f"... OpenAlex query: {openalex_query}")
122
 
123
- retriever_openalex = OpenAlexRetriever(
124
- min_year = state.get("min_year",1960),
125
- max_year = state.get("max_year",None),
126
- k = k_before_reranking
127
- )
128
- docs_question = await retriever_openalex.ainvoke(openalex_query,config)
 
 
 
 
 
 
 
129
 
130
- else:
131
- raise Exception(f"Index {index} not found in the routing index")
132
-
133
- # Rerank
134
- if reranker is not None:
135
- with suppress_output():
136
- docs_question = rerank_docs(reranker,docs_question,question)
137
- else:
138
- # Add a default reranking score
139
- for doc in docs_question:
140
- doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
141
-
142
- # If rerank by question we select the top documents for each question
143
- if rerank_by_question:
144
- docs_question = docs_question[:k_by_question]
145
-
146
- # Add sources used in the metadata
147
  for doc in docs_question:
148
- doc.metadata["sources_used"] = sources
149
- doc.metadata["question_used"] = question
150
- doc.metadata["index_used"] = index
151
-
152
- # Add to the list of docs
153
- docs.extend(docs_question)
154
 
155
- # Sorting the list in descending order by rerank_score
156
- docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
157
- new_state = {"documents":docs,"remaining_questions":remaining_questions}
158
- return new_state
 
 
 
 
 
 
 
 
159
 
160
- return retrieve_documents
 
 
 
 
 
 
 
 
 
 
161
 
 
8
  from langchain_core.runnables import RunnableLambda
9
 
10
  from ..reranker import rerank_docs
11
+ # from ...knowledge.retriever import ClimateQARetriever
12
  from ...knowledge.openalex import OpenAlexRetriever
13
  from .keywords_extraction import make_keywords_extraction_chain
14
  from ..utils import log_event
15
+ from langchain_core.vectorstores import VectorStore
16
+ from typing import List
17
+ from langchain_core.documents.base import Document
18
 
19
 
20
 
 
60
  """Just a dummy tool to simulate the retriever query"""
61
  return question
62
 
63
+ def _add_sources_used_in_metadata(docs,sources,question,index):
64
+ for doc in docs:
65
+ doc.metadata["sources_used"] = sources
66
+ doc.metadata["question_used"] = question
67
+ doc.metadata["index_used"] = index
68
+ return docs
69
 
70
+ def _get_k_summary_by_question(n_questions):
71
+ if n_questions == 0:
72
+ return 0
73
+ elif n_questions == 1:
74
+ return 5
75
+ elif n_questions == 2:
76
+ return 3
77
+ elif n_questions == 3:
78
+ return 2
79
+ else:
80
+ return 1
81
+
82
+ def _get_k_images_by_question(n_questions):
83
+ if n_questions == 0:
84
+ return 0
85
+ elif n_questions == 1:
86
+ return 7
87
+ elif n_questions == 2:
88
+ return 5
89
+ elif n_questions == 3:
90
+ return 2
91
+ else:
92
+ return 1
93
+
94
+ def _add_metadata_and_score(docs: List) -> Document:
95
+ # Add score to metadata
96
+ docs_with_metadata = []
97
+ for i,(doc,score) in enumerate(docs):
98
+ doc.page_content = doc.page_content.replace("\r\n"," ")
99
+ doc.metadata["similarity_score"] = score
100
+ doc.metadata["content"] = doc.page_content
101
+ doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
102
+ # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
103
+ docs_with_metadata.append(doc)
104
+ return docs_with_metadata
105
 
106
+ async def get_IPCC_relevant_documents(
107
+ query: str,
108
+ vectorstore:VectorStore,
109
+ sources:list = ["IPCC","IPBES","IPOS"],
110
+ search_figures:bool = False,
111
+ reports:list = [],
112
+ threshold:float = 0.6,
113
+ k_summary:int = 3,
114
+ k_total:int = 10,
115
+ k_images: int = 5,
116
+ namespace:str = "vectors",
117
+ min_size:int = 200,
118
+ search_only:bool = False,
119
+ ) :
120
 
121
+ # Check if all elements in the list are either IPCC or IPBES
122
+ assert isinstance(sources,list)
123
+ assert sources
124
+ assert all([x in ["IPCC","IPBES","IPOS"] for x in sources])
125
+ assert k_total > k_summary, "k_total should be greater than k_summary"
126
 
127
+ # Prepare base search kwargs
128
+ filters = {}
129
 
130
+ if len(reports) > 0:
131
+ filters["short_name"] = {"$in":reports}
132
+ else:
133
+ filters["source"] = { "$in": sources}
134
 
135
+ # INIT
136
+ docs_summaries = []
137
+ docs_full = []
138
+ docs_images = []
139
 
140
+ if search_only:
141
+ # Only search for images if search_only is True
142
+ if search_figures:
143
+ filters_image = {
144
+ **filters,
145
+ "chunk_type":"image"
146
+ }
147
+ docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
148
+ docs_images = _add_metadata_and_score(docs_images)
149
+ else:
150
+ # Regular search flow for text and optionally images
151
+ # Search for k_summary documents in the summaries dataset
152
+ filters_summaries = {
153
+ **filters,
154
+ "chunk_type":"text",
155
+ "report_type": { "$in":["SPM"]},
156
+ }
157
 
158
+ docs_summaries = vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = k_summary)
159
+ docs_summaries = [x for x in docs_summaries if x[1] > threshold]
160
 
161
+ # Search for k_total - k_summary documents in the full reports dataset
162
+ filters_full = {
163
+ **filters,
164
+ "chunk_type":"text",
165
+ "report_type": { "$nin":["SPM"]},
166
+ }
167
+ k_full = k_total - len(docs_summaries)
168
+ docs_full = vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
 
 
 
 
 
 
 
 
 
169
 
170
+ if search_figures:
171
+ # Images
172
+ filters_image = {
173
+ **filters,
174
+ "chunk_type":"image"
175
+ }
176
+ docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
177
 
178
+ docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
179
+
180
+ # Filter if length are below threshold
181
+ docs_summaries = [x for x in docs_summaries if len(x.page_content) > min_size]
182
+ docs_full = [x for x in docs_full if len(x.page_content) > min_size]
183
+
184
+ return {
185
+ "docs_summaries" : docs_summaries,
186
+ "docs_full" : docs_full,
187
+ "docs_images" : docs_images,
188
+ }
189
 
190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
+ # The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
193
+ # @chain
194
+ async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5, k_images=5):
195
+ """
196
+ Retrieve and rerank documents based on the current question in the state.
197
+
198
+ Args:
199
+ state (dict): The current state containing documents, related content, relevant content sources, remaining questions and n_questions.
200
+ config (dict): Configuration settings for logging and other purposes.
201
+ vectorstore (object): The vector store used to retrieve relevant documents.
202
+ reranker (object): The reranker used to rerank the retrieved documents.
203
+ llm (object): The language model used for processing.
204
+ rerank_by_question (bool, optional): Whether to rerank documents by question. Defaults to True.
205
+ k_final (int, optional): The final number of documents to retrieve. Defaults to 15.
206
+ k_before_reranking (int, optional): The number of documents to retrieve before reranking. Defaults to 100.
207
+ k_summary (int, optional): The number of summary documents to retrieve. Defaults to 5.
208
+ k_images (int, optional): The number of image documents to retrieve. Defaults to 5.
209
+ Returns:
210
+ dict: The updated state containing the retrieved and reranked documents, related content, and remaining questions.
211
+ """
212
+ print("---- Retrieve documents ----")
213
+
214
+ # Get the documents from the state
215
+ if "documents" in state and state["documents"] is not None:
216
+ docs = state["documents"]
217
+ else:
218
+ docs = []
219
+ # Get the related_content from the state
220
+ if "related_content" in state and state["related_content"] is not None:
221
+ related_content = state["related_content"]
222
+ else:
223
+ related_content = []
224
+
225
+ search_figures = "IPCC figures" in state["relevant_content_sources"]
226
+ search_only = state["search_only"]
227
 
228
+ # Get the current question
229
+ current_question = state["remaining_questions"][0]
230
+ remaining_questions = state["remaining_questions"][1:]
231
+
232
+ k_by_question = k_final // state["n_questions"]
233
+ k_summary_by_question = _get_k_summary_by_question(state["n_questions"])
234
+ k_images_by_question = _get_k_images_by_question(state["n_questions"])
235
+
236
+ sources = current_question["sources"]
237
+ question = current_question["question"]
238
+ index = current_question["index"]
239
+
240
+ print(f"Retrieve documents for question: {question}")
241
+ await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
242
 
 
243
 
244
+ if index == "Vector": # always true for now
245
+ docs_question_dict = await get_IPCC_relevant_documents(
246
+ query = question,
247
+ vectorstore=vectorstore,
248
+ search_figures = search_figures,
249
+ sources = sources,
250
+ min_size = 200,
251
+ k_summary = k_summary_by_question,
252
+ k_total = k_before_reranking,
253
+ k_images = k_images_by_question,
254
+ threshold = 0.5,
255
+ search_only = search_only,
256
+ )
257
 
258
+
259
+ # Rerank
260
+ if reranker is not None:
261
+ with suppress_output():
262
+ docs_question_summary_reranked = rerank_docs(reranker,docs_question_dict["docs_summaries"],question)
263
+ docs_question_fulltext_reranked = rerank_docs(reranker,docs_question_dict["docs_full"],question)
264
+ docs_question_images_reranked = rerank_docs(reranker,docs_question_dict["docs_images"],question)
265
+ if rerank_by_question:
266
+ docs_question_summary_reranked = sorted(docs_question_summary_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
267
+ docs_question_fulltext_reranked = sorted(docs_question_fulltext_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
268
+ docs_question_images_reranked = sorted(docs_question_images_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
269
+ else:
270
+ docs_question = docs_question_dict["docs_summaries"] + docs_question_dict["docs_full"]
271
+ # Add a default reranking score
 
 
 
272
  for doc in docs_question:
273
+ doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
274
+
275
+ docs_question = docs_question_summary_reranked + docs_question_fulltext_reranked
276
+ docs_question = docs_question[:k_by_question]
277
+ images_question = docs_question_images_reranked[:k_images]
 
278
 
279
+ if reranker is not None and rerank_by_question:
280
+ docs_question = sorted(docs_question, key=lambda x: x.metadata["reranking_score"], reverse=True)
281
+
282
+ # Add sources used in the metadata
283
+ docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index)
284
+ images_question = _add_sources_used_in_metadata(images_question,sources,question,index)
285
+
286
+ # Add to the list of docs
287
+ docs.extend(docs_question)
288
+ related_content.extend(images_question)
289
+ new_state = {"documents":docs, "related_contents": related_content,"remaining_questions":remaining_questions}
290
+ return new_state
291
 
292
+
293
+
294
+ def make_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
295
+ @chain
296
+ async def retrieve_docs(state, config):
297
+ state = await retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question, k_final, k_before_reranking, k_summary)
298
+ return state
299
+
300
+ return retrieve_docs
301
+
302
+
303
 
climateqa/engine/chains/retrieve_papers.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from climateqa.engine.keywords import make_keywords_chain
2
+ from climateqa.engine.llm import get_llm
3
+ from climateqa.knowledge.openalex import OpenAlex
4
+ from climateqa.engine.chains.answer_rag import make_rag_papers_chain
5
+ from front.utils import make_html_papers
6
+ from climateqa.engine.reranker import get_reranker
7
+
8
+ oa = OpenAlex()
9
+
10
+ llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
11
+ reranker = get_reranker("nano")
12
+
13
+
14
+ papers_cols_widths = {
15
+ "id":100,
16
+ "title":300,
17
+ "doi":100,
18
+ "publication_year":100,
19
+ "abstract":500,
20
+ "is_oa":50,
21
+ }
22
+
23
+ papers_cols = list(papers_cols_widths.keys())
24
+ papers_cols_widths = list(papers_cols_widths.values())
25
+
26
+
27
+
28
+ def generate_keywords(query):
29
+ chain = make_keywords_chain(llm)
30
+ keywords = chain.invoke(query)
31
+ keywords = " AND ".join(keywords["keywords"])
32
+ return keywords
33
+
34
+
35
+ async def find_papers(query,after, relevant_content_sources, reranker= reranker):
36
+ if "OpenAlex" in relevant_content_sources:
37
+ summary = ""
38
+ keywords = generate_keywords(query)
39
+ df_works = oa.search(keywords,after = after)
40
+
41
+ print(f"Found {len(df_works)} papers")
42
+
43
+ if not df_works.empty:
44
+ df_works = df_works.dropna(subset=["abstract"])
45
+ df_works = df_works[df_works["abstract"] != ""].reset_index(drop = True)
46
+ df_works = oa.rerank(query,df_works,reranker)
47
+ df_works = df_works.sort_values("rerank_score",ascending=False)
48
+ docs_html = []
49
+ for i in range(10):
50
+ docs_html.append(make_html_papers(df_works, i))
51
+ docs_html = "".join(docs_html)
52
+ G = oa.make_network(df_works)
53
+
54
+ height = "750px"
55
+ network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
56
+ network_html = network.generate_html()
57
+
58
+ network_html = network_html.replace("'", "\"")
59
+ css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
60
+ network_html = network_html + css_to_inject
61
+
62
+
63
+ network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
64
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
65
+ allow-scripts allow-same-origin allow-popups
66
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
67
+ allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
68
+
69
+
70
+ docs = df_works["content"].head(10).tolist()
71
+
72
+ df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
73
+ df_works["doc"] = df_works["doc"] + 1
74
+ df_works = df_works[papers_cols]
75
+
76
+ yield docs_html, network_html, summary
77
+
78
+ chain = make_rag_papers_chain(llm)
79
+ result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
80
+ path_answer = "/logs/StrOutputParser/streamed_output/-"
81
+
82
+ async for op in result:
83
+
84
+ op = op.ops[0]
85
+
86
+ if op['path'] == path_answer: # reforulated question
87
+ new_token = op['value'] # str
88
+ summary += new_token
89
+ else:
90
+ continue
91
+ yield docs_html, network_html, summary
92
+ else :
93
+ print("No papers found")
94
+ else :
95
+ yield "","", ""
climateqa/engine/chains/retriever.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import sys
2
+ # import os
3
+ # from contextlib import contextmanager
4
+
5
+ # from ..reranker import rerank_docs
6
+ # from ...knowledge.retriever import ClimateQARetriever
7
+
8
+
9
+
10
+
11
+ # def divide_into_parts(target, parts):
12
+ # # Base value for each part
13
+ # base = target // parts
14
+ # # Remainder to distribute
15
+ # remainder = target % parts
16
+ # # List to hold the result
17
+ # result = []
18
+
19
+ # for i in range(parts):
20
+ # if i < remainder:
21
+ # # These parts get base value + 1
22
+ # result.append(base + 1)
23
+ # else:
24
+ # # The rest get the base value
25
+ # result.append(base)
26
+
27
+ # return result
28
+
29
+
30
+ # @contextmanager
31
+ # def suppress_output():
32
+ # # Open a null device
33
+ # with open(os.devnull, 'w') as devnull:
34
+ # # Store the original stdout and stderr
35
+ # old_stdout = sys.stdout
36
+ # old_stderr = sys.stderr
37
+ # # Redirect stdout and stderr to the null device
38
+ # sys.stdout = devnull
39
+ # sys.stderr = devnull
40
+ # try:
41
+ # yield
42
+ # finally:
43
+ # # Restore stdout and stderr
44
+ # sys.stdout = old_stdout
45
+ # sys.stderr = old_stderr
46
+
47
+
48
+
49
+ # def make_retriever_node(vectorstore,reranker,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
50
+
51
+ # def retrieve_documents(state):
52
+
53
+ # POSSIBLE_SOURCES = ["IPCC","IPBES","IPOS"] # ,"OpenAlex"]
54
+ # questions = state["questions"]
55
+
56
+ # # Use sources from the user input or from the LLM detection
57
+ # if "sources_input" not in state or state["sources_input"] is None:
58
+ # sources_input = ["auto"]
59
+ # else:
60
+ # sources_input = state["sources_input"]
61
+ # auto_mode = "auto" in sources_input
62
+
63
+ # # There are several options to get the final top k
64
+ # # Option 1 - Get 100 documents by question and rerank by question
65
+ # # Option 2 - Get 100/n documents by question and rerank the total
66
+ # if rerank_by_question:
67
+ # k_by_question = divide_into_parts(k_final,len(questions))
68
+
69
+ # docs = []
70
+
71
+ # for i,q in enumerate(questions):
72
+
73
+ # sources = q["sources"]
74
+ # question = q["question"]
75
+
76
+ # # If auto mode, we use the sources detected by the LLM
77
+ # if auto_mode:
78
+ # sources = [x for x in sources if x in POSSIBLE_SOURCES]
79
+
80
+ # # Otherwise, we use the config
81
+ # else:
82
+ # sources = sources_input
83
+
84
+ # # Search the document store using the retriever
85
+ # # Configure high top k for further reranking step
86
+ # retriever = ClimateQARetriever(
87
+ # vectorstore=vectorstore,
88
+ # sources = sources,
89
+ # # reports = ias_reports,
90
+ # min_size = 200,
91
+ # k_summary = k_summary,
92
+ # k_total = k_before_reranking,
93
+ # threshold = 0.5,
94
+ # )
95
+ # docs_question = retriever.get_relevant_documents(question)
96
+
97
+ # # Rerank
98
+ # if reranker is not None:
99
+ # with suppress_output():
100
+ # docs_question = rerank_docs(reranker,docs_question,question)
101
+ # else:
102
+ # # Add a default reranking score
103
+ # for doc in docs_question:
104
+ # doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
105
+
106
+ # # If rerank by question we select the top documents for each question
107
+ # if rerank_by_question:
108
+ # docs_question = docs_question[:k_by_question[i]]
109
+
110
+ # # Add sources used in the metadata
111
+ # for doc in docs_question:
112
+ # doc.metadata["sources_used"] = sources
113
+
114
+ # # Add to the list of docs
115
+ # docs.extend(docs_question)
116
+
117
+ # # Sorting the list in descending order by rerank_score
118
+ # # Then select the top k
119
+ # docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
120
+ # docs = docs[:k_final]
121
+
122
+ # new_state = {"documents":docs}
123
+ # return new_state
124
+
125
+ # return retrieve_documents
126
+
climateqa/engine/chains/set_defaults.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def set_defaults(state):
2
+ print("---- Setting defaults ----")
3
+
4
+ if not state["audience"] or state["audience"] is None:
5
+ state.update({"audience": "experts"})
6
+
7
+ sources_input = state["sources_input"] if "sources_input" in state else ["auto"]
8
+ state.update({"sources_input": sources_input})
9
+
10
+ # if not state["sources_input"] or state["sources_input"] is None:
11
+ # state.update({"sources_input": ["auto"]})
12
+
13
+ return state
climateqa/engine/chains/translation.py CHANGED
@@ -30,10 +30,11 @@ def make_translation_chain(llm):
30
 
31
 
32
  def make_translation_node(llm):
33
-
34
  translation_chain = make_translation_chain(llm)
35
 
36
  def translate_query(state):
 
 
37
  user_input = state["user_input"]
38
  translation = translation_chain.invoke({"input":user_input})
39
  return {"query":translation["translation"]}
 
30
 
31
 
32
  def make_translation_node(llm):
 
33
  translation_chain = make_translation_chain(llm)
34
 
35
  def translate_query(state):
36
+ print("---- Translate query ----")
37
+
38
  user_input = state["user_input"]
39
  translation = translation_chain.invoke({"input":user_input})
40
  return {"query":translation["translation"]}
climateqa/engine/graph.py CHANGED
@@ -7,7 +7,7 @@ from langgraph.graph import END, StateGraph
7
  from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod
8
 
9
  from typing_extensions import TypedDict
10
- from typing import List
11
 
12
  from IPython.display import display, HTML, Image
13
 
@@ -18,6 +18,9 @@ from .chains.translation import make_translation_node
18
  from .chains.intent_categorization import make_intent_categorization_node
19
  from .chains.retrieve_documents import make_retriever_node
20
  from .chains.answer_rag import make_rag_node
 
 
 
21
 
22
  class GraphState(TypedDict):
23
  """
@@ -26,16 +29,21 @@ class GraphState(TypedDict):
26
  user_input : str
27
  language : str
28
  intent : str
 
29
  query: str
30
  remaining_questions : List[dict]
31
  n_questions : int
32
  answer: str
33
  audience: str = "experts"
34
  sources_input: List[str] = ["IPCC","IPBES"]
 
35
  sources_auto: bool = True
36
  min_year: int = 1960
37
  max_year: int = None
38
  documents: List[Document]
 
 
 
39
 
40
  def search(state): #TODO
41
  return state
@@ -52,6 +60,13 @@ def route_intent(state):
52
  else:
53
  # Search route
54
  return "search"
 
 
 
 
 
 
 
55
 
56
  def route_translation(state):
57
  if state["language"].lower() == "english":
@@ -66,11 +81,18 @@ def route_based_on_relevant_docs(state,threshold_docs=0.2):
66
  else:
67
  return "answer_rag_no_docs"
68
 
 
 
 
 
 
 
 
69
 
70
  def make_id_dict(values):
71
  return {k:k for k in values}
72
 
73
- def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
74
 
75
  workflow = StateGraph(GraphState)
76
 
@@ -80,21 +102,26 @@ def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
80
  translate_query = make_translation_node(llm)
81
  answer_chitchat = make_chitchat_node(llm)
82
  answer_ai_impact = make_ai_impact_node(llm)
83
- retrieve_documents = make_retriever_node(vectorstore,reranker,llm)
84
- answer_rag = make_rag_node(llm,with_docs=True)
85
- answer_rag_no_docs = make_rag_node(llm,with_docs=False)
 
 
86
 
87
  # Define the nodes
 
88
  workflow.add_node("categorize_intent", categorize_intent)
89
  workflow.add_node("search", search)
90
  workflow.add_node("answer_search", answer_search)
91
  workflow.add_node("transform_query", transform_query)
92
  workflow.add_node("translate_query", translate_query)
93
  workflow.add_node("answer_chitchat", answer_chitchat)
94
- # workflow.add_node("answer_ai_impact", answer_ai_impact)
95
- workflow.add_node("retrieve_documents",retrieve_documents)
96
- workflow.add_node("answer_rag",answer_rag)
97
- workflow.add_node("answer_rag_no_docs",answer_rag_no_docs)
 
 
98
 
99
  # Entry point
100
  workflow.set_entry_point("categorize_intent")
@@ -106,6 +133,12 @@ def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
106
  make_id_dict(["answer_chitchat","search"])
107
  )
108
 
 
 
 
 
 
 
109
  workflow.add_conditional_edges(
110
  "search",
111
  route_translation,
@@ -113,8 +146,9 @@ def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
113
  )
114
  workflow.add_conditional_edges(
115
  "retrieve_documents",
116
- lambda state : "retrieve_documents" if len(state["remaining_questions"]) > 0 else "answer_search",
117
- make_id_dict(["retrieve_documents","answer_search"])
 
118
  )
119
 
120
  workflow.add_conditional_edges(
@@ -122,14 +156,21 @@ def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
122
  lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
123
  make_id_dict(["answer_rag","answer_rag_no_docs"])
124
  )
 
 
 
 
 
125
 
126
  # Define the edges
127
  workflow.add_edge("translate_query", "transform_query")
128
  workflow.add_edge("transform_query", "retrieve_documents")
 
 
129
  workflow.add_edge("answer_rag", END)
130
  workflow.add_edge("answer_rag_no_docs", END)
131
- workflow.add_edge("answer_chitchat", END)
132
- # workflow.add_edge("answer_ai_impact", END)
133
 
134
  # Compile
135
  app = workflow.compile()
@@ -146,4 +187,4 @@ def display_graph(app):
146
  draw_method=MermaidDrawMethod.API,
147
  )
148
  )
149
- )
 
7
  from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod
8
 
9
  from typing_extensions import TypedDict
10
+ from typing import List, Dict
11
 
12
  from IPython.display import display, HTML, Image
13
 
 
18
  from .chains.intent_categorization import make_intent_categorization_node
19
  from .chains.retrieve_documents import make_retriever_node
20
  from .chains.answer_rag import make_rag_node
21
+ from .chains.graph_retriever import make_graph_retriever_node
22
+ from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
23
+ # from .chains.set_defaults import set_defaults
24
 
25
  class GraphState(TypedDict):
26
  """
 
29
  user_input : str
30
  language : str
31
  intent : str
32
+ search_graphs_chitchat : bool
33
  query: str
34
  remaining_questions : List[dict]
35
  n_questions : int
36
  answer: str
37
  audience: str = "experts"
38
  sources_input: List[str] = ["IPCC","IPBES"]
39
+ relevant_content_sources: List[str] = ["IPCC figures"]
40
  sources_auto: bool = True
41
  min_year: int = 1960
42
  max_year: int = None
43
  documents: List[Document]
44
+ related_contents : Dict[str,Document]
45
+ recommended_content : List[Document]
46
+ search_only : bool = False
47
 
48
  def search(state): #TODO
49
  return state
 
60
  else:
61
  # Search route
62
  return "search"
63
+
64
+ def chitchat_route_intent(state):
65
+ intent = state["search_graphs_chitchat"]
66
+ if intent is True:
67
+ return "retrieve_graphs_chitchat"
68
+ elif intent is False:
69
+ return END
70
 
71
  def route_translation(state):
72
  if state["language"].lower() == "english":
 
81
  else:
82
  return "answer_rag_no_docs"
83
 
84
+ def route_retrieve_documents(state):
85
+ if state["search_only"] :
86
+ return END
87
+ elif len(state["remaining_questions"]) > 0:
88
+ return "retrieve_documents"
89
+ else:
90
+ return "answer_search"
91
 
92
  def make_id_dict(values):
93
  return {k:k for k in values}
94
 
95
+ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, threshold_docs=0.2):
96
 
97
  workflow = StateGraph(GraphState)
98
 
 
102
  translate_query = make_translation_node(llm)
103
  answer_chitchat = make_chitchat_node(llm)
104
  answer_ai_impact = make_ai_impact_node(llm)
105
+ retrieve_documents = make_retriever_node(vectorstore_ipcc, reranker, llm)
106
+ retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
107
+ answer_rag = make_rag_node(llm, with_docs=True)
108
+ answer_rag_no_docs = make_rag_node(llm, with_docs=False)
109
+ chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
110
 
111
  # Define the nodes
112
+ # workflow.add_node("set_defaults", set_defaults)
113
  workflow.add_node("categorize_intent", categorize_intent)
114
  workflow.add_node("search", search)
115
  workflow.add_node("answer_search", answer_search)
116
  workflow.add_node("transform_query", transform_query)
117
  workflow.add_node("translate_query", translate_query)
118
  workflow.add_node("answer_chitchat", answer_chitchat)
119
+ workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
120
+ workflow.add_node("retrieve_graphs", retrieve_graphs)
121
+ workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
122
+ workflow.add_node("retrieve_documents", retrieve_documents)
123
+ workflow.add_node("answer_rag", answer_rag)
124
+ workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
125
 
126
  # Entry point
127
  workflow.set_entry_point("categorize_intent")
 
133
  make_id_dict(["answer_chitchat","search"])
134
  )
135
 
136
+ workflow.add_conditional_edges(
137
+ "chitchat_categorize_intent",
138
+ chitchat_route_intent,
139
+ make_id_dict(["retrieve_graphs_chitchat", END])
140
+ )
141
+
142
  workflow.add_conditional_edges(
143
  "search",
144
  route_translation,
 
146
  )
147
  workflow.add_conditional_edges(
148
  "retrieve_documents",
149
+ # lambda state : "retrieve_documents" if len(state["remaining_questions"]) > 0 else "answer_search",
150
+ route_retrieve_documents,
151
+ make_id_dict([END,"retrieve_documents","answer_search"])
152
  )
153
 
154
  workflow.add_conditional_edges(
 
156
  lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
157
  make_id_dict(["answer_rag","answer_rag_no_docs"])
158
  )
159
+ workflow.add_conditional_edges(
160
+ "transform_query",
161
+ lambda state : "retrieve_graphs" if "OurWorldInData" in state["relevant_content_sources"] else END,
162
+ make_id_dict(["retrieve_graphs", END])
163
+ )
164
 
165
  # Define the edges
166
  workflow.add_edge("translate_query", "transform_query")
167
  workflow.add_edge("transform_query", "retrieve_documents")
168
+
169
+ workflow.add_edge("retrieve_graphs", END)
170
  workflow.add_edge("answer_rag", END)
171
  workflow.add_edge("answer_rag_no_docs", END)
172
+ workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
173
+
174
 
175
  # Compile
176
  app = workflow.compile()
 
187
  draw_method=MermaidDrawMethod.API,
188
  )
189
  )
190
+ )
climateqa/engine/graph_retriever.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.retrievers import BaseRetriever
2
+ from langchain_core.documents.base import Document
3
+ from langchain_core.vectorstores import VectorStore
4
+ from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
5
+
6
+ from typing import List
7
+
8
+ # class GraphRetriever(BaseRetriever):
9
+ # vectorstore:VectorStore
10
+ # sources:list = ["OWID"] # plus tard ajouter OurWorldInData # faudra integrate avec l'autre retriever
11
+ # threshold:float = 0.5
12
+ # k_total:int = 10
13
+
14
+ # def _get_relevant_documents(
15
+ # self, query: str, *, run_manager: CallbackManagerForRetrieverRun
16
+ # ) -> List[Document]:
17
+
18
+ # # Check if all elements in the list are IEA or OWID
19
+ # assert isinstance(self.sources,list)
20
+ # assert self.sources
21
+ # assert any([x in ["OWID"] for x in self.sources])
22
+
23
+ # # Prepare base search kwargs
24
+ # filters = {}
25
+
26
+ # filters["source"] = {"$in": self.sources}
27
+
28
+ # docs = self.vectorstore.similarity_search_with_score(query=query, filter=filters, k=self.k_total)
29
+
30
+ # # Filter if scores are below threshold
31
+ # docs = [x for x in docs if x[1] > self.threshold]
32
+
33
+ # # Remove duplicate documents
34
+ # unique_docs = []
35
+ # seen_docs = []
36
+ # for i, doc in enumerate(docs):
37
+ # if doc[0].page_content not in seen_docs:
38
+ # unique_docs.append(doc)
39
+ # seen_docs.append(doc[0].page_content)
40
+
41
+ # # Add score to metadata
42
+ # results = []
43
+ # for i,(doc,score) in enumerate(unique_docs):
44
+ # doc.metadata["similarity_score"] = score
45
+ # doc.metadata["content"] = doc.page_content
46
+ # results.append(doc)
47
+
48
+ # return results
49
+
50
+ async def retrieve_graphs(
51
+ query: str,
52
+ vectorstore:VectorStore,
53
+ sources:list = ["OWID"], # plus tard ajouter OurWorldInData # faudra integrate avec l'autre retriever
54
+ threshold:float = 0.5,
55
+ k_total:int = 10,
56
+ )-> List[Document]:
57
+
58
+ # Check if all elements in the list are IEA or OWID
59
+ assert isinstance(sources,list)
60
+ assert sources
61
+ assert any([x in ["OWID"] for x in sources])
62
+
63
+ # Prepare base search kwargs
64
+ filters = {}
65
+
66
+ filters["source"] = {"$in": sources}
67
+
68
+ docs = vectorstore.similarity_search_with_score(query=query, filter=filters, k=k_total)
69
+
70
+ # Filter if scores are below threshold
71
+ docs = [x for x in docs if x[1] > threshold]
72
+
73
+ # Remove duplicate documents
74
+ unique_docs = []
75
+ seen_docs = []
76
+ for i, doc in enumerate(docs):
77
+ if doc[0].page_content not in seen_docs:
78
+ unique_docs.append(doc)
79
+ seen_docs.append(doc[0].page_content)
80
+
81
+ # Add score to metadata
82
+ results = []
83
+ for i,(doc,score) in enumerate(unique_docs):
84
+ doc.metadata["similarity_score"] = score
85
+ doc.metadata["content"] = doc.page_content
86
+ results.append(doc)
87
+
88
+ return results
climateqa/engine/keywords.py CHANGED
@@ -11,10 +11,12 @@ class KeywordsOutput(BaseModel):
11
 
12
  keywords: list = Field(
13
  description="""
14
- Generate 1 or 2 relevant keywords from the user query to ask a search engine for scientific research papers.
 
15
 
16
  Example:
17
  - "What is the impact of deep sea mining ?" -> ["deep sea mining"]
 
18
  - "How will El Nino be impacted by climate change" -> ["el nino"]
19
  - "Is climate change a hoax" -> [Climate change","hoax"]
20
  """
 
11
 
12
  keywords: list = Field(
13
  description="""
14
+ Generate 1 or 2 relevant keywords from the user query to ask a search engine for scientific research papers. Answer only with English keywords.
15
+ Do not use special characters or accents.
16
 
17
  Example:
18
  - "What is the impact of deep sea mining ?" -> ["deep sea mining"]
19
+ - "Quel est l'impact de l'exploitation minière en haute mer ?" -> ["deep sea mining"]
20
  - "How will El Nino be impacted by climate change" -> ["el nino"]
21
  - "Is climate change a hoax" -> [Climate change","hoax"]
22
  """
climateqa/engine/reranker.py CHANGED
@@ -1,11 +1,14 @@
1
  import os
 
2
  from scipy.special import expit, logit
3
  from rerankers import Reranker
 
4
 
 
5
 
6
- def get_reranker(model = "nano",cohere_api_key = None):
7
 
8
- assert model in ["nano","tiny","small","large"]
9
 
10
  if model == "nano":
11
  reranker = Reranker('ms-marco-TinyBERT-L-2-v2', model_type='flashrank')
@@ -17,11 +20,18 @@ def get_reranker(model = "nano",cohere_api_key = None):
17
  if cohere_api_key is None:
18
  cohere_api_key = os.environ["COHERE_API_KEY"]
19
  reranker = Reranker("cohere", lang='en', api_key = cohere_api_key)
 
 
 
 
 
20
  return reranker
21
 
22
 
23
 
24
  def rerank_docs(reranker,docs,query):
 
 
25
 
26
  # Get a list of texts from langchain docs
27
  input_docs = [x.page_content for x in docs]
 
1
  import os
2
+ from dotenv import load_dotenv
3
  from scipy.special import expit, logit
4
  from rerankers import Reranker
5
+ from sentence_transformers import CrossEncoder
6
 
7
+ load_dotenv()
8
 
9
+ def get_reranker(model = "nano", cohere_api_key = None):
10
 
11
+ assert model in ["nano","tiny","small","large", "jina"]
12
 
13
  if model == "nano":
14
  reranker = Reranker('ms-marco-TinyBERT-L-2-v2', model_type='flashrank')
 
20
  if cohere_api_key is None:
21
  cohere_api_key = os.environ["COHERE_API_KEY"]
22
  reranker = Reranker("cohere", lang='en', api_key = cohere_api_key)
23
+ elif model == "jina":
24
+ # Reached token quota so does not work
25
+ reranker = Reranker("jina-reranker-v2-base-multilingual", api_key = os.getenv("JINA_RERANKER_API_KEY"))
26
+ # marche pas sans gpu ? et anyways returns with another structure donc faudrait changer le code du retriever node
27
+ # reranker = CrossEncoder("jinaai/jina-reranker-v2-base-multilingual", automodel_args={"torch_dtype": "auto"}, trust_remote_code=True,)
28
  return reranker
29
 
30
 
31
 
32
  def rerank_docs(reranker,docs,query):
33
+ if docs == []:
34
+ return []
35
 
36
  # Get a list of texts from langchain docs
37
  input_docs = [x.page_content for x in docs]
climateqa/engine/vectorstore.py CHANGED
@@ -4,6 +4,7 @@
4
  import os
5
  from pinecone import Pinecone
6
  from langchain_community.vectorstores import Pinecone as PineconeVectorstore
 
7
 
8
  # LOAD ENVIRONMENT VARIABLES
9
  try:
@@ -13,7 +14,12 @@ except:
13
  pass
14
 
15
 
16
- def get_pinecone_vectorstore(embeddings,text_key = "content"):
 
 
 
 
 
17
 
18
  # # initialize pinecone
19
  # pinecone.init(
@@ -27,7 +33,7 @@ def get_pinecone_vectorstore(embeddings,text_key = "content"):
27
  # return vectorstore
28
 
29
  pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
30
- index = pc.Index(os.getenv("PINECONE_API_INDEX"))
31
 
32
  vectorstore = PineconeVectorstore(
33
  index, embeddings, text_key,
 
4
  import os
5
  from pinecone import Pinecone
6
  from langchain_community.vectorstores import Pinecone as PineconeVectorstore
7
+ from langchain_chroma import Chroma
8
 
9
  # LOAD ENVIRONMENT VARIABLES
10
  try:
 
14
  pass
15
 
16
 
17
+ def get_chroma_vectorstore(embedding_function, persist_directory="/home/dora/climate-question-answering/data/vectorstore"):
18
+ vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embedding_function)
19
+ return vectorstore
20
+
21
+
22
+ def get_pinecone_vectorstore(embeddings,text_key = "content", index_name = os.getenv("PINECONE_API_INDEX")):
23
 
24
  # # initialize pinecone
25
  # pinecone.init(
 
33
  # return vectorstore
34
 
35
  pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
36
+ index = pc.Index(index_name)
37
 
38
  vectorstore = PineconeVectorstore(
39
  index, embeddings, text_key,
climateqa/event_handler.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.runnables.schema import StreamEvent
2
+ from gradio import ChatMessage
3
+ from climateqa.engine.chains.prompts import audience_prompts
4
+ from front.utils import make_html_source,parse_output_llm_with_sources,serialize_docs,make_toolbox,generate_html_graphs
5
+ import numpy as np
6
+
7
+ def init_audience(audience :str) -> str:
8
+ if audience == "Children":
9
+ audience_prompt = audience_prompts["children"]
10
+ elif audience == "General public":
11
+ audience_prompt = audience_prompts["general"]
12
+ elif audience == "Experts":
13
+ audience_prompt = audience_prompts["experts"]
14
+ else:
15
+ audience_prompt = audience_prompts["experts"]
16
+ return audience_prompt
17
+
18
+ def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage], used_documents : list[str]) -> tuple[str, list[ChatMessage], list[str]]:
19
+ """
20
+ Handles the retrieved documents and returns the HTML representation of the documents
21
+
22
+ Args:
23
+ event (StreamEvent): The event containing the retrieved documents
24
+ history (list[ChatMessage]): The current message history
25
+ used_documents (list[str]): The list of used documents
26
+
27
+ Returns:
28
+ tuple[str, list[ChatMessage], list[str]]: The updated HTML representation of the documents, the updated message history and the updated list of used documents
29
+ """
30
+ try:
31
+ docs = event["data"]["output"]["documents"]
32
+ docs_html = []
33
+ textual_docs = [d for d in docs if d.metadata["chunk_type"] == "text"]
34
+ for i, d in enumerate(textual_docs, 1):
35
+ if d.metadata["chunk_type"] == "text":
36
+ docs_html.append(make_html_source(d, i))
37
+
38
+ used_documents = used_documents + [f"{d.metadata['short_name']} - {d.metadata['name']}" for d in docs]
39
+ if used_documents!=[]:
40
+ history[-1].content = "Adding sources :\n\n - " + "\n - ".join(np.unique(used_documents))
41
+
42
+ docs_html = "".join(docs_html)
43
+
44
+ related_contents = event["data"]["output"]["related_contents"]
45
+
46
+ except Exception as e:
47
+ print(f"Error getting documents: {e}")
48
+ print(event)
49
+ return docs, docs_html, history, used_documents, related_contents
50
+
51
+ def stream_answer(history: list[ChatMessage], event : StreamEvent, start_streaming : bool, answer_message_content : str)-> tuple[list[ChatMessage], bool, str]:
52
+ """
53
+ Handles the streaming of the answer and updates the history with the new message content
54
+
55
+ Args:
56
+ history (list[ChatMessage]): The current message history
57
+ event (StreamEvent): The event containing the streamed answer
58
+ start_streaming (bool): A flag indicating if the streaming has started
59
+ new_message_content (str): The content of the new message
60
+
61
+ Returns:
62
+ tuple[list[ChatMessage], bool, str]: The updated history, the updated streaming flag and the updated message content
63
+ """
64
+ if start_streaming == False:
65
+ start_streaming = True
66
+ history.append(ChatMessage(role="assistant", content = ""))
67
+ answer_message_content += event["data"]["chunk"].content
68
+ answer_message_content = parse_output_llm_with_sources(answer_message_content)
69
+ history[-1] = ChatMessage(role="assistant", content = answer_message_content)
70
+ # history.append(ChatMessage(role="assistant", content = new_message_content))
71
+ return history, start_streaming, answer_message_content
72
+
73
+ def handle_retrieved_owid_graphs(event :StreamEvent, graphs_html: str) -> str:
74
+ """
75
+ Handles the retrieved OWID graphs and returns the HTML representation of the graphs
76
+
77
+ Args:
78
+ event (StreamEvent): The event containing the retrieved graphs
79
+ graphs_html (str): The current HTML representation of the graphs
80
+
81
+ Returns:
82
+ str: The updated HTML representation
83
+ """
84
+ try:
85
+ recommended_content = event["data"]["output"]["recommended_content"]
86
+
87
+ unique_graphs = []
88
+ seen_embeddings = set()
89
+
90
+ for x in recommended_content:
91
+ embedding = x.metadata["returned_content"]
92
+
93
+ # Check if the embedding has already been seen
94
+ if embedding not in seen_embeddings:
95
+ unique_graphs.append({
96
+ "embedding": embedding,
97
+ "metadata": {
98
+ "source": x.metadata["source"],
99
+ "category": x.metadata["category"]
100
+ }
101
+ })
102
+ # Add the embedding to the seen set
103
+ seen_embeddings.add(embedding)
104
+
105
+
106
+ categories = {}
107
+ for graph in unique_graphs:
108
+ category = graph['metadata']['category']
109
+ if category not in categories:
110
+ categories[category] = []
111
+ categories[category].append(graph['embedding'])
112
+
113
+
114
+ for category, embeddings in categories.items():
115
+ graphs_html += f"<h3>{category}</h3>"
116
+ for embedding in embeddings:
117
+ graphs_html += f"<div>{embedding}</div>"
118
+
119
+
120
+ except Exception as e:
121
+ print(f"Error getting graphs: {e}")
122
+
123
+ return graphs_html
climateqa/knowledge/openalex.py CHANGED
@@ -41,6 +41,10 @@ class OpenAlex():
41
  break
42
 
43
  df_works = pd.DataFrame(page)
 
 
 
 
44
  df_works = df_works.dropna(subset = ["title"])
45
  df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
46
  df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x)).fillna("")
@@ -51,8 +55,9 @@ class OpenAlex():
51
  df_works["num_tokens"] = df_works["content"].map(lambda x : num_tokens_from_string(x))
52
 
53
  df_works = df_works.drop(columns = ["abstract_inverted_index"])
54
- # df_works["subtitle"] = df_works["title"] + " - " + df_works["primary_location"]["source"]["display_name"] + " - " + df_works["publication_year"]
55
-
 
56
  return df_works
57
  else:
58
  raise Exception("Keywords must be a string")
@@ -62,11 +67,10 @@ class OpenAlex():
62
 
63
  scores = reranker.rank(
64
  query,
65
- df["content"].tolist(),
66
- top_k = len(df),
67
  )
68
- scores.sort(key = lambda x : x["corpus_id"])
69
- scores = [x["score"] for x in scores]
70
  df["rerank_score"] = scores
71
  return df
72
 
 
41
  break
42
 
43
  df_works = pd.DataFrame(page)
44
+
45
+ if df_works.empty:
46
+ return df_works
47
+
48
  df_works = df_works.dropna(subset = ["title"])
49
  df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
50
  df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x)).fillna("")
 
55
  df_works["num_tokens"] = df_works["content"].map(lambda x : num_tokens_from_string(x))
56
 
57
  df_works = df_works.drop(columns = ["abstract_inverted_index"])
58
+ df_works["display_name"] = df_works["primary_location"].apply(lambda x :x["source"] if type(x) == dict and 'source' in x else "").apply(lambda x : x["display_name"] if type(x) == dict and "display_name" in x else "")
59
+ df_works["subtitle"] = df_works["title"].astype(str) + " - " + df_works["display_name"].astype(str) + " - " + df_works["publication_year"].astype(str)
60
+
61
  return df_works
62
  else:
63
  raise Exception("Keywords must be a string")
 
67
 
68
  scores = reranker.rank(
69
  query,
70
+ df["content"].tolist()
 
71
  )
72
+ scores = sorted(scores.results, key = lambda x : x.document.doc_id)
73
+ scores = [x.score for x in scores]
74
  df["rerank_score"] = scores
75
  return df
76
 
climateqa/knowledge/retriever.py CHANGED
@@ -1,81 +1,102 @@
1
- # https://github.com/langchain-ai/langchain/issues/8623
2
-
3
- import pandas as pd
4
-
5
- from langchain_core.retrievers import BaseRetriever
6
- from langchain_core.vectorstores import VectorStoreRetriever
7
- from langchain_core.documents.base import Document
8
- from langchain_core.vectorstores import VectorStore
9
- from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
10
-
11
- from typing import List
12
- from pydantic import Field
13
-
14
- class ClimateQARetriever(BaseRetriever):
15
- vectorstore:VectorStore
16
- sources:list = ["IPCC","IPBES","IPOS"]
17
- reports:list = []
18
- threshold:float = 0.6
19
- k_summary:int = 3
20
- k_total:int = 10
21
- namespace:str = "vectors",
22
- min_size:int = 200,
23
-
24
-
25
- def _get_relevant_documents(
26
- self, query: str, *, run_manager: CallbackManagerForRetrieverRun
27
- ) -> List[Document]:
28
-
29
- # Check if all elements in the list are either IPCC or IPBES
30
- assert isinstance(self.sources,list)
31
- assert all([x in ["IPCC","IPBES","IPOS"] for x in self.sources])
32
- assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
33
-
34
- # Prepare base search kwargs
35
- filters = {}
36
-
37
- if len(self.reports) > 0:
38
- filters["short_name"] = {"$in":self.reports}
39
- else:
40
- filters["source"] = { "$in":self.sources}
41
-
42
- # Search for k_summary documents in the summaries dataset
43
- filters_summaries = {
44
- **filters,
45
- "report_type": { "$in":["SPM"]},
46
- }
47
-
48
- docs_summaries = self.vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = self.k_summary)
49
- docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
50
-
51
- # Search for k_total - k_summary documents in the full reports dataset
52
- filters_full = {
53
- **filters,
54
- "report_type": { "$nin":["SPM"]},
55
- }
56
- k_full = self.k_total - len(docs_summaries)
57
- docs_full = self.vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
58
-
59
- # Concatenate documents
60
- docs = docs_summaries + docs_full
61
-
62
- # Filter if scores are below threshold
63
- docs = [x for x in docs if len(x[0].page_content) > self.min_size]
64
- # docs = [x for x in docs if x[1] > self.threshold]
65
-
66
- # Add score to metadata
67
- results = []
68
- for i,(doc,score) in enumerate(docs):
69
- doc.page_content = doc.page_content.replace("\r\n"," ")
70
- doc.metadata["similarity_score"] = score
71
- doc.metadata["content"] = doc.page_content
72
- doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
73
- # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
74
- results.append(doc)
75
-
76
- # Sort by score
77
- # results = sorted(results,key = lambda x : x.metadata["similarity_score"],reverse = True)
78
-
79
- return results
80
-
81
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # # https://github.com/langchain-ai/langchain/issues/8623
2
+
3
+ # import pandas as pd
4
+
5
+ # from langchain_core.retrievers import BaseRetriever
6
+ # from langchain_core.vectorstores import VectorStoreRetriever
7
+ # from langchain_core.documents.base import Document
8
+ # from langchain_core.vectorstores import VectorStore
9
+ # from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
10
+
11
+ # from typing import List
12
+ # from pydantic import Field
13
+
14
+ # def _add_metadata_and_score(docs: List) -> Document:
15
+ # # Add score to metadata
16
+ # docs_with_metadata = []
17
+ # for i,(doc,score) in enumerate(docs):
18
+ # doc.page_content = doc.page_content.replace("\r\n"," ")
19
+ # doc.metadata["similarity_score"] = score
20
+ # doc.metadata["content"] = doc.page_content
21
+ # doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
22
+ # # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
23
+ # docs_with_metadata.append(doc)
24
+ # return docs_with_metadata
25
+
26
+ # class ClimateQARetriever(BaseRetriever):
27
+ # vectorstore:VectorStore
28
+ # sources:list = ["IPCC","IPBES","IPOS"]
29
+ # reports:list = []
30
+ # threshold:float = 0.6
31
+ # k_summary:int = 3
32
+ # k_total:int = 10
33
+ # namespace:str = "vectors",
34
+ # min_size:int = 200,
35
+
36
+
37
+
38
+ # def _get_relevant_documents(
39
+ # self, query: str, *, run_manager: CallbackManagerForRetrieverRun
40
+ # ) -> List[Document]:
41
+
42
+ # # Check if all elements in the list are either IPCC or IPBES
43
+ # assert isinstance(self.sources,list)
44
+ # assert self.sources
45
+ # assert all([x in ["IPCC","IPBES","IPOS"] for x in self.sources])
46
+ # assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
47
+
48
+ # # Prepare base search kwargs
49
+ # filters = {}
50
+
51
+ # if len(self.reports) > 0:
52
+ # filters["short_name"] = {"$in":self.reports}
53
+ # else:
54
+ # filters["source"] = { "$in":self.sources}
55
+
56
+ # # Search for k_summary documents in the summaries dataset
57
+ # filters_summaries = {
58
+ # **filters,
59
+ # "chunk_type":"text",
60
+ # "report_type": { "$in":["SPM"]},
61
+ # }
62
+
63
+ # docs_summaries = self.vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = self.k_summary)
64
+ # docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
65
+ # # docs_summaries = []
66
+
67
+ # # Search for k_total - k_summary documents in the full reports dataset
68
+ # filters_full = {
69
+ # **filters,
70
+ # "chunk_type":"text",
71
+ # "report_type": { "$nin":["SPM"]},
72
+ # }
73
+ # k_full = self.k_total - len(docs_summaries)
74
+ # docs_full = self.vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
75
+
76
+ # # Images
77
+ # filters_image = {
78
+ # **filters,
79
+ # "chunk_type":"image"
80
+ # }
81
+ # docs_images = self.vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_full)
82
+
83
+ # # docs_images = []
84
+
85
+ # # Concatenate documents
86
+ # # docs = docs_summaries + docs_full + docs_images
87
+
88
+ # # Filter if scores are below threshold
89
+ # # docs = [x for x in docs if x[1] > self.threshold]
90
+
91
+ # docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
92
+
93
+ # # Filter if length are below threshold
94
+ # docs_summaries = [x for x in docs_summaries if len(x.page_content) > self.min_size]
95
+ # docs_full = [x for x in docs_full if len(x.page_content) > self.min_size]
96
+
97
+
98
+ # return {
99
+ # "docs_summaries" : docs_summaries,
100
+ # "docs_full" : docs_full,
101
+ # "docs_images" : docs_images,
102
+ # }
climateqa/utils.py CHANGED
@@ -20,3 +20,16 @@ def get_image_from_azure_blob_storage(path):
20
  file_object = get_file_from_azure_blob_storage(path)
21
  image = Image.open(file_object)
22
  return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  file_object = get_file_from_azure_blob_storage(path)
21
  image = Image.open(file_object)
22
  return image
23
+
24
+ def remove_duplicates_keep_highest_score(documents):
25
+ unique_docs = {}
26
+
27
+ for doc in documents:
28
+ doc_id = doc.metadata.get('doc_id')
29
+ if doc_id in unique_docs:
30
+ if doc.metadata['reranking_score'] > unique_docs[doc_id].metadata['reranking_score']:
31
+ unique_docs[doc_id] = doc
32
+ else:
33
+ unique_docs[doc_id] = doc
34
+
35
+ return list(unique_docs.values())
front/utils.py CHANGED
@@ -1,12 +1,19 @@
1
 
2
  import re
 
 
 
 
 
 
3
 
4
- def make_pairs(lst):
 
5
  """from a list of even lenght, make tupple pairs"""
6
  return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
7
 
8
 
9
- def serialize_docs(docs):
10
  new_docs = []
11
  for doc in docs:
12
  new_doc = {}
@@ -17,7 +24,7 @@ def serialize_docs(docs):
17
 
18
 
19
 
20
- def parse_output_llm_with_sources(output):
21
  # Split the content into a list of text and "[Doc X]" references
22
  content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output)
23
  parts = []
@@ -32,6 +39,119 @@ def parse_output_llm_with_sources(output):
32
  content_parts = "".join(parts)
33
  return content_parts
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  def make_html_source(source,i):
37
  meta = source.metadata
@@ -108,6 +228,31 @@ def make_html_source(source,i):
108
  return card
109
 
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  def make_html_figure_sources(source,i,img_str):
112
  meta = source.metadata
113
  content = source.page_content.strip()
 
1
 
2
  import re
3
+ from collections import defaultdict
4
+ from climateqa.utils import get_image_from_azure_blob_storage
5
+ from climateqa.engine.chains.prompts import audience_prompts
6
+ from PIL import Image
7
+ from io import BytesIO
8
+ import base64
9
 
10
+
11
+ def make_pairs(lst:list)->list:
12
  """from a list of even lenght, make tupple pairs"""
13
  return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
14
 
15
 
16
+ def serialize_docs(docs:list)->list:
17
  new_docs = []
18
  for doc in docs:
19
  new_doc = {}
 
24
 
25
 
26
 
27
+ def parse_output_llm_with_sources(output:str)->str:
28
  # Split the content into a list of text and "[Doc X]" references
29
  content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output)
30
  parts = []
 
39
  content_parts = "".join(parts)
40
  return content_parts
41
 
42
+ def process_figures(docs:list)->tuple:
43
+ gallery=[]
44
+ used_figures =[]
45
+ figures = '<div class="figures-container"><p></p> </div>'
46
+ docs_figures = [d for d in docs if d.metadata["chunk_type"] == "image"]
47
+ for i, doc in enumerate(docs_figures):
48
+ if doc.metadata["chunk_type"] == "image":
49
+ if doc.metadata["figure_code"] != "N/A":
50
+ title = f"{doc.metadata['figure_code']} - {doc.metadata['short_name']}"
51
+ else:
52
+ title = f"{doc.metadata['short_name']}"
53
+
54
+
55
+ if title not in used_figures:
56
+ used_figures.append(title)
57
+ try:
58
+ key = f"Image {i+1}"
59
+
60
+ image_path = doc.metadata["image_path"].split("documents/")[1]
61
+ img = get_image_from_azure_blob_storage(image_path)
62
+
63
+ # Convert the image to a byte buffer
64
+ buffered = BytesIO()
65
+ max_image_length = 500
66
+ img_resized = img.resize((max_image_length, int(max_image_length * img.size[1]/img.size[0])))
67
+ img_resized.save(buffered, format="PNG")
68
+
69
+ img_str = base64.b64encode(buffered.getvalue()).decode()
70
+
71
+ figures = figures + make_html_figure_sources(doc, i, img_str)
72
+ gallery.append(img)
73
+ except Exception as e:
74
+ print(f"Skipped adding image {i} because of {e}")
75
+
76
+ return figures, gallery
77
+
78
+
79
+ def generate_html_graphs(graphs:list)->str:
80
+ # Organize graphs by category
81
+ categories = defaultdict(list)
82
+ for graph in graphs:
83
+ category = graph['metadata']['category']
84
+ categories[category].append(graph['embedding'])
85
+
86
+ # Begin constructing the HTML
87
+ html_code = '''
88
+ <!DOCTYPE html>
89
+ <html lang="en">
90
+ <head>
91
+ <meta charset="UTF-8">
92
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
93
+ <title>Graphs by Category</title>
94
+ <style>
95
+ .tab-content {
96
+ display: none;
97
+ }
98
+ .tab-content.active {
99
+ display: block;
100
+ }
101
+ .tabs {
102
+ margin-bottom: 20px;
103
+ }
104
+ .tab-button {
105
+ background-color: #ddd;
106
+ border: none;
107
+ padding: 10px 20px;
108
+ cursor: pointer;
109
+ margin-right: 5px;
110
+ }
111
+ .tab-button.active {
112
+ background-color: #ccc;
113
+ }
114
+ </style>
115
+ <script>
116
+ function showTab(tabId) {
117
+ var contents = document.getElementsByClassName('tab-content');
118
+ var buttons = document.getElementsByClassName('tab-button');
119
+ for (var i = 0; i < contents.length; i++) {
120
+ contents[i].classList.remove('active');
121
+ buttons[i].classList.remove('active');
122
+ }
123
+ document.getElementById(tabId).classList.add('active');
124
+ document.querySelector('button[data-tab="'+tabId+'"]').classList.add('active');
125
+ }
126
+ </script>
127
+ </head>
128
+ <body>
129
+ <div class="tabs">
130
+ '''
131
+
132
+ # Add buttons for each category
133
+ for i, category in enumerate(categories.keys()):
134
+ active_class = 'active' if i == 0 else ''
135
+ html_code += f'<button class="tab-button {active_class}" onclick="showTab(\'tab-{i}\')" data-tab="tab-{i}">{category}</button>'
136
+
137
+ html_code += '</div>'
138
+
139
+ # Add content for each category
140
+ for i, (category, embeds) in enumerate(categories.items()):
141
+ active_class = 'active' if i == 0 else ''
142
+ html_code += f'<div id="tab-{i}" class="tab-content {active_class}">'
143
+ for embed in embeds:
144
+ html_code += embed
145
+ html_code += '</div>'
146
+
147
+ html_code += '''
148
+ </body>
149
+ </html>
150
+ '''
151
+
152
+ return html_code
153
+
154
+
155
 
156
  def make_html_source(source,i):
157
  meta = source.metadata
 
228
  return card
229
 
230
 
231
+ def make_html_papers(df,i):
232
+ title = df['title'][i]
233
+ content = df['abstract'][i]
234
+ url = df['doi'][i]
235
+ publication_date = df['publication_year'][i]
236
+ subtitle = df['subtitle'][i]
237
+
238
+ card = f"""
239
+ <div class="card" id="doc{i}">
240
+ <div class="card-content">
241
+ <h2>Doc {i+1} - {title}</h2>
242
+ <p>{content}</p>
243
+ </div>
244
+ <div class="card-footer">
245
+ <span>{subtitle}</span>
246
+ <a href="{url}" target="_blank" class="pdf-link">
247
+ <span role="img" aria-label="Open paper">🔗</span>
248
+ </a>
249
+ </div>
250
+ </div>
251
+ """
252
+
253
+ return card
254
+
255
+
256
  def make_html_figure_sources(source,i,img_str):
257
  meta = source.metadata
258
  content = source.page_content.strip()
sandbox/20240310 - CQA - Semantic Routing 1.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
sandbox/20240702 - CQA - Graph Functionality.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
sandbox/20241104 - CQA - StepByStep CQA.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
style.css CHANGED
@@ -3,6 +3,61 @@
3
  --user-image: url('https://ih1.redbubble.net/image.4776899543.6215/st,small,507x507-pad,600x600,f8f8f8.jpg');
4
  } */
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  /* fix for huggingface infinite growth*/
8
  main.flex.flex-1.flex-col {
@@ -85,7 +140,12 @@ body.dark .tip-box * {
85
  font-size:14px !important;
86
 
87
  }
88
-
 
 
 
 
 
89
 
90
  a {
91
  text-decoration: none;
@@ -161,60 +221,111 @@ a {
161
  border:none;
162
  }
163
 
164
- /* .gallery-item > div:hover{
165
- background-color:#7494b0 !important;
166
- color:white!important;
167
- }
168
 
169
- .gallery-item:hover{
170
- border:#7494b0 !important;
171
  }
172
 
173
- .gallery-item > div{
174
- background-color:white !important;
175
- color:#577b9b!important;
176
  }
177
 
178
- .label{
179
- color:#577b9b!important;
180
- } */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
- /* .paginate{
183
- color:#577b9b!important;
 
 
184
  } */
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
 
 
 
 
187
 
188
- /* span[data-testid="block-info"]{
189
- background:none !important;
190
- color:#577b9b;
191
- } */
192
 
193
- /* Pseudo-element for the circularly cropped picture */
194
- /* .message.bot::before {
195
- content: '';
196
  position: absolute;
197
- top: -10px;
198
- left: -10px;
199
- width: 30px;
200
- height: 30px;
201
- background-image: var(--user-image);
202
- background-size: cover;
203
- background-position: center;
 
 
204
  border-radius: 50%;
205
- z-index: 10;
206
- }
207
- */
208
-
209
- label.selected{
210
- background:none !important;
211
  }
212
-
213
- #submit-button{
214
- padding:0px !important;
215
  }
216
 
 
 
217
  @media screen and (min-width: 1024px) {
 
 
 
 
 
 
218
  .gradio-container {
219
  max-height: calc(100vh - 190px) !important;
220
  overflow: hidden;
@@ -225,6 +336,8 @@ label.selected{
225
 
226
  } */
227
 
 
 
228
  div#tab-examples{
229
  height:calc(100vh - 190px) !important;
230
  overflow-y: scroll !important;
@@ -236,6 +349,10 @@ label.selected{
236
  overflow-y: scroll !important;
237
  /* overflow-y: auto !important; */
238
  }
 
 
 
 
239
 
240
  div#sources-figures{
241
  height:calc(100vh - 300px) !important;
@@ -243,6 +360,18 @@ label.selected{
243
  overflow-y: scroll !important;
244
  }
245
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  div#tab-config{
247
  height:calc(100vh - 190px) !important;
248
  overflow-y: scroll !important;
@@ -409,8 +538,7 @@ span.chatbot > p > img{
409
  }
410
 
411
  #dropdown-samples{
412
- /*! border:none !important; */
413
- /*! border-width:0px !important; */
414
  background:none !important;
415
 
416
  }
@@ -468,6 +596,10 @@ span.chatbot > p > img{
468
  input[type="checkbox"]:checked + .dropdown-content {
469
  display: block;
470
  }
 
 
 
 
471
 
472
  .dropdown-content {
473
  display: none;
@@ -489,7 +621,7 @@ span.chatbot > p > img{
489
  border-bottom: 5px solid black;
490
  }
491
 
492
- .loader {
493
  border: 1px solid #d0d0d0 !important; /* Light grey background */
494
  border-top: 1px solid #db3434 !important; /* Blue color */
495
  border-right: 1px solid #3498db !important; /* Blue color */
@@ -499,41 +631,64 @@ span.chatbot > p > img{
499
  animation: spin 2s linear infinite;
500
  display:inline-block;
501
  margin-right:10px !important;
502
- }
503
 
504
- .checkmark{
505
  color:green !important;
506
  font-size:18px;
507
  margin-right:10px !important;
508
- }
509
 
510
- @keyframes spin {
511
  0% { transform: rotate(0deg); }
512
  100% { transform: rotate(360deg); }
513
- }
514
 
515
 
516
- .relevancy-score{
517
  margin-top:10px !important;
518
  font-size:10px !important;
519
  font-style:italic;
520
- }
521
 
522
- .score-green{
523
  color:green !important;
524
- }
525
 
526
- .score-orange{
527
  color:orange !important;
528
- }
529
 
530
- .score-red{
531
  color:red !important;
532
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
  .message-buttons-left.panel.message-buttons.with-avatar {
534
  display: none;
535
  }
536
 
 
537
  /* Specific fixes for Hugging Face Space iframe */
538
  .h-full {
539
  height: auto !important;
 
3
  --user-image: url('https://ih1.redbubble.net/image.4776899543.6215/st,small,507x507-pad,600x600,f8f8f8.jpg');
4
  } */
5
 
6
+ #tab-recommended_content{
7
+ padding-top: 0px;
8
+ padding-left : 0px;
9
+ padding-right: 0px;
10
+ }
11
+ #group-subtabs {
12
+ /* display: block; */
13
+ width: 100%; /* Ensures the parent uses the full width */
14
+ position : sticky;
15
+ }
16
+
17
+ #group-subtabs .tab-container {
18
+ display: flex;
19
+ text-align: center;
20
+ width: 100%; /* Ensures the tabs span the full width */
21
+ }
22
+
23
+ #group-subtabs .tab-container button {
24
+ flex: 1; /* Makes each button take equal width */
25
+ }
26
+
27
+
28
+ #papers-summary-popup button span{
29
+ /* make label of accordio in bold, center, and bigger */
30
+ font-size: 16px;
31
+ font-weight: bold;
32
+ text-align: center;
33
+
34
+ }
35
+
36
+ #papers-relevant-popup span{
37
+ /* make label of accordio in bold, center, and bigger */
38
+ font-size: 16px;
39
+ font-weight: bold;
40
+ text-align: center;
41
+ }
42
+
43
+
44
+
45
+ #tab-citations .button{
46
+ padding: 12px 16px;
47
+ font-size: 16px;
48
+ font-weight: bold;
49
+ cursor: pointer;
50
+ border: none;
51
+ outline: none;
52
+ text-align: left;
53
+ transition: background-color 0.3s ease;
54
+ }
55
+
56
+
57
+ .gradio-container {
58
+ width: 100%!important;
59
+ max-width: 100% !important;
60
+ }
61
 
62
  /* fix for huggingface infinite growth*/
63
  main.flex.flex-1.flex-col {
 
140
  font-size:14px !important;
141
 
142
  }
143
+ .card-content img {
144
+ display: block;
145
+ margin: auto;
146
+ max-width: 100%; /* Ensures the image is responsive */
147
+ height: auto;
148
+ }
149
 
150
  a {
151
  text-decoration: none;
 
221
  border:none;
222
  }
223
 
 
 
 
 
224
 
225
+ label.selected{
226
+ background: #93c5fd !important;
227
  }
228
 
229
+ #submit-button{
230
+ padding:0px !important;
 
231
  }
232
 
233
+ #modal-config .block.modal-block.padded {
234
+ padding-top: 25px;
235
+ height: 100vh;
236
+
237
+ }
238
+ #modal-config .modal-container{
239
+ margin: 0px;
240
+ padding: 0px;
241
+ }
242
+ /* Modal styles */
243
+ #modal-config {
244
+ position: fixed;
245
+ top: 0;
246
+ left: 0;
247
+ height: 100vh;
248
+ width: 500px;
249
+ background-color: white;
250
+ box-shadow: 2px 0 10px rgba(0, 0, 0, 0.1);
251
+ z-index: 1000;
252
+ padding: 15px;
253
+ transform: none;
254
+ }
255
+ #modal-config .close{
256
+ display: none;
257
+ }
258
 
259
+ /* Push main content to the right when modal is open */
260
+ /* .modal ~ * {
261
+ margin-left: 300px;
262
+ transition: margin-left 0.3s ease;
263
  } */
264
 
265
+ #modal-config .modal .wrap ul{
266
+ position:static;
267
+ top: 100%;
268
+ left: 0;
269
+ /* min-height: 100px; */
270
+ height: 100%;
271
+ /* margin-top: 0; */
272
+ z-index: 9999;
273
+ pointer-events: auto;
274
+ height: 200px;
275
+ }
276
+ #config-button{
277
+ background: none;
278
+ border: none;
279
+ padding: 8px;
280
+ cursor: pointer;
281
+ width: 40px;
282
+ height: 40px;
283
+ display: flex;
284
+ align-items: center;
285
+ justify-content: center;
286
+ border-radius: 50%;
287
+ transition: background-color 0.2s;
288
+ }
289
 
290
+ #config-button::before {
291
+ content: '⚙️';
292
+ font-size: 20px;
293
+ }
294
 
295
+ #config-button:hover {
296
+ background-color: rgba(0, 0, 0, 0.1);
297
+ }
 
298
 
299
+ #checkbox-config{
300
+ display: block;
 
301
  position: absolute;
302
+ background: none;
303
+ border: none;
304
+ padding: 8px;
305
+ cursor: pointer;
306
+ width: 40px;
307
+ height: 40px;
308
+ display: flex;
309
+ align-items: center;
310
+ justify-content: center;
311
  border-radius: 50%;
312
+ transition: background-color 0.2s;
313
+ font-size: 20px;
314
+ text-align: center;
 
 
 
315
  }
316
+ #checkbox-config:checked{
317
+ display: block;
 
318
  }
319
 
320
+
321
+
322
  @media screen and (min-width: 1024px) {
323
+ /* Additional style for scrollable tab content */
324
+ /* div#tab-recommended_content {
325
+ overflow-y: auto;
326
+ max-height: 80vh;
327
+ } */
328
+
329
  .gradio-container {
330
  max-height: calc(100vh - 190px) !important;
331
  overflow: hidden;
 
336
 
337
  } */
338
 
339
+
340
+
341
  div#tab-examples{
342
  height:calc(100vh - 190px) !important;
343
  overflow-y: scroll !important;
 
349
  overflow-y: scroll !important;
350
  /* overflow-y: auto !important; */
351
  }
352
+ div#graphs-container{
353
+ height:calc(100vh - 210px) !important;
354
+ overflow-y: scroll !important;
355
+ }
356
 
357
  div#sources-figures{
358
  height:calc(100vh - 300px) !important;
 
360
  overflow-y: scroll !important;
361
  }
362
 
363
+ div#graphs-container{
364
+ height:calc(100vh - 300px) !important;
365
+ max-height: 90vh !important;
366
+ overflow-y: scroll !important;
367
+ }
368
+
369
+ div#tab-citations{
370
+ height:calc(100vh - 300px) !important;
371
+ max-height: 90vh !important;
372
+ overflow-y: scroll !important;
373
+ }
374
+
375
  div#tab-config{
376
  height:calc(100vh - 190px) !important;
377
  overflow-y: scroll !important;
 
538
  }
539
 
540
  #dropdown-samples{
541
+
 
542
  background:none !important;
543
 
544
  }
 
596
  input[type="checkbox"]:checked + .dropdown-content {
597
  display: block;
598
  }
599
+
600
+ #checkbox-chat input[type="checkbox"] {
601
+ display: flex !important;
602
+ }
603
 
604
  .dropdown-content {
605
  display: none;
 
621
  border-bottom: 5px solid black;
622
  }
623
 
624
+ .loader {
625
  border: 1px solid #d0d0d0 !important; /* Light grey background */
626
  border-top: 1px solid #db3434 !important; /* Blue color */
627
  border-right: 1px solid #3498db !important; /* Blue color */
 
631
  animation: spin 2s linear infinite;
632
  display:inline-block;
633
  margin-right:10px !important;
634
+ }
635
 
636
+ .checkmark{
637
  color:green !important;
638
  font-size:18px;
639
  margin-right:10px !important;
640
+ }
641
 
642
+ @keyframes spin {
643
  0% { transform: rotate(0deg); }
644
  100% { transform: rotate(360deg); }
645
+ }
646
 
647
 
648
+ .relevancy-score{
649
  margin-top:10px !important;
650
  font-size:10px !important;
651
  font-style:italic;
652
+ }
653
 
654
+ .score-green{
655
  color:green !important;
656
+ }
657
 
658
+ .score-orange{
659
  color:orange !important;
660
+ }
661
 
662
+ .score-red{
663
  color:red !important;
664
+ }
665
+
666
+ /* Mobile specific adjustments */
667
+ @media screen and (max-width: 767px) {
668
+ div#tab-recommended_content {
669
+ max-height: 50vh; /* Reduce height for smaller screens */
670
+ overflow-y: auto;
671
+ }
672
+ }
673
+
674
+ /* Additional style for scrollable tab content */
675
+ div#tab-saved-graphs {
676
+ overflow-y: auto; /* Enable vertical scrolling */
677
+ max-height: 80vh; /* Adjust height as needed */
678
+ }
679
+
680
+ /* Mobile specific adjustments */
681
+ @media screen and (max-width: 767px) {
682
+ div#tab-saved-graphs {
683
+ max-height: 50vh; /* Reduce height for smaller screens */
684
+ overflow-y: auto;
685
+ }
686
+ }
687
  .message-buttons-left.panel.message-buttons.with-avatar {
688
  display: none;
689
  }
690
 
691
+
692
  /* Specific fixes for Hugging Face Space iframe */
693
  .h-full {
694
  height: auto !important;