from typing import Iterator, List, Optional

import gradio as gr
from datasets import load_dataset
from huggingface_hub import InferenceClient

from prompt_bot import prompt_for_template, template_bot
from util.extract_data import extrair_dados_template
from util.import_dataset import get_response_from_huggingface_dataset

descricao, regras, comportamento = extrair_dados_template()

MODEL: str = "meta-llama/Llama-3.2-3B-Instruct"
TEMPLATE_BOT = template_bot()
prompt_template = prompt_for_template(TEMPLATE_BOT)

# TODO: modify in the future
DATASET = load_dataset("wendellast/GUI-Ban")

client: InferenceClient = InferenceClient(model=MODEL)


def respond(
    message: str,
    history: List[dict],
    system_message: str,  # collected from the UI but not used when building the prompt
    max_tokens: int,
    temperature: float,
    top_p: float,
) -> Iterator[str]:
    # If the dataset already has a canned answer for this message, return it directly.
    dataset_response: Optional[str] = get_response_from_huggingface_dataset(message, DATASET)
    if dataset_response:
        yield dataset_response
        return

    # Build the system prompt from the bot template and the extracted template data.
    prompt: str = prompt_template.format(
        description=descricao,
        regras=regras,
        comportamento=comportamento,
        mensagem=message,
    )
    print(prompt)  # debug: log the assembled prompt

    messages: List[dict] = [{"role": "system", "content": prompt}]

    # Stream tokens from the model, yielding the accumulated response after each chunk.
    response: str = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token: Optional[str] = chunk.choices[0].delta.content
        if token:
            response += token
            yield response


demo: gr.ChatInterface = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    title="Commit-IA",
    type="messages",
)


if __name__ == "__main__":
    demo.launch()