# Copyright (c) Microsoft, Inc. 2020
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: Zhou Bo
# Date: 01/15/2020
#

import copy
import torch
import os
import random

import json
from .ops import *
from .bert import *
from .bert import BertLayer
from .config import ModelConfig
from .cache_utils import load_model_state
from .nnmodule import NNModule

# from ..utils.bad_grad_viz import register_hooks

__all__ = ['WywLM']

def flatten_states(q_states, mask_index):
    q_states = q_states.reshape((-1, q_states.size(-1)))
    q_states = q_states.index_select(0, mask_index)
    return q_states
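
# Illustrative shapes for flatten_states (a minimal sketch; tensors are hypothetical):
#   hidden      : [batch, seq_len, hidden], e.g. torch.randn(2, 4, 8)
#   mask_index  : flat indices of selected positions, e.g. (labels.view(-1) > 0).nonzero().view(-1)
#   flatten_states(hidden, mask_index) -> [num_selected, hidden]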


class UGDecoder(torch.nn.Module):
    def __init__(self, config, vocab_size):
        super().__init__()
        self.config = config
        self.position_biased_input = getattr(config, 'position_biased_input', True)
        # self.layer = torch.nn.ModuleList([BertLayer(config) for _ in range(2)])

    def forward(self, ctx_layers, word_embedding, input_ids, z_states, attention_mask, \
                encoder, target_ids=None, relative_pos=None, decode=False, s2s_idx=None):
        causal_outputs, lm_outputs = self.emd_context_layer(ctx_layers, z_states, attention_mask, 
                                                encoder, target_ids, input_ids, 
                                                relative_pos=relative_pos, decode=decode,
                                                word_embedding=word_embedding, s2s_idx=s2s_idx)

        return causal_outputs[-1], lm_outputs[-1]

    def emd_context_layer(self, encoder_layers, z_states, attention_mask, encoder, target_ids, input_ids,\
                          relative_pos=None, decode=False, word_embedding=None, s2s_idx=None):
        # if decode:
        #     attention_mask = torch.tril(torch.ones((input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])), diagonal=0).to(input_ids.device)
        # else:
        if attention_mask.dim() <= 2:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            att_mask = extended_attention_mask.byte()
            attention_mask = att_mask * att_mask.squeeze(-2).unsqueeze(-1)
        elif attention_mask.dim() == 3:
            attention_mask = attention_mask.unsqueeze(1)

        if not self.position_biased_input:
            lm_outputs = []
            hidden_states = encoder_layers[-2]
            # Re-apply the last encoder layer twice, using the position states as the query
            # (EMD-style decoding).
            layers = [encoder.layer[-1] for _ in range(2)]
            z_states += hidden_states
            query_states = z_states
            query_mask = attention_mask
            rel_embeddings = encoder.get_rel_embedding()
            for layer in layers:
                # TODO: pass relative pos ids
                output = layer(hidden_states, query_mask, return_att=False, 
                            query_states=query_states, relative_pos=relative_pos, 
                            rel_embeddings=rel_embeddings)
                query_states = output
                lm_outputs.append(query_states)

            # Causal decoding branch: build a lower-triangular attention mask over the targets.
            attention_mask = torch.tril(torch.ones((input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])), 
                                        diagonal=0).to(input_ids.device)
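            # For example, with sequence length 3, torch.tril produces
            #   [[1, 0, 0],
            #    [1, 1, 0],
            #    [1, 1, 1]]
            # so target position i can only attend to positions <= i.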
            causal_outputs = []
            target_embd = word_embedding(target_ids)
            target_embd += z_states.detach()
            # Self-attention over the target embeddings under the causal mask.
            output = layers[-2](target_embd, attention_mask, return_att=False, 
                        query_states=target_embd, relative_pos=relative_pos, 
                        rel_embeddings=encoder.get_rel_embedding())
            causal_outputs.append(output)
            # Cross-attention: queries come from the LM branch, keys/values from the target states.
            output = layers[-1](output, attention_mask, return_att=False, 
                        query_states=query_states, relative_pos=relative_pos, 
                        rel_embeddings=encoder.get_rel_embedding())
            causal_outputs.append(output)

        else:
            causal_outputs = [encoder_layers[-1]]
            lm_outputs = [encoder_layers[-1]]
        return causal_outputs, lm_outputs


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
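
# Illustrative behaviour of shift_tokens_right (hypothetical ids; pad_token_id=0,
# decoder_start_token_id=1):
#   shift_tokens_right(torch.tensor([[5, 6, 7], [8, 9, 0]]), 0, 1)
#   # -> tensor([[1, 5, 6], [1, 8, 9]])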


class WywLMLoss(torch.nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
        hidden_size = getattr(config, 'embedding_size', config.hidden_size)
        self.compare = torch.nn.Linear(hidden_size * 3, 2)
        # self.mlm_head = BertLMPredictionHead(config, config.vocab_size)
        self.lm_head = BertLMPredictionHead(config, config.vocab_size)
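        # The compare head scores [entry, desc, entry - desc] feature triples built in
        # forward(); a minimal shape sketch (values hypothetical):
        #   entry = torch.randn(1, hidden_size); desc = torch.randn(1, hidden_size)
        #   feats = torch.cat([entry, desc, entry - desc], dim=1)   # [1, hidden_size * 3]
        #   self.compare(feats)                                     # [1, 2] match / mismatch logits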

    def forward(self, logits, lm_logits, target_ids, dict_pos, input_ids, target_ids_s2s, decode=False, ebd_weight=None, task=0):
        loss_compare = torch.tensor(0).to(logits).float()
        mlm_loss = torch.tensor(0).to(logits).float()
        lm_loss = torch.tensor(0).to(logits).float()

        if task == 1:
            # Build entry/description comparison pairs from the dictionary positions.
            compare_logits = []
            compare_labels = []
            for bi, sample_pos in enumerate(dict_pos):
                # Each dictionary item contributes 4 offsets:
                # [entry_start, entry_end, desc_start, desc_end].
                num_pos = int((sample_pos > 0).sum().detach().cpu().numpy() / 4) - 1
                if num_pos <= 1:
                    continue
                for pi in range(num_pos):
                    pos = sample_pos[pi]
                    entry_logits = logits[bi][pos[0]: pos[1]]
                    desc_logits = logits[bi][pos[2]: pos[3]]
                    # Sample a negative description from another dictionary item.
                    neg_num = random.randint(0, num_pos)
                    ids_neg = input_ids[bi][sample_pos[neg_num][0]: sample_pos[neg_num][1]]
                    ids_pos = input_ids[bi][pos[0]: pos[1]]
                    if neg_num == pi or (ids_neg.shape == ids_pos.shape and torch.all(ids_neg == ids_pos)):
                        # Resample until the negative entry differs from the positive one.
                        neg_num = -1
                        for ni in range(num_pos):
                            neg_num = random.randint(0, num_pos)
                            ids_neg = input_ids[bi][sample_pos[neg_num][0]: sample_pos[neg_num][1]]
                            if neg_num != pi and (ids_neg.shape != ids_pos.shape or not torch.all(ids_neg == ids_pos)):
                                break
                            else:
                                neg_num = -1
                    if neg_num == -1:
                        continue
                    neg_desc_logits = logits[bi][sample_pos[neg_num][2]: sample_pos[neg_num][3]]
                    if torch.any(torch.isnan(neg_desc_logits)):
                        print('warning: NaN detected in negative description logits')
                    # Mean-pool each span and build [entry, desc, entry - desc] features
                    # for the positive and the negative pair.
                    entry_logits = entry_logits.mean(dim=0, keepdim=True).float()
                    desc_logits = desc_logits.mean(dim=0, keepdim=True).float()
                    neg_desc_logits = neg_desc_logits.mean(dim=0, keepdim=True).float()
                    compare_logits.append(torch.concat([entry_logits, desc_logits, entry_logits - desc_logits], dim=1))
                    compare_logits.append(torch.concat([entry_logits, neg_desc_logits, entry_logits - neg_desc_logits], dim=1))
                    compare_labels += [1, 0]
            if len(compare_logits) > 0:
                compare_logits = torch.concat(compare_logits, dim=0).to(logits.dtype)
                compare_pred = self.compare(compare_logits)
                loss_compare = self.loss_fn(compare_pred, torch.tensor(compare_labels, dtype=torch.long, device=compare_logits.device)).mean()

        # Fall back to a dummy comparison so the compare head always contributes to the loss graph.
        if torch.all(loss_compare == 0):
            entry_logits = logits[0][0].unsqueeze(0)
            compare_logits = torch.concat([entry_logits, entry_logits, entry_logits - entry_logits], dim=1)
            compare_pred = self.compare(compare_logits)
            compare_labels = [1]
            loss_compare = self.loss_fn(compare_pred, torch.tensor(compare_labels, dtype=torch.long, device=compare_logits.device)).mean()

        if task == 0:
            # Seq2seq LM loss over the positions where target_ids_s2s is non-zero.
            _mask_index = (target_ids_s2s > 0).view(-1).nonzero().view(-1)
            lm_logits_ = flatten_states(lm_logits, _mask_index)
            lm_pred = self.lm_head(lm_logits_, ebd_weight).float()
            lm_labels = target_ids_s2s.clone().reshape(-1)
            lm_labels = lm_labels.index_select(0, _mask_index)
            lm_loss = self.loss_fn(lm_pred, lm_labels.long())

        # Masked LM loss over the masked positions (non-zero entries of target_ids).
        _mask_index = (target_ids > 0).view(-1).nonzero().view(-1)
        mlm_logits = flatten_states(logits, _mask_index)
        mlm_pred = self.lm_head(mlm_logits, ebd_weight).float()
        mlm_labels = target_ids.view(-1)
        mlm_labels = mlm_labels.index_select(0, _mask_index)
        mlm_loss = self.loss_fn(mlm_pred, mlm_labels.long())
        return loss_compare, mlm_loss, lm_loss
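
# WywLMLoss returns (loss_compare, mlm_loss, lm_loss); a minimal sketch of how they are
# combined downstream (mirrors MaskedLanguageModel.forward below):
#   loss = loss_compare * 10 + mlm_loss + lm_loss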

class WywLM(torch.nn.Module):
    """ DeBERTa encoder
    This module is composed of the input embedding layer with stacked transformer layers with disentangled attention.

    Parameters:
        config:
            A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`, \
                    for more details, please refer :class:`~DeBERTa.deberta.ModelConfig`

        pre_trained:
            The pre-trained DeBERTa model, it can be a physical path of a pre-trained DeBERTa model or a released configurations, \
                    i.e. [**base, large, base_mnli, large_mnli**]

    """

    def __init__(self, config=None, pre_trained=None):
        super().__init__()
        state = None
        if pre_trained is not None:
            state, model_config = load_model_state(pre_trained)
            if config is not None and model_config is not None:
                for k in config.__dict__:
                    if k not in ['hidden_size',
                        'intermediate_size',
                        'num_attention_heads',
                        'num_hidden_layers',
                        'vocab_size',
                        'max_position_embeddings']:
                        model_config.__dict__[k] = config.__dict__[k]
            config = copy.copy(model_config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.config = config
        self.pre_trained = pre_trained
        self.apply_state(state)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, output_all_encoded_layers=True, position_ids = None, return_att = False):
        """
        Args:
            input_ids:
                a torch.LongTensor of shape [batch_size, sequence_length] \
            with the word token indices in the vocabulary

            attention_mask:
                an optional parameter for input mask or attention mask.

                - If it's an input mask, then it will be torch.LongTensor of shape [batch_size, sequence_length] with indices \
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max \
            input sequence length in the current batch. It's the mask that we typically use for attention when \
            a batch has varying length sentences.

                - If it's an attention mask then it will be torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. \
            In this case, it's a mask indicating which tokens in the sequence should be attended to by other tokens in the sequence.

            token_type_ids:
                an optional torch.LongTensor of shape [batch_size, sequence_length] with the token \
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to \
            a `sentence B` token (see BERT paper for more details).

            output_all_encoded_layers:
                whether to output the results of all encoder layers; default: `True`

        Returns:

            - The output of the stacked transformer layers if `output_all_encoded_layers=True`, else \
            the last layer of stacked transformer layers

            - Attention matrix of self-attention layers if `return_att=True`


        Example::

            # Batch of WordPiece token ids.
            # Each sample is padded with zeros to the maximum length of the batch
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            # Mask of valid input ids
            attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

            # WywLM model initialized with the pre-trained base checkpoint
            model = WywLM(pre_trained='base')

            encoder_output = model(input_ids, attention_mask=attention_mask)
            encoder_layers = encoder_output['hidden_states']

        """

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
            token_mask = torch.ones_like(input_ids)
        else:
            # Suffix-sum trick: mark every position at or before the last type-1 token.
            idxs = torch.flip(torch.cumsum(torch.flip(token_type_ids, [-1]), dim=1), [-1])
            token_mask = idxs > 0
            token_mask = token_mask.byte()
        ebd_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, token_mask)
        embedding_output = ebd_output['embeddings']
        encoder_output = self.encoder(embedding_output,
                                     attention_mask,
                                     output_all_encoded_layers=output_all_encoded_layers, return_att = return_att)
        encoder_output.update(ebd_output)
        return encoder_output
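
    # The returned dict merges the encoder outputs with the embedding outputs; downstream
    # code reads encoder_output['hidden_states'] and encoder_output['position_embeddings'].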

    def apply_state(self, state = None):
        """ Load state from previous loaded model state dictionary.

            Args:
                state (:obj:`dict`, optional): State dictionary as returned by torch.nn.Module.state_dict(), default: `None`. \
                        If it's `None`, the pre-trained state loaded via the constructor will be used to re-initialize \
                        the `WywLM` model
        """
        if self.pre_trained is None and state is None:
            return
        if state is None:
            state, config = load_model_state(self.pre_trained)
            self.config = config
        
        prefix = ''
        for k in state:
            if 'embeddings.' in k:
                if not k.startswith('embeddings.'):
                    prefix = k[:k.index('embeddings.')]
                break

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        self._load_from_state_dict(state, prefix = prefix, local_metadata=None, strict=True, missing_keys=missing_keys, unexpected_keys=unexpected_keys, error_msgs=error_msgs)
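
    # Illustrative use of apply_state (a sketch; the checkpoint filename is hypothetical):
    #   model = WywLM(pre_trained='base')    # state is loaded and applied in the constructor
    #   state = torch.load('pytorch_model.bin', map_location='cpu')
    #   model.apply_state(state)             # re-initialize from an explicit state dict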


class MaskedLanguageModel(NNModule):
    """ Masked language model
    """
    def __init__(self, config, *wargs, **kwargs):
        super().__init__(config)
        self.backbone = WywLM(config)

        self.max_relative_positions = getattr(config, 'max_relative_positions', -1)
        self.position_buckets = getattr(config, 'position_buckets', -1)
        if self.max_relative_positions < 1:
            self.max_relative_positions = config.max_position_embeddings
        # self.mlm_predictions = UGDecoder(self.backbone.config, self.backbone.embeddings.word_embeddings.weight.size(0))
        self.lm_predictions = UGDecoder(self.backbone.config, self.backbone.embeddings.word_embeddings.weight.size(0))
        self.device = None
        self.loss = WywLMLoss(config)
        # self.loss_lm = WywLMLoss(config)
        self.apply(self.init_weights)

    def forward(self, samples, position_ids=None):
        task = samples['task']
        if task == 0:
            input_ids = samples['s2s_input_ids']
            type_ids = samples['s2s_token_type_ids']
            attention_mask = samples['s2s_attention_mask']
            labels = samples['s2s_masked_lm_labels']
            dict_pos = samples['dict_pos']
            s2s_label = samples['s2s_label']
        else:
            input_ids = samples['input_ids']
            type_ids = samples['token_type_ids']
            attention_mask = samples['attention_mask']
            labels = samples['masked_lm_labels']
            dict_pos = samples['dict_pos']
            s2s_label = samples['s2s_label']
        
        if self.device is None:
            self.device = next(self.parameters()).device
        
        input_ids = input_ids.to(self.device)

        # Token type ids are dropped here; the backbone then falls back to all-zero type ids.
        type_ids = None
        lm_labels = labels.to(self.device)
        s2s_label = s2s_label.to(self.device)
        attention_mask = attention_mask.to(self.device)

        encoder_output = self.backbone(input_ids, attention_mask, type_ids, output_all_encoded_layers=True, position_ids = position_ids)
        encoder_layers = encoder_output['hidden_states']
        z_states = encoder_output['position_embeddings']
        ctx_layer = encoder_layers[-1]
        mlm_loss = torch.tensor(0).to(ctx_layer).float()
        lm_loss = torch.tensor(0).to(ctx_layer).float()
        lm_logits = None
        label_inputs = None
        loss = torch.tensor(0).to(ctx_layer).float()
        loss_compare = torch.tensor(0).to(ctx_layer).float()

        ebd_weight = self.backbone.embeddings.word_embeddings.weight
        lm_logits, mlm_logits = self.lm_predictions(encoder_layers, self.backbone.embeddings.word_embeddings, 
                                        input_ids, z_states, 
                                        attention_mask, self.backbone.encoder,
                                        target_ids=lm_labels)
        # if lm_labels.detach().sum() != 0:
        loss_compare, mlm_loss, lm_loss = self.loss(mlm_logits, 
                                                    lm_logits,
                                                    lm_labels, 
                                                    dict_pos, 
                                                    target_ids_s2s=s2s_label,
                                                    decode=False, 
                                                    ebd_weight=ebd_weight, 
                                                    input_ids=input_ids,
                                                    task=task)
        # Weighted combination: the dictionary-comparison loss is up-weighted by 10.
        loss = loss_compare * 10 + mlm_loss + lm_loss

        return {
                'logits' : lm_logits,
                'labels' : lm_labels,
                's2s_label': s2s_label,
                'loss' : loss.float(),
                'loss_compare': loss_compare.float(),
                'lm_loss': lm_loss.float(),
                'mlm_loss': mlm_loss.float()
            }
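

# Illustrative usage of MaskedLanguageModel (a minimal sketch; the sample dict must be
# produced by the pretraining data pipeline, and all tensor contents below are hypothetical):
#   model = MaskedLanguageModel(config)          # config: a ModelConfig instance
#   samples = {
#       'task': 1,                               # 0: seq2seq batch, otherwise: MLM batch
#       'input_ids': ..., 'token_type_ids': ..., 'attention_mask': ...,
#       'masked_lm_labels': ..., 'dict_pos': ..., 's2s_label': ...,
#   }
#   outputs = model(samples)                     # dict with 'loss', 'mlm_loss', 'lm_loss', 'loss_compare'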