import gradio as gr
import time
import os
import yaml
from qdrant_client import models
from tqdm import tqdm
from collections import defaultdict
import pandas as pd

from spinoza_project.source.backend.llm_utils import (
    get_llm_api,
)
from spinoza_project.source.frontend.utils import (
    init_env,
    parse_output_llm_with_sources,
)
from spinoza_project.source.frontend.gradio_utils import (
    get_sources,
    set_prompts,
    get_config,
    get_prompts,
    get_assets,
    get_theme,
    get_init_prompt,
    get_synthesis_prompt,
    get_qdrants,
    start_agents,
    end_agents,
    next_call,
    zip_longest_fill,
    reformulate,
    answer,
    get_text,
    update_translation,
)

from assets.utils_javascript import (
    accordion_trigger,
    accordion_trigger_end,
    accordion_trigger_spinoza,
    accordion_trigger_spinoza_end,
    update_footer
)

init_env()

# The config file is local and trusted, so yaml.full_load is acceptable here;
# switch to yaml.safe_load if it can ever come from an external source.
# The encoding is explicit because the platform default is not guaranteed to be
# UTF-8 and the config may contain non-ASCII text.
with open("./spinoza_project/config.yaml", encoding="utf-8") as f:
    config = yaml.full_load(f)

## Loading Prompts
print("Loading Prompts")
prompts = get_prompts(config)
chat_qa_prompts, chat_reformulation_prompts = set_prompts(prompts, config)
synthesis_prompt_template = get_synthesis_prompt(config)

## Building LLM
print("Building LLM")
groq_model_name = config.get("groq_model_name", "")
llm = get_llm_api(groq_model_name)

## Loading BDDs
print("Loading Databases")
qdrants, df_qdrants = get_qdrants(config)

# One catalogue DataFrame per source, without the now-redundant 'Source' column.
dataframes_by_source = {
    source: df_qdrants[df_qdrants['Source'] == source].drop(columns=['Source'])
    for source in df_qdrants['Source'].unique()
}

for source, df in dataframes_by_source.items():
    # Missing filter labels are grouped under 'Unknown'. Guarded because the UI
    # below explicitly tolerates sources without a 'Filter' column; unguarded,
    # such a source would raise KeyError here at startup.
    if 'Filter' in df.columns:
        df['Filter'] = df['Filter'].fillna('Unknown')

    # Drop columns where every value is 'Unknown': they carry no information.
    unknown_percentage = df.apply(lambda x: (x == 'Unknown').mean())
    columns_to_drop = unknown_percentage[unknown_percentage == 1.0].index

    if len(columns_to_drop) > 0:
        print(f"Deleting following columns for {source}: {columns_to_drop.tolist()}")
        # Rebinding an existing key while iterating .items() is safe: the dict
        # size does not change.
        dataframes_by_source[source] = df.drop(columns=columns_to_drop)

## Loading Assets
print("Loading assets")
css, source_information_fr, source_information_en, about_contact_fr, about_contact_en = get_assets()
theme = get_theme()
init_prompt = get_init_prompt()

## Updating TRANSLATIONS dictionary
list_tabs = list(config["tabs"])
update_translation(list_tabs, config)

def get_source_df(source_name):
    """Return the cleaned catalogue DataFrame of ``source_name``.

    Unknown sources yield a fresh, empty ``pd.DataFrame`` instead of raising.
    """
    if source_name in dataframes_by_source:
        return dataframes_by_source[source_name]
    return pd.DataFrame()

# UI language code (the value held by `current_language`) -> language label
# passed to the LLM prompt templates.
LANGUAGE_MAPPING = dict(
    fr="french/français",
    en="english/anglais",
)

def reformulate_questions(
    lang_component,
    question,
    llm=llm,
    chat_reformulation_prompts=chat_reformulation_prompts,
    config=config,
):
    """Stream the per-tab reformulations of ``question``.

    Args:
        lang_component: a UI language code ("fr"/"en") or an object exposing it
            through a ``.value`` attribute.
        question: the user's raw question.
        llm, chat_reformulation_prompts, config: dependencies bound at import time.

    Yields:
        One tuple per step of ``zip_longest_fill`` over the per-tab
        ``reformulate`` streams (one entry per tab in ``config["tabs"]``).
    """
    lang = getattr(lang_component, "value", lang_component)
    language = LANGUAGE_MAPPING.get(lang, "french/français")

    per_tab_streams = [
        reformulate(language, llm, chat_reformulation_prompts, question, tab, config=config)
        for tab in config["tabs"]
    ]
    for step in zip_longest_fill(*per_tab_streams):
        # Small pause so the UI has time to render each streamed update.
        time.sleep(0.02)
        yield step

