import os

import librosa
import torch
from torch import nn
from transformers import (
    AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer, BertConfig,
    PretrainedConfig, PreTrainedModel, WhisperForConditionalGeneration,
)
from transformers.models.bert.modeling_bert import BertEncoder


class Desta2Config(PretrainedConfig):
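    """Configuration for DestaModel: bundles the Whisper model id, the Llama model id, and the Q-Former prompt size."""
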
    model_type = "DestaModel"

    def __init__(
        self,
        llama_model_id="meta-llama/Meta-Llama-3-8B-Instruct",
        whisper_model_id="openai/whisper-small",
        prompt_size=64,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.llama_model_id = llama_model_id
        self.whisper_model_id = whisper_model_id
        self.prompt_size = prompt_size

        self.whisper_config = AutoConfig.from_pretrained(self.whisper_model_id)
        self.llama_config = AutoConfig.from_pretrained(self.llama_model_id)

class QformerConnector(PreTrainedModel):
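    """Q-Former style connector between the Whisper encoder and the Llama decoder.

    Learnable prompts cross-attend to hidden states from selected Whisper encoder layers;
    the per-layer outputs are mixed with learned softmax weights and projected to the Llama hidden size.
    """
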
    def __init__(self, cfg):
        super().__init__(cfg)
        self.cfg = cfg
        
        
        # Whisper encoder layers whose hidden states are fed to the Q-Former.
        if self.cfg.whisper_model_id == "openai/whisper-medium":
            self.target_layer_ids = [5, 11, 17, 23]
        elif self.cfg.whisper_model_id == "openai/whisper-small":
            self.target_layer_ids = [2, 5, 8, 11]
        elif self.cfg.whisper_model_id == "openai/whisper-tiny":
            self.target_layer_ids = [0, 1, 2, 3]
        elif self.cfg.whisper_model_id == "openai/whisper-large-v3":
            self.target_layer_ids = [3, 7, 11, 15, 19, 23, 27, 31]
        else:
            raise NotImplementedError(f"model_id {self.cfg.whisper_model_id} not implemented")


        # One learnable prompt per target layer: (1, prompt_size, d_model)
        self.layer_prompts = nn.ParameterList([
            nn.Parameter(torch.randn(1, self.cfg.prompt_size, self.cfg.whisper_config.d_model))
            for _ in range(len(self.target_layer_ids))
        ])
        
        
        # (prompt_size, target_layers)
        self.layer_weights = nn.Parameter(torch.zeros(self.cfg.prompt_size, len(self.target_layer_ids), dtype=torch.float))

        # A shallow BERT encoder with cross-attention serves as the Q-Former.
        qformer_config = BertConfig()
        qformer_config.num_hidden_layers = 2
        qformer_config.num_attention_heads = self.cfg.whisper_config.encoder_attention_heads
        qformer_config.hidden_size = self.cfg.whisper_config.d_model
        qformer_config.add_cross_attention = True
        qformer_config.is_decoder = True

        self.qformer = BertEncoder(qformer_config)
        self.proj = nn.Sequential(
                nn.LayerNorm(self.cfg.whisper_config.d_model),
                nn.Linear(self.cfg.whisper_config.d_model, self.cfg.llama_config.hidden_size) # project to llama hidden size
            )
    
    def forward(self, encoder_hidden_states):
        # encoder_hidden_states: tuple of Whisper encoder hidden states (embedding output followed by each layer's output).
        layer_prompt_outputs = []
        for idx, encoder_hidden_state in enumerate(encoder_hidden_states):
            if idx in self.target_layer_ids:
                # Expand the layer prompt to the batch size and let it cross-attend to this layer's hidden states.
                layer_prompt = self.layer_prompts[self.target_layer_ids.index(idx)].expand(encoder_hidden_state.size(0), -1, -1)
                qformer_output = self.qformer(
                    hidden_states=layer_prompt,
                    encoder_hidden_states=encoder_hidden_state,
                )
                layer_prompt_output = qformer_output.last_hidden_state
                layer_prompt_outputs.append(layer_prompt_output)
        
        # (num_target_layers, b, prompt_size, d_model) -> (b, prompt_size, num_target_layers, d_model)
        layer_prompt_outputs = torch.stack(layer_prompt_outputs, dim=0)
        layer_prompt_outputs = layer_prompt_outputs.permute(1, 2, 0, 3)

        # Softmax over the layer dimension yields per-prompt mixing weights.
        norm_weights = torch.nn.functional.softmax(self.layer_weights, dim=-1).unsqueeze(-1)

        output = (layer_prompt_outputs * norm_weights).sum(dim=2) # (b, prompt_size, d_model)

        output = self.proj(output)
        
        return output

class SpeechPerception(PreTrainedModel):
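    """Runs Whisper on the input audio and returns both the transcription and the Q-Former speech features."""
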
    def __init__(self, cfg):
        super().__init__(cfg)
        self.cfg = cfg

        self.whisper = WhisperForConditionalGeneration.from_pretrained(cfg.whisper_model_id)
        self.processor = AutoProcessor.from_pretrained(cfg.whisper_model_id)

        self.connector = QformerConnector(cfg)

    def generate(self, input_features):
        input_features = input_features.to(self.whisper.device)
        
        # Use Whisper's default generation config; request hidden states so the connector can read the encoder outputs.
        outputs = self.whisper.generate(inputs=input_features, return_dict_in_generate=True, output_hidden_states=True)

        # Batch size is 1, so keep only the first decoded transcription.
        transcription = self.processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
        speech_features = self.connector(outputs.encoder_hidden_states)

        return transcription, speech_features


class DestaModel(PreTrainedModel):
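    """Speech-aware LLM: Whisper-based speech perception feeds prompt embeddings (plus the transcription) into a Llama chat model."""
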
    config_class = Desta2Config

    def __init__(self, config):
        super().__init__(config)

        self.speech_perception = SpeechPerception(config)
        self.llama = AutoModelForCausalLM.from_pretrained(config.llama_model_id, torch_dtype=torch.bfloat16)
        self.tokenizer = AutoTokenizer.from_pretrained(config.llama_model_id)
        

    def chat(self, messages, max_new_tokens=128, do_sample=True, temperature=0.6, top_p=0.9):
        """
        messages: list of dicts with keys "role" and "content"
        ```
        [
            {"role": "system", "content": "You are a helpful voice assistant."},
            {"role": "audio", "content": "<path_to_audio_file>"},
            {"role": "user", "content": "Describe the audio."}
        ]
        ```
        """

        audio_path, input_features = self.load_audio(messages)
        transcription, audio_features = self.speech_perception.generate(input_features)
        inputs, audio_position = self.process_text(messages, audio_path, transcription)

        inputs_embeds, attention_mask = self.prepare_llm_input(
            input_ids=inputs.input_ids, 
            attention_mask=inputs.attention_mask, 
            audio_position=audio_position,
            audio_features=audio_features
        )

        outputs = self.llama.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            pad_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
        )
        return outputs

    def process_text(self, messages, audio_path, transcription):
        context = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # Replace the audio path placeholder in the prompt with the Whisper transcription.
        left_text, right_text = context.split(audio_path)
        right_text = transcription + right_text

        # Token index at which the audio features will be spliced into the LLM input.
        audio_position = len(self.tokenizer.tokenize(left_text))
        context = left_text + right_text

        inputs = self.tokenizer(context, return_tensors="pt")

        return inputs, audio_position


    def prepare_llm_input(self, input_ids, attention_mask, audio_position, audio_features):
        input_ids = input_ids.to(self.llama.device)
        attention_mask = attention_mask.to(self.llama.device)
        audio_features = audio_features.to(self.llama.device)
        audio_feature_length = audio_features.size(1)

        inputs_embeds = self.llama.model.embed_tokens(input_ids) # [bs, seq_len, hidden_size]


        # Splice the projected audio features into the token embeddings at audio_position (batch size 1 only).
        inputs_embeds = torch.cat([inputs_embeds[0, :audio_position], audio_features[0, :], inputs_embeds[0, audio_position:]], dim=0)
        attention_mask = torch.cat([attention_mask[0, :audio_position], torch.ones([audio_feature_length], dtype=torch.long, device=self.llama.device), attention_mask[0, audio_position:]], dim=0)

        inputs_embeds = inputs_embeds.to(self.llama.dtype)
        attention_mask = attention_mask.to(self.llama.dtype)
        return inputs_embeds.unsqueeze(0), attention_mask.unsqueeze(0)

    
    def load_audio(self, messages):
        audio_path = None
        for message in messages:
            if message["role"] == "audio":
                if audio_path is not None:
                    raise ValueError("Multiple audio file paths found in messages. Only one audio file per conversation is supported at the moment.")
                audio_path = message["content"]
        if audio_path is None:
            raise ValueError("No audio file path found in messages")

        # Whisper expects 16 kHz audio.
        audio, ori_sr = librosa.load(audio_path)
        audio = librosa.resample(audio, orig_sr=ori_sr, target_sr=16000)
        input_features = self.speech_perception.processor(audio, sampling_rate=16000, return_tensors="pt").input_features

        return audio_path, input_features
    
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, config=None, cache_dir=None, **kwargs):
        if config is None:
            config = cls.config_class.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir, **kwargs)
        model = cls(config)

        # Only the Q-Former connector weights live in the checkpoint; Whisper and Llama are loaded from their own model ids.
        if os.path.isdir(pretrained_model_name_or_path):
            connector_path = os.path.join(pretrained_model_name_or_path, "qformer_connector.pth")
        else:
            from huggingface_hub import hf_hub_download
            connector_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="qformer_connector.pth", cache_dir=cache_dir)

        model.speech_perception.connector.load_state_dict(torch.load(connector_path, map_location="cpu"))

        return model
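

# ---------------------------------------------------------------------------
# Minimal usage sketch (only runs when this file is executed directly).
# The checkpoint id and audio path below are placeholders, not real assets:
# point them at a checkpoint directory / Hugging Face repo that contains the
# DestaModel config plus "qformer_connector.pth", and at a local audio file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = DestaModel.from_pretrained("<path_or_repo_with_qformer_connector>")  # placeholder checkpoint
    model.to("cuda")  # assumes a CUDA device; the Llama weights are loaded in bfloat16
    model.eval()

    messages = [
        {"role": "system", "content": "You are a helpful voice assistant."},
        {"role": "audio", "content": "<path_to_audio_file>"},  # placeholder audio path
        {"role": "user", "content": "Describe the audio."},
    ]

    # chat() returns only the newly generated token ids (generation starts from
    # inputs_embeds), so decode them with the model's own tokenizer.
    output_ids = model.chat(messages, max_new_tokens=128, do_sample=True, temperature=0.6, top_p=0.9)
    response = model.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    print(response)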