# spinozaed / app.py
# Last commit f07b5e8 by momenaca: "add features to ease hackathon"
import gradio as gr
import time
import os
from spinoza_project.source.backend.llm_utils import (
get_llm_api,
get_vectorstore_api,
)
from spinoza_project.source.frontend.utils import (
init_env,
parse_output_llm_with_sources,
)
from spinoza_project.source.frontend.gradio_utils import (
get_sources,
set_prompts,
get_config,
get_prompts,
get_assets,
get_theme,
get_init_prompt,
get_synthesis_prompt,
get_qdrants,
get_qdrants_public,
start_agents,
end_agents,
next_call,
zip_longest_fill,
reformulate,
answer,
)
from assets.utils_javascript import (
accordion_trigger,
accordion_trigger_end,
accordion_trigger_spinoza,
accordion_trigger_spinoza_end,
update_footer,
)
def _eki_deployment_configured():
    """Return True when EKI_OPENAI_LLM_DEPLOYMENT_NAME is set to a non-empty value."""
    return bool(os.getenv("EKI_OPENAI_LLM_DEPLOYMENT_NAME"))


init_env()
config = get_config()

## Loading Prompts
print("Loading Prompts")
prompts = get_prompts(config)
chat_qa_prompts, chat_reformulation_prompts = set_prompts(prompts, config)
synthesis_prompt_template = get_synthesis_prompt(config)

## Building LLM
print("Building LLM")
# When the EKI/OpenAI deployment variable is set, no Groq model name is passed
# to get_llm_api; otherwise the Groq model named in the config is used.
if _eki_deployment_configured():
    groq_model_name = ""
else:
    groq_model_name = config["groq_model_name"]
llm = get_llm_api(groq_model_name)

## Loading BDDs
print("Loading Databases")
qdrants = get_qdrants(config)
if _eki_deployment_configured():
    # Press and AFP collections are fetched through the vector-store API.
    bdd_presse = get_vectorstore_api("presse")
    bdd_afp = get_vectorstore_api("afp")
else:
    # Without the deployment, the public Qdrant collections are merged in
    # (public entries win on key clashes) and the press/AFP stores are disabled.
    qdrants_public = get_qdrants_public(config)
    merged_qdrants = dict(qdrants)
    merged_qdrants.update(qdrants_public)
    qdrants = merged_qdrants
    bdd_presse = None
    bdd_afp = None

## Loading Assets
css, source_information = get_assets()
theme = get_theme()
init_prompt = get_init_prompt()
def reformulate_questions(
    question,
    llm=llm,
    chat_reformulation_prompts=chat_reformulation_prompts,
    config=config,
):
    """Stream the per-tab reformulations of ``question``.

    One ``reformulate`` stream is opened for every tab in ``config["tabs"]``.
    The streams are combined with the project helper ``zip_longest_fill``
    (presumably padding exhausted streams with their last value; confirm).
    Each combined step is yielded after a 20 ms pause, giving one value per
    tab output.
    """
    streams = [
        reformulate(llm, chat_reformulation_prompts, question, tab_name, config=config)
        for tab_name in config["tabs"]
    ]
    for step in zip_longest_fill(*streams):
        time.sleep(0.02)
        yield step
def retrieve_sources(
    *questions,
    qdrants=qdrants,
    bdd_presse=bdd_presse,
    bdd_afp=bdd_afp,
    config=config,
):
    """Look up sources for each per-tab question.

    Returns a flat tuple whose first item is the formatted sources (bound to
    the sources HTML panel), followed by every text-source entry returned by
    ``get_sources``.
    """
    formatted, per_tab_texts = get_sources(
        questions, qdrants, bdd_presse, bdd_afp, config
    )
    return (formatted,) + tuple(per_tab_texts)