def retrieve_sources(
    *questions,
    filters_dict,
    qdrants=qdrants,
    config=config,
):
    """Retrieve documents for the given questions via ``get_sources``.

    Args:
        *questions: positional question values, forwarded to ``get_sources`` as a tuple.
        filters_dict: per-agent filter mapping; ``None`` is treated as no filters.
        qdrants, config: dependencies bound at import time.

    Returns:
        ``(formated_sources, *text_sources)``, i.e. the formatted sources
        followed by every element of the text sources returned by ``get_sources``.
    """
    active_filters = {} if filters_dict is None else filters_dict
    formated_sources, text_sources = get_sources(questions, active_filters, qdrants, config)
    return tuple([formated_sources, *text_sources])

def retrieve_sources_wrapper(*args):
    """Adapt Gradio's flat inputs to ``retrieve_sources``.

    The last positional argument is the filters mapping; all preceding ones are
    the per-tab questions.
    """
    filters = args[-1]
    question_values = [*args[:-1]]
    # NOTE(review): the list is passed as ONE positional argument, so
    # retrieve_sources receives a 1-tuple wrapping the list. Presumably
    # get_sources expects that shape — verify before changing to *question_values.
    return retrieve_sources(question_values, filters_dict=filters)

def answer_questions(
    lang_component,
    *questions_sources,
    llm=llm,
    chat_qa_prompts=chat_qa_prompts,
    config=config
):
    """Stream every agent's answer into its own chatbot.

    Args:
        lang_component: a UI language code ("fr"/"en") or an object exposing it
            through a ``.value`` attribute.
        *questions_sources: the per-tab questions followed by the per-tab
            sources; the first half are questions, the second half sources.
        llm, chat_qa_prompts, config: dependencies bound at import time.

    Yields:
        A list with one single-turn chatbot history ``[(question, rendered)]``
        per tab, refreshed on every streamed step.
    """
    lang = getattr(lang_component, "value", lang_component)
    language = LANGUAGE_MAPPING.get(lang, "french/français")

    half = len(questions_sources) // 2
    questions = list(questions_sources[:half])
    sources = list(questions_sources[half:])

    answer_streams = [
        answer(language, llm, chat_qa_prompts, agent_question, agent_source, tab, config)
        for agent_question, agent_source, tab in zip(questions, sources, config["tabs"])
    ]
    for partial_answers in zip_longest_fill(*answer_streams):
        time.sleep(0.02)
        histories = []
        for agent_question, partial in zip(questions, partial_answers):
            histories.append([(agent_question, parse_output_llm_with_sources(partial))])
        yield histories

def get_synthesis(
    lang_component,
    question,
    *answers,
    llm=llm,
    synthesis_prompt_template=synthesis_prompt_template,
    config=config,
):
    """Stream Spinoza's synthesis of the per-agent answers.

    Args:
        lang_component: a UI language code ("fr"/"en") or an object exposing it
            through a ``.value`` attribute.
        question: the user's question, as typed in the input box.
        *answers: the per-tab chatbot values, in ``config["tabs"]`` order.
        llm, synthesis_prompt_template, config: dependencies bound at import time.

    Yields:
        A single-turn chatbot history ``[(question, rendered_text)]``, refreshed
        on every streamed chunk. When no agent produced a usable answer, a single
        history carrying an explanatory message is yielded instead.
    """
    lang = lang_component.value if hasattr(lang_component, 'value') else lang_component
    language = LANGUAGE_MAPPING.get(lang, "french/français")

    # Only answers of at least 100 characters (string form) are treated as
    # substantive and fed to the synthesis prompt.
    answer_blocks = []
    for tab, agent_answer in zip(config["tabs"], answers):
        if len(str(agent_answer)) >= 100:
            answer_blocks.append(
                f"{tab}\n{agent_answer}".replace("<p>", "").replace("</p>\n", "")
            )

    if not answer_blocks:
        # This function is a generator: a plain `return <message>` would only set
        # StopIteration.value and the chatbot would never display it, so yield
        # the message in chatbot format instead.
        yield [(question, "Aucune source n'a pu être identifiée pour répondre, veuillez modifier votre question")]
        return

    for elt in llm.stream(
        synthesis_prompt_template,
        {
            "question": question.replace("<p>", "").replace("</p>\n", ""),
            "answers": "\n\n".join(answer_blocks),
            "language": language
        },
    ):
        time.sleep(0.01)
        yield [(question, parse_output_llm_with_sources(elt))]

