"""LSTM-based textual encoder for tokenized input."""

from typing import Any

import torch
from torch import nn


class TextEncoder(nn.Module):
    """Simple text encoder based on a bidirectional LSTM."""

    def __init__(self, vocab_size: int, emb_dim: int, hidden_dim: int) -> None:
        """
        Initialize the embedding lookup for tokens and the main LSTM.

        :param vocab_size:
            Size of the vocabulary created for the textual input (L from the paper).
        :param emb_dim: Length of the embedding vector for each word.
        :param hidden_dim:
            Length of the hidden state of an LSTM cell; 2 x hidden_dim = C (from the LWGAN paper).
        """
        super().__init__()
        self.embs = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, bidirectional=True, batch_first=True)

    def forward(self, tokens: torch.Tensor) -> Any:
        """
        Propagate the text tokens through the LSTM and return two kinds of
        embeddings: word-level and sentence-level.

        :param torch.Tensor tokens: Input text token indices from the vocabulary, shape (B, L)
        :return: Word-level embeddings (B x C x L) and sentence-level embeddings (B x C)
        :rtype: Any
        """
        embs = self.embs(tokens)  # (B, L, emb_dim)
        output, (hidden_states, _) = self.lstm(embs)  # output: (B, L, 2 * hidden_dim)
        # Word-level embeddings: transpose to (B, C, L) with C = 2 * hidden_dim
        word_embs = torch.transpose(output, 1, 2)
        # Sentence-level embeddings: concatenate the final hidden states of both directions
        sent_embs = torch.cat((hidden_states[-1, :, :], hidden_states[0, :, :]), dim=1)
        return word_embs, sent_embs
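

# A minimal usage sketch (not part of the original module): the hyperparameter
# values and batch/sequence sizes below are illustrative assumptions, shown only
# to make the tensor shapes returned by TextEncoder concrete.
if __name__ == "__main__":
    encoder = TextEncoder(vocab_size=5000, emb_dim=128, hidden_dim=256)
    tokens = torch.randint(0, 5000, (4, 16))  # a batch of 4 sequences, 16 tokens each
    word_embs, sent_embs = encoder(tokens)
    print(word_embs.shape)  # torch.Size([4, 512, 16]) -> B x C x L with C = 2 * hidden_dim
    print(sent_embs.shape)  # torch.Size([4, 512])     -> B x C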