Spaces:
Sleeping
Sleeping
import gradio as gr | |
import time | |
import os | |
import yaml | |
from qdrant_client import models | |
from tqdm import tqdm | |
from collections import defaultdict | |
import pandas as pd | |
from spinoza_project.source.backend.llm_utils import ( | |
get_llm_api, | |
) | |
from spinoza_project.source.frontend.utils import ( | |
init_env, | |
parse_output_llm_with_sources, | |
) | |
from spinoza_project.source.frontend.gradio_utils import ( | |
get_sources, | |
set_prompts, | |
get_config, | |
get_prompts, | |
get_assets, | |
get_theme, | |
get_init_prompt, | |
get_synthesis_prompt, | |
get_qdrants, | |
start_agents, | |
end_agents, | |
next_call, | |
zip_longest_fill, | |
reformulate, | |
answer, | |
get_text, | |
update_translation, | |
) | |
from assets.utils_javascript import ( | |
accordion_trigger, | |
accordion_trigger_end, | |
accordion_trigger_spinoza, | |
accordion_trigger_spinoza_end, | |
update_footer | |
) | |
# Module-level start-up: everything below runs once at import time and
# builds the globals (config, prompts, llm, qdrants, per-source tables,
# assets) that the handlers and the UI further down rely on.
init_env()

# Read the configuration with an explicit encoding so the result does not
# depend on the platform's default locale.
with open("./spinoza_project/config.yaml", encoding="utf-8") as f:
    # NOTE(review): full_load accepts more than plain YAML scalars/maps;
    # acceptable for a repository-owned file, but prefer yaml.safe_load if
    # this path could ever point at user-supplied content.
    config = yaml.full_load(f)

## Loading prompts
print("Loading Prompts")
prompts = get_prompts(config)
chat_qa_prompts, chat_reformulation_prompts = set_prompts(prompts, config)
synthesis_prompt_template = get_synthesis_prompt(config)

## Building the LLM client
print("Building LLM")
groq_model_name = config.get("groq_model_name", "")
llm = get_llm_api(groq_model_name)

## Loading databases
print("Loading Databases")
qdrants, df_qdrants = get_qdrants(config)

# One table per distinct value of the 'Source' column, without that column.
dataframes_by_source = {
    source: df_qdrants[df_qdrants['Source'] == source].drop(columns=['Source'])
    for source in df_qdrants['Source'].unique()
}

for source, df in dataframes_by_source.items():
    # `df` is the very object stored in the dict, so this updates it in place.
    df['Filter'] = df['Filter'].fillna('Unknown')

    # Share of cells equal to 'Unknown' per column; columns that are
    # 'Unknown' on every row carry no information and are dropped.
    unknown_percentage = df.apply(lambda x: (x == 'Unknown').mean())
    columns_to_drop = unknown_percentage[unknown_percentage == 1.0].index
    if len(columns_to_drop) > 0:
        print(f"Deleting following columns for {source}: {columns_to_drop.tolist()}")
        # Rebinding the value of an existing key is safe while iterating.
        dataframes_by_source[source] = df.drop(columns=columns_to_drop)

## Loading assets
print("Loading assets")
css, source_information_fr, source_information_en, about_contact_fr, about_contact_en = get_assets()
theme = get_theme()
init_prompt = get_init_prompt()

## Updating the TRANSLATIONS dictionary
list_tabs = list(config["tabs"])
update_translation(list_tabs, config)
def get_source_df(source_name):
    """Return the table registered for ``source_name``.

    Looks the name up in the module-level ``dataframes_by_source`` mapping
    and falls back to a new empty ``pandas.DataFrame`` when it is absent.
    """
    fallback = pd.DataFrame()
    if source_name in dataframes_by_source:
        return dataframes_by_source[source_name]
    return fallback
