import pdb
import os
import torch
import copy
from torch import nn, tensor
from .config import ModelConfig
from ..utils import xtqdm as tqdm
from .cache_utils import load_model_state
from .flash import GAULinear

from ..utils import get_logger
logger = get_logger()

__all__ = ['NNModule']

def truncated_normal_(shape, mean=0, std=0.09):
    """Return a new tensor drawn from a normal distribution truncated at +/-2 std.

    Each element is produced by drawing 4 standard-normal candidates and keeping
    the first one that lies strictly inside (-2, 2). The result is then scaled
    by `std` and shifted by `mean`, so the truncation bounds are
    ``mean +/- 2 * std``. Despite the trailing underscore, nothing is modified
    in place: a freshly allocated tensor is returned.

    Args:
      shape: Output shape, given as an int, tuple, list or ``torch.Size``.
      mean: Mean of the distribution.
      std: Standard deviation of the un-truncated distribution.

    Returns:
      torch.Tensor: A tensor of the given shape with the default floating dtype.
    """
    shape = (shape,) if isinstance(shape, int) else tuple(shape)
    with torch.no_grad():
        candidates = torch.randn(shape + (4,))
        valid = (candidates < 2) & (candidates > -2)
        # argmax returns the index of the first maximal value, i.e. the first
        # valid candidate. If no candidate is valid, it returns index 0.
        first_valid = valid.to(torch.int8).argmax(-1, keepdim=True)
        sample = candidates.gather(-1, first_valid).squeeze(-1)
        # If all 4 candidates fell outside (-2, 2), the index-0 fallback is out
        # of range. The chance is about 4e-6 per element, which is still
        # noticeable for large matrices. Clamp it back inside the bounds.
        sample.clamp_(-2, 2)
        return sample.mul_(std).add_(mean)

