danfu09 committed
Commit
0f1a7ee
1 Parent(s): 2f57ef2

Update bert_layers.py

Files changed (1)
  1. bert_layers.py +114 -0
bert_layers.py CHANGED
@@ -870,3 +870,117 @@ class BertForSequenceClassification(BertPreTrainedModel):
             hidden_states=None,
             attentions=None,
         )
+
+class BertForTextEncoding(BertPreTrainedModel):
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        if config.is_decoder:
+            warnings.warn(
+                'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
+                'bi-directional self-attention.')
+
+        self.bert = BertModel(config, add_pooling_layer=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @classmethod
+    def from_composer(cls,
+                      pretrained_checkpoint,
+                      state_dict=None,
+                      cache_dir=None,
+                      from_tf=False,
+                      config=None,
+                      *inputs,
+                      **kwargs):
+        """Load from pre-trained."""
+        model = cls(config, *inputs, **kwargs)
+        if from_tf:
+            raise ValueError(
+                'TensorFlow is not supported.')
+
+        state_dict = torch.load(pretrained_checkpoint)
+        # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
+        consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
+        missing_keys, unexpected_keys = model.load_state_dict(state_dict,
+                                                              strict=False)
+
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
+            )
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
+            )
+
+        return model
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
+
+        if (input_ids is not None) == (inputs_embeds is not None):
+            raise ValueError('Must specify either input_ids or input_embeds!')
+
+        if labels is None:
+            masked_tokens_mask = None
+        else:
+            masked_tokens_mask = labels > 0
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            masked_tokens_mask=masked_tokens_mask,
+        )
+
+        pooled_output = outputs[1]
+
+        return {"sentence_embedding": pooled_output}
+
+    def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
+                                      attention_mask: torch.Tensor,
+                                      **model_kwargs):
+        input_shape = input_ids.shape
+        effective_batch_size = input_shape[0]
+
+        # add a dummy token
+        if self.config.pad_token_id is None:
+            raise ValueError('The PAD token should be defined for generation')
+
+        attention_mask = torch.cat([
+            attention_mask,
+            attention_mask.new_zeros((attention_mask.shape[0], 1))
+        ], dim=-1)
+        dummy_token = torch.full((effective_batch_size, 1),
+                                 self.config.pad_token_id,
+                                 dtype=torch.long,
+                                 device=input_ids.device)
+        input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+        return {'input_ids': input_ids, 'attention_mask': attention_mask}
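
For reference, a minimal usage sketch of the newly added BertForTextEncoding class, not part of the commit itself. The checkpoint path, the bare BertConfig(), and the tokenizer name are placeholders: a real run needs a config matching this repo's BertModel and a Composer checkpoint trained with it. Only the returned 'sentence_embedding' key is taken from the diff above.

# Usage sketch (editorial, hypothetical paths/config): obtain sentence embeddings
# from the BertForTextEncoding class added in this commit.
import torch
from transformers import AutoTokenizer, BertConfig

from bert_layers import BertForTextEncoding  # the module updated in this commit

# Placeholder config; in practice, load the config shipped with the checkpoint.
config = BertConfig()
model = BertForTextEncoding.from_composer(
    pretrained_checkpoint='path/to/composer_checkpoint.pt',  # placeholder path
    config=config,
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # placeholder tokenizer
batch = tokenizer(['An example sentence.'], return_tensors='pt', padding=True)

with torch.no_grad():
    out = model(input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'])

# forward() returns a dict with the pooled representation under 'sentence_embedding'.
embedding = out['sentence_embedding']
print(embedding.shape)  # expected: (batch_size, hidden_size)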