|
import torch |
|
import torch.nn as nn |
|
|
|
class biLSTM(nn.Module):
    """Bidirectional LSTM binary classifier, intended for sentiment analysis.

    Token ids are embedded, run through a (possibly stacked) bidirectional
    LSTM, and the outputs of *every* time step are flattened and fed to a
    two-layer fully connected head that ends in a sigmoid.

    Because the head consumes the flattened per-step outputs, its input
    width is fixed at construction time: every input sequence passed to
    ``forward`` must contain exactly ``seq_len`` tokens.

    Args:
        vocab_size: Number of rows in the embedding table; token ids must
            lie in ``[0, vocab_size)``.
        embedding_dim: Size of each token embedding.
        hidden_dim: LSTM hidden size per direction.
        n_layers: Number of stacked LSTM layers.
        drop_prob: Dropout probability between stacked LSTM layers, in
            ``[0, 1]``. PyTorch applies it only when ``n_layers > 1``.
        seq_len: Fixed input sequence length expected by ``forward``.
        fc_hidden_dim: Width of the hidden fully connected layer.
        fc_drop_prob: Dropout probability applied in the classifier head.

    Raises:
        ValueError: If ``drop_prob`` is outside ``[0, 1]``.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        hidden_dim: int,
        n_layers: int,
        drop_prob: float = 0.5,
        seq_len: int = 128,
        *,
        fc_hidden_dim: int = 256,
        fc_drop_prob: float = 0.5,
    ) -> None:
        super().__init__()
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")

        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.seq_len = seq_len

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # nn.LSTM only applies dropout *between* stacked layers and warns when
        # given non-zero dropout with a single layer, so pass 0 in that case;
        # the computation is identical either way.
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            n_layers,
            dropout=drop_prob if n_layers > 1 else 0.0,
            batch_first=True,
            bidirectional=True,
        )

        # Classifier head. Attribute names are kept stable so previously
        # saved state_dicts still load.
        self.do = nn.Dropout(fc_drop_prob)
        # Factor 2: forward and backward LSTM outputs are concatenated per step.
        self.fc1 = nn.Linear(2 * hidden_dim * self.seq_len, fc_hidden_dim)
        self.fc2 = nn.Linear(fc_hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one sigmoid score per sequence in ``x``.

        Args:
            x: Integer token ids, batch-first, of shape ``(batch, seq_len)``;
                the sequence length must equal ``self.seq_len``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``.
        """
        embeds = self.embedding(x)        # (batch, seq_len, embedding_dim)
        lstm_out, _ = self.lstm(embeds)   # (batch, seq_len, 2 * hidden_dim)
        features = lstm_out.flatten(1)    # (batch, seq_len * 2 * hidden_dim)
        # Dropout acts on the fc1 pre-activations, before the tanh.
        hidden = torch.tanh(self.do(self.fc1(features)))
        return self.sigmoid(self.fc2(hidden))