# cbert / model_cbert.py
# Author: yonichi - commit 4230aba ("add model")
import torch
import torch.nn as nn
import random
import numpy as np
import numpy as np
import pandas as pd
import torch.nn.functional as F
from transformers import BertModel, PreTrainedModel
from configuration_cbert import BertCustomConfig
import torch.optim as optim
class BertSentiment(PreTrainedModel):
    """BERT encoder followed by a mean-pooled two-layer classification head.

    Token states from the encoder are mean-pooled (over non-padding tokens
    when an ``attention_mask`` is supplied), passed through a hidden linear
    layer, dropout and ReLU, and projected to ``num_labels`` logits.

    Args:
        config: ``BertCustomConfig``; must provide ``hyperparams["num_labels"]``
            and ``hidden_dropout_prob``, plus the BERT architecture fields
            needed by ``BertModel(config)`` when ``weight_path`` is not given.
        weight_path: Optional name or path handed to
            ``BertModel.from_pretrained`` to start from pretrained encoder
            weights (e.g. ``'yiyanghkust/finbert-tone'``). When falsy, the
            encoder is built from ``config`` with freshly initialised weights.
    """

    config_class = BertCustomConfig

    def __init__(self, config, weight_path=None):
        super().__init__(config)
        self.config = config
        self.num_labels = self.config.hyperparams["num_labels"]
        if weight_path:
            self.bert = BertModel.from_pretrained(weight_path)
        else:
            self.bert = BertModel(self.config)
        # Size the head from the encoder that was actually built: a checkpoint
        # loaded via ``weight_path`` carries its own config, which need not
        # match ``self.config``.
        hidden_size = self.bert.config.hidden_size
        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.hidden = nn.Linear(hidden_size, hidden_size)
        self.classifier = nn.Linear(hidden_size, self.num_labels)
        nn.init.xavier_normal_(self.hidden.weight)
        nn.init.xavier_normal_(self.classifier.weight)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                labels=None, graphEmbeddings=None):
        """Compute classification logits for a batch of token ids.

        Args:
            input_ids: Token ids forwarded to the BERT encoder.
            token_type_ids: Optional segment ids forwarded to the encoder.
            attention_mask: Optional padding mask forwarded to the encoder and
                used to exclude padding positions from mean pooling.
            labels: Accepted for interface compatibility; unused (no loss is
                computed here).
            graphEmbeddings: Accepted for interface compatibility; unused.

        Returns:
            Output of the final ``nn.Linear`` classifier, one score per label
            (``num_labels`` columns).
        """
        # Arguments are passed by keyword: BertModel.forward's positional order
        # is (input_ids, attention_mask, token_type_ids, ...), so a positional
        # call with (input_ids, token_type_ids, attention_mask) swaps the
        # padding mask and the segment ids.
        encoder_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=False,
        )
        # Index rather than unpack exactly two values: the returned tuple grows
        # when the config enables output_hidden_states / output_attentions.
        sequence_output = encoder_outputs[0]
        pooled_output = self._mean_pool(sequence_output, attention_mask)
        pooled_output = self.hidden(pooled_output)
        pooled_output = self.dropout(pooled_output)
        pooled_output = F.relu(pooled_output)
        return self.classifier(pooled_output)

    @staticmethod
    def _mean_pool(sequence_output, attention_mask=None):
        """Average token states over dim 1, ignoring padded positions if masked.

        Assumes ``sequence_output`` is (batch, seq_len, hidden) and
        ``attention_mask`` is (batch, seq_len) with 1 for real tokens and 0 for
        padding, as documented for HF ``BertModel``.
        """
        if attention_mask is None:
            return torch.mean(sequence_output, 1)
        mask = attention_mask.unsqueeze(-1).to(sequence_output.dtype)
        summed = (sequence_output * mask).sum(dim=1)
        # clamp keeps a row that is entirely padding at 0 instead of 0/0 = NaN.
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts