import torch
import transformers
import gradio as gr
from ragatouille import RAGPretrainedModel
import re
from datetime import datetime
import json
import arxiv

from helper import rag_cleaner, get_prompt_text, get_references, get_rag, SaveResponseAndRead, get_md_text_abstract, search_cleaner, get_arxiv_live_search

# Constants
# Number of papers fetched per query from either search backend.
RETRIEVE_RESULTS = 20

# Models offered in the "LLM Model" dropdown; the literal string 'None'
# is handled by ask_llm as "generation disabled".
LLM_MODELS = [
    'mistralai/Mixtral-8x7B-Instruct-v0.1',
    'mistralai/Mistral-7B-Instruct-v0.2',
    'google/gemma-7b-it',
    'None',
]
DEFAULT_LLM_MODEL = 'mistralai/Mistral-7B-Instruct-v0.2'

# Extra keyword arguments forwarded to client.text_generation in ask_llm.
# do_sample is False, so temperature/top_p are deliberately left unset.
GENERATE_KWARGS = dict(
    temperature=None,
    max_new_tokens=512,
    top_p=None,
    do_sample=False,
)

# RAG Model setup: load the prebuilt ColBERT index from disk. A missing index
# disables the semantic-search source instead of aborting the app.
try:
    RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
except FileNotFoundError:
    RAG = None
    semantic_search_available = False
    gr.Warning("Colbert index not found. Semantic search will be unavailable.")
else:
    semantic_search_available = True

    # Warm-up query to exercise the retriever once at startup.
    # NOTE(review): a failed warm-up only warns; semantic_search_available stays
    # True, so the source is still offered — confirm that is intended.
    # NOTE(review): these gr.Info/gr.Warning calls run at import time, outside any
    # event handler — confirm they surface anywhere (they may only be logged).
    try:
        gr.Info("Setting up retriever, please wait...")
        rag_initial_output = RAG.search("What is Generative AI in Healthcare?", k=1)
        gr.Info("Retriever working successfully!")
    except Exception as e:
        gr.Warning(f"Retriever not working: {str(e)}")

# Header setup
# Heading rendered above the list of retrieved papers.
mark_text = '# 🩺🔍 Search Results\n'
header_text = "## Arxiv Paper Summary With QA Retrieval Augmented Generation \n"

# Label of the semantic-search source in the "Search Source" dropdown. When
# README.md contains a line "Index Last Updated : YYYY-MM-DD", that date is
# appended both to this label and to the page header. A missing README, a README
# without that line, or a date that is not a real calendar date all fall back to
# the plain label instead of crashing at import time.
index_info = "Semantic Search"
try:
    with open("README.md", "r", encoding="utf-8") as f:
        mdfile = f.read()
except FileNotFoundError:
    mdfile = ""

date_match = re.search(r'Index Last Updated : (\d{4}-\d{2}-\d{2})', mdfile)
if date_match is not None:
    try:
        formatted_date = datetime.strptime(date_match.group(1), '%Y-%m-%d').strftime('%d %b %Y')
    except ValueError:
        # Digits fit the pattern but do not form a valid date (e.g. month 13):
        # keep the plain label.
        pass
    else:
        header_text += f'Index Last Updated: {formatted_date}\n'
        index_info = f"Semantic Search - up to {formatted_date}"

# Offer the local ColBERT index as a source only when it loaded successfully.
if semantic_search_available:
    database_choices = [index_info, 'Arxiv Search - Latest']
else:
    database_choices = ['Arxiv Search - Latest']

# Arxiv API setup: probe the live search once at startup so the dropdown does not
# offer a source that is currently unreachable.
arx_client = arxiv.Client()
try:
    check_arxiv_result = get_arxiv_live_search("What is Self Rewarding AI and how can it be used in Multi-Agent Systems?", arx_client, RETRIEVE_RESULTS)
except Exception as e:
    # A network/API failure here should degrade the app, not stop it from starting;
    # update_with_rag_md guards the same call the same way.
    print(f"Arxiv search check failed: {e}")
    check_arxiv_result = []

is_arxiv_available = len(check_arxiv_result) > 0
if not is_arxiv_available:
    print("Arxiv search not working, switching to default search ...")
    # NOTE(review): if semantic search is also unavailable, this still offers only
    # index_info; update_with_rag_md then falls back to the arXiv path anyway.
    database_choices = [index_info]

