Spaces:
Runtime error
Runtime error
from torch import nn | |
from transformers import RobertaModel, RobertaConfig | |
class RobertaSNLI(nn.Module): | |
def __init__(self): | |
super(RobertaSNLI, self).__init__() | |
config = RobertaConfig.from_pretrained('roberta-base') | |
config.output_attentions = True # activer sortie des poids d'attention | |
config.max_position_embeddings = 130 # gérer la longueur des séquences | |
config.hidden_size = 256 # taille des états cachés du modèle | |
config.num_hidden_layers = 4 # nombre de couches cachées dans le transformateur | |
config.intermediate_size = 512 # taille couche intermédiaire dans modèle de transformateur | |
config.num_attention_heads = 4 # nombre de têtes d'attentions | |
self.roberta = RobertaModel(config) | |
self.roberta.requires_grad = True | |
self.output = nn.Linear(256, 3) # couche de sortie linéaire. Entrée la taille des états cachées et 3 sorties | |
def forward(self, input_ids, attention_mask=None): | |
outputs = self.roberta(input_ids, attention_mask=attention_mask) | |
roberta_out = outputs[0] # séquence des états cachés à la sortie de la dernière couche | |
attentions = outputs.attentions # poids d'attention du modèle RoBERTa | |
return self.output(roberta_out[:, 0]), attentions | |