from transformers import AutoModel | |
from torch import nn | |
import pytorch_lightning as pl | |
class MediumBert(pl.LightningModule):
    """Pretrained transformer encoder followed by one linear classification layer.

    The encoder is loaded with ``AutoModel.from_pretrained`` and its pooled
    output is projected to ``num_classes`` logits by ``self.fc``.

    Args:
        model_name: Hugging Face model identifier (or local path) of the
            encoder checkpoint to load.
        num_classes: Number of output logits produced by the linear head.
    """

    def __init__(self, model_name: str = 'asafaya/bert-medium-arabic',
                 num_classes: int = 18):
        super().__init__()
        self.bert_model = AutoModel.from_pretrained(model_name)
        # Take the head's input width from the loaded encoder instead of a
        # hard-coded 512, so the head always matches whichever checkpoint was
        # loaded (the default checkpoint's pooled output is 512 wide, as the
        # original fixed-size head required).
        encoder_width = self.bert_model.config.hidden_size
        self.fc = nn.Linear(encoder_width, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return the linear head's output for a batch of tokenized inputs.

        Args:
            input_ids: Token ids, forwarded unchanged to the encoder.
            attention_mask: Attention mask, forwarded unchanged to the encoder.

        Returns:
            The output of ``self.fc`` applied to the encoder's pooled output
            (presumably one row of ``num_classes`` logits per input sequence;
            confirm against the caller's batch layout).
        """
        encoder_out = self.bert_model(input_ids=input_ids,
                                      attention_mask=attention_mask)
        # Positional access works for both tuple and ModelOutput returns.
        # NOTE(review): element 1 is the pooled output for BERT-style models;
        # confirm this holds if a non-BERT checkpoint is passed as model_name.
        pooled = encoder_out[1]
        return self.fc(pooled)