# sescore/__init__.py  (upstream commit bfa7fa8, author xu1998hz)
# Registers a RoBERTa-based encoder with COMET's encoder registry.
import comet
from typing import Dict
import torch
from comet.encoders.base import Encoder
from comet.encoders.bert import BERTEncoder
from transformers import AutoModel, AutoTokenizer
class robertaEncoder(BERTEncoder):
    """COMET encoder wrapping a HuggingFace RoBERTa checkpoint.

    Subclasses comet's BERTEncoder for its interface, but bypasses
    BERTEncoder.__init__ (see the super() call below) so the tokenizer
    and model are loaded here from ``pretrained_model`` instead.
    """

    def __init__(self, pretrained_model: str) -> None:
        # Deliberate MRO trick: super(Encoder, self) skips both
        # BERTEncoder.__init__ and Encoder.__init__ (which would load a
        # BERT checkpoint) and initializes from the class above Encoder
        # in the MRO (presumably nn.Module — TODO confirm against comet).
        super(Encoder, self).__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
        # Pooling head is dropped; the sentence embedding is taken from
        # the first (<s>/CLS-position) token in forward() instead.
        self.model = AutoModel.from_pretrained(
            pretrained_model, add_pooling_layer=False
        )
        # NOTE(review): this sets an attribute on the encoder *module*,
        # not on model.config; forward() also passes
        # output_hidden_states=True explicitly, so hidden states are
        # returned regardless of this line.
        self.model.encoder.output_hidden_states = True

    @classmethod
    def from_pretrained(cls, pretrained_model: str) -> Encoder:
        """Alternate constructor required by the comet Encoder interface."""
        return robertaEncoder(pretrained_model)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Encode a batch and return embeddings in comet's expected layout.

        Returns a dict with:
          sentemb        -- embedding of the first token, (batch, dim)
          wordemb        -- last hidden states, (batch, seq, dim)
          all_layers     -- per-layer hidden states from the model
          attention_mask -- the input mask, passed through unchanged
        """
        # With return_dict=False the model returns a plain tuple; this
        # unpacking assumes a 3-tuple (last_hidden, <ignored>, hidden_states).
        # NOTE(review): with add_pooling_layer=False some transformers
        # versions return a 2-tuple (no pooled output) — confirm against
        # the pinned transformers version before upgrading.
        last_hidden_states, _, all_layers = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=False,
        )
        return {
            # first-token embedding used as the sentence representation
            "sentemb": last_hidden_states[:, 0, :],
            "wordemb": last_hidden_states,
            "all_layers": all_layers,
            "attention_mask": attention_mask,
        }
# Register robertaEncoder in comet's global encoder registry so model
# configs can select it by the key 'RoBERTa' (import side effect).
comet.encoders.str2encoder['RoBERTa'] = robertaEncoder