import gradio as gr
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.examples.protein_solubility.task import ProteinSolubilityTask
from mammal.keys import (
    CLS_PRED,
    ENCODER_INPUTS_STR,
    SCORES,
)
from mammal.model import Mammal

from mammal_demo.demo_framework import MammalObjectBroker, MammalTask

# Module-level shortcuts to the solubility task's helpers, so the demo class
# below can call them as plain functions without going through the task class.
process_model_output = ProteinSolubilityTask.process_model_output
data_preprocessing = ProteinSolubilityTask.data_preprocessing


class PsTask(MammalTask):
    """Demo task that asks a Mammal model whether a protein is water-soluble.

    The class wires three stages together -- prompt creation, generation and
    output decoding -- and exposes them through a small Gradio panel.
    """

    def __init__(self, model_dict):
        """Register the task under its display name and set demo defaults.

        Args:
            model_dict: forwarded to the base class; ``create_and_run_prompt``
                indexes ``self.model_dict`` by model name, so it is presumably
                a name -> model-holder mapping stored by the base class.
        """
        super().__init__(name="Protein Solubility", model_dict=model_dict)
        self.description = "Protein Solubility (PS)"
        # Default value shown in the sequence textbox when the demo opens.
        self.examples = {"protein_seq": "LLQTGIHVRVSQPSL"}
        self.markup_text = """
# Mammal based protein solubility estimation

Given the protein sequence, estimate if it's water-soluble.
"""

    def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker):
        """Turn raw user inputs into a model-ready sample dictionary.

        A shallow copy of ``sample_inputs`` is handed to the task's
        ``data_preprocessing`` helper, so the caller's dictionary is not the
        object being preprocessed.

        Args:
            sample_inputs (dict): raw inputs; the protein sequence is read
                from the ``"protein_seq"`` key.
            model_holder (MammalObjectBroker): supplies the tokenizer op and
                the model whose device is passed to preprocessing.
        Returns:
           dict: the sample dictionary returned by ``data_preprocessing``.
        """
        return data_preprocessing(
            sample_dict=dict(sample_inputs),  # shallow copy
            protein_sequence_key="protein_seq",
            tokenizer_op=model_holder.tokenizer_op,
            device=model_holder.model.device,
        )

    def run_model(self, sample_dict, model: Mammal):
        """Run generation on a single sample and return the batch dictionary.

        The sample is wrapped in a one-element list, and generation is asked
        to return scores alongside at most five new tokens.
        """
        return model.generate(
            [sample_dict],
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=5,
        )

    def decode_output(self, batch_dict, tokenizer_op: ModularTokenizerOp) -> list:
        """Extract the decoded text, predicted class and scores.

        Only the first entry of the batch is inspected.

        Returns:
            list: ``[decoded_text, pred, not_normalized_score,
            normalized_score]`` where both scores are converted with
            ``.item()``.
        """
        first_prediction = batch_dict[CLS_PRED][0]
        first_scores = batch_dict[SCORES][0]
        task_answer = process_model_output(
            tokenizer_op=tokenizer_op,
            decoder_output=first_prediction,
            decoder_output_scores=first_scores,
        )
        decoded_text = tokenizer_op._tokenizer.decode(first_prediction)
        return [
            decoded_text,
            task_answer["pred"],
            task_answer["not_normalized_scores"].item(),
            task_answer["normalized_scores"].item(),
        ]

    def create_and_run_prompt(self, model_name, protein_seq):
        """Build the prompt for ``protein_seq``, run the model and decode.

        Args:
            model_name: key into ``self.model_dict`` selecting the model holder.
            protein_seq: the protein sequence entered by the user.
        Returns:
            tuple: the prompt string followed by the four values produced by
            ``decode_output``.
        """
        holder = self.model_dict[model_name]
        sample = self.crate_sample_dict(
            sample_inputs={"protein_seq": protein_seq},
            model_holder=holder,
        )
        # Grab the prompt before generation, mirroring the processing order.
        prompt_text = sample[ENCODER_INPUTS_STR]
        generated = self.run_model(sample_dict=sample, model=holder.model)
        decoded_fields = self.decode_output(generated, tokenizer_op=holder.tokenizer_op)
        return (prompt_text, *decoded_fields)

    def create_demo(self, model_name_widget):
        """Build the Gradio panel for this task and return it hidden.

        Args:
            model_name_widget: Gradio component whose value is passed as the
                ``model_name`` argument of ``create_and_run_prompt``.
        Returns:
            The ``gr.Group`` holding the panel, with ``visible`` set to False.
        """
        with gr.Group() as panel:
            gr.Markdown(self.markup_text)

            with gr.Row():
                sequence_input = gr.Textbox(
                    label="Protein sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["protein_seq"],
                )

            with gr.Row():
                run_button = gr.Button(
                    "Run Mammal prompt for protein solubility estimation",
                    variant="primary",
                )

            with gr.Row():
                prompt_output = gr.Textbox(label="Mammal prompt", lines=5)

            with gr.Row():
                decoded_output = gr.Textbox(label="Mammal output")
                class_output = gr.Textbox(label="Mammal prediction")
                with gr.Column():
                    raw_score_output = gr.Number(label="Non normalized score")
                    norm_score_output = gr.Number(label="normalized score")

                # Order must match the tuple returned by create_and_run_prompt.
                result_widgets = [
                    prompt_output,
                    decoded_output,
                    class_output,
                    raw_score_output,
                    norm_score_output,
                ]
                run_button.click(
                    fn=self.create_and_run_prompt,
                    inputs=[model_name_widget, sequence_input],
                    outputs=result_widgets,
                )

            # The panel starts hidden.
            panel.visible = False
        return panel