import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoModel,
    Wav2Vec2Model,
)

class Projection(nn.Module):
    """Two-layer projection head with a residual connection, dropout, and layer norm."""

    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual combination of the linear projection and a GELU-activated branch.
        embed1 = self.linear1(x)
        embed2 = self.drop(self.linear2(F.gelu(embed1)))
        embeds = self.layer_norm(embed1 + embed2)
        return embeds


class SpeechEncoder(nn.Module):
    """Wraps a pretrained Wav2Vec2 model and mean-pools its hidden states over time."""

    def __init__(self, model_name: str) -> None:
        super().__init__()
        self.model_name = model_name
        self.base = Wav2Vec2Model.from_pretrained(self.model_name)
        self.hidden_size = self.base.config.hidden_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, samples) waveform -> (batch, frames, hidden) -> (batch, hidden)
        x = self.base(x)['last_hidden_state']
        x = x.mean(dim=1)
        return x


class TextEncoder(nn.Module):
    """Wraps a pretrained Hugging Face text model and returns its [CLS] embedding."""

    def __init__(self, model_name: str) -> None:
        super().__init__()
        self.base = AutoModel.from_pretrained(model_name)

    def forward(self, x) -> torch.Tensor:
        out = self.base(**x)[0]
        # Take the [CLS] token; detach() keeps gradients from flowing into the text base model.
        out = out[:, 0, :].detach()
        return out


class CLAP(nn.Module):
    """Contrastive Language-Audio Pretraining model: projects speech and text embeddings
    into a shared space and exposes a learnable temperature (logit scale)."""

    def __init__(self, speech_name: str, text_name: str, embedding_dim: int = 1024) -> None:
        super().__init__()

        self.audio_branch = SpeechEncoder(model_name=speech_name)
        self.text_branch = TextEncoder(model_name=text_name)

        self.audio_projection = Projection(self.audio_branch.hidden_size, embedding_dim)
        self.text_projection = Projection(self.text_branch.base.config.hidden_size, embedding_dim)

        # CLIP-style temperature initialization: logit_scale = log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, audio, text):
        speech_emb = self.audio_branch(audio)
        text_emb = self.text_branch(text)

        speech_emb = self.audio_projection(speech_emb)
        text_emb = self.text_projection(text_emb)

        return text_emb, speech_emb, self.logit_scale.exp()
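

# --------------------------------------------------------------------------------------
# Minimal usage sketch. Assumptions: the checkpoint names below are placeholders, the toy
# waveforms/captions are illustrative, and the symmetric cross-entropy objective is the
# standard CLIP-style contrastive loss, which is not defined in this file.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model = CLAP(
        speech_name="facebook/wav2vec2-base-960h",  # placeholder audio checkpoint
        text_name="bert-base-uncased",              # placeholder text checkpoint
    )
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    waveforms = torch.randn(4, 16000)  # toy batch: four 1-second, 16 kHz waveforms
    captions = tokenizer(
        ["a dog barking", "rain on a window", "a car horn", "someone clapping"],
        return_tensors="pt",
        padding=True,
    )

    text_emb, speech_emb, logit_scale = model(waveforms, captions)

    # CLIP-style symmetric contrastive loss on L2-normalized embeddings.
    text_emb = F.normalize(text_emb, dim=-1)
    speech_emb = F.normalize(speech_emb, dim=-1)
    logits = logit_scale * speech_emb @ text_emb.t()
    targets = torch.arange(logits.size(0))
    loss = (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2
    print(loss.item())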