# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tf.keras Models for NHNet."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

from typing import Optional, Text

from absl import logging
import gin
import tensorflow as tf

from official.modeling import tf_utils
from official.modeling.hyperparams import params_dict
from official.nlp.modeling import networks
from official.nlp.modeling.layers import multi_channel_attention
from official.nlp.nhnet import configs
from official.nlp.nhnet import decoder
from official.nlp.nhnet import utils
from official.nlp.transformer import beam_search


def embedding_linear(embedding_matrix, x):
  """Uses embeddings as linear transformation weights."""
  with tf.name_scope("presoftmax_linear"):
    batch_size = tf.shape(x)[0]
    length = tf.shape(x)[1]
    hidden_size = tf.shape(x)[2]
    vocab_size = tf.shape(embedding_matrix)[0]

    x = tf.reshape(x, [-1, hidden_size])
    logits = tf.matmul(x, embedding_matrix, transpose_b=True)

    return tf.reshape(logits, [batch_size, length, vocab_size])
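

# A minimal shape sketch of the tied output projection above, kept as comments
# so the module stays import-only. With an embedding matrix of shape
# [vocab_size, hidden_size] and decoder outputs of shape
# [batch_size, length, hidden_size], `embedding_linear` returns logits of
# shape [batch_size, length, vocab_size]; the sizes below are arbitrary:
#
#   embeddings = tf.random.normal([100, 16])        # [vocab_size, hidden]
#   decoder_outputs = tf.random.normal([2, 7, 16])  # [batch, length, hidden]
#   logits = embedding_linear(embeddings, decoder_outputs)
#   assert logits.shape == (2, 7, 100)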


def _add_sos_to_seq(seq, start_token_id):
  """Add a start sequence token while keeping seq length."""
  batch_size = tf.shape(seq)[0]
  seq_len = tf.shape(seq)[1]
  sos_ids = tf.ones([batch_size], tf.int32) * start_token_id
  targets = tf.concat([tf.expand_dims(sos_ids, axis=1), seq], axis=1)
  targets = targets[:, :-1]
  tf.assert_equal(tf.shape(targets), (batch_size, seq_len))
  return targets


def remove_sos_from_seq(seq, pad_token_id):
  """Remove the start sequence token while keeping seq length."""
  batch_size, seq_len = tf_utils.get_shape_list(seq, expected_rank=2)
  # Remove the leading <s> token.
  targets = seq[:, 1:]
  # Pad at the end so the sequence length is unchanged.
  pad_ids = tf.ones([batch_size], tf.int32) * pad_token_id
  targets = tf.concat([targets, tf.expand_dims(pad_ids, axis=1)], axis=1)
  tf.assert_equal(tf.shape(targets), (batch_size, seq_len))
  return targets
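

# The two helpers above shift a fixed-length sequence by one position. A small
# comment-only illustration (token ids chosen arbitrarily, assuming
# start_token_id=101 and pad_token_id=0):
#
#   seq = tf.constant([[7, 8, 9, 0]])
#   _add_sos_to_seq(seq, 101)                           # -> [[101, 7, 8, 9]]
#   remove_sos_from_seq(_add_sos_to_seq(seq, 101), 0)   # -> [[7, 8, 9, 0]]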


class Bert2Bert(tf.keras.Model):
  """Bert2Bert encoder decoder model for training."""

  def __init__(self, params, bert_layer, decoder_layer, name=None):
    super(Bert2Bert, self).__init__(name=name)
    self.params = params
    if not bert_layer.built:
      raise ValueError("bert_layer should be built.")
    if not decoder_layer.built:
      raise ValueError("decoder_layer should be built.")
    self.bert_layer = bert_layer
    self.decoder_layer = decoder_layer

  def get_config(self):
    return {"params": self.params.as_dict()}

  def get_decode_logits(self,
                        decoder_inputs,
                        ids,
                        decoder_self_attention_bias,
                        step,
                        cache=None):
    if cache:
      if self.params.get("padded_decode", False):
        bias_shape = decoder_self_attention_bias.shape.as_list()
        self_attention_bias = tf.slice(
            decoder_self_attention_bias, [0, 0, step, 0],
            [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
      else:
        self_attention_bias = decoder_self_attention_bias[:, :, step:step + 1,
                                                          :step + 1]
      # Sets decoder input to the last generated IDs.
      decoder_input = ids[:, -1:]
    else:
      self_attention_bias = decoder_self_attention_bias[:, :, :step + 1,
                                                        :step + 1]
      decoder_input = ids
    decoder_inputs["target_ids"] = decoder_input
    decoder_inputs["self_attention_bias"] = self_attention_bias
    if cache:
      decoder_outputs = self.decoder_layer(
          decoder_inputs,
          cache,
          decode_loop_step=step,
          padded_decode=self.params.get("padded_decode", False))
    else:
      decoder_outputs = self.decoder_layer(decoder_inputs)
    logits = embedding_linear(self.decoder_layer.embedding_lookup.embeddings,
                              decoder_outputs[:, -1:, :])
    logits = tf.squeeze(logits, axis=[1])
    return logits

  def _get_symbols_to_logits_fn(self, max_decode_length):
    """Returns a decoding function that calculates logits of the next tokens."""
    # Max decode length should be smaller than the positional embedding max
    # sequence length.
    decoder_self_attention_bias = decoder.get_attention_bias(
        input_tensor=None,
        bias_type="decoder_self",
        max_length=max_decode_length)

    def _symbols_to_logits_fn(ids, i, cache):
      """Generates logits for the next candidate IDs.

      Args:
        ids: Current decoded sequences. int tensor with shape
          [batch_size * beam_size, i + 1].
        i: Loop index.
        cache: Dictionary of values storing the encoder output, encoder-decoder
          attention bias, and previous decoder attention values.

      Returns:
        Tuple of
          (logits with shape [batch_size * beam_size, vocab_size],
           updated cache values)
      """
      decoder_inputs = dict(
          all_encoder_outputs=cache["all_encoder_outputs"],
          attention_bias=cache["attention_bias"])
      logits = self.get_decode_logits(
          decoder_inputs,
          ids,
          decoder_self_attention_bias,
          step=i,
          cache=cache if self.params.use_cache else None)
      return logits, cache

    return _symbols_to_logits_fn

  def train_decode(self, decode_outputs):
    logits = embedding_linear(self.decoder_layer.embedding_lookup.embeddings,
                              decode_outputs)
    decode_output_ids = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
    output_log_probs = tf.nn.log_softmax(logits, axis=-1)
    return logits, decode_output_ids, output_log_probs

  def predict_decode(self, start_token_ids, cache):
    symbols_to_logits_fn = self._get_symbols_to_logits_fn(
        self.params.len_title)
    # Use beam search to find the top beam_size sequences and scores.
    decoded_ids, scores = beam_search.sequence_beam_search(
        symbols_to_logits_fn=symbols_to_logits_fn,
        initial_ids=start_token_ids,
        initial_cache=cache,
        vocab_size=self.params.vocab_size,
        beam_size=self.params.beam_size,
        alpha=self.params.alpha,
        max_decode_length=self.params.len_title,
        padded_decode=self.params.get("padded_decode", False),
        eos_id=self.params.end_token_id)
    return decoded_ids, scores

  def _get_logits_for_decode_ids(self, decoder_inputs, top_decoded_ids):
    """Returns the logits for the given decoded ids."""
    target_ids = _add_sos_to_seq(top_decoded_ids, self.params.start_token_id)
    decoder_inputs["self_attention_bias"] = decoder.get_attention_bias(
        target_ids, bias_type="decoder_self")
    decoder_inputs["target_ids"] = target_ids
    decoder_outputs = self.decoder_layer(decoder_inputs)
    logits = embedding_linear(self.decoder_layer.embedding_lookup.embeddings,
                              decoder_outputs)
    return logits

  def _init_cache(self, batch_size):
    num_heads = self.params.num_decoder_attn_heads
    dim_per_head = self.params.hidden_size // num_heads
    init_decode_length = (
        self.params.len_title if self.params.get("padded_decode", False) else 0)
    cache = {}
    for layer in range(self.params.num_decoder_layers):
      cache[str(layer)] = {
          "key":
              tf.zeros(
                  [batch_size, init_decode_length, num_heads, dim_per_head],
                  dtype=tf.float32),
          "value":
              tf.zeros(
                  [batch_size, init_decode_length, num_heads, dim_per_head],
                  dtype=tf.float32)
      }
    return cache

  def call(self, inputs, mode="train"):
    """Implements call().

    Args:
      inputs: a dictionary of tensors.
      mode: string, one of "train", "eval" or "predict".

    Returns:
      (logits, decode_output_ids, output_log_probs) for "train";
      logits for the top decoded ids for "eval";
      (decoded_ids, scores) from beam search for "predict".
    """
    input_ids = inputs["input_ids"]
    input_mask = inputs["input_mask"]
    segment_ids = inputs["segment_ids"]
    all_encoder_outputs, _ = self.bert_layer(
        [input_ids, input_mask, segment_ids])

    if mode not in ("train", "eval", "predict"):
      raise ValueError("Invalid call mode: %s" % mode)
    encoder_decoder_attention_bias = decoder.get_attention_bias(
        input_ids,
        bias_type="single_cross",
        padding_value=self.params.pad_token_id)
    if mode == "train":
      self_attention_bias = decoder.get_attention_bias(
          inputs["target_ids"], bias_type="decoder_self")
      decoder_inputs = dict(
          attention_bias=encoder_decoder_attention_bias,
          all_encoder_outputs=all_encoder_outputs,
          target_ids=inputs["target_ids"],
          self_attention_bias=self_attention_bias)
      decoder_outputs = self.decoder_layer(decoder_inputs)
      return self.train_decode(decoder_outputs)

    batch_size = tf.shape(input_ids)[0]
    start_token_ids = tf.ones([batch_size],
                              tf.int32) * self.params.start_token_id
    # Add encoder output and attention bias to the cache.
    if self.params.use_cache:
      cache = self._init_cache(batch_size)
    else:
      cache = {}
    cache["all_encoder_outputs"] = all_encoder_outputs
    cache["attention_bias"] = encoder_decoder_attention_bias
    decoded_ids, scores = self.predict_decode(start_token_ids, cache)
    if mode == "predict":
      return decoded_ids[:, :self.params.beam_size,
                         1:], scores[:, :self.params.beam_size]

    decoder_inputs = dict(
        attention_bias=encoder_decoder_attention_bias,
        all_encoder_outputs=all_encoder_outputs)
    top_decoded_ids = decoded_ids[:, 0, 1:]
    return self._get_logits_for_decode_ids(decoder_inputs, top_decoded_ids)
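

# A hedged usage sketch for Bert2Bert, kept as comments so this module stays
# import-only. The sequence lengths are arbitrary and `params` is assumed to
# be a configs.BERT2BERTConfig built elsewhere:
#
#   model = create_bert2bert_model(params)
#   features = dict(
#       input_ids=tf.zeros([2, 128], tf.int32),
#       input_mask=tf.ones([2, 128], tf.int32),
#       segment_ids=tf.zeros([2, 128], tf.int32),
#       target_ids=tf.zeros([2, params.len_title], tf.int32))
#   logits, output_ids, log_probs = model(features, mode="train")
#   decoded_ids, scores = model(features, mode="predict")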


class NHNet(Bert2Bert):
  """NHNet model which performs multi-doc decoding."""

  def __init__(self, params, bert_layer, decoder_layer, name=None):
    super(NHNet, self).__init__(params, bert_layer, decoder_layer, name=name)
    self.doc_attention = multi_channel_attention.VotingAttention(
        num_heads=params.num_decoder_attn_heads,
        head_size=params.hidden_size // params.num_decoder_attn_heads)

  def _expand_doc_attention_probs(self, doc_attention_probs, target_length):
    """Expands doc attention probs to fit the decoding sequence length."""
    doc_attention_probs = tf.expand_dims(
        doc_attention_probs, axis=[1])  # [B, 1, A]
    doc_attention_probs = tf.expand_dims(
        doc_attention_probs, axis=[2])  # [B, 1, 1, A]
    return tf.tile(doc_attention_probs,
                   [1, self.params.num_decoder_attn_heads, target_length, 1])

  def _get_symbols_to_logits_fn(self, max_decode_length):
    """Returns a decoding function that calculates logits of the next tokens."""
    # Max decode length should be smaller than the positional embedding max
    # sequence length.
    decoder_self_attention_bias = decoder.get_attention_bias(
        input_tensor=None,
        bias_type="decoder_self",
        max_length=max_decode_length)

    def _symbols_to_logits_fn(ids, i, cache):
      """Generates logits for the next candidate IDs."""
      if self.params.use_cache:
        target_length = 1
      else:
        target_length = i + 1
      decoder_inputs = dict(
          doc_attention_probs=self._expand_doc_attention_probs(
              cache["doc_attention_probs"], target_length),
          all_encoder_outputs=cache["all_encoder_outputs"],
          attention_bias=cache["attention_bias"])
      logits = self.get_decode_logits(
          decoder_inputs,
          ids,
          decoder_self_attention_bias,
          step=i,
          cache=cache if self.params.use_cache else None)
      return logits, cache

    return _symbols_to_logits_fn

  def call(self, inputs, mode="train"):
    """Implements call(); see Bert2Bert.call for the accepted modes."""
    input_shape = tf_utils.get_shape_list(inputs["input_ids"], expected_rank=3)
    batch_size, num_docs, len_passage = (input_shape[0], input_shape[1],
                                         input_shape[2])
    input_ids = tf.reshape(inputs["input_ids"], [-1, len_passage])
    input_mask = tf.reshape(inputs["input_mask"], [-1, len_passage])
    segment_ids = tf.reshape(inputs["segment_ids"], [-1, len_passage])
    all_encoder_outputs, _ = self.bert_layer(
        [input_ids, input_mask, segment_ids])
    encoder_outputs = tf.reshape(
        all_encoder_outputs[-1],
        [batch_size, num_docs, len_passage, self.params.hidden_size])
    # A document counts as valid only if it has more than two non-padding
    # tokens.
    doc_attention_mask = tf.reshape(
        tf.cast(
            tf.math.count_nonzero(input_mask, axis=1, dtype=tf.int32) > 2,
            tf.int32), [batch_size, num_docs])

    doc_attention_probs = self.doc_attention(encoder_outputs,
                                             doc_attention_mask)
    encoder_decoder_attention_bias = decoder.get_attention_bias(
        inputs["input_ids"],
        bias_type="multi_cross",
        padding_value=self.params.pad_token_id)

    if mode == "train":
      target_length = tf_utils.get_shape_list(
          inputs["target_ids"], expected_rank=2)[1]
      doc_attention_probs = self._expand_doc_attention_probs(
          doc_attention_probs, target_length)
      self_attention_bias = decoder.get_attention_bias(
          inputs["target_ids"], bias_type="decoder_self")
      decoder_inputs = dict(
          attention_bias=encoder_decoder_attention_bias,
          self_attention_bias=self_attention_bias,
          target_ids=inputs["target_ids"],
          all_encoder_outputs=encoder_outputs,
          doc_attention_probs=doc_attention_probs)
      decoder_outputs = self.decoder_layer(decoder_inputs)
      return self.train_decode(decoder_outputs)

    # Adds encoder output and attention bias to the cache.
    if self.params.use_cache:
      cache = self._init_cache(batch_size)
    else:
      cache = {}
    cache["all_encoder_outputs"] = [encoder_outputs]
    cache["attention_bias"] = encoder_decoder_attention_bias
    cache["doc_attention_probs"] = doc_attention_probs

    start_token_ids = tf.ones([batch_size],
                              tf.int32) * self.params.start_token_id
    decoded_ids, scores = self.predict_decode(start_token_ids, cache)
    if mode == "predict":
      return decoded_ids[:, :self.params.beam_size,
                         1:], scores[:, :self.params.beam_size]

    top_decoded_ids = decoded_ids[:, 0, 1:]
    target_length = tf_utils.get_shape_list(top_decoded_ids)[-1]
    decoder_inputs = dict(
        attention_bias=encoder_decoder_attention_bias,
        all_encoder_outputs=[encoder_outputs],
        doc_attention_probs=self._expand_doc_attention_probs(
            doc_attention_probs, target_length))
    return self._get_logits_for_decode_ids(decoder_inputs, top_decoded_ids)
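

# NHNet differs from Bert2Bert mainly in its inputs: `input_ids`, `input_mask`
# and `segment_ids` are [batch_size, num_docs, len_passage], the per-doc
# encoder outputs are weighted by VotingAttention, and cross-attention uses
# the "multi_cross" bias. A comment-only sketch of the expected shapes
# (sizes are arbitrary; `params` is assumed to be a configs.NHNetConfig and
# `nhnet_model` an instance of NHNet):
#
#   features = dict(
#       input_ids=tf.zeros([2, 5, 128], tf.int32),   # [B, num_docs, length]
#       input_mask=tf.ones([2, 5, 128], tf.int32),
#       segment_ids=tf.zeros([2, 5, 128], tf.int32),
#       target_ids=tf.zeros([2, params.len_title], tf.int32))
#   logits, output_ids, log_probs = nhnet_model(features, mode="train")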


def get_bert2bert_layers(params: configs.BERT2BERTConfig):
  """Creates a Bert2Bert stem model and returns Bert encoder/decoder.

  We use a functional-style stem model because we need all layers to be built
  in order to restore variables in a customized way. The layers are called
  with placeholder inputs to make them fully built.

  Args:
    params: ParamsDict.

  Returns:
    two keras Layers, bert_model_layer and decoder_layer
  """
  input_ids = tf.keras.layers.Input(
      shape=(None,), name="input_ids", dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(None,), name="input_mask", dtype=tf.int32)
  segment_ids = tf.keras.layers.Input(
      shape=(None,), name="segment_ids", dtype=tf.int32)
  target_ids = tf.keras.layers.Input(
      shape=(None,), name="target_ids", dtype=tf.int32)
  bert_config = utils.get_bert_config_from_params(params)
  bert_model_layer = networks.TransformerEncoder(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=None,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      return_all_encoder_outputs=True,
      name="bert_encoder")
  all_encoder_outputs, _ = bert_model_layer(
      [input_ids, input_mask, segment_ids])
  # pylint: disable=protected-access
  decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)
  # pylint: enable=protected-access
  cross_attention_bias = decoder.AttentionBias(bias_type="single_cross")(
      input_ids)
  self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
      target_ids)
  decoder_inputs = dict(
      attention_bias=cross_attention_bias,
      self_attention_bias=self_attention_bias,
      target_ids=target_ids,
      all_encoder_outputs=all_encoder_outputs)
  _ = decoder_layer(decoder_inputs)
  return bert_model_layer, decoder_layer


def get_nhnet_layers(params: configs.NHNetConfig):
  """Creates a multi-doc encoder/decoder.

  Args:
    params: ParamsDict.

  Returns:
    two keras Layers, bert_model_layer and decoder_layer
  """
  input_ids = tf.keras.layers.Input(
      shape=(None,), name="input_ids", dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(None,), name="input_mask", dtype=tf.int32)
  segment_ids = tf.keras.layers.Input(
      shape=(None,), name="segment_ids", dtype=tf.int32)
  bert_config = utils.get_bert_config_from_params(params)
  bert_model_layer = networks.TransformerEncoder(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=None,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      return_all_encoder_outputs=True,
      name="bert_encoder")
  bert_model_layer([input_ids, input_mask, segment_ids])

  # Builds the decoder with multi-doc placeholder inputs.
  input_ids = tf.keras.layers.Input(
      shape=(None, None), name="input_ids", dtype=tf.int32)
  all_encoder_outputs = tf.keras.layers.Input((None, None, params.hidden_size),
                                              dtype=tf.float32)
  target_ids = tf.keras.layers.Input(
      shape=(None,), name="target_ids", dtype=tf.int32)
  doc_attention_probs = tf.keras.layers.Input(
      (params.num_decoder_attn_heads, None, None), dtype=tf.float32)
  # pylint: disable=protected-access
  decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)
  # pylint: enable=protected-access
  cross_attention_bias = decoder.AttentionBias(bias_type="multi_cross")(
      input_ids)
  self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
      target_ids)
  decoder_inputs = dict(
      attention_bias=cross_attention_bias,
      self_attention_bias=self_attention_bias,
      target_ids=target_ids,
      all_encoder_outputs=all_encoder_outputs,
      doc_attention_probs=doc_attention_probs)
  _ = decoder_layer(decoder_inputs)
  return bert_model_layer, decoder_layer


def create_transformer_model(params,
                             init_checkpoint: Optional[Text] = None
                            ) -> tf.keras.Model:
  """A helper to create a Transformer model."""
  bert_layer, decoder_layer = get_bert2bert_layers(params=params)
  model = Bert2Bert(
      params=params,
      bert_layer=bert_layer,
      decoder_layer=decoder_layer,
      name="transformer")
  if init_checkpoint:
    logging.info(
        "Checkpoint file %s found and restoring from "
        "initial checkpoint.", init_checkpoint)
    ckpt = tf.train.Checkpoint(model=model)
    ckpt.restore(init_checkpoint).expect_partial()
  return model


def create_bert2bert_model(
    params: configs.BERT2BERTConfig,
    cls=Bert2Bert,
    init_checkpoint: Optional[Text] = None) -> tf.keras.Model:
  """A helper to create a Bert2Bert model."""
  bert_layer, decoder_layer = get_bert2bert_layers(params=params)
  if init_checkpoint:
    utils.initialize_bert2bert_from_pretrained_bert(bert_layer, decoder_layer,
                                                    init_checkpoint)
  return cls(
      params=params,
      bert_layer=bert_layer,
      decoder_layer=decoder_layer,
      name="bert2bert")


def create_nhnet_model(
    params: configs.NHNetConfig,
    cls=NHNet,
    init_checkpoint: Optional[Text] = None) -> tf.keras.Model:
  """A helper to create an NHNet model."""
  bert_layer, decoder_layer = get_nhnet_layers(params=params)
  model = cls(
      params=params,
      bert_layer=bert_layer,
      decoder_layer=decoder_layer,
      name="nhnet")
  if init_checkpoint:
    logging.info(
        "Checkpoint file %s found and restoring from "
        "initial checkpoint.", init_checkpoint)
    if params.init_from_bert2bert:
      ckpt = tf.train.Checkpoint(model=model)
      ckpt.restore(init_checkpoint).assert_existing_objects_matched()
    else:
      utils.initialize_bert2bert_from_pretrained_bert(bert_layer, decoder_layer,
                                                      init_checkpoint)
  return model


def get_model_params(model: Optional[Text] = "bert2bert",
                     config_class=None) -> params_dict.ParamsDict:
  """Helper function to convert a config file to a ParamsDict."""
  if model == "bert2bert":
    return configs.BERT2BERTConfig()
  elif model == "nhnet":
    return configs.NHNetConfig()
  elif config_class:
    return config_class()
  else:
    raise KeyError("The model type is not defined: %s" % model)


def create_model(model_type: Text,
                 params,
                 init_checkpoint: Optional[Text] = None):
  """A factory function to create different types of models."""
  if model_type == "bert2bert":
    return create_bert2bert_model(params, init_checkpoint=init_checkpoint)
  elif model_type == "nhnet":
    return create_nhnet_model(params, init_checkpoint=init_checkpoint)
  elif "transformer" in model_type:
    return create_transformer_model(params, init_checkpoint=init_checkpoint)
  else:
    raise KeyError("The model type is not defined: %s" % model_type)
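

# Example of wiring the factories above together, kept as comments so the
# module stays import-only. The checkpoint path is a placeholder:
#
#   params = get_model_params("nhnet")            # or "bert2bert"
#   model = create_model("nhnet", params)
#   # Weights can be restored at creation time as well:
#   #   model = create_model("nhnet", params, init_checkpoint="/path/to/ckpt")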