# Language code used by the UI -> language label passed on to the
# reformulation / answer / synthesis calls below.
LANGUAGE_MAPPING = dict(
    fr="french/français",
    en="english/anglais",
)
def reformulate_questions(
    lang_component,
    question,
    llm=llm,
    chat_reformulation_prompts=chat_reformulation_prompts,
    config=config,
):
    """Stream the per-tab reformulations of ``question``.

    Args:
        lang_component: A language code, or any object exposing the code
            through a ``value`` attribute.
        question: The text to reformulate.
        llm, chat_reformulation_prompts, config: Default to the module-level
            objects built at start-up.

    Yields:
        Each item produced by ``zip_longest_fill`` over one ``reformulate``
        call per entry of ``config["tabs"]``, with a 20 ms pause before each.
    """
    # Accept either the raw code or a wrapper carrying it in `.value`.
    lang_code = getattr(lang_component, "value", lang_component)
    # Unknown codes fall back to French.
    language = LANGUAGE_MAPPING.get(lang_code, "french/français")

    per_tab_streams = []
    for tab in config["tabs"]:
        per_tab_streams.append(
            reformulate(
                language,
                llm,
                chat_reformulation_prompts,
                question,
                tab,
                config=config,
            )
        )

    for step in zip_longest_fill(*per_tab_streams):
        time.sleep(0.02)
        yield step
def retrieve_sources(
    *questions,
    filters_dict,
    qdrants=qdrants,
    config=config,
):
    """Call ``get_sources`` and flatten its result into a single tuple.

    Args:
        *questions: Forwarded to ``get_sources`` as one tuple.
        filters_dict: Keyword-only; ``None`` is replaced by an empty dict.
        qdrants, config: Default to the module-level objects.

    Returns:
        A tuple whose first element is the first value returned by
        ``get_sources`` and whose remaining elements are the unpacked second
        value.
    """
    active_filters = {} if filters_dict is None else filters_dict
    formatted, texts = get_sources(questions, active_filters, qdrants, config)
    return (formatted, *texts)
def retrieve_sources_wrapper(*args):
    """Adapt a flat positional argument list to ``retrieve_sources``.

    The last positional argument is taken as the filters; every argument
    before it is gathered into a list of questions.

    NOTE(review): the list is passed as ONE positional argument, so
    ``retrieve_sources`` receives ``questions == ([q1, q2, ...],)``. Confirm
    that ``get_sources`` expects this wrapped form before changing it.
    """
    filters = args[-1]
    questions = [*args[:-1]]
    return retrieve_sources(questions, filters_dict=filters)
def answer_questions(
    lang_component,
    *questions_sources,
    llm=llm,
    chat_qa_prompts=chat_qa_prompts,
    config=config
):
    """Stream one answer per tab from paired questions and sources.

    Args:
        lang_component: A language code, or an object exposing it as
            ``value``.
        *questions_sources: First half are the questions, second half the
            matching sources (split at ``len // 2``).
        llm, chat_qa_prompts, config: Default to the module-level objects.

    Yields:
        For every step of ``zip_longest_fill`` over the per-tab ``answer``
        calls, a list with one ``[(question, parsed_answer)]`` entry per
        question.
    """
    lang_code = getattr(lang_component, "value", lang_component)
    language = LANGUAGE_MAPPING.get(lang_code, "french/français")

    half = len(questions_sources) // 2
    questions = list(questions_sources[:half])
    sources = list(questions_sources[half:])

    streams = []
    for asked, retrieved, tab in zip(questions, sources, config["tabs"]):
        streams.append(
            answer(language, llm, chat_qa_prompts, asked, retrieved, tab, config)
        )

    for partials in zip_longest_fill(*streams):
        time.sleep(0.02)
        histories = []
        for asked, partial in zip(questions, partials):
            histories.append([(asked, parse_output_llm_with_sources(partial))])
        yield histories
def get_synthesis(
    lang_component,
    question,
    *answers,
    llm=llm,
    synthesis_prompt_template=synthesis_prompt_template,
    config=config,
):
    """Stream a synthesis of the per-tab answers.

    Args:
        lang_component: A language code, or an object exposing it as
            ``value``; unknown codes fall back to French.
        question: The user's question.
        *answers: One value per entry of ``config["tabs"]``, in the same
            order (the UI wires the tab chatbots here).
        llm, synthesis_prompt_template, config: Default to the module-level
            objects.

    Yields:
        ``[(question, text)]`` chat histories: one per chunk streamed by
        ``llm.stream``, or a single fallback message when no answer is long
        enough to be used.
    """
    lang = getattr(lang_component, "value", lang_component)
    language = LANGUAGE_MAPPING.get(lang, "french/français")

    # Keep only answers whose string form has at least 100 characters.
    # zip() tolerates fewer answers than tabs instead of raising IndexError.
    usable_answers = [
        f"{tab}\n{tab_answer}".replace("<p>", "").replace("</p>\n", "")
        for tab, tab_answer in zip(config["tabs"], answers)
        if len(str(tab_answer)) >= 100
    ]

    if not usable_answers:
        # This function is a generator: a `return <value>` here would be
        # discarded by the caller, so the message must be yielded, and in the
        # same history shape as the streamed chunks.
        yield [
            (
                question,
                "Aucune source n'a pu être identifiée pour répondre, veuillez modifier votre question",
            )
        ]
        return

    for elt in llm.stream(
        synthesis_prompt_template,
        {
            "question": question.replace("<p>", "").replace("</p>\n", ""),
            "answers": "\n\n".join(usable_answers),
            "language": language,
        },
    ):
        time.sleep(0.01)
        yield [(question, parse_output_llm_with_sources(elt))]