def get_unique_values_filters(df):
    """Return the distinct, non-blank values of ``df['Filter']`` as sorted strings.

    Missing values (NaN/None) and values that are empty after ``strip()`` are skipped.
    """
    labels = []
    for raw in df['Filter'].unique():
        if not pd.notna(raw):
            continue
        text = str(raw)
        if text.strip() == '':
            continue
        labels.append(text)
    labels.sort()
    return labels

def filter_data(filter, source):
    """Return the catalogue rows of ``source``, optionally restricted to some filters.

    Args:
        filter: selected filter labels; an empty/``None`` selection keeps every row.
        source: key into ``dataframes_by_source``.

    Returns:
        The matching rows as a list of lists (``DataFrame.values.tolist()``).

    Raises:
        ValueError: if ``source`` is not a known source.
    """
    if source not in dataframes_by_source:
        raise ValueError(f"'{source}' not found withing the sources availible")

    catalogue = dataframes_by_source[source]
    if not filter:
        return catalogue.values.tolist()

    # Compare as strings, with missing labels mapped to '' so they never match.
    labels = catalogue['Filter'].fillna('').astype(str)
    return catalogue[labels.isin(filter)].values.tolist()
              
def update_filters(filters_dict, agent, values):
    """Return a new filters mapping with ``agent``'s file-filtering selection updated.

    Args:
        filters_dict: current mapping ``{agent: {field: selection}}``; ``None`` is
            treated as empty. It is not mutated.
        agent: key of the source whose selection changed.
        values: the selected checkbox labels. A non-empty selection is stored; an
            empty selection (``None`` or ``[]``) removes the filter, and the agent
            entry too if nothing else remains in it.

    Returns:
        The same new mapping twice: once for the ``gr.State`` and once for the
        JSON display component.
    """
    field = "file_filtering_modality"
    new_filters = dict(filters_dict or {})
    # Copy the per-agent dict as well: mutating it in place would silently
    # modify the dict held by the previous state value.
    agent_filters = dict(new_filters.get(agent, {}))

    # Only emptiness decides. CheckboxGroup always sends a list, so the previous
    # `isinstance(values, list)` test discarded every real selection.
    if values:
        agent_filters[field] = values
    else:
        agent_filters.pop(field, None)

    if agent_filters:
        new_filters[agent] = agent_filters
    else:
        new_filters.pop(agent, None)

    return new_filters, new_filters

