import os
import copy
import torch
from torch import nn

from .config import ModelConfig
from .cache_utils import load_model_state
from .flash import GAULinear
from ..utils import xtqdm as tqdm
from ..utils import get_logger

logger = get_logger()

__all__ = ['NNModule']


def truncated_normal_(shape, mean=0, std=0.09):
  """ Fill a new tensor of `shape` with values drawn from a normal distribution
  truncated to two standard deviations, then scale by `std` and shift by `mean`.
  """
  with torch.no_grad():
    tensor = torch.zeros(shape)
    # Draw 4 candidates per element and keep one that falls inside (-2, 2);
    # with 4 draws the chance that none is valid is negligible.
    tmp = tensor.new_empty(shape + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)
    return tensor
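

# Minimal usage sketch (illustrative, not part of the original module): with
# mean=0 and std=1 the samples produced by `truncated_normal_` lie inside
# (-2, 2), because out-of-range candidates are discarded before scaling.
def _truncated_normal_example():
  sample = truncated_normal_((3, 4), mean=0, std=1)
  assert sample.abs().max().item() < 2.0
  return sample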


class NNModule(nn.Module):
  """ An abstract class to handle weights initialization and \
  a simple interface for downloading and loading pretrained models.

  Args:

    config (:obj:`~DeBERTa.deberta.ModelConfig`): The model config to the module

  """

  def __init__(self, config, *inputs, **kwargs):
    super().__init__()
    self.config = config

  def init_weights(self, module):
    """ Apply Gaussian(mean=0, std=`config.initializer_range`) initialization to the module.

    Args:

      module (:obj:`torch.nn.Module`): The module to apply the initialization.

    Example::

      class MyModule(NNModule):
        def __init__(self, config):
          super().__init__(config)
          # Add construction instructions
          self.bert = DeBERTa(config)

          # Add other modules
          ...

          # Apply initialization
          self.apply(self.init_weights)

    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  def init_weights_gau(self, module):
    """ Apply GAU-aware initialization to the module: :obj:`GAULinear` modules use
    their own `init_weight`, while other linear/embedding weights are drawn from
    the depth-scaled truncated normal produced by :meth:`initializer`.

    Args:

      module (:obj:`torch.nn.Module`): The module to apply the initialization.

    Example::

      class MyModule(NNModule):
        def __init__(self, config):
          super().__init__(config)
          # Add construction instructions
          self.bert = DeBERTa(config)

          # Add other modules
          ...

          # Apply initialization
          self.apply(self.init_weights_gau)

    """
    if isinstance(module, GAULinear):
      module.init_weight()
    elif isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.copy_(self.initializer(module.weight.data.shape))
      if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()

  def initializer(self, shape, dtype=None, order=3, gain=1.0):
    # Heuristic choice of the hidden size from the weight shape: a second
    # dimension that is vocabulary-sized (>10000) or tiny (<10) is not the
    # hidden dimension, so fall back to the first dimension.
    if shape[1] > 10000 or shape[1] < 10:
      hidden_size = shape[0]
    else:
      hidden_size = shape[1]
    # Shrink the gain with depth, num_hidden_layers**(-1/order), to keep deep
    # residual stacks stable at initialization.
    gain *= self.config.num_hidden_layers ** (-1.0 / order)
    # 1.13684723 ~= 1/0.8796 compensates for the variance lost by truncating
    # the normal distribution at +/-2 standard deviations.
    stddev = 1.13684723 / hidden_size ** 0.5 * gain
    return torch.nn.init.trunc_normal_(torch.empty(shape, dtype=dtype), std=stddev)
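
  # Worked example (illustrative assumption: a 12-layer model with hidden size
  # 768). For a square (768, 768) weight, `initializer` computes
  #   gain   = 12 ** (-1/3)                   ~= 0.4368
  #   stddev = 1.13684723 / 768 ** 0.5 * gain ~= 0.0179
  # so `self.initializer((768, 768))` returns a tensor drawn from a truncated
  # normal with std ~= 0.0179.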

  @classmethod
  def load_model(cls, model_path, model_config=None, tag=None, no_cache=False, cache_dir=None, *inputs, **kwargs):
    """ Instantiate a sub-class of NNModule from a pre-trained model file.

    Args:

      model_path (:obj:`str`): Path or name of the pre-trained model, which can be either,

        - The path of a pre-trained model

        - A pre-trained DeBERTa model name from the `DeBERTa GitHub releases <https://github.com/microsoft/DeBERTa/releases>`_, i.e. [**base, base_mnli, large, large_mnli**].

        If `model_path` is `None` or `-`, then the method will create a new sub-class without initializing from pre-trained models.

      model_config (:obj:`str`): The path of the model config file. If it's `None`, then the method will try to find the config in order:

        1. ['config'] in the model state dictionary.

        2. `model_config.json` alongside `model_path`.

        If no config can be found, the method will fail.

      tag (:obj:`str`, optional): The release tag of DeBERTa, default: `None`.

      no_cache (:obj:`bool`, optional): Disable local cache of downloaded models, default: `False`.

      cache_dir (:obj:`str`, optional): The cache directory used to save the downloaded models, default: `None`. If it's `None`, then the models will be saved at `$HOME/.~DeBERTa`.

    Return:

      :obj:`NNModule` : The sub-class object.

    """
    if model_config:
      config = ModelConfig.from_json_file(model_config)
    else:
      config = None
    model_config = None
    model_state = None
    if (model_path is not None) and (model_path.strip() == '-' or model_path.strip() == ''):
      model_path = None
    try:
      model_state, model_config = load_model_state(model_path, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
    except Exception as exp:
      raise Exception(f'Failed to get model {model_path}. Exception: {exp}')

    # Merge the explicit config into the config recovered from the checkpoint:
    # structural fields (sizes, layer/head counts, vocabulary) are kept from the
    # checkpoint unless they are missing or negative there.
    if config is not None and model_config is not None:
      for k in config.__dict__:
        if k not in ['hidden_size',
                     'intermediate_size',
                     'num_attention_heads',
                     'num_hidden_layers',
                     'vocab_size',
                     'max_position_embeddings'] or (k not in model_config.__dict__) or (model_config.__dict__[k] < 0):
          model_config.__dict__[k] = config.__dict__[k]
    if model_config is not None:
      config = copy.copy(model_config)
    vocab_size = config.vocab_size

    model = cls(config, *inputs, **kwargs)
    if not model_state:
      return model

    state_dict = model_state.copy()

    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    metadata = getattr(state_dict, '_metadata', None)

    # Walk the module tree and load parameters prefix-by-prefix, collecting
    # missing/unexpected keys instead of failing hard.
    def load(module, prefix=''):
      local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
      module._load_from_state_dict(
          state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
      for name, child in module._modules.items():
        if child is not None:
          load(child, prefix + name + '.')

    load(model)
    logger.warning(f'Missing keys: {missing_keys}, unexpected_keys: {unexpected_keys}, error_msgs: {error_msgs}')
    return model
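

# Hedged usage sketch (`MyModule` and the paths below are illustrative, not
# part of this module): any NNModule sub-class whose __init__ takes a
# ModelConfig can be instantiated from a local checkpoint,
#
#   model = MyModule.load_model('/path/to/pytorch_model.bin',
#                               model_config='/path/to/model_config.json')
#
# or from a named DeBERTa GitHub release, downloaded and cached automatically:
#
#   model = MyModule.load_model('base')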