Update bert_layers.py
bert_layers.py CHANGED (+114 -0)
@@ -870,3 +870,117 @@ class BertForSequenceClassification(BertPreTrainedModel):
             hidden_states=None,
             attentions=None,
         )
+
+class BertForTextEncoding(BertPreTrainedModel):
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        if config.is_decoder:
+            warnings.warn(
+                'If you want to use `BertForTextEncoding` make sure `config.is_decoder=False` for '
+                'bi-directional self-attention.')
+
+        self.bert = BertModel(config, add_pooling_layer=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @classmethod
+    def from_composer(cls,
+                      pretrained_checkpoint,
+                      state_dict=None,
+                      cache_dir=None,
+                      from_tf=False,
+                      config=None,
+                      *inputs,
+                      **kwargs):
+        """Load from pre-trained."""
+        model = cls(config, *inputs, **kwargs)
+        if from_tf:
+            raise ValueError(
+                'TensorFlow is not supported.')
+
+        state_dict = torch.load(pretrained_checkpoint)
+        # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
+        consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
+        missing_keys, unexpected_keys = model.load_state_dict(state_dict,
+                                                              strict=False)
+
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
+            )
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
+            )
+
+        return model
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
+
+        if (input_ids is not None) == (inputs_embeds is not None):
+            raise ValueError('Must specify either input_ids or inputs_embeds!')
+
+        if labels is None:
+            masked_tokens_mask = None
+        else:
+            masked_tokens_mask = labels > 0
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            masked_tokens_mask=masked_tokens_mask,
+        )
+
+        pooled_output = outputs[1]
+
+        return {"sentence_embedding": pooled_output}
+
+    def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
+                                      attention_mask: torch.Tensor,
+                                      **model_kwargs):
+        input_shape = input_ids.shape
+        effective_batch_size = input_shape[0]
+
+        # add a dummy token
+        if self.config.pad_token_id is None:
+            raise ValueError('The PAD token should be defined for generation')
+
+        attention_mask = torch.cat([
+            attention_mask,
+            attention_mask.new_zeros((attention_mask.shape[0], 1))
+        ], dim=-1)
+        dummy_token = torch.full((effective_batch_size, 1),
+                                 self.config.pad_token_id,
+                                 dtype=torch.long,
+                                 device=input_ids.device)
+        input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+        return {'input_ids': input_ids, 'attention_mask': attention_mask}
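For context, here is a minimal usage sketch of the `BertForTextEncoding` class added above. The checkpoint path, the `bert-base-uncased` config and tokenizer, and the `bert_layers` import path are illustrative assumptions and not part of this commit; the `from_composer` constructor and the `{'sentence_embedding': ...}` return value come from the code in the diff.

# Hypothetical usage sketch: paths and names below are placeholders, not part of this commit.
import torch
from transformers import AutoTokenizer, BertConfig

from bert_layers import BertForTextEncoding  # assumed import path for this repo's module

# Assumed artifacts: a standard BERT config and a Composer checkpoint on disk.
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForTextEncoding.from_composer(
    pretrained_checkpoint='path/to/composer_checkpoint.pt',
    config=config,
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
batch = tokenizer(['An example sentence to embed.'],
                  padding=True,
                  truncation=True,
                  return_tensors='pt')

# forward() returns a dict with the pooled representation under 'sentence_embedding'.
with torch.no_grad():
    out = model(input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'])

embeddings = out['sentence_embedding']  # (batch_size, hidden_size)

Returning a plain dict keyed by 'sentence_embedding' (rather than a `MaskedLMOutput`) keeps the pooled output easy to consume directly in sentence-embedding style pipelines.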