def answer_questions(
    *questions_sources, llm=llm, chat_qa_prompts=chat_qa_prompts, config=config
):
    """Stream one answer per tab, shaped as per-tab chatbot histories.

    Args:
        *questions_sources: The per-tab questions followed by the per-tab
            sources. The first half holds the questions and the second half
            the sources, both presumably in ``config["tabs"]`` order (that is
            how the UI wires them).
        llm: LLM client passed through to ``answer``.
        chat_qa_prompts: QA prompts passed through to ``answer``.
        config: App configuration; ``config["tabs"]`` names the tabs.

    Yields:
        A list with one chatbot history per tab, each holding a single
        ``(question, rendered_answer)`` pair.
    """
    half = len(questions_sources) // 2
    questions = list(questions_sources[:half])
    sources = list(questions_sources[half:])

    answer_streams = [
        answer(llm, chat_qa_prompts, question, source, tab, config)
        for question, source, tab in zip(questions, sources, config["tabs"])
    ]
    for partial_answers in zip_longest_fill(*answer_streams):
        time.sleep(0.02)
        yield [
            [(question, parse_output_llm_with_sources(ans))]
            for question, ans in zip(questions, partial_answers)
        ]
def get_synthesis(
    question,
    *answers,
    llm=llm,
    synthesis_prompt_template=synthesis_prompt_template,
    config=config,
):
    """Stream Spinoza's synthesis of the per-tab agent answers.

    Args:
        question: Question text sent to the synthesis prompt and shown as the
            user turn.
        *answers: One chatbot value per tab, in ``config["tabs"]`` order.
        llm: LLM client whose ``stream`` method produces the synthesis.
        synthesis_prompt_template: Prompt template for the synthesis.
        config: App configuration; ``config["tabs"]`` names the tabs.

    Yields:
        A single-turn chatbot history ``[(question, rendered_synthesis)]`` for
        each streamed chunk. If no tab produced a usable answer, yields one
        turn carrying a fallback message instead.
    """
    usable_answers = []
    for tab, tab_answer in zip(config["tabs"], answers):
        # Heuristic: a tab whose stringified chatbot value is shorter than
        # 100 characters is treated as having produced no answer.
        # NOTE(review): tab_answer is presumably the chatbot history (a list
        # of message pairs), so it is embedded through its repr. Newlines
        # inside it then appear as a literal backslash-n, and the
        # "</p>\n" replacement probably never matches. Confirm the intended
        # prompt format.
        if len(str(tab_answer)) >= 100:
            usable_answers.append(
                f"{tab}\n{tab_answer}".replace("<p>", "").replace("</p>\n", "")
            )

    if not usable_answers:
        # This function is a generator: a plain `return <message>` would only
        # set StopIteration.value, so Gradio would never display it. Yield
        # the message as a chatbot turn instead.
        yield [
            (
                question,
                "Aucune source n'a pu être identifiée pour répondre, veuillez modifier votre question",
            )
        ]
        return

    for elt in llm.stream(
        synthesis_prompt_template,
        {
            "question": question.replace("<p>", "").replace("</p>\n", ""),
            "answers": "\n\n".join(usable_answers),
        },
    ):
        time.sleep(0.01)
        yield [(question, parse_output_llm_with_sources(elt))]
