# Provenance (residue of the hosting page this file was copied from):
# uploaded by poonehmousavi, commit message "Upload 10 files",
# commit 89a1ae3 (verified), file size 3.21 kB.
import torch
class AttentionMLP(torch.nn.Module):
    """Score input vectors with a small MLP and softmax-normalise the scores.

    The network is Linear(input_dim -> hidden_dim), ReLU, then a bias-free
    Linear(hidden_dim -> 1), so each input vector receives one scalar score.
    The scores are normalised with a softmax over dimension 2.

    Arguments
    ---------
    input_dim : int
        Size of the last dimension of the input tensor.
    hidden_dim : int
        Width of the hidden layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # One Sequential named ``layers`` keeps the state-dict keys as
        # ``layers.0.*`` / ``layers.2.*`` so existing checkpoints still load.
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1, bias=False),
        )

    def forward(self, x):
        """Return attention weights for ``x``.

        Arguments
        ---------
        x : torch.Tensor
            Tensor whose last dimension is ``input_dim``. Presumably shaped
            (batch, time, num_items, input_dim) so that dim 2 indexes the
            items being attended over; TODO confirm with callers. For a 3-D
            input, dim 2 is the size-1 score axis and every weight is 1.0.

        Returns
        -------
        torch.Tensor
            Same shape as ``x`` with the last dimension reduced to 1. Values
            sum to 1 along dim 2.
        """
        scores = self.layers(x)
        return scores.softmax(dim=2)
class Discrete_EmbeddingLayer(torch.nn.Module):
    """This class handles embedding layers for discrete tokens.

    Tokens from all codebooks share a single ``torch.nn.Embedding`` table
    with ``num_codebooks * vocab_size`` rows. Token ``t`` of codebook ``k`` is
    looked up at row ``k * vocab_size + t``. The input is tokens shaped
    (batch, time, num_codebooks), such as those produced by a multi-codebook
    tokenizer like Encodec.

    Arguments
    ---------
    num_codebooks : int
        Number of codebooks of the tokenizer.
    vocab_size : int
        Size of each codebook's dictionary.
    emb_dim : int
        The size of each embedding vector.
    pad_index : int (default: 0)
        Accepted for API compatibility. NOTE: it is not passed to
        ``torch.nn.Embedding`` (no ``padding_idx`` is set), so it currently
        has no effect.
    init : bool (default: False)
        Stored as ``self.init``. This class does not act on it. Presumably
        callers check it and then call ``init_embedding`` with the
        tokenizer's embedding; confirm against callers.
    freeze : bool (default: False)
        If True, the embedding is frozen. If False, the model will be trained
        alongside the rest of the pipeline.

    Example
    -------
    >>> emb = Discrete_EmbeddingLayer(2, 1024, 1024)
    >>> tokens = torch.randint(0, 1024, (4, 4, 2))
    >>> in_emb = emb(tokens)
    >>> print(in_emb.shape)
    torch.Size([4, 4, 2, 1024])
    """

    def __init__(
        self,
        num_codebooks,
        vocab_size,
        emb_dim,
        pad_index=0,
        init=False,
        freeze=False,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_codebooks = num_codebooks
        self.freeze = freeze
        # One shared table for all codebooks; forward() shifts each codebook's
        # IDs into its own slice of rows.
        self.embedding = torch.nn.Embedding(
            num_codebooks * vocab_size, emb_dim
        ).requires_grad_(not self.freeze)
        self.init = init

    def init_embedding(self, weights):
        """Replace the embedding table with ``weights``.

        The ``Parameter`` object is replaced rather than updated in place, so
        call this before building an optimizer over the module's parameters.

        Arguments
        ---------
        weights : torch.Tensor
            The new table. Presumably of shape
            (num_codebooks * vocab_size, emb_dim) so that it matches the
            offsets used in ``forward``; TODO confirm.
        """
        with torch.no_grad():
            # Honour ``freeze``: a bare nn.Parameter defaults to
            # requires_grad=True, which would silently unfreeze the table.
            self.embedding.weight = torch.nn.Parameter(
                weights, requires_grad=not self.freeze
            )

    def forward(self, in_tokens):
        """Computes the embedding for discrete tokens.

        Arguments
        ---------
        in_tokens : torch.Tensor
            Integer token IDs of shape (batch, time, num_codebooks). Each
            value should lie in ``[0, vocab_size)``; larger values would index
            into the next codebook's rows. The tensor is not modified.

        Returns
        -------
        in_embs : torch.Tensor
            Embeddings of shape (batch, time, num_codebooks, emb_dim).
        """
        with torch.set_grad_enabled(not self.freeze):
            # Make token IDs unique across codebooks by adding
            # k * vocab_size to the IDs of codebook k.
            offsets = torch.arange(
                0,
                self.num_codebooks * self.vocab_size,
                self.vocab_size,
                device=in_tokens.device,
            )
            # Offset a copy so the caller's tensor is left untouched. The
            # in-place add on the copy keeps in_tokens' dtype and requires
            # the offsets to broadcast into its shape (last dim equal to
            # num_codebooks, or num_codebooks == 1); otherwise RuntimeError.
            token_ids = in_tokens.clone()
            token_ids += offsets
            in_embs = self.embedding(token_ids)
        return in_embs