def get_unique_values_filters(df):
    """Return the sorted, distinct, non-blank values of ``df['Filter']``.

    Missing values and values whose string form is empty after stripping
    whitespace are skipped; the rest are converted to ``str`` and sorted.
    """
    kept = []
    for raw in df['Filter'].unique():
        if not pd.notna(raw):
            continue
        as_text = str(raw)
        if as_text.strip() == '':
            continue
        kept.append(as_text)
    kept.sort()
    return kept
def filter_data(filter, source):
    """Return the rows of a source table, optionally restricted by filter.

    Args:
        filter: Collection of accepted 'Filter' values; a falsy value means
            "no restriction".
        source: Key into the module-level ``dataframes_by_source``.

    Returns:
        The selected rows as a list of lists.

    Raises:
        ValueError: If ``source`` is not a known key.
    """
    if source not in dataframes_by_source:
        raise ValueError(f"'{source}' not found withing the sources availible")

    table = dataframes_by_source[source]
    if not filter:
        return table.values.tolist()

    # Compare on the string form, with missing values treated as ''.
    filter_as_text = table['Filter'].fillna('').astype(str)
    selected = table[filter_as_text.isin(filter)]
    return selected.values.tolist()
def update_filters(filters_dict, agent, values):
    """Return a new filters mapping with ``agent``'s selection applied.

    Args:
        filters_dict: Current mapping ``{agent: {field: values}}`` or
            ``None``. It is not modified.
        agent: Key of the agent whose selection changed.
        values: The new selection. Any falsy value (``None``, ``[]``, ...)
            clears the agent's selection; anything else is stored as is.

    Returns:
        The updated mapping, twice (one copy of the reference per UI output).
    """
    field = "file_filtering_modality"

    # Copy one level deeper than dict(filters_dict) so the per-agent dicts of
    # the caller's state are never mutated.
    new_filters = {
        name: dict(agent_filters) if isinstance(agent_filters, dict) else agent_filters
        for name, agent_filters in (filters_dict or {}).items()
    }

    agent_filters = new_filters.setdefault(agent, {})
    if values:
        # A non-empty selection (checkbox groups deliver a list) is stored.
        agent_filters[field] = values
    else:
        agent_filters.pop(field, None)

    # Do not keep an empty entry for an agent without any active filter.
    if not agent_filters:
        del new_filters[agent]

    return new_filters, new_filters
