import pdb
import os
import torch
import copy
from torch import nn, tensor
from .config import ModelConfig
from ..utils import xtqdm as tqdm
from .cache_utils import load_model_state
from .flash import GAULinear

from ..utils import get_logger
logger = get_logger()

__all__ = ['NNModule']

def truncated_normal_(shape, mean=0, std=0.09):
    # Sample from a standard normal truncated to (-2, 2), then scale by `std` and shift by `mean`.
    with torch.no_grad():
        tensor = torch.zeros(shape)
        # Draw 4 candidates per element and keep one that falls inside (-2, 2).
        tmp = tensor.new_empty(shape + (4,)).normal_()
        valid = (tmp < 2) & (tmp > -2)
        ind = valid.max(-1, keepdim=True)[1]
        tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
        tensor.data.mul_(std).add_(mean)
        return tensor

class NNModule(nn.Module):
  """ An abstract class to handle weights initialization and \
    a simple interface for downloading and loading pretrained models.

  Args:
    
    config (:obj:`~DeBERTa.deberta.ModelConfig`): The model config for the module

  """

  def __init__(self, config, *inputs, **kwargs):
    super().__init__()
    self.config = config

  def init_weights(self, module):
    """ Apply Gaussian(mean=0, std=`config.initializer_range`) initialization to the module.

    Args:
      
      module (:obj:`torch.nn.Module`): The module to apply the initialization to.
    
    Example::
      
      class MyModule(NNModule):
        def __init__(self, config):
          super().__init__(config)
          # Add construction instructions
          self.bert = DeBERTa(config)
          
          # Add other modules
          ...

          # Apply initialization
          self.apply(self.init_weights)

    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  def init_weights_gau(self, module):
    """ Apply Gaussian(mean=0, std=`config.initializer_range`) initialization to the module.

    Args:
      
      module (:obj:`torch.nn.Module`): The module to apply the initialization to.
    
    Example::
      
      class MyModule(NNModule):
        def __init__(self, config):
          super().__init__(config)
          # Add construction instructions
          self.bert = DeBERTa(config)
          
          # Add other modules
          ...

          # Apply initialization
          self.apply(self.init_weights_gau)

    """
    if isinstance(module, GAULinear):
      module.init_weight()
    else:
      if isinstance(module, (nn.Linear, nn.Embedding)):
        # module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        module.weight.data.copy_(self.initializer(module.weight.data.shape))
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  def initializer(self, shape, dtype=None, order=3, gain=1.0):
    # Pick the dimension that corresponds to the hidden size; fall back to the first
    # dimension when the second one looks like a vocabulary size or a tiny projection.
    if shape[1] > 10000 or shape[1] < 10:
      hidden_size = shape[0]
    else:
      hidden_size = shape[1]
    # Scale the gain down with depth, then draw from a truncated normal whose std is
    # ~1/sqrt(hidden_size); 1.13684723 ~= 1/0.8796, the std of a standard normal truncated at +/-2.
    gain *= self.config.num_hidden_layers ** (-1.0 / order)
    stddev = 1.13684723 / hidden_size**0.5 * gain
    return torch.nn.init.trunc_normal_(torch.empty(shape, dtype=dtype), std=stddev)  # alternatively: truncated_normal_(shape, std=stddev)

  @classmethod
  def load_model(cls, model_path, model_config=None, tag=None, no_cache=False, cache_dir=None , *inputs, **kwargs):
    """ Instantiate a sub-class of NNModule from a pre-trained model file.
      
    Args:

      model_path (:obj:`str`): Path or name of the pre-trained model, which can be either:
        
        - The path of pre-trained model

        - The pre-trained DeBERTa model name in `DeBERTa GitHub releases <https://github.com/microsoft/DeBERTa/releases>`_, i.e. [**base, base_mnli, large, large_mnli**].

        If `model_path` is `None` or `-`, then the method will create a new sub-class instance without initializing it from pre-trained models.

      model_config (:obj:`str`): The path of the model config file. If it's `None`, then the method will try to find the config in the following order:
        
        1. The `config` entry in the model state dictionary.

        2. `model_config.json` next to `model_path`.
        
        If no config can be found, the method will fail.

      tag (:obj:`str`, optional): The release tag of DeBERTa, default: `None`.

      no_cache (:obj:`bool`, optional): Disable local cache of downloaded models, default: `False`.

      cache_dir (:obj:`str`, optional): The cache directory used to save the downloaded models, default: `None`. If it's `None`, then the models will be saved at `$HOME/.~DeBERTa`

    Return:
      
      :obj:`NNModule` : The sub-class object.
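
    Example (a minimal sketch; assumes `MyModule` is a sub-class of `NNModule` and `base` is one of the released models)::
      
      model = MyModule.load_model('base')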

    """
    # Load config
    if model_config:
      config = ModelConfig.from_json_file(model_config)
    else:
      config = None
    model_config = None
    model_state = None
    if (model_path is not None) and (model_path.strip() == '-' or model_path.strip()==''):
      model_path = None
    try:
      model_state, model_config = load_model_state(model_path, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
    except Exception as exp:
      raise Exception(f'Failed to get model {model_path}. Exception: {exp}')
    
    # Merge configs: values from the user-supplied config override the pretrained config,
    # except for structural fields that the pretrained config already sets to a valid (>= 0) value.
    if config is not None and model_config is not None:
      for k in config.__dict__:
        if k not in ['hidden_size',
          'intermediate_size',
          'num_attention_heads',
          'num_hidden_layers',
          'vocab_size',
          'max_position_embeddings'] or (k not in model_config.__dict__) or (model_config.__dict__[k] < 0):
          model_config.__dict__[k] = config.__dict__[k]
    if model_config is not None:
      config = copy.copy(model_config)
    vocab_size = config.vocab_size
    # Instantiate model.
    model = cls(config, *inputs, **kwargs)
    if not model_state:
      return model
    # copy state_dict so _load_from_state_dict can modify it
    state_dict = model_state.copy()

    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    metadata = getattr(state_dict, '_metadata', None)
    def load(module, prefix=''):
      local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
      module._load_from_state_dict(
        state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
      for name, child in module._modules.items():
        if child is not None:
          load(child, prefix + name + '.')
    load(model)
    logger.warning(f'Missing keys: {missing_keys}, unexpected_keys: {unexpected_keys}, error_msgs: {error_msgs}')
    return model
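
# Minimal usage sketch (illustrative only; `MyClassifier` is an assumption, not part of this
# module): sub-class NNModule, build sub-modules in __init__, apply init_weights, then
# restore pretrained weights via load_model.
#
#   class MyClassifier(NNModule):
#     def __init__(self, config):
#       super().__init__(config)
#       self.classifier = nn.Linear(config.hidden_size, 2)
#       self.apply(self.init_weights)
#
#   model = MyClassifier.load_model('base')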