# weibo_senti_cls / cnn.py
# Last commit: 995278d "Update weibo_senti_cls" (ZaynSu99)
import torch
from transformers import BertModel,BertTokenizer,AdamW
# Pre-trained Chinese BERT encoder used by the classifiers below. Fall back to
# CPU when CUDA is unavailable so this module can still be imported on a
# CPU-only machine (previously a hard-coded 'cuda' device raised at import).
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pretrained = BertModel.from_pretrained('bert-base-chinese').to(DEVICE)
# Freeze every encoder parameter (requires_grad=False) so no gradient is ever
# computed for or applied to the encoder weights.
pretrained.requires_grad_(False)
class Model(torch.nn.Module):
    """Sentiment classifier: frozen BERT encoder + one linear layer on token 0.

    The encoder is the module-level ``pretrained`` model. It is not registered
    as a submodule, so it does not appear in this model's ``parameters()`` or
    ``state_dict()``; only ``self.fc`` belongs to this module.

    Args:
        hidden_size: Width of the encoder's hidden states (768 for
            ``bert-base-chinese``).
        num_labels: Number of output classes.
    """

    def __init__(self, hidden_size=768, num_labels=2):
        super().__init__()
        self.fc = torch.nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return class probabilities of shape ``(batch, num_labels)``.

        The three inputs are forwarded unchanged to the encoder. They are
        presumably ``(batch, seq_len)`` tensors produced by ``BertTokenizer``
        and already on the encoder's device -- confirm against the caller.
        """
        # The encoder is frozen; skip building its autograd graph entirely.
        with torch.no_grad():
            encoded = pretrained(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
        # Position 0 is the [CLS] slot for BERT-tokenized input.
        logits = self.fc(encoded.last_hidden_state[:, 0])
        # NOTE(review): this returns probabilities, not logits. If training
        # uses torch.nn.CrossEntropyLoss (which applies log-softmax itself),
        # softmax ends up applied twice -- confirm against the training loop.
        return logits.softmax(dim=1)
class CNN(torch.nn.Module):
    """Sentiment classifier: frozen BERT encoder + convolutional n-gram head.

    The encoder's hidden states are treated as a one-channel image of height
    ``seq_len`` and width ``hidden_size``. Three convolutions with kernel
    heights 4, 3 and 2 each span the full hidden width, so they slide along
    the sequence axis only. Each branch's output goes through tanh and is
    max-pooled in ``num_chunks`` non-overlapping windows. The concatenated
    ``3 * num_chunks`` values feed a linear layer followed by softmax.

    The defaults reproduce the original hard-coded sizes: pool heights
    68/69/69 and an ``fc`` input width of 6, which correspond to
    ``seq_len=140`` and ``num_chunks=2``. Each pool height is
    ``(seq_len - kernel_height + 1) // num_chunks``. Trailing positions that
    do not fill a whole window are dropped; with the defaults, that is the
    last position of the 4-gram and 2-gram branches.

    Args:
        hidden_size: Width of the encoder's hidden states.
        num_labels: Number of output classes.
        seq_len: Padded token length of every input batch -- presumably the
            tokenizer's ``max_length``; TODO confirm against the caller.
        num_chunks: Number of max-pooling windows per branch.

    Raises:
        ValueError: If ``num_chunks < 1`` or ``seq_len`` cannot yield exactly
            ``num_chunks`` pooling windows in every branch.
    """

    def __init__(self, hidden_size=768, num_labels=2, seq_len=140, num_chunks=2):
        super().__init__()
        if num_chunks < 1:
            raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")
        pool_heights = []
        for kernel_height in (4, 3, 2):
            conv_len = seq_len - kernel_height + 1
            pool_height = conv_len // num_chunks
            # MaxPool2d with stride == kernel height yields conv_len // pool_height
            # windows; it must be exactly num_chunks to match fc's input width.
            if pool_height < 1 or conv_len // pool_height != num_chunks:
                raise ValueError(
                    f"seq_len={seq_len} cannot be split into {num_chunks} "
                    f"pooling windows for kernel height {kernel_height}"
                )
            pool_heights.append(pool_height)
        h1, h2, h3 = pool_heights
        # Creation order (conv1-3, pool1-3, fc) matches the original, so
        # parameter init under a fixed seed and state_dict keys are unchanged.
        self.conv1 = torch.nn.Conv2d(1, 1, (4, hidden_size))
        self.conv2 = torch.nn.Conv2d(1, 1, (3, hidden_size))
        self.conv3 = torch.nn.Conv2d(1, 1, (2, hidden_size))
        self.pool1 = torch.nn.MaxPool2d((h1, 1), stride=h1)
        self.pool2 = torch.nn.MaxPool2d((h2, 1), stride=h2)
        self.pool3 = torch.nn.MaxPool2d((h3, 1), stride=h3)
        self.fc = torch.nn.Linear(3 * num_chunks, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return class probabilities of shape ``(batch, num_labels)``.

        The three inputs are forwarded unchanged to the encoder. They are
        presumably ``(batch, seq_len)`` tensors from ``BertTokenizer``, padded
        to exactly the ``seq_len`` given at construction -- confirm against
        the caller.
        """
        # The encoder is frozen; skip building its autograd graph entirely.
        with torch.no_grad():
            encoded = pretrained(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
        # Add a channel axis so Conv2d sees a one-channel
        # (seq_len x hidden_size) image.
        image = torch.unsqueeze(encoded.last_hidden_state, 1)
        branches = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        # Each branch yields (batch, 1, num_chunks, 1).
        pooled = [pool(torch.tanh(conv(image))) for conv, pool in branches]
        features = torch.cat(pooled, dim=2)  # (batch, 1, 3 * num_chunks, 1)
        logits = self.fc(features[:, 0, :, 0])
        # NOTE(review): returns probabilities, not logits; with
        # torch.nn.CrossEntropyLoss this would apply softmax twice -- confirm.
        return logits.softmax(dim=1)