import copy

import torch
from torch import nn

from .config import ModelConfig
from .cache_utils import load_model_state
from .flash import GAULinear
from ..utils import get_logger

logger = get_logger()

__all__ = ['NNModule']

def truncated_normal_(shape, mean=0, std=0.09):
  """ Sample a tensor from a normal distribution truncated to (-2, 2) standard
  deviations: for each element, draw four standard-normal candidates, keep the
  first one inside (-2, 2), then scale by `std` and shift by `mean`.
  """
  with torch.no_grad():
    tensor = torch.zeros(shape)
    tmp = tensor.new_empty(shape + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)
    return tensor
class NNModule(nn.Module):
  """ An abstract class that handles weight initialization and provides a simple
  interface for downloading and loading pretrained models.

  Args:
    config (:obj:`~DeBERTa.deberta.ModelConfig`): The model config for the module.
  """

  def __init__(self, config, *inputs, **kwargs):
    super().__init__()
    self.config = config
  def init_weights(self, module):
    """ Apply Gaussian(mean=0, std=`config.initializer_range`) initialization to the module.

    Args:
      module (:obj:`torch.nn.Module`): The module to apply the initialization to.

    Example::

      class MyModule(NNModule):
        def __init__(self, config):
          # Add construction instructions
          self.bert = DeBERTa(config)

          # Add other modules
          ...

          # Apply initialization
          self.apply(self.init_weights)
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()
  def init_weights_gau(self, module):
    """ Apply GAU-aware initialization to the module: `GAULinear` modules use their
    own `init_weight()`, while other linear/embedding weights are drawn from the
    depth-scaled truncated normal produced by :meth:`initializer`.

    Args:
      module (:obj:`torch.nn.Module`): The module to apply the initialization to.

    Example::

      class MyModule(NNModule):
        def __init__(self, config):
          # Add construction instructions
          self.bert = DeBERTa(config)

          # Add other modules
          ...

          # Apply initialization
          self.apply(self.init_weights_gau)
    """
    if isinstance(module, GAULinear):
      module.init_weight()
    else:
      if isinstance(module, (nn.Linear, nn.Embedding)):
        # Use the depth-scaled truncated normal instead of a plain Gaussian.
        module.weight.data.copy_(self.initializer(module.weight.data.shape))
      if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()
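
  # Sketch: picking an initializer for a model (`model` stands for any NNModule
  # subclass instance; GAULinear is assumed to come from .flash and to define
  # init_weight()):
  #
  #   model.apply(model.init_weights)      # plain Gaussian(0, initializer_range)
  #   model.apply(model.init_weights_gau)  # GAU-aware, depth-scaled truncated normal
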
  def initializer(self, shape, dtype=None, order=3, gain=1.0):
    """ Build a depth-scaled truncated-normal weight tensor of the given shape. """
    # Heuristic fan-in: use the second dimension unless it looks like a vocabulary
    # (very large) or a tiny projection, in which case fall back to the first.
    if shape[1] > 10000 or shape[1] < 10:
      hidden_size = shape[0]
    else:
      hidden_size = shape[1]
    # Scale the gain down with depth so deeper stacks start with smaller weights.
    gain *= self.config.num_hidden_layers ** (-1.0 / order)
    # 1.13684723 ~= 1 / std of a standard normal truncated to (-2, 2); it compensates
    # for the variance lost to truncation so the effective stddev matches the target.
    stddev = 1.13684723 / hidden_size ** 0.5 * gain
    # Truncate at +/-2 standard deviations; trunc_normal_'s `a`/`b` are absolute
    # bounds, so they must be scaled by stddev for the correction factor to apply.
    return torch.nn.init.trunc_normal_(torch.empty(shape, dtype=dtype), std=stddev, a=-2 * stddev, b=2 * stddev)
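
  # Worked example (assumed config: num_hidden_layers=12; shape=(768, 3072), so
  # hidden_size=3072, and with order=3, gain=1.0):
  #
  #   gain   = 1.0 * 12 ** (-1.0 / 3)            # ~= 0.437
  #   stddev = 1.13684723 / 3072 ** 0.5 * 0.437  # ~= 0.00896
  #
  # After the +/-2 * stddev truncation, the effective standard deviation is about
  # stddev * 0.8796 ~= 12 ** (-1/3) / sqrt(3072), i.e. the target before correction.
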
  @classmethod
  def load_model(cls, model_path, model_config=None, tag=None, no_cache=False, cache_dir=None, *inputs, **kwargs):
    """ Instantiate a sub-class of NNModule from a pre-trained model file.

    Args:
      model_path (:obj:`str`): Path or name of the pre-trained model, which can be either:

        - The path of a pre-trained model
        - A pre-trained DeBERTa model name from the `DeBERTa GitHub releases <https://github.com/microsoft/DeBERTa/releases>`_, i.e. [**base, base_mnli, large, large_mnli**].

        If `model_path` is `None` or `-`, the method creates a new sub-class without initializing from a pre-trained model.
      model_config (:obj:`str`): The path of the model config file. If it's `None`, the method tries to find the config in this order:

        1. ['config'] in the model state dictionary.
        2. `model_config.json` next to `model_path`.

        If no config can be found, the method fails.
      tag (:obj:`str`, optional): The release tag of DeBERTa, default: `None`.
      no_cache (:obj:`bool`, optional): Disable the local cache of downloaded models, default: `False`.
      cache_dir (:obj:`str`, optional): The cache directory used to save the downloaded models, default: `None`. If it's `None`, the models will be saved at `$HOME/.~DeBERTa`.

    Return:
      :obj:`NNModule`: The sub-class object.
    """
    # Load config
    if model_config:
      config = ModelConfig.from_json_file(model_config)
    else:
      config = None
    model_config = None
    model_state = None
    if (model_path is not None) and (model_path.strip() == '-' or model_path.strip() == ''):
      model_path = None
    try:
      model_state, model_config = load_model_state(model_path, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
    except Exception as exp:
      raise Exception(f'Failed to get model {model_path}. Exception: {exp}') from exp

    if config is not None and model_config is not None:
      # Let the explicit config override everything except the architecture-defining
      # fields, which must come from the checkpoint unless missing or invalid there.
      for k in config.__dict__:
        if k not in ['hidden_size',
            'intermediate_size',
            'num_attention_heads',
            'num_hidden_layers',
            'vocab_size',
            'max_position_embeddings'] or (k not in model_config.__dict__) or (model_config.__dict__[k] < 0):
          model_config.__dict__[k] = config.__dict__[k]
    if model_config is not None:
      config = copy.copy(model_config)

    # Instantiate model.
    model = cls(config, *inputs, **kwargs)
    if not model_state:
      return model
    # copy state_dict so _load_from_state_dict can modify it
    state_dict = model_state.copy()

    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    metadata = getattr(state_dict, '_metadata', None)

    def load(module, prefix=''):
      # Recursively load each sub-module's weights, collecting mismatch diagnostics.
      local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
      module._load_from_state_dict(
          state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
      for name, child in module._modules.items():
        if child is not None:
          load(child, prefix + name + '.')

    load(model)
    logger.warning(f'Missing keys: {missing_keys}, unexpected keys: {unexpected_keys}, error messages: {error_msgs}')
    return model
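
# Usage sketch (hypothetical subclass, shown for illustration only; 'base' refers
# to the DeBERTa GitHub release names listed in the docstring above):
#
#   class MyModel(NNModule):
#       ...
#
#   model = MyModel.load_model('base')                                  # download + load weights
#   blank = MyModel.load_model('-', model_config='model_config.json')   # config only, no weights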