# DeBERTa-base / modeling / nnmodule.py
import os
import copy

import torch
from torch import nn

from .config import ModelConfig
from ..utils import xtqdm as tqdm
from .cache_utils import load_model_state
from .flash import GAULinear
from ..utils import get_logger

logger = get_logger()

__all__ = ['NNModule']
def truncated_normal_(shape, mean=0, std=0.09):
    """Fill a new tensor of `shape` with a truncated normal: for each element,
    draw 4 standard-normal candidates, keep the first one inside (-2, 2),
    then scale by `std` and shift by `mean`."""
    with torch.no_grad():
        tensor = torch.zeros(shape)
        tmp = tensor.new_empty(shape + (4,)).normal_()
        valid = (tmp < 2) & (tmp > -2)
        # argmax over the boolean mask returns the index of the first valid draw
        ind = valid.max(-1, keepdim=True)[1]
        tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
        tensor.data.mul_(std).add_(mean)
        return tensor
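
# A minimal sanity check of the sampler above (illustrative only, not part of
# the original module). Almost all draws land within 2*std of the mean; a value
# can escape the range only when all 4 candidates fall outside (-2, 2), which
# happens with probability ~(0.0455)**4 per element.
#
#     w = truncated_normal_((512, 768), mean=0.0, std=0.02)
#     frac_in_range = (w.abs() <= 2 * 0.02).float().mean().item()  # ~1.0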
class NNModule(nn.Module):
    """An abstract class that handles weights initialization and provides
    a simple interface for downloading and loading pretrained models.

    Args:
        config (:obj:`~DeBERTa.deberta.ModelConfig`): The model config for the module
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        self.config = config
    def init_weights(self, module):
        """Apply Gaussian(mean=0, std=`config.initializer_range`) initialization to the module.

        Args:
            module (:obj:`torch.nn.Module`): The module to apply the initialization.

        Example::

            class MyModule(NNModule):
                def __init__(self, config):
                    super().__init__(config)
                    # Add construction instructions
                    self.bert = DeBERTa(config)
                    # Add other modules
                    ...
                    # Apply initialization
                    self.apply(self.init_weights)
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
    def init_weights_gau(self, module):
        """Apply truncated-normal initialization (see :meth:`initializer`) to the module.
        :class:`GAULinear` modules are initialized by their own :meth:`init_weight` instead.

        Args:
            module (:obj:`torch.nn.Module`): The module to apply the initialization.

        Example::

            class MyModule(NNModule):
                def __init__(self, config):
                    super().__init__(config)
                    # Add construction instructions
                    self.bert = DeBERTa(config)
                    # Add other modules
                    ...
                    # Apply initialization
                    self.apply(self.init_weights_gau)
        """
        if isinstance(module, GAULinear):
            module.init_weight()
        else:
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.copy_(self.initializer(module.weight.data.shape))
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
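
    # Illustrative note (not in the original module): `nn.Module.apply` walks
    # every submodule, so a single call such as
    #
    #     model.apply(model.init_weights_gau)
    #
    # routes GAULinear layers to their own init_weight() and everything else
    # through initializer() below.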
    def initializer(self, shape, dtype=None, order=3, gain=1.0):
        # Use the second dimension as the fan-in unless it looks like a
        # vocabulary-sized (>10000) or degenerate (<10) axis, in which case
        # fall back to the first dimension.
        if shape[1] > 10000 or shape[1] < 10:
            hidden_size = shape[0]
        else:
            hidden_size = shape[1]
        # Scale the gain down with depth: num_hidden_layers ** (-1/order).
        gain *= self.config.num_hidden_layers ** (-1.0 / order)
        # 1.13684723 ~= 1 / 0.87962566, the std of a standard normal truncated
        # at two standard deviations; it compensates so the effective std is
        # approximately gain / sqrt(hidden_size) after truncation.
        stddev = 1.13684723 / hidden_size ** 0.5 * gain
        return torch.nn.init.trunc_normal_(torch.empty(shape, dtype=dtype), std=stddev)
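
    # Worked example (illustrative; the shapes assume a base-sized config with
    # num_hidden_layers=12): for a nn.Linear(768, 3072) weight of shape
    # (3072, 768), shape[1]=768 is the fan-in, the depth-scaled gain is
    # 12 ** (-1/3) ~= 0.437, and the resulting std is
    # 1.13684723 / 768 ** 0.5 * 0.437 ~= 0.0179.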
    @classmethod
    def load_model(cls, model_path, model_config=None, tag=None, no_cache=False, cache_dir=None, *inputs, **kwargs):
        """Instantiate a sub-class of NNModule from a pre-trained model file.

        Args:
            model_path (:obj:`str`): Path or name of the pre-trained model, which can be either,

                - The path of a pre-trained model
                - The pre-trained DeBERTa model name in `DeBERTa GitHub releases <https://github.com/microsoft/DeBERTa/releases>`_, i.e. [**base, base_mnli, large, large_mnli**].

                If `model_path` is `None` or `-`, then the method will create a new sub-class without initializing from pre-trained models.
            model_config (:obj:`str`): The path of the model config file. If it's `None`, then the method will try to find the config in this order:

                1. ['config'] in the model state dictionary.
                2. `model_config.json` next to `model_path`.

                If no config can be found, the method will fail.
            tag (:obj:`str`, optional): The release tag of DeBERTa, default: `None`.
            no_cache (:obj:`bool`, optional): Disable local cache of downloaded models, default: `False`.
            cache_dir (:obj:`str`, optional): The cache directory used to save the downloaded models, default: `None`. If it's `None`, then the models will be saved at `$HOME/.~DeBERTa`

        Returns:
            :obj:`NNModule`: The sub-class object.
        """
        # Load config
        if model_config:
            config = ModelConfig.from_json_file(model_config)
        else:
            config = None
        model_config = None
        model_state = None
        # Treat '-' or an empty string as "no pre-trained model".
        if (model_path is not None) and (model_path.strip() == '-' or model_path.strip() == ''):
            model_path = None
        try:
            model_state, model_config = load_model_state(model_path, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
        except Exception as exp:
            raise Exception(f'Failed to get model {model_path}. Exception: {exp}')
        # Overwrite the loaded config with values from the user-provided one,
        # except for architecture-defining fields that must match the checkpoint
        # (unless the checkpoint's value is missing or negative).
        if config is not None and model_config is not None:
            for k in config.__dict__:
                if k not in ['hidden_size',
                             'intermediate_size',
                             'num_attention_heads',
                             'num_hidden_layers',
                             'vocab_size',
                             'max_position_embeddings'] or (k not in model_config.__dict__) or (model_config.__dict__[k] < 0):
                    model_config.__dict__[k] = config.__dict__[k]
        if model_config is not None:
            config = copy.copy(model_config)
        # Instantiate the model.
        model = cls(config, *inputs, **kwargs)
        if not model_state:
            return model
        # Copy state_dict so _load_from_state_dict can modify it.
        state_dict = model_state.copy()
        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        metadata = getattr(state_dict, '_metadata', None)
        def load(module, prefix=''):
            # Recursively load parameters into each submodule, mirroring
            # torch.nn.Module.load_state_dict but collecting problems
            # instead of raising.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model)
        logger.warning(f'Missing keys: {missing_keys}, unexpected_keys: {unexpected_keys}, error_msgs: {error_msgs}')
        return model
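
# Illustrative usage of load_model (a sketch, not part of the original module;
# `MyClassifier` is a hypothetical sub-class, and 'base' is one of the release
# names listed in the docstring above):
#
#     class MyClassifier(NNModule):
#         def __init__(self, config):
#             super().__init__(config)
#             self.bert = DeBERTa(config)
#             self.classifier = nn.Linear(config.hidden_size, 2)
#             self.apply(self.init_weights)
#
#     model = MyClassifier.load_model('base')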