| """LSTM-based textual encoder for tokenized input""" | |
| from typing import Any | |
| import torch | |
| from torch import nn | |
class TextEncoder(nn.Module):
    """Simple text encoder based on a bidirectional LSTM."""
    def __init__(self, vocab_size: int, emb_dim: int, hidden_dim: int) -> None:
        """
        Initialize the embedding lookup for tokens and the main LSTM.

        :param vocab_size:
            Size of the vocabulary built for the textual input (L in the paper).
        :param emb_dim: Length of the embedding for each word.
        :param hidden_dim:
            Length of the hidden state of an LSTM cell; 2 x hidden_dim = C from the LWGAN paper.
        """
        super().__init__()
        self.embs = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, bidirectional=True, batch_first=True)

    def forward(self, tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Propagate the text tokens through the LSTM and return two kinds of
        embeddings: word-level and sentence-level.

        :param torch.Tensor tokens: Input text tokens from the vocabulary (B x L).
        :return: Word-level embeddings (B x C x L) and sentence-level embeddings (B x C).
        :rtype: tuple[torch.Tensor, torch.Tensor]
        """
        embs = self.embs(tokens)  # B x L x emb_dim
        output, (hidden_states, _) = self.lstm(embs)  # output: B x L x (2 * hidden_dim)
        # Word-level features: move the channel dimension forward -> B x C x L.
        word_embs = torch.transpose(output, 1, 2)
        # Sentence-level features: concatenate the final hidden states of the two
        # directions -> B x C, where C = 2 * hidden_dim.
        sent_embs = torch.cat((hidden_states[-1, :, :], hidden_states[0, :, :]), dim=1)
        return word_embs, sent_embs
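

if __name__ == "__main__":
    # Minimal usage sketch. The hyperparameters and shapes below are illustrative
    # assumptions, not values taken from the paper or its configuration.
    encoder = TextEncoder(vocab_size=5000, emb_dim=300, hidden_dim=128)
    tokens = torch.randint(0, 5000, (4, 18))  # batch of 4 sequences, 18 tokens each
    word_embs, sent_embs = encoder(tokens)
    print(word_embs.shape)  # torch.Size([4, 256, 18]) -> B x C x L with C = 2 * hidden_dim
    print(sent_embs.shape)  # torch.Size([4, 256])     -> B x C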