File size: 2,242 Bytes
d4be371 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
AutoModel,
Wav2Vec2Model,
)
class Projection(torch.nn.Module):
    """Two-layer projection head with a residual connection.

    Maps ``d_in``-dimensional features to ``d_out`` dimensions with a
    bias-free linear layer, adds a GELU-activated, dropout-regularized
    second linear branch on top of it, and layer-normalizes the sum.

    Args:
        d_in: Dimensionality of the incoming features.
        d_out: Dimensionality of the projected embedding space.
        p: Dropout probability applied to the second branch (default 0.5).
    """

    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(d_in, d_out, bias=False)
        self.linear2 = torch.nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = torch.nn.LayerNorm(d_out)
        self.drop = torch.nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and return the layer-normalized residual sum."""
        projected = self.linear1(x)
        residual = self.drop(self.linear2(F.gelu(projected)))
        return self.layer_norm(projected + residual)
class SpeechEncoder(torch.nn.Module):
    """Wav2Vec2-based audio encoder producing one vector per input clip.

    Loads a pretrained Wav2Vec2 backbone and mean-pools its last hidden
    state over the time dimension. ``hidden_size`` mirrors the backbone's
    configured hidden size so downstream projections can size themselves.
    """

    def __init__(self, model_name):
        super().__init__()
        self.model_name = model_name
        self.base = Wav2Vec2Model.from_pretrained(self.model_name)
        self.hidden_size = self.base.config.hidden_size

    def forward(self, x):
        """Encode raw audio ``x`` into a (batch, hidden_size) tensor."""
        hidden = self.base(x)['last_hidden_state']
        # NOTE(review): mean over time also averages any padded frames —
        # confirm callers pass unpadded batches or accept that behavior.
        pooled = hidden.mean(dim=1)
        return pooled
class TextEncoder(torch.nn.Module):
    """Transformer text encoder that returns the CLS-token embedding.

    The returned embedding is detached from the autograd graph, so no
    gradients flow back into the base text model — the backbone is
    effectively frozen during training (presumably intentional; its
    parameters are still trainable if the detach is removed).
    """

    def __init__(self, model_name: str) -> None:
        super().__init__()
        self.base = AutoModel.from_pretrained(model_name)

    def forward(self, x):
        """Encode tokenized input ``x`` (a dict of model kwargs)."""
        hidden_states = self.base(**x)[0]
        # Position 0 holds the CLS token; detach keeps the backbone out
        # of the backward pass.
        cls_embedding = hidden_states[:, 0, :]
        return cls_embedding.detach()
class CLAP(torch.nn.Module):
    """Contrastive language-audio model with speech and text branches.

    Each branch encodes its modality and is projected into a shared
    ``embedding_dim``-dimensional space. ``logit_scale`` is a learnable
    temperature initialized to ``log(1 / 0.07)`` (the CLIP convention);
    its exponential is returned for scaling similarity logits.

    Args:
        speech_name: Pretrained Wav2Vec2 checkpoint name for the audio branch.
        text_name: Pretrained transformer checkpoint name for the text branch.
        embedding_dim: Size of the shared embedding space (default 1024).
    """

    def __init__(self, speech_name: str, text_name: str, embedding_dim: int = 1024):
        super().__init__()
        self.audio_branch = SpeechEncoder(model_name=speech_name)
        self.text_branch = TextEncoder(model_name=text_name)
        self.audio_projection = Projection(self.audio_branch.hidden_size, embedding_dim)
        self.text_projection = Projection(self.text_branch.base.config.hidden_size, embedding_dim)
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, audio, text):
        """Return (text_emb, speech_emb, temperature) for a batch pair."""
        audio_features = self.audio_branch(audio)
        text_features = self.text_branch(text)
        audio_features = self.audio_projection(audio_features)
        text_features = self.text_projection(text_features)
        return text_features, audio_features, self.logit_scale.exp()
|