|
from math import sqrt, log
|
import sys |
|
import torch |
|
import torch.nn as nn |
|
from torch.nn.functional import softmax, relu, linear, gelu
|
from common import PositionalEncoding |
|
from hopfield import HopfieldLayer, HopfieldMHA, HopfieldReLU, HopfieldSoftmax |
|
from configuration_energy import BertEnergyConfig |
|
from torch.cuda.amp import autocast |
|
import yaml |
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
|
|
|
from transformers import PreTrainedModel, PretrainedConfig |
|
from transformers.modeling_outputs import MaskedLMOutput, BaseModelOutput, SequenceClassifierOutput
|
|
|
ACT2FN={'relu': relu, 'gelu': gelu, 'softmax': softmax} |
|
|
|
class BertModel(PreTrainedModel): |
|
""" Backbone of standard BERT model |
|
outputs : last hidden state, history""" |
|
|
|
config_class = BertEnergyConfig |
|
|
|
def __init__(self, config, add_pooling_layer=True, pad_idx=None, **kwargs): |
|
super().__init__(config) |
|
|
|
self.Emb_in = nn.Embedding(config.vocabulary_size, config.embedding_dim, padding_idx=pad_idx) |
|
self.posn = PositionalEncoding(config.embedding_dim, max_len=config.block_size,dropout=config.dropout) if config.positional else None |
|
|
|
        self.embedding_hidden_in = nn.Linear(config.embedding_dim, config.forward_memories) if config.share_layers else None
|
|
|
self.embed_norm = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm) |
|
self.embed_dropout = nn.Dropout(config.dropout) |
|
|
|
|
|
self.num_layers = config.num_layers |
|
self.share_layers = config.share_layers |
|
|
|
if config.share_layers: |
|
layer = nn.TransformerEncoderLayer(config.forward_memories, |
|
config.num_heads, |
|
activation=config.activation, |
|
dim_feedforward=config.forward_memories*4, |
|
dropout=config.dropout, |
|
layer_norm_eps=config.layer_norm, |
|
batch_first=True, |
|
norm_first=True, |
|
) |
|
self.layers = nn.ModuleList([layer]) |
|
|
|
else: |
|
            self.layers = nn.ModuleList([nn.TransformerEncoderLayer(config.embedding_dim,
                                                                     config.num_heads,
                                                                     activation=config.activation,
                                                                     dim_feedforward=config.forward_memories*4,
                                                                     dropout=config.dropout,
                                                                     layer_norm_eps=config.layer_norm,
                                                                     batch_first=True,
                                                                     norm_first=True,
                                                                     ) for _ in range(config.num_layers)])
|
|
|
def forward(self,input_ids, attention_mask=None, **kwargs): |
|
""" Warning : expect attention mask with 0 pad tokens -> mismatch Pytorch/HF tokenizer""" |
|
|
|
xbatch = self.Emb_in(input_ids) |
|
|
|
if self.posn: |
|
X = xbatch + self.posn(xbatch) |
|
else: |
|
X = xbatch |
|
|
|
|
|
if self.share_layers: |
|
X = self.embed_norm(X) |
|
X = self.embed_dropout(X) |
|
X = self.embedding_hidden_in(X) |
|
|
|
history = None if self.training else [X] |
|
|
|
|
|
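        # HF tokenizers mark real tokens with 1 and padding with 0, whereas nn.TransformerEncoderLayer
        # expects src_key_padding_mask to be True at padding positions, hence the inversion below.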
        if attention_mask is not None:
            attention_mask = ~attention_mask.bool()
|
for i in range(self.num_layers): |
|
if self.share_layers: |
|
layer = self.layers[0] |
|
else: |
|
layer = self.layers[i] |
|
X = layer(X, src_key_padding_mask=attention_mask) |
|
|
|
if not self.training: |
|
history.append(X) |
|
|
|
|
|
return BaseModelOutput(last_hidden_state=X, |
|
hidden_states=history, |
|
attentions=None) |
|
|
|
class BertModelForMaskedLM(PreTrainedModel): |
|
""" Bert model to be trained on the MLM task. |
|
Based on the backbone Bert model + projection on the vocabulary with tied weight and norm |
|
outputs: cross entropy loss / logits / hidden states |
|
""" |
|
|
|
config_class = BertEnergyConfig |
|
ignore_index = -100 |
|
|
|
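    # Tells the transformers save/load machinery which output-head parameters are tied to the input embeddings.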
_tied_weights_keys = ["Emb_out.weight", "Emb_out.bias"] |
|
|
|
def __init__(self, config, add_pooling_layer=True, pad_idx=None): |
|
super().__init__(config) |
|
self.config = config |
|
|
|
self.model = BertModel(config, pad_idx=pad_idx) |
|
|
|
self.norm = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm) |
|
self.dense = nn.Linear(config.forward_memories, config.embedding_dim) |
|
self.activation = ACT2FN[config.activation] |
|
""" |
|
if config.tie_weights: |
|
self.Emb_out = nn.Linear(config.embedding_dim, config.vocabulary_size, bias=False) |
|
self.tie_weights() |
|
else: |
|
self.Emb_out = nn.Linear(config.embedding_dim, config.vocabulary_size) |
|
self.bias = nn.Parameter(torch.zeros(config.vocabulary_size)) |
|
self.Emb_out.bias = self.bias |
|
""" |
|
self.Emb_out = nn.Linear(config.forward_memories, config.vocabulary_size) |
|
self.bias = nn.Parameter(torch.zeros(config.vocabulary_size)) |
|
self.Emb_out.bias = self.bias |
|
|
|
def get_input_embeddings(self): |
|
return self.model.Emb_in |
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
self.Emb_out = new_embeddings |
|
|
|
def forward(self,input_ids, attention_mask=None, labels=None, **kwargs): |
|
|
|
outputs = self.model(input_ids, attention_mask, **kwargs) |
|
last_hidden_state = outputs.last_hidden_state |
|
hidden_states = outputs.hidden_states |
|
attentions = outputs.attentions |
|
|
|
last_hidden_state = self.dense(last_hidden_state) |
|
last_hidden_state = self.activation(last_hidden_state) |
|
last_hidden_state = self.norm(last_hidden_state) |
|
|
|
""" |
|
if self.config.tie_weights: |
|
logits = last_hidden_state @ self.Emb_out.weight.transpose(-1,-2) |
|
else: |
|
logits = self.Emb_out(last_hidden_state) |
|
""" |
|
|
|
logits = self.Emb_out(last_hidden_state) |
|
|
|
loss = None |
|
|
|
if labels is not None: |
|
loss_fct = CrossEntropyLoss() |
|
loss = loss_fct(logits.view(-1, self.config.vocabulary_size), labels.view(-1)) |
|
|
|
return MaskedLMOutput(loss=loss, |
|
logits=logits, |
|
hidden_states=hidden_states, |
|
attentions=attentions) |
|
|
|
|
|
class BertModelForSequenceClassification(PreTrainedModel): |
|
""" Bert model to be trained on Sequence classification tasks. |
|
Based on the backbone Bert model + projection on the vocabulary with tied weight and norm |
|
outputs: cross entropy loss / logits / hidden states |
|
""" |
|
|
|
config_class = BertEnergyConfig |
|
ignore_index = -100 |
|
|
|
def __init__(self, config, add_pooling_layer=True, pad_idx=None, |
|
num_labels=2, classifier_dropout=None, return_dict=True): |
|
super().__init__(config) |
|
self.config = config |
|
self.num_labels = num_labels |
|
self.classifier_dropout = classifier_dropout |
|
self.return_dict = return_dict |
|
|
|
self.model = BertModel(config, pad_idx=pad_idx) |
|
self.dense = nn.Linear(config.forward_memories, config.forward_memories) |
|
classifier_dropout = ( |
|
classifier_dropout if classifier_dropout is not None else config.dropout |
|
) |
|
self.dropout = nn.Dropout(classifier_dropout) |
|
self.classifier = nn.Linear(config.forward_memories,num_labels) |
|
self.norm = nn.LayerNorm(config.embedding_dim) |
|
|
|
|
|
|
|
|
|
    def forward(self, input_ids, labels=None, return_dict=None, **kwargs):
|
|
|
outputs = self.model(input_ids, **kwargs) |
|
last_hidden_state = self.norm(outputs.last_hidden_state) |
|
|
|
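        # pool the representation of the first ([CLS]) token for classification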
x = last_hidden_state[:, 0, :] |
|
x = self.dropout(x) |
|
x = self.dense(x) |
|
x = torch.tanh(x) |
|
x = self.dropout(x) |
|
|
|
logits = self.classifier(x) |
|
hidden_states = outputs.hidden_states |
|
attentions = outputs.attentions |
|
|
|
loss = None |
|
|
|
if labels is not None: |
|
|
|
labels = labels.to(logits.device) |
|
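            # infer the problem type from num_labels and the label dtype (HF convention) when it is not set in the config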
if self.config.problem_type is None: |
|
if self.num_labels == 1: |
|
self.config.problem_type = "regression" |
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): |
|
self.config.problem_type = "single_label_classification" |
|
else: |
|
self.config.problem_type = "multi_label_classification" |
|
|
|
if self.config.problem_type == "regression": |
|
loss_fct = MSELoss() |
|
if self.num_labels == 1: |
|
loss = loss_fct(logits.squeeze(), labels.squeeze()) |
|
else: |
|
loss = loss_fct(logits, labels) |
|
elif self.config.problem_type == "single_label_classification": |
|
loss_fct = CrossEntropyLoss() |
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) |
|
elif self.config.problem_type == "multi_label_classification": |
|
loss_fct = BCEWithLogitsLoss() |
|
loss = loss_fct(logits, labels) |
|
|
|
        return_dict = return_dict if return_dict is not None else self.return_dict
        if not return_dict:
|
output = (logits,) + outputs[2:] |
|
return ((loss,) + output) if loss is not None else output |
|
|
|
return SequenceClassifierOutput( |
|
loss=loss, |
|
logits=logits, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
def compute_loss(self, logits, labels): |
|
|
|
log_probs = -nn.functional.log_softmax(logits, dim=-1) |
|
if labels.dim() == log_probs.dim() - 1: |
|
labels = labels.unsqueeze(-1) |
|
|
|
padding_mask = labels.eq(self.ignore_index) |
|
|
|
|
|
labels = torch.clamp(labels, min=0) |
|
nll_loss = log_probs.gather(dim=-1, index=labels) |
|
nll_loss.masked_fill_(padding_mask, 0.0) |
|
num_active_elements = padding_mask.numel() - padding_mask.long().sum() |
|
nll_loss = nll_loss.sum() / num_active_elements |
|
return nll_loss |
|
|
|
|
|
class BertEnergyModel(PreTrainedModel): |
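    """ Energy-based backbone: a single shared HopfieldLayer applied recurrently for num_layers update steps of size alpha.
    outputs: last hidden state, history """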
|
|
|
config_class = BertEnergyConfig |
|
|
|
def __init__(self, config, add_pooling_layer=True, pad_idx=None, **kwargs): |
|
super().__init__(config) |
|
|
|
self.Emb_in = nn.Embedding(config.vocabulary_size, config.embedding_dim, padding_idx=pad_idx) |
|
self.posn = PositionalEncoding(config.embedding_dim,max_len=config.block_size,dropout=config.dropout) if config.positional else None |
|
|
|
self.num_layers = config.num_layers |
|
        self.layer = HopfieldLayer(config.embedding_dim,
                                   config.num_heads,
                                   forward_memories=config.forward_memories,
                                   forward_activation=config.activation,
                                   bias=config.bias,
                                   beta=config.beta,
                                   dropout=config.dropout)
|
|
|
self.alpha = config.alpha |
|
|
|
def forward(self,input_ids, attention_mask=None, **kwargs): |
|
|
|
xbatch = self.Emb_in(input_ids) |
|
|
|
if self.posn: |
|
X = xbatch + self.posn(xbatch) |
|
else: |
|
X = xbatch |
|
|
|
history = None if self.training else [X] |
|
|
|
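        # Recurrent update: each step moves X against the output of the shared HopfieldLayer with step
        # size alpha, i.e. a gradient-descent-style update repeated num_layers times instead of a stack
        # of distinct layers.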
for _ in range(self.num_layers): |
|
|
|
X = X - self.alpha * self.layer(X, src_key_padding_mask=attention_mask, is_causal=False) |
|
if not self.training: |
|
history.append(X) |
|
|
|
return BaseModelOutput(last_hidden_state=X, |
|
hidden_states=history, |
|
attentions=None) |
|
|
|
|
|
class BertEnergyModelForMaskedLM(PreTrainedModel): |
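    """ Energy-based Bert model to be trained on the MLM task.
    Based on the BertEnergyModel backbone + a projection onto the vocabulary (weight-tied when config.tie_weights is set) and a norm layer.
    outputs: cross entropy loss / logits / hidden states
    """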
|
|
|
config_class = BertEnergyConfig |
|
ignore_index = -100 |
|
|
|
_tied_weights_keys = ["Emb_out.weight", "Emb_out.bias"] |
|
|
|
def __init__(self, config, add_pooling_layer=True, pad_idx=None): |
|
super().__init__(config) |
|
self.config = config |
|
|
|
self.model = BertEnergyModel(config, pad_idx=pad_idx) |
|
|
|
self.norm = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm) |
|
self.dense = nn.Linear(config.embedding_dim, config.embedding_dim) |
|
self.activation = ACT2FN[config.activation] |
|
|
|
self.Emb_out = nn.Linear(config.embedding_dim, config.vocabulary_size) |
|
self.bias = nn.Parameter(torch.zeros(config.vocabulary_size)) |
|
self.Emb_out.bias = self.bias |
|
|
|
|
|
def get_input_embeddings(self): |
|
return self.model.Emb_in |
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
self.Emb_out = new_embeddings |
|
|
|
    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):

        outputs = self.model(input_ids, attention_mask=attention_mask)
|
last_hidden_state = outputs.last_hidden_state |
|
hidden_states = outputs.hidden_states |
|
attentions = outputs.attentions |
|
|
|
last_hidden_state = self.dense(last_hidden_state) |
|
        last_hidden_state = self.activation(last_hidden_state)
|
last_hidden_state = self.norm(last_hidden_state) |
|
|
|
|
|
if self.config.tie_weights: |
|
logits = last_hidden_state @ self.Emb_out.weight.transpose(-1,-2) |
|
else: |
|
logits = self.Emb_out(last_hidden_state) |
|
|
|
loss = None |
|
|
|
|
|
|
|
|
if labels is not None: |
|
loss_fct = CrossEntropyLoss() |
|
loss = loss_fct(logits.view(-1, self.config.vocabulary_size), labels.view(-1)) |
|
|
|
return MaskedLMOutput(loss=loss, |
|
logits=logits, |
|
hidden_states=hidden_states, |
|
attentions=attentions) |
|
|
|
if __name__ == '__main__': |
|
|
|
def grads(f, x): |
|
""" Autograd used for the energy """ |
|
return torch.func.jacrev(f)(x) |
|
|
|
|
|
|
|
x = torch.randn(1,10) |
|
input_ids = torch.tensor([[3,12,44, 2]]) |
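
    # Hedged sketch of a round trip through the transformer baseline defined above. It assumes
    # BertEnergyConfig accepts these fields as keyword arguments (they are the fields read by
    # BertModel / BertModelForMaskedLM); uncomment once that assumption is verified.
    # cfg = BertEnergyConfig(vocabulary_size=64, embedding_dim=32, block_size=16, num_layers=2,
    #                        num_heads=2, forward_memories=32, activation='gelu', dropout=0.1,
    #                        layer_norm=1e-5, positional=True, share_layers=False, tie_weights=False)
    # mlm = BertModelForMaskedLM(cfg, pad_idx=0)
    # mlm.eval()
    # print(mlm(input_ids, attention_mask=torch.ones_like(input_ids)).logits.shape)  # (1, 4, 64)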
|
|
    config = BertEnergyConfig(path="../lmconfig.yaml")  # assumption: BertEnergyConfig loads its fields from this yaml file
|
print(config) |
|
|
|
    mdl = BertEnergyModel(config)
|
mdl.eval() |
|
|
|
out = mdl(input_ids) |
|
print(out[0].mean()) |
|
mdl.save_pretrained("test_checkpoint") |
|
    reloaded = BertEnergyModel.from_pretrained("test_checkpoint")
|
out_reloaded = reloaded(input_ids) |
|
print(out_reloaded[0].mean()) |
|
    if torch.cuda.is_available():
        reloaded.to("cuda:0")
        print(reloaded(input_ids.to("cuda:0"))[0])
|
|