import torch


class AttentionMLP(torch.nn.Module):
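    """A small MLP that computes attention weights over a set of embeddings.

    Given an input of shape (Batch x Time x num_codebooks x input_dim), e.g.
    the output of Discrete_EmbeddingLayer, it returns softmax-normalized
    weights of shape (Batch x Time x num_codebooks x 1), normalized over
    dim=2 (the codebook dimension).

    Arguments
    ---------
    input_dim : int
        Size of the input feature (embedding) dimension.
    hidden_dim : int
        Size of the hidden layer.

    Example
    -------
    >>> inp = torch.randn(4, 10, 2, 16)
    >>> att_mlp = AttentionMLP(16, 32)
    >>> att_w = att_mlp(inp)
    >>> att_w.shape
    torch.Size([4, 10, 2, 1])
    """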

    def __init__(self, input_dim, hidden_dim):
        super(AttentionMLP, self).__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1, bias=False),
        )

    def forward(self, x):
        x = self.layers(x)
        # Softmax over dim=2: the codebook dimension when the input is
        # (Batch x Time x num_codebooks x input_dim).
        att_w = torch.nn.functional.softmax(x, dim=2)
        return att_w


class Discrete_EmbeddingLayer(torch.nn.Module):
    """This class handles embedding layers for discrete tokens.

    Arguments
    ---------
    num_codebooks: int
        Number of codebooks of the tokenizer.
    vocab_size: int
        Size of the dictionary of embeddings.
    emb_dim: int
        The size of each embedding vector.
    pad_index: int (default: 0)
        If specified, the entries at padding_idx do not contribute to the gradient.
    init: boolean (default: False)
        If True, initialize the embedding with the tokenizer embedding; otherwise initialize randomly.
    freeze: boolean (default: False)
        If True, the embedding is frozen. If False, the embedding is trained
        alongside the rest of the pipeline.

    Example
    -------
    >>> from speechbrain.lobes.models.huggingface_transformers.encodec import Encodec
    >>> model_hub = "facebook/encodec_24khz"
    >>> save_path = "savedir"
    >>> model = Encodec(model_hub, save_path)
    >>> audio = torch.randn(4, 1000)
    >>> length = torch.tensor([1.0, .5, .75, 1.0])
    >>> tokens, emb = model.encode(audio, length)
    >>> print(tokens.shape)
    torch.Size([4, 4, 2])
    >>> emb = Discrete_EmbeddingLayer(2, 1024, 1024)
    >>> in_emb = emb(tokens)
    >>> print(in_emb.shape)
    torch.Size([4, 4, 2, 1024])
    """

    def __init__(
        self,
        num_codebooks,
        vocab_size,
        emb_dim,
        pad_index=0,
        init=False,
        freeze=False,
    ):
        super(Discrete_EmbeddingLayer, self).__init__()
        self.vocab_size = vocab_size
        self.num_codebooks = num_codebooks
        self.freeze = freeze
        self.embedding = torch.nn.Embedding(
            num_codebooks * vocab_size, emb_dim
        ).requires_grad_(not self.freeze)
        self.init = init

    def init_embedding(self, weights):
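        """Replaces the embedding weights with pre-computed ones.

        Arguments
        ---------
        weights : torch.Tensor
            A (num_codebooks * vocab_size, emb_dim) tensor of embedding
            weights, e.g. the tokenizer's codebook vectors.
        """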
        with torch.no_grad():
            # Preserve the freeze setting when swapping in the new weights.
            self.embedding.weight = torch.nn.Parameter(
                weights, requires_grad=not self.freeze
            )

    def forward(self, in_tokens):
        """Computes the embeddings for discrete tokens.

        Arguments
        ---------
        in_tokens : torch.Tensor
            A (Batch x Time x num_codebooks) tensor of discrete token indices.

        Returns
        -------
        in_embs : torch.Tensor
            A (Batch x Time x num_codebooks x emb_dim) tensor of embeddings.
        """
        with torch.set_grad_enabled(not self.freeze):
            # Offset each codebook into its own slice of the flattened
            # embedding table: codebook i uses indices
            # [i * vocab_size, (i + 1) * vocab_size). Avoid the in-place +=
            # so the caller's token tensor is not modified.
            in_tokens = in_tokens + torch.arange(
                0,
                self.num_codebooks * self.vocab_size,
                self.vocab_size,
                device=in_tokens.device,
            )
            in_embs = self.embedding(in_tokens)
            return in_embs
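

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the SpeechBrain
    # API): embed discrete tokens, then pool the per-codebook embeddings into
    # a single vector per frame using AttentionMLP weights. The shapes and
    # dimensions below are made up for the example.
    num_codebooks, vocab_size, emb_dim = 2, 1024, 64
    emb_layer = Discrete_EmbeddingLayer(num_codebooks, vocab_size, emb_dim)
    att_mlp = AttentionMLP(emb_dim, 32)

    tokens = torch.randint(0, vocab_size, (4, 10, num_codebooks))
    in_embs = emb_layer(tokens)  # (4, 10, 2, 64)
    att_w = att_mlp(in_embs)  # (4, 10, 2, 1), softmax over codebooks
    # Weighted sum over the codebook dimension -> (4, 10, 64)
    feats = torch.matmul(att_w.transpose(2, 3), in_embs).squeeze(2)
    print(feats.shape)  # torch.Size([4, 10, 64])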