Spaces: Build error
import comet
import torch
from typing import Dict

from comet.encoders.base import Encoder
from comet.encoders.bert import BERTEncoder
from transformers import AutoModel, AutoTokenizer


class robertaEncoder(BERTEncoder):
    def __init__(self, pretrained_model: str) -> None:
        # Skip BERTEncoder.__init__ (which would load a BERT checkpoint) and
        # initialize nn.Module directly, mirroring COMET's own encoder subclasses.
        super(Encoder, self).__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
        self.model = AutoModel.from_pretrained(
            pretrained_model, add_pooling_layer=False
        )
        self.model.encoder.output_hidden_states = True

    @classmethod
    def from_pretrained(cls, pretrained_model: str) -> Encoder:
        return robertaEncoder(pretrained_model)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
    ) -> Dict[str, torch.Tensor]:
        # With return_dict=False the model yields a tuple
        # (last_hidden_state, pooled_output, hidden_states); pooled_output is
        # None here because the pooling layer was dropped above.
        last_hidden_states, _, all_layers = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=False,
        )
        return {
            "sentemb": last_hidden_states[:, 0, :],  # <s> token embedding
            "wordemb": last_hidden_states,
            "all_layers": all_layers,
            "attention_mask": attention_mask,
        }


# Register the RoBERTa encoder in COMET's str2encoder mapping.
comet.encoders.str2encoder['RoBERTa'] = robertaEncoder
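Once registered, COMET can look the encoder up by the name "RoBERTa" through str2encoder, and the class can also be exercised on its own as a quick sanity check. Below is a minimal smoke test, assuming the class and imports above are in scope and that "roberta-base" is an acceptable checkpoint name; any RoBERTa-style checkpoint should behave the same way.

# Minimal smoke test (hypothetical checkpoint name "roberta-base").
encoder = robertaEncoder.from_pretrained("roberta-base")
batch = encoder.tokenizer(
    ["A quick sentence to embed."], return_tensors="pt", padding=True
)
with torch.no_grad():
    outputs = encoder(
        input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
    )
print(outputs["sentemb"].shape)    # (batch_size, hidden_size)
print(len(outputs["all_layers"]))  # embedding layer + one entry per transformer layer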