import os
import copy
import torch
from torch import nn
from .config import ModelConfig
from ..utils import xtqdm as tqdm
from .cache_utils import load_model_state
from .flash import GAULinear
from ..utils import get_logger

logger = get_logger()

__all__ = ['NNModule']

def truncated_normal_(shape, mean=0, std=0.09):
  """ Fill a new tensor of `shape` with a truncated normal: draw four candidates
  per element and keep the first one that falls within two standard deviations. """
  with torch.no_grad():
    tensor = torch.zeros(shape)
    tmp = tensor.new_empty(shape + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)
    return tensor

class NNModule(nn.Module):
  """ An abstract class to handle weights initialization and \
    a simple interface for downloading and loading pretrained models.

  Args:
    config (:obj:`~DeBERTa.deberta.ModelConfig`): The model config for the module
  """

  def __init__(self, config, *inputs, **kwargs):
    super().__init__()
    self.config = config

  def init_weights(self, module):
    """ Apply Gaussian(mean=0, std=`config.initializer_range`) initialization to the module.

    Args:
      module (:obj:`torch.nn.Module`): The module to apply the initialization.

    Example::

      class MyModule(NNModule):
        def __init__(self, config):
          # Add construction instructions
          self.bert = DeBERTa(config)

          # Add other modules
          ...

          # Apply initialization
          self.apply(self.init_weights)

    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  def init_weights_gau(self, module):
    """ Apply GAU-aware initialization to the module.

    :class:`GAULinear` modules are initialized with their own :meth:`init_weight`;
    other linear and embedding weights use the depth-scaled truncated normal
    produced by :meth:`initializer`.

    Args:
      module (:obj:`torch.nn.Module`): The module to apply the initialization.

    Example::

      class MyModule(NNModule):
        def __init__(self, config):
          # Add construction instructions
          self.bert = DeBERTa(config)

          # Add other modules
          ...

          # Apply initialization
          self.apply(self.init_weights_gau)

    """
    if isinstance(module, GAULinear):
      module.init_weight()
    else:
      if isinstance(module, (nn.Linear, nn.Embedding)):
        # module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        module.weight.data.copy_(self.initializer(module.weight.data.shape))
      if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()

  def initializer(self, shape, dtype=None, order=3, gain=1.0):
    # Pick the hidden dimension: a very large (vocabulary-sized) or very small
    # second dimension indicates the first dimension is the hidden size.
    if shape[1] > 10000 or shape[1] < 10:
      hidden_size = shape[0]
    else:
      hidden_size = shape[1]
    # Scale the gain down with depth, and compensate for the variance lost by
    # truncating the normal at two standard deviations (factor ~1.1368).
    gain *= self.config.num_hidden_layers ** (-1.0 / order)
    stddev = 1.13684723 / hidden_size**0.5 * gain
    return torch.nn.init.trunc_normal_(torch.empty(shape, dtype=dtype), std=stddev)  # truncated_normal_(shape, std=stddev)

  @classmethod
  def load_model(cls, model_path, model_config=None, tag=None, no_cache=False, cache_dir=None, *inputs, **kwargs):
    """ Instantiate a sub-class of NNModule from a pre-trained model file.

    Args:

      model_path (:obj:`str`): Path or name of the pre-trained model which can be either,

        - The path of pre-trained model
        - The pre-trained DeBERTa model name in `DeBERTa GitHub releases <https://github.com/microsoft/DeBERTa/releases>`_, i.e. [**base, base_mnli, large, large_mnli**].

        If `model_path` is `None` or `-`, then the method will create a new sub-class without initializing from pre-trained models.

      model_config (:obj:`str`): The path of model config file. If it's `None`, then the method will try to find the config in order:

        1. ['config'] in the model state dictionary.

        2. `model_config.json` aside the `model_path`.

        If it fails to find a config, the method will fail.
      tag (:obj:`str`, optional): The release tag of DeBERTa, default: `None`.

      no_cache (:obj:`bool`, optional): Disable local cache of downloaded models, default: `False`.

      cache_dir (:obj:`str`, optional): The cache directory used to save the downloaded models, default: `None`. If it's `None`, then the models will be saved at `$HOME/.~DeBERTa`

    Return:
      :obj:`NNModule` : The sub-class object.

    """
    # Load config
    if model_config:
      config = ModelConfig.from_json_file(model_config)
    else:
      config = None
    model_config = None
    model_state = None
    # Treat '-' or an empty string as "no pre-trained model".
    if (model_path is not None) and (model_path.strip() == '-' or model_path.strip() == ''):
      model_path = None
    try:
      model_state, model_config = load_model_state(model_path, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
    except Exception as exp:
      raise Exception(f'Failed to get model {model_path}. Exception: {exp}')

    # Merge the explicit config into the config found with the checkpoint:
    # structural fields (hidden size, layer counts, vocabulary, ...) come from
    # the checkpoint unless they are missing or negative there.
    if config is not None and model_config is not None:
      for k in config.__dict__:
        if k not in ['hidden_size', 'intermediate_size', 'num_attention_heads', 'num_hidden_layers', 'vocab_size', 'max_position_embeddings'] \
            or (k not in model_config.__dict__) or (model_config.__dict__[k] < 0):
          model_config.__dict__[k] = config.__dict__[k]
    if model_config is not None:
      config = copy.copy(model_config)

    # Instantiate model.
    model = cls(config, *inputs, **kwargs)
    if not model_state:
      return model

    # copy state_dict so _load_from_state_dict can modify it
    state_dict = model_state.copy()
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    metadata = getattr(state_dict, '_metadata', None)

    def load(module, prefix=''):
      local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
      module._load_from_state_dict(
          state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
      for name, child in module._modules.items():
        if child is not None:
          load(child, prefix + name + '.')

    load(model)
    logger.warning(f'Missing keys: {missing_keys}, unexpected_keys: {unexpected_keys}, error_msgs: {error_msgs}')
    return model
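
# Usage sketch (illustrative only, not part of the library): a minimal sub-class
# of NNModule and how it would typically be restored with `load_model`. The class
# name `MyClassifier`, its `num_labels` argument, and the checkpoint/config paths
# below are hypothetical placeholders.
#
#   class MyClassifier(NNModule):
#     def __init__(self, config, num_labels=2, **kwargs):
#       super().__init__(config)
#       self.bert = DeBERTa(config)        # backbone, as in the docstring example above
#       self.classifier = nn.Linear(config.hidden_size, num_labels)
#       self.apply(self.init_weights)
#
#   # `model_path` may be a local checkpoint or a released model name such as
#   # 'base'; extra keyword arguments (here `num_labels`) are forwarded to the
#   # sub-class constructor by `load_model`.
#   model = MyClassifier.load_model('/path/to/pytorch_model.bin',
#                                   model_config='/path/to/model_config.json',
#                                   num_labels=2)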