"""Tweaked version of corresponding AllenNLP file"""
import logging
from copy import deepcopy
from typing import Dict
import torch
import torch.nn.functional as F
from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
from allennlp.nn import util
from transformers import AutoModel, PreTrainedModel
logger = logging.getLogger(__name__)
class PretrainedBertModel:
"""
In some instances you may want to load the same BERT model twice
(e.g. to use as a token embedder and also as a pooling layer).
This factory provides a cache so that you don't actually have to load the model twice.
"""
_cache: Dict[str, PreTrainedModel] = {}
@classmethod
def load(cls, model_name: str, cache_model: bool = True) -> PreTrainedModel:
if model_name in cls._cache:
return PretrainedBertModel._cache[model_name]
model = AutoModel.from_pretrained(model_name)
if cache_model:
cls._cache[model_name] = model
return model
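# Illustrative usage sketch (not part of the original file): loading the same model
# name twice returns the identical cached object, so the weights are only read once.
# The model name below is an assumption; any HuggingFace model id would work.
#
#   bert_a = PretrainedBertModel.load("bert-base-uncased")
#   bert_b = PretrainedBertModel.load("bert-base-uncased")
#   assert bert_a is bert_b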
class BertEmbedder(TokenEmbedder):
"""
A ``TokenEmbedder`` that produces BERT embeddings for your tokens.
Should be paired with a ``BertIndexer``, which produces wordpiece ids.
    You most likely want to use ``PretrainedBertEmbedder``
    for one of the named pretrained models, rather than this base class.
Parameters
----------
    bert_model: ``PreTrainedModel``
The BERT model being wrapped.
top_layer_only: ``bool``, optional (default = ``False``)
        If ``True``, then only return the top layer instead of applying the scalar mix.
max_pieces : int, optional (default: 512)
The BERT embedder uses positional embeddings and so has a corresponding
        maximum length for its input ids. Assuming the inputs are windowed
        and padded appropriately to this length, the embedder will split them into a
        large batch, feed them into BERT, and recombine the outputs as if they were one
        longer sequence.
num_start_tokens : int, optional (default: 1)
The number of starting special tokens input to BERT (usually 1, i.e., [CLS])
num_end_tokens : int, optional (default: 1)
The number of ending tokens input to BERT (usually 1, i.e., [SEP])
    scalar_mix_parameters: ``List[float]``, optional (default = ``None``)
        Not accepted by this tweaked version's constructor; ``self._scalar_mix`` is
        always ``None``, so only the top layer's representation is returned.
"""
def __init__(
self,
bert_model: PreTrainedModel,
top_layer_only: bool = False,
max_pieces: int = 512,
num_start_tokens: int = 1,
num_end_tokens: int = 1
) -> None:
super().__init__()
self.bert_model = deepcopy(bert_model)
self.output_dim = bert_model.config.hidden_size
self.max_pieces = max_pieces
self.num_start_tokens = num_start_tokens
self.num_end_tokens = num_end_tokens
self._scalar_mix = None
def set_weights(self, freeze):
for param in self.bert_model.parameters():
param.requires_grad = not freeze
return
def get_output_dim(self) -> int:
return self.output_dim
def forward(
self,
input_ids: torch.LongTensor,
offsets: torch.LongTensor = None
) -> torch.Tensor:
"""
Parameters
----------
input_ids : ``torch.LongTensor``
The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However, you will likely
            want one per original token. In that case, ``offsets``
represents the indices of the desired wordpiece for each original token.
Depending on how your token indexer is configured, this could be the
position of the last wordpiece for each token, or it could be the position
of the first wordpiece for each token.
For example, if you had the sentence "Definitely not", and if the corresponding
wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
If offsets are provided, the returned tensor will contain only the wordpiece
embeddings at those positions, and (in particular) will contain one embedding
per token. If offsets are not provided, the entire tensor of wordpiece embeddings
will be returned.
"""
batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
initial_dims = list(input_ids.shape[:-1])
# The embedder may receive an input tensor that has a sequence length longer than can
# be fit. In that case, we should expect the wordpiece indexer to create padded windows
# of length `self.max_pieces` for us, and have them concatenated into one long sequence.
# E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..."
# We can then split the sequence into sub-sequences of that length, and concatenate them
# along the batch dimension so we effectively have one huge batch of partial sentences.
# This can then be fed into BERT without any sentence length issues. Keep in mind
# that the memory consumption can dramatically increase for large batches with extremely
# long sentences.
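        # Shape sketch under assumed numbers (not taken from the comment above): with
        # max_pieces = 512 and full_seq_len = 1024, the input is split into two
        # windows of 512, giving BERT an effective batch of 2 * batch_size.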
needs_split = full_seq_len > self.max_pieces
last_window_size = 0
if needs_split:
# Split the flattened list by the window size, `max_pieces`
split_input_ids = list(input_ids.split(self.max_pieces, dim=-1))
# We want all sequences to be the same length, so pad the last sequence
last_window_size = split_input_ids[-1].size(-1)
padding_amount = self.max_pieces - last_window_size
split_input_ids[-1] = F.pad(split_input_ids[-1], pad=[0, padding_amount], value=0)
# Now combine the sequences along the batch dimension
input_ids = torch.cat(split_input_ids, dim=0)
input_mask = (input_ids != 0).long()
# input_ids may have extra dimensions, so we reshape down to 2-d
# before calling the BERT model and then reshape back at the end.
all_encoder_layers = self.bert_model(
input_ids=util.combine_initial_dims(input_ids),
attention_mask=util.combine_initial_dims(input_mask),
)[0]
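        # Normalize the output to a layers-first tensor: a tuple/list of per-layer
        # (batch, seq_len, hidden) tensors is stacked along a new leading dimension,
        # while a single (batch, seq_len, hidden) tensor gets a layer dimension of 1.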
if len(all_encoder_layers[0].shape) == 3:
all_encoder_layers = torch.stack(all_encoder_layers)
elif len(all_encoder_layers[0].shape) == 2:
all_encoder_layers = torch.unsqueeze(all_encoder_layers, dim=0)
if needs_split:
# First, unpack the output embeddings into one long sequence again
unpacked_embeddings = torch.split(all_encoder_layers, batch_size, dim=1)
unpacked_embeddings = torch.cat(unpacked_embeddings, dim=2)
# Next, select indices of the sequence such that it will result in embeddings representing the original
# sentence. To capture maximal context, the indices will be the middle part of each embedded window
# sub-sequence (plus any leftover start and final edge windows), e.g.,
# 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
# with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
# and final windows with indices [0, 1] and [14, 15] respectively.
# Find the stride as half the max pieces, ignoring the special start and end tokens
# Calculate an offset to extract the centermost embeddings of each window
stride = (self.max_pieces - self.num_start_tokens - self.num_end_tokens) // 2
stride_offset = stride // 2 + self.num_start_tokens
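            # For the example above (max_pieces = 8, one [CLS] and one [SEP]):
            # stride = (8 - 1 - 1) // 2 = 3 and stride_offset = 3 // 2 + 1 = 2,
            # so the lines below select first_window = [0, 1] and keep positions
            # 2, 3, 4 of each window.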
first_window = list(range(stride_offset))
max_context_windows = [
i
for i in range(full_seq_len)
if stride_offset - 1 < i % self.max_pieces < stride_offset + stride
]
            # Look back at what's left, unless it's a whole self.max_pieces window
if full_seq_len % self.max_pieces == 0:
lookback = self.max_pieces
else:
lookback = full_seq_len % self.max_pieces
final_window_start = full_seq_len - lookback + stride_offset + stride
final_window = list(range(final_window_start, full_seq_len))
select_indices = first_window + max_context_windows + final_window
initial_dims.append(len(select_indices))
recombined_embeddings = unpacked_embeddings[:, :, select_indices]
else:
recombined_embeddings = all_encoder_layers
# Recombine the outputs of all layers
# (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim)
# recombined = torch.cat(combined, dim=2)
input_mask = (recombined_embeddings != 0).long()
if self._scalar_mix is not None:
mix = self._scalar_mix(recombined_embeddings, input_mask)
else:
mix = recombined_embeddings[-1]
# At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)
if offsets is None:
# Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
dims = initial_dims if needs_split else input_ids.size()
return util.uncombine_initial_dims(mix, dims)
else:
# offsets is (batch_size, d1, ..., dn, orig_sequence_length)
offsets2d = util.combine_initial_dims(offsets)
# now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
range_vector = util.get_range_vector(
offsets2d.size(0), device=util.get_device_of(mix)
).unsqueeze(1)
# selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
selected_embeddings = mix[range_vector, offsets2d]
return util.uncombine_initial_dims(selected_embeddings, offsets.size())
# @TokenEmbedder.register("bert-pretrained")
class PretrainedBertEmbedder(BertEmbedder):
"""
Parameters
----------
pretrained_model: ``str``
Either the name of the pretrained model to use (e.g. 'bert-base-uncased'),
or the path to the .tar.gz file with the model weights.
If the name is a key in the list of pretrained models at
https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py#L41
the corresponding path will be used; otherwise it will be interpreted as a path or URL.
    requires_grad : ``bool``, optional (default = ``False``)
        If ``True``, compute gradients of the BERT parameters for fine tuning.
    top_layer_only: ``bool``, optional (default = ``False``)
        If ``True``, then only return the top layer instead of applying the scalar mix.
    special_tokens_fix: ``int``, optional (default = ``0``)
        If nonzero, resize the model's token embedding matrix by one extra row to
        reserve space for an additional special token.
"""
def __init__(
self,
pretrained_model: str,
requires_grad: bool = False,
top_layer_only: bool = False,
special_tokens_fix: int = 0,
) -> None:
model = PretrainedBertModel.load(pretrained_model)
for param in model.parameters():
param.requires_grad = requires_grad
super().__init__(
bert_model=model,
top_layer_only=top_layer_only
)
if special_tokens_fix:
try:
vocab_size = self.bert_model.embeddings.word_embeddings.num_embeddings
except AttributeError:
# reserve more space
vocab_size = self.bert_model.word_embedding.num_embeddings + 5
self.bert_model.resize_token_embeddings(vocab_size + 1)
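# Minimal usage sketch (assumption, not part of the original file): "bert-base-uncased"
# is downloadable, and the random wordpiece ids stand in for real BertIndexer output.
if __name__ == "__main__":
    embedder = PretrainedBertEmbedder("bert-base-uncased", requires_grad=False)
    dummy_input_ids = torch.randint(1, 1000, (2, 16))  # (batch_size, num_wordpieces)
    embeddings = embedder(dummy_input_ids)  # -> (2, 16, hidden_size)
    assert embeddings.size(-1) == embedder.get_output_dim()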