# Gradio UI. The "Q&A" tab holds one accordion + chatbot per configured agent
# tab, plus a final "Spinoza" synthesis chatbot, an input textbox and a
# sources panel. It is followed by "Source information" and "Contact" tabs.
# Submitting the textbox runs the multi-stage pipeline wired at the bottom.
with gr.Blocks(
    title="🔍 Spinoza",
    css=css,
    js=update_footer(),
    theme=theme,
) as demo:
    # Chatbot components keyed by tab name, plus "Spinoza" for the synthesis.
    chatbots = {}
    # NOTE(review): `question`, `docs_textbox`, `component_sources` and
    # `tab_states` are not wired to any event in this block. They are kept in
    # case other code references them; confirm before removing.
    question = gr.State("")
    docs_textbox = gr.State([""])
    # Per-tab reformulated question (filled by reformulate_questions).
    agent_questions = {elt: gr.State("") for elt in config["tabs"]}
    component_sources = {elt: gr.State("") for elt in config["tabs"]}
    # Per-tab retrieved text sources (filled by retrieve_sources).
    text_sources = {elt: gr.State("") for elt in config["tabs"]}
    tab_states = {elt: gr.State(elt) for elt in config["tabs"]}
    tab_names = list(config["tabs"].keys())

    with gr.Tab("Q&A", elem_id="main-component"):
        with gr.Row(elem_id="chatbot-row"):
            with gr.Column(scale=2, elem_id="center-panel"):
                with gr.Group(elem_id="chatbot-group"):
                    # One accordion per agent tab, then the Spinoza synthesis
                    # accordion, which is the only one open by default.
                    for tab in tab_names + ["Spinoza"]:
                        is_spinoza = tab == "Spinoza"
                        if is_spinoza:
                            agent_name = "Spinoza"
                            elem_id = f"accordion-{tab}"
                            elem_classes = "accordion accordion-agent spinoza-agent"
                        else:
                            agent_name = f"Agent {config['source_mapping'][tab]}"
                            elem_id = f"accordion-{config['source_mapping'][tab]}"
                            elem_classes = "accordion accordion-agent"
                        with gr.Accordion(
                            agent_name,
                            open=is_spinoza,
                            elem_id=elem_id,
                            elem_classes=elem_classes,
                        ):
                            chatbots[tab] = gr.Chatbot(
                                # Only the Spinoza chatbot starts with the
                                # initial prompt message.
                                value=[(None, init_prompt)] if is_spinoza else None,
                                show_copy_button=True,
                                show_share_button=False,
                                show_label=False,
                                elem_id=f"chatbot-{agent_name.lower().replace(' ', '-')}",
                                layout="panel",
                                avatar_images=(
                                    "./assets/logos/help.png",
                                    "./assets/logos/spinoza.png" if is_spinoza else None,
                                ),
                            )
                with gr.Row(elem_id="input-message"):
                    ask = gr.Textbox(
                        placeholder="Ask me anything here!",
                        show_label=False,
                        scale=7,
                        lines=1,
                        interactive=True,
                        elem_id="input-textbox",
                    )
            with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
                with gr.TabItem("Sources", elem_id="tab-sources", id=0):
                    sources_textbox = gr.HTML(
                        show_label=False, elem_id="sources-textbox"
                    )

    with gr.Tab("Source information", elem_id="source-component"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(source_information)

    with gr.Tab("Contact", elem_id="contact-component"):
        with gr.Row():
            with gr.Column(scale=1):
                # NOTE(review): "[email protected]" looks like an
                # email-obfuscation placeholder rather than a real address;
                # confirm the intended contact.
                gr.Markdown("For any issue contact **[email protected]**.")

    # Submit pipeline, in order:
    #   start_agents -> reformulate per tab -> retrieve sources
    #   -> stream per-tab answers -> Spinoza synthesis -> end_agents.
    # The `next_call` steps have no inputs or outputs; they appear to exist
    # only to run the accordion JavaScript between stages.
    ask.submit(
        start_agents, inputs=[], outputs=[chatbots["Spinoza"]], js=accordion_trigger()
    ).then(
        fn=reformulate_questions,
        inputs=[ask],
        outputs=[agent_questions[tab] for tab in config["tabs"]],
    ).then(
        fn=retrieve_sources,
        inputs=[agent_questions[tab] for tab in config["tabs"]],
        outputs=[sources_textbox] + [text_sources[tab] for tab in config["tabs"]],
    ).then(
        fn=answer_questions,
        inputs=[agent_questions[tab] for tab in config["tabs"]]
        + [text_sources[tab] for tab in config["tabs"]],
        outputs=[chatbots[tab] for tab in config["tabs"]],
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_end()
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza()
    ).then(
        # NOTE(review): the synthesis uses the reformulated question of the
        # *second* configured tab (index 1). This raises IndexError when
        # fewer than two tabs are configured; confirm the choice is intended.
        fn=get_synthesis,
        inputs=[agent_questions[tab_names[1]]]
        + [chatbots[tab] for tab in config["tabs"]],
        outputs=[chatbots["Spinoza"]],
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza_end()
    ).then(
        fn=end_agents, inputs=[], outputs=[]
    )
if __name__ == "__main__":
    # Enable request queueing on the app, then start it with debug output.
    queued_demo = demo.queue()
    queued_demo.launch(debug=True)