class NNModule(nn.Module):
  """An abstract base class that handles weight initialization and provides a
  simple interface for downloading and loading pretrained models.

  Args:

    config (:obj:`~DeBERTa.deberta.ModelConfig`): The model config of the module.

  """

  def __init__(self, config, *inputs, **kwargs):
    super().__init__()
    self.config = config

  def init_weights(self, module):
    """Apply Gaussian(mean=0, std=`config.initializer_range`) initialization to the module.

    `nn.Linear` and `nn.Embedding` weights are re-sampled, and `nn.Linear`
    biases (when present) are zeroed. Other module types are left untouched.

    Args:

      module (:obj:`torch.nn.Module`): The module to apply the initialization.

    Example::

      class MyModule(NNModule):
        def __init__(self, config):
          super().__init__(config)
          # Add construction instructions
          self.bert = DeBERTa(config)

          # Add other modules
          ...

          # Apply initialization
          self.apply(self.init_weights)

    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  def init_weights_gau(self, module):
    """Apply the GAU initialization scheme to the module.

    - `GAULinear` modules delegate to their own ``init_weight()`` method.
    - Other `nn.Linear` / `nn.Embedding` weights are filled with
      :meth:`initializer`, a depth-scaled truncated normal. This does NOT use
      `config.initializer_range`.
    - `nn.Linear` biases (when present) are zeroed.

    Args:

      module (:obj:`torch.nn.Module`): The module to apply the initialization.

    Example::

      class MyModule(NNModule):
        def __init__(self, config):
          super().__init__(config)
          # Add construction instructions
          ...

          # Apply initialization
          self.apply(self.init_weights_gau)

    """
    if isinstance(module, GAULinear):
      module.init_weight()
    elif isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.copy_(self.initializer(module.weight.data.shape))
    # NOTE(review): if GAULinear subclasses nn.Linear, its bias is also zeroed
    # here, after init_weight(). Confirm that this is intended.
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  def initializer(self, shape, dtype=None, order=3, gain=1.0):
    """Sample a depth-scaled truncated-normal tensor for a 2-D weight.

    The samples are truncated at +/-2 standard deviations. The pre-truncation
    std is inflated by 1.13684723, which is approximately 1 / std of a standard
    normal truncated to [-2, 2]. As a result, the std after truncation is about
    ``gain * num_hidden_layers ** (-1 / order) / sqrt(hidden_size)``.

    Args:
      shape: Weight shape. `shape[0]` and `shape[1]` are read.
      dtype: dtype of the returned tensor. ``None`` uses the torch default.
      order: Depth-scaling exponent. `gain` is multiplied by
        ``config.num_hidden_layers ** (-1 / order)``.
      gain: Extra multiplicative factor on the std.

    Returns:
      torch.Tensor: A new tensor of `shape` whose values lie within
      ``+/-2 * stddev``.
    """
    # Use the second dimension as the hidden size. That is `in_features` for an
    # nn.Linear weight (out, in) and the embedding width for nn.Embedding.
    # Fall back to the first dimension when the second looks like a vocabulary
    # (>10000) or is tiny (<10).
    if shape[1] > 10000 or shape[1] < 10:
      hidden_size = shape[0]
    else:
      hidden_size = shape[1]
    gain *= self.config.num_hidden_layers ** (-1.0 / order)
    stddev = 1.13684723 / hidden_size**0.5 * gain
    # trunc_normal_'s `a`/`b` are absolute bounds (default +/-2.0). They must be
    # scaled by stddev to truncate at +/-2 std; otherwise no truncation happens
    # and the 1.1368 correction factor over-inflates the std.
    return torch.nn.init.trunc_normal_(
      torch.empty(shape, dtype=dtype), mean=0.0, std=stddev, a=-2.0 * stddev, b=2.0 * stddev)

  @classmethod
  def load_model(cls, model_path, model_config=None, tag=None, no_cache=False, cache_dir=None , *inputs, **kwargs):
    """Instantiate a sub-class of NNModule from a pre-trained model file.

    Args:

      model_path (:obj:`str`): Path or name of the pre-trained model. It can be either:

        - The path of a pre-trained model.

        - The name of a pre-trained DeBERTa model in `DeBERTa GitHub releases <https://github.com/microsoft/DeBERTa/releases>`_, i.e. [**base, base_mnli, large, large_mnli**].

        If `model_path` is `None`, empty or `-`, then the method creates a new sub-class instance without initializing from a pre-trained model.

      model_config (:obj:`str`): The path of the model config file. If it's `None`, then the method tries to find the config in this order:

        1. ['config'] in the model state dictionary.

        2. `model_config.json` beside the `model_path`.

        If it fails to find a config, the method fails. When both an explicit
        config and a saved config exist, the saved config wins for the
        architecture keys (hidden_size, intermediate_size, num_attention_heads,
        num_hidden_layers, vocab_size, max_position_embeddings), unless the
        saved value is missing or negative. For all other keys the explicit
        config wins.

      tag (:obj:`str`, optional): The release tag of DeBERTa, default: `None`.

      no_cache (:obj:`bool`, optional): Disable the local cache of downloaded models, default: `False`.

      cache_dir (:obj:`str`, optional): The cache directory used to save the downloaded models, default: `None`. If it's `None`, then the models will be saved at `$HOME/.~DeBERTa`.

    Return:

      :obj:`NNModule` : The sub-class object.

    Raises:

      Exception: If the model state or config cannot be loaded. The original
        error is chained as the cause.

      ValueError: If no model config could be found.

    """
    # Explicitly supplied config file, if any.
    config = ModelConfig.from_json_file(model_config) if model_config else None
    if model_path is not None and model_path.strip() in ('-', ''):
      model_path = None
    try:
      model_state, saved_config = load_model_state(model_path, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
    except Exception as exp:
      raise Exception(f'Failed to get model {model_path}. Exception: {exp}') from exp

    if config is not None and saved_config is not None:
      arch_keys = ('hidden_size', 'intermediate_size', 'num_attention_heads',
        'num_hidden_layers', 'vocab_size', 'max_position_embeddings')
      saved = saved_config.__dict__
      for k, v in config.__dict__.items():
        if k not in arch_keys or k not in saved or saved[k] < 0:
          saved[k] = v
    if saved_config is not None:
      config = copy.copy(saved_config)
    if config is None:
      raise ValueError(f'No model config found for model {model_path}; pass `model_config` explicitly.')
    # Instantiate model.
    model = cls(config, *inputs, **kwargs)
    if not model_state:
      return model
    # Read the metadata from the ORIGINAL state dict: dict/OrderedDict.copy()
    # does not carry over the `_metadata` attribute, so reading it from the
    # copy would always yield None.
    metadata = getattr(model_state, '_metadata', None)
    # Copy the state_dict so _load_from_state_dict can modify it.
    state_dict = model_state.copy()

    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    def load(module, prefix=''):
      local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
      module._load_from_state_dict(
        state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
      for name, child in module._modules.items():
        if child is not None:
          load(child, prefix + name + '.')
    load(model)
    logger.warning(f'Missing keys: {missing_keys}, unexpected_keys: {unexpected_keys}, error_msgs: {error_msgs}')
    return model