with gr.Blocks(
    title="🔍 Spinoza",
    css=css,
    js=update_footer(),
    theme=theme,
) as demo:
    # Components keyed by source name (config["source_mapping"][tab]) so the
    # language toggle can relabel them later.
    accordions_qa = {}
    accordions_filters = {}
    current_language = gr.State(value="fr")
    chatbots = {}
    question = gr.State("")
    agt_input_flt = {}
    agt_desc = {}
    # NOTE(review): the default value of this State is used as a plain dict that
    # holds the display-tab CheckboxGroups; a regular dict would be clearer.
    agt_input_dsp = gr.State({})
    docs_textbox = gr.State([""])
    agent_questions = {elt: gr.State("") for elt in config["tabs"]}
    component_sources = {elt: gr.State("") for elt in config["tabs"]}
    text_sources = {elt: gr.State("") for elt in config["tabs"]}
    tab_states = {elt: gr.State(elt) for elt in config["tabs"]}
    filters_state = gr.State({})
    filters_display = gr.JSON(
        label="Filtres sélectionnés",
        value={},
        visible=False
    )
    # Every per-source "question_filter" Markdown, so that all of them are
    # translated on a language switch (not only the last one created).
    question_filters = []

    with gr.Row(elem_classes="header-row"):
        button_fr = gr.Button("", elem_id="fr-button", elem_classes="lang-button", icon='./assets/logos/france_round.png')
        button_en = gr.Button("", elem_id="en-button", elem_classes="lang-button", icon='./assets/logos/us_round.png')

    with gr.Row(elem_classes="main-row"):
        with gr.Tab("Q&A", elem_id="main-component"):
            with gr.Row(elem_id="chatbot-row"):
                with gr.Column(scale=2, elem_id="center-panel"):
                    with gr.Row(elem_id="input-message"):
                        ask = gr.Textbox(
                            placeholder=get_text("ask_placeholder", current_language.value),
                            show_label=False,
                            scale=7,
                            lines=1,
                            interactive=True,
                            elem_id="input-textbox",
                        )

                    with gr.Group(elem_id="chatbot-group"):
                        # One collapsible chatbot per agent (source).
                        for tab in list(config["tabs"].keys()):
                            agent_name = get_text(f"agent_{config['source_mapping'][tab]}_qa", current_language.value)
                            elem_id = f"accordion-{config['source_mapping'][tab]}"
                            elem_classes = "accordion accordion-agent"

                            with gr.Accordion(
                                label=agent_name,
                                open=False,
                                elem_id=elem_id,
                                elem_classes=elem_classes,
                            ) as accordions_qa[config['source_mapping'][tab]]:
                                chatbots[tab] = gr.Chatbot(
                                    value=None,
                                    show_copy_button=True,
                                    show_share_button=False,
                                    show_label=False,
                                    elem_id=f"chatbot-{agent_name.lower().replace(' ', '-')}",
                                    layout="panel",
                                    avatar_images=(
                                        "./assets/logos/help.png",
                                        (
                                            "./assets/logos/spinoza.png"
                                            if agent_name == "Spinoza"
                                            else None
                                        ),
                                    )
                                )

                        # Synthesis chatbot: open by default and seeded with the welcome prompt.
                        agent_name = "Spinoza"
                        with gr.Accordion(
                            label=agent_name,
                            open=True,
                            elem_id="accordion-Spinoza",
                            elem_classes="accordion accordion-agent spinoza-agent",
                        ) as accordion_spinoza:
                            chatbots["Spinoza"] = gr.Chatbot(
                                value=([(None, get_text("init_prompt", current_language.value))]),
                                show_copy_button=True,
                                show_share_button=False,
                                show_label=False,
                                elem_id=f"chatbot-{agent_name.lower().replace(' ', '-')}",
                                layout="panel",
                                avatar_images=(
                                    "./assets/logos/help.png",
                                    "./assets/logos/spinoza.png",
                                ),
                            )

                with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
                    with gr.TabItem("Sources", elem_id="tab-sources", id=0):
                        sources_textbox = gr.HTML(
                            show_label=False, elem_id="sources-textbox"
                        )

        with gr.Tab(label=get_text("source_filter_label", current_language.value), elem_id="filter-component") as source_filter_tab:
            source_filter_title = gr.Markdown(value=get_text("source_filter_title", current_language.value))
            source_filter_subtitle = gr.Markdown(value=get_text("source_filter_subtitle", current_language.value))

            with gr.Row(elem_id="filter-row"):
                with gr.Column(scale=2, elem_id="filter-center-panel"):
                    with gr.Group(elem_id="filter-group"):
                        for tab in list(config["tabs"].keys()):
                            agent_name = get_text(f"agent_{config['source_mapping'][tab]}_flt", current_language.value)
                            elem_id = f"accordion-filter-{config['source_mapping'][tab]}"
                            elem_classes = "accordion accordion-source"

                            with gr.Accordion(
                                label=agent_name,
                                open=False,
                                elem_id=elem_id,
                                elem_classes=elem_classes,
                            ) as accordions_filters[config['source_mapping'][tab]]:
                                question_filter = gr.Markdown(value=get_text("question_filter", current_language.value))
                                question_filters.append(question_filter)
                                with gr.Tabs():
                                    df = get_source_df(config['source_mapping'][tab])
                                    if not df.empty and 'Filter' in df.columns:
                                        filters = get_unique_values_filters(df)

                                        with gr.Row():
                                            var_name = f"{config['source_mapping'][tab]}_input_flt"
                                            agt_input_flt[var_name] = gr.CheckboxGroup(
                                                list(filters),
                                                label="Filter(s):"
                                            )

                                            # Keep the shared filters state (and its hidden
                                            # JSON mirror) in sync with this selection.
                                            agt_input_flt[var_name].change(
                                                fn=update_filters,
                                                inputs=[filters_state, gr.State(config['source_mapping'][tab]), agt_input_flt[var_name]],
                                                outputs=[filters_state, filters_display]
                                            )

                                    else:
                                        gr.Markdown("**Error:** No data / 'Filter' column doesn't exist...")

        with gr.Tab(label=get_text("source_informatation_label", current_language.value), elem_id="source-component") as source_information_tab:
            with gr.Row():
                with gr.Column(scale=1):
                    display_info_desc = gr.Markdown(value=get_text("display_info_desc", current_language.value))
                    accordions_inf = {}
                    with gr.Tabs(elem_id="main-tab-disp"):
                        for tab in list(config["tabs"].keys()):
                            agent_name = get_text(f"agent_{config['source_mapping'][tab]}_tab", current_language.value)
                            elem_id = f"accordion-{config['source_mapping'][tab]}-tab"
                            elem_classes = "disp-tabs"

                            with gr.Tab(
                                label=agent_name,
                                elem_id=elem_id,
                                elem_classes=elem_classes
                            ) as accordions_inf[config['source_mapping'][tab]]:
                                var_name = f"{config['source_mapping'][tab]}_desc"
                                agt_desc[var_name] = gr.Markdown(value=get_text(f"{config['source_mapping'][tab]}_desc", current_language.value))
                                df = get_source_df(config['source_mapping'][tab])
                                if not df.empty and 'Filter' in df.columns:
                                    filters = get_unique_values_filters(df)

                                    with gr.Row():
                                        var_name = f"{config['source_mapping'][tab]}_input_dsp"
                                        agt_input_dsp.value[var_name] = gr.CheckboxGroup(
                                            list(filters),
                                            label="Filter(s):"
                                        )

                                    # NOTE(review): these headers assume exactly four columns,
                                    # but all-'Unknown' columns may have been dropped at
                                    # startup — confirm the catalogue shape.
                                    output_df = gr.Dataframe(
                                        headers=['Title', 'Pages', 'Filter Category', 'Publishing Date'],
                                        datatype=['str', 'number', 'str', 'number'],
                                        value=df.values.tolist(),
                                        column_widths=[300, 100, 100, 150],
                                        wrap=True
                                    )

                                    agt_input_dsp.value[var_name].change(
                                        filter_data,
                                        inputs=[agt_input_dsp.value[var_name]] + [gr.State(config['source_mapping'][tab])],
                                        outputs=[output_df]
                                    )

                                else:
                                    gr.Markdown("**Error:** No data / 'Filter' column doesn't exist...")

        with gr.Tab(label=get_text("contact_label", current_language.value), elem_id="contact-component") as contact_label:
            with gr.Row():
                with gr.Column(scale=1):
                    contact_info = gr.Markdown(value=about_contact_fr)

    # Question pipeline: reformulate per agent -> retrieve sources -> answer per
    # agent -> synthesize with Spinoza, with JS hooks driving the accordions.
    ask.submit(
        start_agents, inputs=[current_language], outputs=[chatbots["Spinoza"]] + [source_filter_tab], js=accordion_trigger()
    ).then(
        fn=reformulate_questions,
        inputs=[current_language] + [ask],
        outputs=[agent_questions[tab] for tab in config["tabs"]],
    ).then(
        fn=retrieve_sources_wrapper,
        inputs=[agent_questions[tab] for tab in config["tabs"]] + [filters_state],
        outputs=[sources_textbox] + [text_sources[tab] for tab in config["tabs"]],
    ).then(
        fn=answer_questions,
        inputs=[current_language]
        + [agent_questions[tab] for tab in config["tabs"]]
        + [text_sources[tab] for tab in config["tabs"]],
        outputs=[chatbots[tab] for tab in config["tabs"]],
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_end()
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza()
    ).then(
        fn=get_synthesis,
        inputs=[current_language]
        + [ask]
        + [chatbots[tab] for tab in config["tabs"]],
        outputs=[chatbots["Spinoza"]],
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza_end()
    ).then(
        fn=end_agents, inputs=[current_language], outputs=[source_filter_tab]
    )

    def reset_app(language):
        """Build the updates that clear the conversation for ``language``.

        Returns a dict of ``gr.update`` objects: per-chatbot resets (the Spinoza
        chatbot is re-seeded with the localized welcome prompt), cleared filter
        state/display, an emptied input box with a localized placeholder, an
        emptied sources panel, and cleared checkbox values keyed by component name.
        """
        chatbot_updates = {}
        for tab in config["tabs"]:
            chatbot_updates[tab] = gr.update(value=None)
        chatbot_updates["Spinoza"] = gr.update(value=[(None, get_text("init_prompt", language))])

        empty_checkbox = gr.update(value=None)
        checkbox_components = list(agt_input_flt.keys()) + list(agt_input_dsp.value.keys())
        checkbox_updates = {component: empty_checkbox for component in checkbox_components}

        return {
            "chatbots": chatbot_updates,
            "filters_state": gr.update(value={}),
            "filters_display": gr.update(value={}),
            "ask": gr.update(value="", placeholder=get_text("ask_placeholder", language)),
            "sources_textbox": gr.update(value=""),
            "checkbox_updates": checkbox_updates
        }

    def toggle_language(language):
        """Relabel the whole UI in ``language`` and reset the session.

        Returns one value/update per component of ``language_outputs``, in the
        same order. Keyed groups are generated by iterating the same dicts that
        ``language_outputs`` expands, so both sides always line up.
        """
        reset_state = reset_app(language)
        about_contact = about_contact_fr if language == "fr" else about_contact_en
        return [
            language,
            reset_state["ask"],
            reset_state["chatbots"]["Spinoza"],
            *[reset_state["chatbots"][tab] for tab in config["tabs"]],
            *[
                gr.update(
                    label=get_text(f"agent_{key}_qa", language),
                    open=False,
                    elem_id=f"accordion-{key}",
                    elem_classes="accordion accordion-agent"
                )
                for key in accordions_qa
            ],
            gr.update(label=get_text("source_filter_label", language), elem_id="filter-component"),
            *[
                gr.update(
                    label=get_text(f"agent_{key}_flt", language),
                    elem_id=f"accordion-filter-{key}",
                    elem_classes="accordion accordion-source"
                )
                for key in accordions_filters
            ],
            gr.update(value=get_text("source_filter_title", language)),
            gr.update(value=get_text("source_filter_subtitle", language)),
            *[gr.update(value=get_text("question_filter", language)) for _ in question_filters],
            gr.update(label=get_text("source_informatation_label", language), elem_id="source-component"),
            gr.update(value=get_text("display_info_desc", language)),
            # agt_desc keys are exactly the translation keys ("<source>_desc").
            *[gr.update(value=get_text(desc_key, language)) for desc_key in agt_desc],
            *[
                gr.update(
                    label=get_text(f"agent_{key}_tab", language),
                    elem_id=f"accordion-{key}-tab",
                    elem_classes="disp-tabs"
                )
                for key in accordions_inf
            ],
            gr.update(label=get_text("contact_label", language)),
            gr.update(value=about_contact),
            gr.update(value=""),
            gr.update(value={}),
            gr.update(value={}),
            *[gr.update(value=None) for _ in agt_input_flt],
        ]

    def toggle_language_fr():
        """Switch the UI to French (see ``toggle_language``)."""
        return toggle_language("fr")

    def toggle_language_en():
        """Switch the UI to English (see ``toggle_language``)."""
        return toggle_language("en")

    # Components refreshed by a language switch; order must match toggle_language().
    language_outputs = [
        current_language,
        ask,
        chatbots["Spinoza"],
        *[chatbots[tab] for tab in config["tabs"]],
        *accordions_qa.values(),
        source_filter_tab,
        *accordions_filters.values(),
        source_filter_title,
        source_filter_subtitle,
        *question_filters,
        source_information_tab,
        display_info_desc,
        *agt_desc.values(),
        *accordions_inf.values(),
        contact_label,
        contact_info,
        sources_textbox,
        filters_state,
        filters_display,
        *agt_input_flt.values(),
    ]

    button_fr.click(
        fn=toggle_language_fr,
        inputs=[],
        outputs=language_outputs
    )

    button_en.click(
        fn=toggle_language_en,
        inputs=[],
        outputs=language_outputs
    )

if __name__ == "__main__":
    # Enable request queuing (needed for the streaming generators), then serve
    # with debug output and a public share link.
    queued_app = demo.queue()
    queued_app.launch(debug=True, share=True)