# Gradio UI: layout, question-answering pipeline and language switching.
# NOTE(review): the layout nesting below was reconstructed from the order of
# the statements; confirm it against the rendered page (in particular the
# position of the right-hand "Sources" column and of the source table).
with gr.Blocks(
    title="🔍 Spinoza",
    css=css,
    js=update_footer(),
    theme=theme,
) as demo:
    # Registries of components that are needed again further down (event
    # wiring, language switch), plus the gr.State holders.
    accordions_qa = {}
    accordions_filters = {}
    current_language = gr.State(value="fr")
    chatbots = {}
    question = gr.State("")
    agt_input_flt = {}
    agt_desc = {}
    # NOTE(review): components are stored in this State's build-time
    # `.value`; the State itself is never used as an event input or output.
    agt_input_dsp = gr.State({})
    docs_textbox = gr.State([""])
    agent_questions = {elt: gr.State("") for elt in config["tabs"]}
    component_sources = {elt: gr.State("") for elt in config["tabs"]}
    text_sources = {elt: gr.State("") for elt in config["tabs"]}
    tab_states = {elt: gr.State(elt) for elt in config["tabs"]}
    filters_state = gr.State({})
    filters_display = gr.JSON(
        label="Filtres sélectionnés",
        value={},
        visible=False,
    )
    # One "question_filter" Markdown exists per filter accordion; all of them
    # are kept so that a language switch retranslates every one, not only the
    # last one created.
    question_filters = []

    with gr.Row(elem_classes="header-row"):
        button_fr = gr.Button("", elem_id="fr-button", elem_classes="lang-button", icon='./assets/logos/france_round.png')
        button_en = gr.Button("", elem_id="en-button", elem_classes="lang-button", icon='./assets/logos/us_round.png')

    with gr.Row(elem_classes="main-row"):
        # ---- Tab 1: question answering --------------------------------
        with gr.Tab("Q&A", elem_id="main-component"):
            with gr.Row(elem_id="chatbot-row"):
                with gr.Column(scale=2, elem_id="center-panel"):
                    with gr.Row(elem_id="input-message"):
                        ask = gr.Textbox(
                            placeholder=get_text("ask_placeholder", current_language.value),
                            show_label=False,
                            scale=7,
                            lines=1,
                            interactive=True,
                            elem_id="input-textbox",
                        )

                    with gr.Group(elem_id="chatbot-group"):
                        # One collapsed accordion + chatbot per configured tab.
                        for tab in config["tabs"]:
                            source_key = config['source_mapping'][tab]
                            agent_name = get_text(f"agent_{source_key}_qa", current_language.value)
                            with gr.Accordion(
                                label=agent_name,
                                open=False,
                                elem_id=f"accordion-{source_key}",
                                elem_classes="accordion accordion-agent",
                            ) as accordions_qa[source_key]:
                                chatbots[tab] = gr.Chatbot(
                                    value=None,
                                    show_copy_button=True,
                                    show_share_button=False,
                                    show_label=False,
                                    elem_id=f"chatbot-{agent_name.lower().replace(' ', '-')}",
                                    layout="panel",
                                    avatar_images=(
                                        "./assets/logos/help.png",
                                        (
                                            "./assets/logos/spinoza.png"
                                            if agent_name == "Spinoza"
                                            else None
                                        ),
                                    ),
                                )

                        # The synthesis agent, open by default.
                        agent_name = "Spinoza"
                        with gr.Accordion(
                            label=agent_name,
                            open=True,
                            elem_id="accordion-Spinoza",
                            elem_classes="accordion accordion-agent spinoza-agent",
                        ) as accordion_spinoza:
                            chatbots["Spinoza"] = gr.Chatbot(
                                value=([(None, get_text("init_prompt", current_language.value))]),
                                show_copy_button=True,
                                show_share_button=False,
                                show_label=False,
                                elem_id=f"chatbot-{agent_name.lower().replace(' ', '-')}",
                                layout="panel",
                                avatar_images=(
                                    "./assets/logos/help.png",
                                    "./assets/logos/spinoza.png",
                                ),
                            )

                with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
                    with gr.TabItem("Sources", elem_id="tab-sources", id=0):
                        sources_textbox = gr.HTML(
                            show_label=False, elem_id="sources-textbox"
                        )

        # ---- Tab 2: per-agent source filters --------------------------
        with gr.Tab(label=get_text("source_filter_label", current_language.value), elem_id="filter-component") as source_filter_tab:
            source_filter_title = gr.Markdown(value=get_text("source_filter_title", current_language.value))
            source_filter_subtitle = gr.Markdown(value=get_text("source_filter_subtitle", current_language.value))
            with gr.Row(elem_id="filter-row"):
                with gr.Column(scale=2, elem_id="filter-center-panel"):
                    with gr.Group(elem_id="filter-group"):
                        for tab in config["tabs"]:
                            source_key = config['source_mapping'][tab]
                            agent_name = get_text(f"agent_{source_key}_flt", current_language.value)
                            with gr.Accordion(
                                label=agent_name,
                                open=False,
                                elem_id=f"accordion-filter-{source_key}",
                                elem_classes="accordion accordion-source",
                            ) as accordions_filters[source_key]:
                                question_filters.append(
                                    gr.Markdown(value=get_text("question_filter", current_language.value))
                                )
                                with gr.Tabs():
                                    df = get_source_df(source_key)
                                    if not df.empty and 'Filter' in df.columns:
                                        filters = get_unique_values_filters(df)
                                        var_name = f"{source_key}_input_flt"
                                        with gr.Row():
                                            agt_input_flt[var_name] = gr.CheckboxGroup(
                                                filters,
                                                label="Filter(s):",
                                            )
                                        # Every change rewrites the shared
                                        # filters state for this agent.
                                        agt_input_flt[var_name].change(
                                            fn=update_filters,
                                            inputs=[filters_state, gr.State(source_key), agt_input_flt[var_name]],
                                            outputs=[filters_state, filters_display],
                                        )
                                    else:
                                        gr.Markdown("**Error:** No data / 'Filter' column doesn't exist...")

        # ---- Tab 3: description and listing of each source ------------
        with gr.Tab(label=get_text("source_informatation_label", current_language.value), elem_id="source-component") as source_information_tab:
            with gr.Row():
                with gr.Column(scale=1):
                    display_info_desc = gr.Markdown(value=get_text("display_info_desc", current_language.value))
                    accordions_inf = {}
                    with gr.Tabs(elem_id="main-tab-disp"):
                        for tab in config["tabs"]:
                            source_key = config['source_mapping'][tab]
                            agent_name = get_text(f"agent_{source_key}_tab", current_language.value)
                            with gr.Tab(
                                label=agent_name,
                                elem_id=f"accordion-{source_key}-tab",
                                elem_classes="disp-tabs",
                            ) as accordions_inf[source_key]:
                                # The registry key doubles as the translation
                                # key of the description text.
                                var_name = f"{source_key}_desc"
                                agt_desc[var_name] = gr.Markdown(value=get_text(var_name, current_language.value))
                                df = get_source_df(source_key)
                                if not df.empty and 'Filter' in df.columns:
                                    filters = get_unique_values_filters(df)
                                    var_name = f"{source_key}_input_dsp"
                                    with gr.Row():
                                        agt_input_dsp.value[var_name] = gr.CheckboxGroup(
                                            filters,
                                            label="Filter(s):",
                                        )
                                    output_df = gr.Dataframe(
                                        headers=['Title', 'Pages', 'Filter Category', 'Publishing Date'],
                                        datatype=['str', 'number', 'str', 'number'],
                                        value=df.values.tolist(),
                                        column_widths=[300, 100, 100, 150],
                                        wrap=True,
                                    )
                                    # Re-filter the displayed table whenever
                                    # the selection changes.
                                    agt_input_dsp.value[var_name].change(
                                        filter_data,
                                        inputs=[agt_input_dsp.value[var_name], gr.State(source_key)],
                                        outputs=[output_df],
                                    )
                                else:
                                    gr.Markdown("**Error:** No data / 'Filter' column doesn't exist...")

        # ---- Tab 4: contact -------------------------------------------
        with gr.Tab(label=get_text("contact_label", current_language.value), elem_id="contact-component") as contact_label:
            with gr.Row():
                with gr.Column(scale=1):
                    contact_info = gr.Markdown(value=about_contact_fr)

    # ---- Question pipeline, run on submit -------------------------------
    # start -> reformulate per tab -> retrieve sources -> answer per tab ->
    # accordion JS hooks -> synthesis -> end.
    tab_questions = [agent_questions[tab] for tab in config["tabs"]]
    tab_texts = [text_sources[tab] for tab in config["tabs"]]
    tab_chatbots = [chatbots[tab] for tab in config["tabs"]]

    ask.submit(
        start_agents, inputs=[current_language], outputs=[chatbots["Spinoza"], source_filter_tab], js=accordion_trigger()
    ).then(
        fn=reformulate_questions,
        inputs=[current_language, ask],
        outputs=tab_questions,
    ).then(
        fn=retrieve_sources_wrapper,
        inputs=tab_questions + [filters_state],
        outputs=[sources_textbox] + tab_texts,
    ).then(
        fn=answer_questions,
        inputs=[current_language] + tab_questions + tab_texts,
        outputs=tab_chatbots,
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_end()
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza()
    ).then(
        fn=get_synthesis,
        inputs=[current_language, ask] + tab_chatbots,
        outputs=[chatbots["Spinoza"]],
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza_end()
    ).then(
        fn=end_agents, inputs=[current_language], outputs=[source_filter_tab]
    )

    # ---- Language switching ---------------------------------------------
    def reset_app(language):
        """Build the updates that put the app back to its initial state.

        Returns a dict with the keys "chatbots", "filters_state",
        "filters_display", "ask", "sources_textbox" and "checkbox_updates".
        """
        chatbot_updates = {tab: gr.update(value=None) for tab in config["tabs"]}
        chatbot_updates["Spinoza"] = gr.update(value=[(None, get_text("init_prompt", language))])

        empty_checkbox = gr.update(value=None)
        checkbox_components = list(agt_input_flt.keys()) + list(agt_input_dsp.value.keys())
        checkbox_updates = {component: empty_checkbox for component in checkbox_components}

        return {
            "chatbots": chatbot_updates,
            "filters_state": gr.update(value={}),
            "filters_display": gr.update(value={}),
            "ask": gr.update(value="", placeholder=get_text("ask_placeholder", language)),
            "sources_textbox": gr.update(value=""),
            "checkbox_updates": checkbox_updates,
        }

    def _language_updates(language):
        """Return one value per entry of ``language_outputs``, in order."""
        reset_state = reset_app(language)
        about_contact = about_contact_fr if language == "fr" else about_contact_en
        return [
            language,
            reset_state["ask"],
            reset_state["chatbots"]["Spinoza"],
            *[reset_state["chatbots"][tab] for tab in config["tabs"]],
            *[
                gr.update(
                    label=get_text(f"agent_{key}_qa", language),
                    open=False,
                    elem_id=f"accordion-{key}",
                    elem_classes="accordion accordion-agent",
                )
                for key in accordions_qa
            ],
            gr.update(label=get_text("source_filter_label", language), elem_id="filter-component"),
            *[
                gr.update(
                    label=get_text(f"agent_{key}_flt", language),
                    elem_id=f"accordion-filter-{key}",
                    elem_classes="accordion accordion-source",
                )
                for key in accordions_filters
            ],
            gr.update(value=get_text("source_filter_title", language)),
            gr.update(value=get_text("source_filter_subtitle", language)),
            *[gr.update(value=get_text("question_filter", language)) for _ in question_filters],
            gr.update(label=get_text("source_informatation_label", language), elem_id="source-component"),
            gr.update(value=get_text("display_info_desc", language)),
            # agt_desc keys are the "<source>_desc" translation keys.
            *[gr.update(value=get_text(key, language)) for key in agt_desc],
            *[
                gr.update(
                    label=get_text(f"agent_{key}_tab", language),
                    elem_id=f"accordion-{key}-tab",
                    elem_classes="disp-tabs",
                )
                for key in accordions_inf
            ],
            gr.update(label=get_text("contact_label", language)),
            gr.update(value=about_contact),
            reset_state["sources_textbox"],
            reset_state["filters_state"],
            reset_state["filters_display"],
            *[gr.update(value=None) for _ in agt_input_flt],
        ]

    def toggle_language_fr():
        """Switch the interface to French and reset the conversation."""
        return _language_updates("fr")

    def toggle_language_en():
        """Switch the interface to English and reset the conversation."""
        return _language_updates("en")

    # Shared by both buttons; the order must mirror _language_updates().
    language_outputs = [
        current_language,
        ask,
        chatbots["Spinoza"],
        *tab_chatbots,
        *accordions_qa.values(),
        source_filter_tab,
        *accordions_filters.values(),
        source_filter_title,
        source_filter_subtitle,
        *question_filters,
        source_information_tab,
        display_info_desc,
        *agt_desc.values(),
        *accordions_inf.values(),
        contact_label,
        contact_info,
        sources_textbox,
        filters_state,
        filters_display,
        *agt_input_flt.values(),
    ]

    button_fr.click(
        fn=toggle_language_fr,
        inputs=[],
        outputs=language_outputs,
    )
    button_en.click(
        fn=toggle_language_en,
        inputs=[],
        outputs=language_outputs,
    )
# Script entry point: enable the request queue, then start the server with
# debug output and a public share link.
if __name__ == "__main__":
    queued_demo = demo.queue()
    queued_demo.launch(debug=True, share=True)