# Gradio UI setup
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    header = gr.Markdown(header_text)

    with gr.Group():
        search_query = gr.Textbox(label='Search', placeholder='What is Generative AI in Healthcare?')

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row(equal_height=True):
                llm_model = gr.Dropdown(choices=LLM_MODELS, value=DEFAULT_LLM_MODEL, label='LLM Model')
                llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results as context")
                # Default to the first offered source: when the ColBERT index failed
                # to load, index_info is not in database_choices, and Gradio rejects
                # a Dropdown value that is not one of its choices.
                database_src = gr.Dropdown(choices=database_choices, value=database_choices[0], label='Search Source')
                # Hidden toggle: streaming is always on in this UI.
                stream_results = gr.Checkbox(value=True, label="Stream output", visible=False)

    output_text = gr.Textbox(show_label=True, container=True, label='LLM Answer', visible=True)
    # Hidden carrier: receives the prompt built by update_with_rag_md and feeds it
    # to ask_llm. NOTE: this name shadows the builtin input() at module level.
    input = gr.Textbox(show_label=False, visible=False)
    gr_md = gr.Markdown(mark_text)
    
    def update_with_rag_md(search_query, llm_results_use=5, database_choice=index_info, llm_model_picked=DEFAULT_LLM_MODEL):
        """Retrieve papers for a query and build the results Markdown and the LLM prompt.

        Args:
            search_query: Free-text query from the search box.
            llm_results_use: How many of the top retrieved papers are folded into
                the LLM prompt; the remaining ones are only rendered as Markdown.
            database_choice: Label selected in the "Search Source" dropdown. The
                local ColBERT index is used only when this equals ``index_info``
                and the index loaded; otherwise the live arXiv API is queried.
            llm_model_picked: Model id forwarded to ``get_prompt_text``.

        Returns:
            Tuple ``(markdown, prompt)``: the ``mark_text`` heading followed by one
            entry per retrieved paper, and the prompt text that the event chain
            passes on to ``ask_llm``.

        Raises:
            gr.Error: if the live arXiv search raises or returns no results.
                Raising (instead of returning empty strings) stops the chained
                ``.success`` listener, so the LLM is not called with an empty prompt.
        """
        if database_choice == index_info and semantic_search_available:
            rag_out = get_rag(search_query, RAG, RETRIEVE_RESULTS)
            database_to_use = 'Semantic Search'
        else:
            try:
                rag_out = get_arxiv_live_search(search_query, arx_client, RETRIEVE_RESULTS)
            except Exception as e:
                gr.Warning(f"Arxiv Search not working: {str(e)}")
                rag_out = []

            if len(rag_out) == 0:
                raise gr.Error("Arxiv search failed. Please try again later.")

            database_to_use = 'Arxiv Search'

        md_parts = [mark_text]
        prompt_parts = []
        for i, rag_answer in enumerate(rag_out):
            if i < llm_results_use:
                # Top results also contribute a numbered entry to the LLM context.
                md_text_paper, prompt_text = get_md_text_abstract(rag_answer, source=database_to_use, return_prompt_formatting=True)
                prompt_parts.append(f"{i+1}. {prompt_text}")
            else:
                md_text_paper = get_md_text_abstract(rag_answer, source=database_to_use)
            md_parts.append(md_text_paper)

        prompt = get_prompt_text(search_query, "".join(prompt_parts), llm_model_picked=llm_model_picked)
        return "".join(md_parts), prompt
    
    def ask_llm(prompt, llm_model_picked=DEFAULT_LLM_MODEL, stream_outputs=False):
        """Generate an answer for ``prompt`` and yield it to the "LLM Answer" box.

        This function is a generator (so Gradio can stream partial output). Every
        path therefore *yields* its result: a value given to ``return`` inside a
        generator is not seen by the code iterating it.

        Args:
            prompt: Prompt text produced by ``update_with_rag_md``.
            llm_model_picked: Model id from ``LLM_MODELS``; the string ``'None'``
                disables generation and shows a fixed message instead.
            stream_outputs: If true, yield the growing answer token by token;
                otherwise yield the complete answer once.

        Yields:
            Streaming: the accumulated answer after each token, passed through the
            project helper ``SaveResponseAndRead`` (its behaviour is not visible
            here). Non-streaming: the full response, or ``""`` if inference failed.
        """
        model_disabled_text = "LLM Model is disabled"

        if llm_model_picked == 'None':
            if stream_outputs:
                # Reveal the message one character at a time to mimic streaming.
                output = ""
                for out in model_disabled_text:
                    output += out
                    yield output
            else:
                yield model_disabled_text
            # Stop here: there is no model to query.
            return

        # NOTE(review): InferenceClient is not imported anywhere in this file as
        # shown (presumably `from huggingface_hub import InferenceClient` is needed);
        # without that import this line raises NameError — confirm.
        client = InferenceClient(llm_model_picked)
        output = ""
        try:
            response = client.text_generation(prompt, stream=stream_outputs, details=False, return_full_text=False, **GENERATE_KWARGS)

            if stream_outputs:
                for token in response:
                    output += token
                    yield SaveResponseAndRead(output)
            else:
                output = response
        except Exception as e:
            gr.Warning(f"LLM Inference failed: {str(e)}")
            output = ""

        # Streaming already yielded its partial results (a mid-stream failure keeps
        # the text shown so far); the non-streaming path emits its single result here.
        if not stream_outputs:
            yield output
    
    # Two-step pipeline: retrieval fills the results Markdown and the hidden prompt
    # box; the LLM step is attached with .success so it runs only after retrieval
    # finishes without an error.
    search_event = search_query.submit(
        update_with_rag_md,
        inputs=[search_query, llm_results, database_src, llm_model],
        outputs=[gr_md, input],
    )
    search_event.success(
        ask_llm,
        inputs=[input, llm_model, stream_results],
        outputs=output_text,
    )

# Start the server only when this file is run as a script (e.g. `python app.py`),
# so importing the module (tests, Gradio reload mode) does not block on launch().
# queue() enables Gradio's request queue, which generator handlers such as
# ask_llm use to stream output.
if __name__ == "__main__":
    demo.queue().launch()