taim-gan / src /models /modules /text_encoder.py
Dmmc's picture
three-model version
c8ddb9b
raw
history blame
1.43 kB
"""LSTM-based textual encoder for tokenized input"""
from typing import Any
import torch
from torch import nn
class TextEncoder(nn.Module):
"""Simple text encoder based on RNN"""
def __init__(self, vocab_size: int, emb_dim: int, hidden_dim: int) -> None:
"""
Initialize embeddings lookup for tokens and main LSTM
:param vocab_size:
Size of created vocabulary for textual input. L from paper
:param emb_dim: Length of embeddings for each word.
:param hidden_dim:
Length of hidden state of a LSTM cell. 2 x hidden_dim = C (from LWGAN paper)
"""
super().__init__()
self.embs = nn.Embedding(vocab_size, emb_dim)
self.lstm = nn.LSTM(emb_dim, hidden_dim, bidirectional=True, batch_first=True)
def forward(self, tokens: torch.Tensor) -> Any:
"""
Propagate the text token input through the LSTM and return
two types of embeddings: word-level and sentence-level
:param torch.Tensor tokens: Input text tokens from vocab
:return: Word-level embeddings (BxCxL) and sentence-level embeddings (BxC)
:rtype: Any
"""
embs = self.embs(tokens)
output, (hidden_states, _) = self.lstm(embs)
word_embs = torch.transpose(output, 1, 2)
sent_embs = torch.cat((hidden_states[-1, :, :], hidden_states[0, :, :]), dim=1)
return word_embs, sent_embs