haoheliu commited on
Commit
bdab1da
·
1 Parent(s): a5d109b

first commit and add large model

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +0 -0
  2. app.py +10 -4
  3. audioldm/__init__.py +3 -0
  4. audioldm/audio/__init__.py +0 -0
  5. audioldm/audio/audio_processing.py +100 -0
  6. audioldm/audio/stft.py +180 -0
  7. audioldm/audio/tools.py +33 -0
  8. audioldm/clap/__init__.py +0 -0
  9. audioldm/clap/encoders.py +169 -0
  10. audioldm/clap/open_clip/__init__.py +25 -0
  11. audioldm/clap/open_clip/bert.py +40 -0
  12. audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  13. audioldm/clap/open_clip/factory.py +277 -0
  14. audioldm/clap/open_clip/feature_fusion.py +192 -0
  15. audioldm/clap/open_clip/htsat.py +1308 -0
  16. audioldm/clap/open_clip/linear_probe.py +66 -0
  17. audioldm/clap/open_clip/loss.py +398 -0
  18. audioldm/clap/open_clip/model.py +934 -0
  19. audioldm/clap/open_clip/model_configs/HTSAT-base.json +23 -0
  20. audioldm/clap/open_clip/model_configs/HTSAT-large.json +23 -0
  21. audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json +23 -0
  22. audioldm/clap/open_clip/model_configs/HTSAT-tiny.json +23 -0
  23. audioldm/clap/open_clip/model_configs/PANN-10.json +23 -0
  24. audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json +23 -0
  25. audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json +23 -0
  26. audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json +23 -0
  27. audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json +23 -0
  28. audioldm/clap/open_clip/model_configs/PANN-14.json +23 -0
  29. audioldm/clap/open_clip/model_configs/PANN-6.json +23 -0
  30. audioldm/clap/open_clip/model_configs/RN101-quickgelu.json +22 -0
  31. audioldm/clap/open_clip/model_configs/RN101.json +21 -0
  32. audioldm/clap/open_clip/model_configs/RN50-quickgelu.json +22 -0
  33. audioldm/clap/open_clip/model_configs/RN50.json +21 -0
  34. audioldm/clap/open_clip/model_configs/RN50x16.json +21 -0
  35. audioldm/clap/open_clip/model_configs/RN50x4.json +21 -0
  36. audioldm/clap/open_clip/model_configs/ViT-B-16.json +16 -0
  37. audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json +17 -0
  38. audioldm/clap/open_clip/model_configs/ViT-B-32.json +16 -0
  39. audioldm/clap/open_clip/model_configs/ViT-L-14.json +16 -0
  40. audioldm/clap/open_clip/openai.py +156 -0
  41. audioldm/clap/open_clip/pann_model.py +703 -0
  42. audioldm/clap/open_clip/pretrained.py +167 -0
  43. audioldm/clap/open_clip/timm_model.py +112 -0
  44. audioldm/clap/open_clip/tokenizer.py +197 -0
  45. audioldm/clap/open_clip/transform.py +45 -0
  46. audioldm/clap/open_clip/utils.py +361 -0
  47. audioldm/clap/open_clip/version.py +1 -0
  48. audioldm/clap/training/__init__.py +0 -0
  49. audioldm/clap/training/audioset_textmap.npy +3 -0
  50. audioldm/clap/training/data.py +977 -0
.gitignore ADDED
File without changes
app.py CHANGED
@@ -1,8 +1,14 @@
1
  import gradio as gr
2
  import numpy as np
 
3
 
4
- def greet(name):
5
- return [(16000, np.random.randn(16000)), (16000, np.random.randn(16000)), (16000, np.random.randn(16000))]
 
 
6
 
7
- iface = gr.Interface(fn=greet, inputs="text", outputs=["audio", "audio", "audio"])
8
- iface.launch()
 
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
+ from audioldm import text_to_audio
4
 
5
+ def greet(text):
6
+ waveform = text_to_audio(text, n_gen=1) # [bs, 1, samples]
7
+ waveform = [(16000, wave[0]) for wave in waveform]
8
+ return waveform
9
 
10
+ iface = gr.Interface(fn=greet, inputs="text", outputs=["audio", "audio"])
11
+ iface.launch()
12
+
13
+ # if __name__ == "__main__":
14
+ # greet("hello world")
audioldm/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .ldm import LatentDiffusion
2
+ from .utils import seed_everything
3
+ from .pipeline import *
audioldm/audio/__init__.py ADDED
File without changes
audioldm/audio/audio_processing.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import librosa.util as librosa_util
4
+ from scipy.signal import get_window
5
+
6
+
7
+ def window_sumsquare(
8
+ window,
9
+ n_frames,
10
+ hop_length,
11
+ win_length,
12
+ n_fft,
13
+ dtype=np.float32,
14
+ norm=None,
15
+ ):
16
+ """
17
+ # from librosa 0.6
18
+ Compute the sum-square envelope of a window function at a given hop length.
19
+
20
+ This is used to estimate modulation effects induced by windowing
21
+ observations in short-time fourier transforms.
22
+
23
+ Parameters
24
+ ----------
25
+ window : string, tuple, number, callable, or list-like
26
+ Window specification, as in `get_window`
27
+
28
+ n_frames : int > 0
29
+ The number of analysis frames
30
+
31
+ hop_length : int > 0
32
+ The number of samples to advance between frames
33
+
34
+ win_length : [optional]
35
+ The length of the window function. By default, this matches `n_fft`.
36
+
37
+ n_fft : int > 0
38
+ The length of each analysis frame.
39
+
40
+ dtype : np.dtype
41
+ The data type of the output
42
+
43
+ Returns
44
+ -------
45
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
46
+ The sum-squared envelope of the window function
47
+ """
48
+ if win_length is None:
49
+ win_length = n_fft
50
+
51
+ n = n_fft + hop_length * (n_frames - 1)
52
+ x = np.zeros(n, dtype=dtype)
53
+
54
+ # Compute the squared window at the desired length
55
+ win_sq = get_window(window, win_length, fftbins=True)
56
+ win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
57
+ win_sq = librosa_util.pad_center(win_sq, n_fft)
58
+
59
+ # Fill the envelope
60
+ for i in range(n_frames):
61
+ sample = i * hop_length
62
+ x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
63
+ return x
64
+
65
+
66
+ def griffin_lim(magnitudes, stft_fn, n_iters=30):
67
+ """
68
+ PARAMS
69
+ ------
70
+ magnitudes: spectrogram magnitudes
71
+ stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
72
+ """
73
+
74
+ angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
75
+ angles = angles.astype(np.float32)
76
+ angles = torch.autograd.Variable(torch.from_numpy(angles))
77
+ signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
78
+
79
+ for i in range(n_iters):
80
+ _, angles = stft_fn.transform(signal)
81
+ signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
82
+ return signal
83
+
84
+
85
+ def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
86
+ """
87
+ PARAMS
88
+ ------
89
+ C: compression factor
90
+ """
91
+ return normalize_fun(torch.clamp(x, min=clip_val) * C)
92
+
93
+
94
+ def dynamic_range_decompression(x, C=1):
95
+ """
96
+ PARAMS
97
+ ------
98
+ C: compression factor used to compress
99
+ """
100
+ return torch.exp(x) / C
audioldm/audio/stft.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from scipy.signal import get_window
5
+ from librosa.util import pad_center, tiny
6
+ from librosa.filters import mel as librosa_mel_fn
7
+
8
+ from audioldm.audio.audio_processing import (
9
+ dynamic_range_compression,
10
+ dynamic_range_decompression,
11
+ window_sumsquare,
12
+ )
13
+
14
+
15
+ class STFT(torch.nn.Module):
16
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
17
+
18
+ def __init__(self, filter_length, hop_length, win_length, window="hann"):
19
+ super(STFT, self).__init__()
20
+ self.filter_length = filter_length
21
+ self.hop_length = hop_length
22
+ self.win_length = win_length
23
+ self.window = window
24
+ self.forward_transform = None
25
+ scale = self.filter_length / self.hop_length
26
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
27
+
28
+ cutoff = int((self.filter_length / 2 + 1))
29
+ fourier_basis = np.vstack(
30
+ [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
31
+ )
32
+
33
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
34
+ inverse_basis = torch.FloatTensor(
35
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :]
36
+ )
37
+
38
+ if window is not None:
39
+ assert filter_length >= win_length
40
+ # get window and zero center pad it to filter_length
41
+ fft_window = get_window(window, win_length, fftbins=True)
42
+ fft_window = pad_center(fft_window, filter_length)
43
+ fft_window = torch.from_numpy(fft_window).float()
44
+
45
+ # window the bases
46
+ forward_basis *= fft_window
47
+ inverse_basis *= fft_window
48
+
49
+ self.register_buffer("forward_basis", forward_basis.float())
50
+ self.register_buffer("inverse_basis", inverse_basis.float())
51
+
52
+ def transform(self, input_data):
53
+ num_batches = input_data.size(0)
54
+ num_samples = input_data.size(1)
55
+
56
+ self.num_samples = num_samples
57
+
58
+ # similar to librosa, reflect-pad the input
59
+ input_data = input_data.view(num_batches, 1, num_samples)
60
+ input_data = F.pad(
61
+ input_data.unsqueeze(1),
62
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
63
+ mode="reflect",
64
+ )
65
+ input_data = input_data.squeeze(1)
66
+
67
+ forward_transform = F.conv1d(
68
+ input_data,
69
+ torch.autograd.Variable(self.forward_basis, requires_grad=False),
70
+ stride=self.hop_length,
71
+ padding=0,
72
+ ).cpu()
73
+
74
+ cutoff = int((self.filter_length / 2) + 1)
75
+ real_part = forward_transform[:, :cutoff, :]
76
+ imag_part = forward_transform[:, cutoff:, :]
77
+
78
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
79
+ phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
80
+
81
+ return magnitude, phase
82
+
83
+ def inverse(self, magnitude, phase):
84
+ recombine_magnitude_phase = torch.cat(
85
+ [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
86
+ )
87
+
88
+ inverse_transform = F.conv_transpose1d(
89
+ recombine_magnitude_phase,
90
+ torch.autograd.Variable(self.inverse_basis, requires_grad=False),
91
+ stride=self.hop_length,
92
+ padding=0,
93
+ )
94
+
95
+ if self.window is not None:
96
+ window_sum = window_sumsquare(
97
+ self.window,
98
+ magnitude.size(-1),
99
+ hop_length=self.hop_length,
100
+ win_length=self.win_length,
101
+ n_fft=self.filter_length,
102
+ dtype=np.float32,
103
+ )
104
+ # remove modulation effects
105
+ approx_nonzero_indices = torch.from_numpy(
106
+ np.where(window_sum > tiny(window_sum))[0]
107
+ )
108
+ window_sum = torch.autograd.Variable(
109
+ torch.from_numpy(window_sum), requires_grad=False
110
+ )
111
+ window_sum = window_sum
112
+ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
113
+ approx_nonzero_indices
114
+ ]
115
+
116
+ # scale by hop ratio
117
+ inverse_transform *= float(self.filter_length) / self.hop_length
118
+
119
+ inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
120
+ inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
121
+
122
+ return inverse_transform
123
+
124
+ def forward(self, input_data):
125
+ self.magnitude, self.phase = self.transform(input_data)
126
+ reconstruction = self.inverse(self.magnitude, self.phase)
127
+ return reconstruction
128
+
129
+
130
+ class TacotronSTFT(torch.nn.Module):
131
+ def __init__(
132
+ self,
133
+ filter_length,
134
+ hop_length,
135
+ win_length,
136
+ n_mel_channels,
137
+ sampling_rate,
138
+ mel_fmin,
139
+ mel_fmax,
140
+ ):
141
+ super(TacotronSTFT, self).__init__()
142
+ self.n_mel_channels = n_mel_channels
143
+ self.sampling_rate = sampling_rate
144
+ self.stft_fn = STFT(filter_length, hop_length, win_length)
145
+ mel_basis = librosa_mel_fn(
146
+ sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
147
+ )
148
+ mel_basis = torch.from_numpy(mel_basis).float()
149
+ self.register_buffer("mel_basis", mel_basis)
150
+
151
+ def spectral_normalize(self, magnitudes, normalize_fun):
152
+ output = dynamic_range_compression(magnitudes, normalize_fun)
153
+ return output
154
+
155
+ def spectral_de_normalize(self, magnitudes):
156
+ output = dynamic_range_decompression(magnitudes)
157
+ return output
158
+
159
+ def mel_spectrogram(self, y, normalize_fun=torch.log):
160
+ """Computes mel-spectrograms from a batch of waves
161
+ PARAMS
162
+ ------
163
+ y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
164
+
165
+ RETURNS
166
+ -------
167
+ mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
168
+ """
169
+ assert torch.min(y.data) >= -1, torch.min(y.data)
170
+ assert torch.max(y.data) <= 1, torch.max(y.data)
171
+
172
+ magnitudes, phases = self.stft_fn.transform(y)
173
+ magnitudes = magnitudes.data
174
+ mel_output = torch.matmul(self.mel_basis, magnitudes)
175
+ mel_output = self.spectral_normalize(mel_output, normalize_fun)
176
+ energy = torch.norm(magnitudes, dim=1)
177
+
178
+ log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
179
+
180
+ return mel_output, log_magnitudes, energy
audioldm/audio/tools.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def get_mel_from_wav(audio, _stft):
6
+ audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
7
+ audio = torch.autograd.Variable(audio, requires_grad=False)
8
+ melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
9
+ melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
10
+ log_magnitudes_stft = (
11
+ torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32)
12
+ )
13
+ energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
14
+ return melspec, log_magnitudes_stft, energy
15
+
16
+
17
+ # def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
18
+ # mel = torch.stack([mel])
19
+ # mel_decompress = _stft.spectral_de_normalize(mel)
20
+ # mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
21
+ # spec_from_mel_scaling = 1000
22
+ # spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
23
+ # spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
24
+ # spec_from_mel = spec_from_mel * spec_from_mel_scaling
25
+
26
+ # audio = griffin_lim(
27
+ # torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters
28
+ # )
29
+
30
+ # audio = audio.squeeze()
31
+ # audio = audio.cpu().numpy()
32
+ # audio_path = out_filename
33
+ # write(audio_path, _stft.sampling_rate, audio)
audioldm/clap/__init__.py ADDED
File without changes
audioldm/clap/encoders.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from audioldm.clap.open_clip import create_model
4
+ from audioldm.clap.training.data import get_audio_features
5
+ import torchaudio
6
+ from transformers import RobertaTokenizer
7
+ import torch.nn.functional as F
8
+
9
+
10
+ class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
11
+ def __init__(
12
+ self,
13
+ pretrained_path="",
14
+ key="class",
15
+ sampling_rate=16000,
16
+ embed_mode="audio",
17
+ unconditional_prob=0.1,
18
+ random_mute=False,
19
+ max_random_mute_portion=0.5,
20
+ training_mode=True,
21
+ ):
22
+ super().__init__()
23
+
24
+ self.key = key
25
+ self.device = "cpu"
26
+ self.precision = "fp32"
27
+ self.amodel = "HTSAT-tiny" # or 'PANN-14'
28
+ self.tmodel = "roberta" # the best text encoder in our training
29
+ self.enable_fusion = False # False if you do not want to use the fusion model
30
+ self.fusion_type = "aff_2d"
31
+ self.pretrained = pretrained_path
32
+ self.embed_mode = embed_mode
33
+ self.embed_mode_orig = embed_mode
34
+ self.sampling_rate = sampling_rate
35
+ self.unconditional_prob = unconditional_prob
36
+ self.random_mute = random_mute
37
+ self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")
38
+ self.max_random_mute_portion = max_random_mute_portion
39
+ self.training_mode = training_mode
40
+ self.model, self.model_cfg = create_model(
41
+ self.amodel,
42
+ self.tmodel,
43
+ self.pretrained,
44
+ precision=self.precision,
45
+ device=self.device,
46
+ enable_fusion=self.enable_fusion,
47
+ fusion_type=self.fusion_type,
48
+ )
49
+ for p in self.model.parameters():
50
+ p.requires_grad = False
51
+
52
+ self.model.eval()
53
+
54
+ def get_unconditional_condition(self, batchsize):
55
+ self.unconditional_token = self.model.get_text_embedding(
56
+ self.tokenizer(["", ""])
57
+ )[0:1]
58
+ return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)
59
+
60
+ def batch_to_list(self, batch):
61
+ ret = []
62
+ for i in range(batch.size(0)):
63
+ ret.append(batch[i])
64
+ return ret
65
+
66
+ def make_decision(self, probability):
67
+ if float(torch.rand(1)) < probability:
68
+ return True
69
+ else:
70
+ return False
71
+
72
+ def random_uniform(self, start, end):
73
+ val = torch.rand(1).item()
74
+ return start + (end - start) * val
75
+
76
+ def _random_mute(self, waveform):
77
+ # waveform: [bs, t-steps]
78
+ t_steps = waveform.size(-1)
79
+ for i in range(waveform.size(0)):
80
+ mute_size = int(
81
+ self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
82
+ )
83
+ mute_start = int(self.random_uniform(0, t_steps - mute_size))
84
+ waveform[i, mute_start : mute_start + mute_size] = 0
85
+ return waveform
86
+
87
+ def cos_similarity(self, waveform, text):
88
+ # waveform: [bs, t_steps]
89
+ with torch.no_grad():
90
+ self.embed_mode = "audio"
91
+ audio_emb = self(waveform.cuda())
92
+ self.embed_mode = "text"
93
+ text_emb = self(text)
94
+ similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
95
+ return similarity.squeeze()
96
+
97
+ def forward(self, batch, key=None):
98
+ # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0
99
+ # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0
100
+ if self.model.training == True and not self.training_mode:
101
+ print(
102
+ "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters."
103
+ )
104
+ self.model, self.model_cfg = create_model(
105
+ self.amodel,
106
+ self.tmodel,
107
+ self.pretrained,
108
+ precision=self.precision,
109
+ device="cuda",
110
+ enable_fusion=self.enable_fusion,
111
+ fusion_type=self.fusion_type,
112
+ )
113
+ for p in self.model.parameters():
114
+ p.requires_grad = False
115
+ self.model.eval()
116
+
117
+ # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
118
+ if self.embed_mode == "audio":
119
+ with torch.no_grad():
120
+ audio_dict_list = []
121
+ assert (
122
+ self.sampling_rate == 16000
123
+ ), "We only support 16000 sampling rate"
124
+ if self.random_mute:
125
+ batch = self._random_mute(batch)
126
+ # batch: [bs, 1, t-samples]
127
+ batch = torchaudio.functional.resample(
128
+ batch, orig_freq=self.sampling_rate, new_freq=48000
129
+ )
130
+ for waveform in self.batch_to_list(batch):
131
+ audio_dict = {}
132
+ audio_dict = get_audio_features(
133
+ audio_dict,
134
+ waveform,
135
+ 480000,
136
+ data_truncating="fusion",
137
+ data_filling="repeatpad",
138
+ audio_cfg=self.model_cfg["audio_cfg"],
139
+ )
140
+ audio_dict_list.append(audio_dict)
141
+ # [bs, 512]
142
+ embed = self.model.get_audio_embedding(audio_dict_list)
143
+ elif self.embed_mode == "text":
144
+ with torch.no_grad():
145
+ # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
146
+ text_data = self.tokenizer(batch)
147
+ embed = self.model.get_text_embedding(text_data)
148
+
149
+ embed = embed.unsqueeze(1)
150
+ self.unconditional_token = self.model.get_text_embedding(
151
+ self.tokenizer(["", ""])
152
+ )[0:1]
153
+
154
+ for i in range(embed.size(0)):
155
+ if self.make_decision(self.unconditional_prob):
156
+ embed[i] = self.unconditional_token
157
+
158
+ # [bs, 1, 512]
159
+ return embed.detach()
160
+
161
+ def tokenizer(self, text):
162
+ result = self.tokenize(
163
+ text,
164
+ padding="max_length",
165
+ truncation=True,
166
+ max_length=77,
167
+ return_tensors="pt",
168
+ )
169
+ return {k: v.squeeze(0) for k, v in result.items()}
audioldm/clap/open_clip/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .factory import (
2
+ list_models,
3
+ create_model,
4
+ create_model_and_transforms,
5
+ add_model_config,
6
+ )
7
+ from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
8
+ from .model import (
9
+ CLAP,
10
+ CLAPTextCfg,
11
+ CLAPVisionCfg,
12
+ CLAPAudioCfp,
13
+ convert_weights_to_fp16,
14
+ trace_model,
15
+ )
16
+ from .openai import load_openai_model, list_openai_models
17
+ from .pretrained import (
18
+ list_pretrained,
19
+ list_pretrained_tag_models,
20
+ list_pretrained_model_tags,
21
+ get_pretrained_url,
22
+ download_pretrained,
23
+ )
24
+ from .tokenizer import SimpleTokenizer, tokenize
25
+ from .transform import image_transform
audioldm/clap/open_clip/bert.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertTokenizer, BertModel
2
+
3
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
4
+ model = BertModel.from_pretrained("bert-base-uncased")
5
+ text = "Replace me by any text you'd like."
6
+
7
+
8
+ def bert_embeddings(text):
9
+ # text = "Replace me by any text you'd like."
10
+ encoded_input = tokenizer(text, return_tensors="pt")
11
+ output = model(**encoded_input)
12
+ return output
13
+
14
+
15
+ from transformers import RobertaTokenizer, RobertaModel
16
+
17
+ tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
18
+ model = RobertaModel.from_pretrained("roberta-base")
19
+ text = "Replace me by any text you'd like."
20
+
21
+
22
+ def Roberta_embeddings(text):
23
+ # text = "Replace me by any text you'd like."
24
+ encoded_input = tokenizer(text, return_tensors="pt")
25
+ output = model(**encoded_input)
26
+ return output
27
+
28
+
29
+ from transformers import BartTokenizer, BartModel
30
+
31
+ tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
32
+ model = BartModel.from_pretrained("facebook/bart-base")
33
+ text = "Replace me by any text you'd like."
34
+
35
+
36
+ def bart_embeddings(text):
37
+ # text = "Replace me by any text you'd like."
38
+ encoded_input = tokenizer(text, return_tensors="pt")
39
+ output = model(**encoded_input)
40
+ return output
audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
audioldm/clap/open_clip/factory.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+
9
+ import torch
10
+
11
+ from .model import CLAP, convert_weights_to_fp16
12
+ from .openai import load_openai_model
13
+ from .pretrained import get_pretrained_url, download_pretrained
14
+ from .transform import image_transform
15
+
16
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
17
+ _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
18
+
19
+
20
+ def _natural_key(string_):
21
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
22
+
23
+
24
+ def _rescan_model_configs():
25
+ global _MODEL_CONFIGS
26
+
27
+ config_ext = (".json",)
28
+ config_files = []
29
+ for config_path in _MODEL_CONFIG_PATHS:
30
+ if config_path.is_file() and config_path.suffix in config_ext:
31
+ config_files.append(config_path)
32
+ elif config_path.is_dir():
33
+ for ext in config_ext:
34
+ config_files.extend(config_path.glob(f"*{ext}"))
35
+
36
+ for cf in config_files:
37
+ if os.path.basename(cf)[0] == ".":
38
+ continue # Ignore hidden files
39
+
40
+ with open(cf, "r") as f:
41
+ model_cfg = json.load(f)
42
+ if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
43
+ _MODEL_CONFIGS[cf.stem] = model_cfg
44
+
45
+ _MODEL_CONFIGS = {
46
+ k: v
47
+ for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
48
+ }
49
+
50
+
51
+ _rescan_model_configs() # initial populate of model config registry
52
+
53
+
54
+ def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
55
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
56
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
57
+ state_dict = checkpoint["state_dict"]
58
+ else:
59
+ state_dict = checkpoint
60
+ if skip_params:
61
+ if next(iter(state_dict.items()))[0].startswith("module"):
62
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
63
+ # for k in state_dict:
64
+ # if k.startswith('transformer'):
65
+ # v = state_dict.pop(k)
66
+ # state_dict['text_branch.' + k[12:]] = v
67
+ return state_dict
68
+
69
+
70
+ def create_model(
71
+ amodel_name: str,
72
+ tmodel_name: str,
73
+ pretrained: str = "",
74
+ precision: str = "fp32",
75
+ device: torch.device = torch.device("cpu"),
76
+ jit: bool = False,
77
+ force_quick_gelu: bool = False,
78
+ openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
79
+ skip_params=True,
80
+ pretrained_audio: str = "",
81
+ pretrained_text: str = "",
82
+ enable_fusion: bool = False,
83
+ fusion_type: str = "None"
84
+ # pretrained_image: bool = False,
85
+ ):
86
+ amodel_name = amodel_name.replace(
87
+ "/", "-"
88
+ ) # for callers using old naming with / in ViT names
89
+ pretrained_orig = pretrained
90
+ pretrained = pretrained.lower()
91
+ if pretrained == "openai":
92
+ if amodel_name in _MODEL_CONFIGS:
93
+ logging.info(f"Loading {amodel_name} model config.")
94
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
95
+ else:
96
+ logging.error(
97
+ f"Model config for {amodel_name} not found; available models {list_models()}."
98
+ )
99
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
100
+
101
+ logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
102
+ # Hard Code in model name
103
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
104
+ model = load_openai_model(
105
+ "ViT-B-16",
106
+ model_cfg,
107
+ device=device,
108
+ jit=jit,
109
+ cache_dir=openai_model_cache_dir,
110
+ enable_fusion=enable_fusion,
111
+ fusion_type=fusion_type,
112
+ )
113
+ # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
114
+ if precision == "amp" or precision == "fp32":
115
+ model = model.float()
116
+ else:
117
+ if amodel_name in _MODEL_CONFIGS:
118
+ logging.info(f"Loading {amodel_name} model config.")
119
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
120
+ else:
121
+ logging.error(
122
+ f"Model config for {amodel_name} not found; available models {list_models()}."
123
+ )
124
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
125
+
126
+ if force_quick_gelu:
127
+ # override for use of QuickGELU on non-OpenAI transformer models
128
+ model_cfg["quick_gelu"] = True
129
+
130
+ # if pretrained_image:
131
+ # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
132
+ # # pretrained weight loading for timm models set via vision_cfg
133
+ # model_cfg['vision_cfg']['timm_model_pretrained'] = True
134
+ # else:
135
+ # assert False, 'pretrained image towers currently only supported for timm models'
136
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
137
+ model_cfg["enable_fusion"] = enable_fusion
138
+ model_cfg["fusion_type"] = fusion_type
139
+ model = CLAP(**model_cfg)
140
+
141
+ if pretrained:
142
+ checkpoint_path = ""
143
+ url = get_pretrained_url(amodel_name, pretrained)
144
+ if url:
145
+ checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
146
+ elif os.path.exists(pretrained_orig):
147
+ checkpoint_path = pretrained_orig
148
+ if checkpoint_path:
149
+ logging.info(
150
+ f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
151
+ )
152
+ ckpt = load_state_dict(checkpoint_path, skip_params=True)
153
+ model.load_state_dict(ckpt)
154
+ param_names = [n for n, p in model.named_parameters()]
155
+ # for n in param_names:
156
+ # print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
157
+ else:
158
+ logging.warning(
159
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
160
+ )
161
+ raise RuntimeError(
162
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
163
+ )
164
+
165
+ if pretrained_audio:
166
+ if amodel_name.startswith("PANN"):
167
+ if "Cnn14_mAP" in pretrained_audio: # official checkpoint
168
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
169
+ audio_ckpt = audio_ckpt["model"]
170
+ keys = list(audio_ckpt.keys())
171
+ for key in keys:
172
+ if (
173
+ "spectrogram_extractor" not in key
174
+ and "logmel_extractor" not in key
175
+ ):
176
+ v = audio_ckpt.pop(key)
177
+ audio_ckpt["audio_branch." + key] = v
178
+ elif os.path.basename(pretrained_audio).startswith(
179
+ "PANN"
180
+ ): # checkpoint trained via HTSAT codebase
181
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
182
+ audio_ckpt = audio_ckpt["state_dict"]
183
+ keys = list(audio_ckpt.keys())
184
+ for key in keys:
185
+ if key.startswith("sed_model"):
186
+ v = audio_ckpt.pop(key)
187
+ audio_ckpt["audio_branch." + key[10:]] = v
188
+ elif os.path.basename(pretrained_audio).startswith(
189
+ "finetuned"
190
+ ): # checkpoint trained via linear probe codebase
191
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
192
+ else:
193
+ raise ValueError("Unknown audio checkpoint")
194
+ elif amodel_name.startswith("HTSAT"):
195
+ if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint
196
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
197
+ audio_ckpt = audio_ckpt["state_dict"]
198
+ keys = list(audio_ckpt.keys())
199
+ for key in keys:
200
+ if key.startswith("sed_model") and (
201
+ "spectrogram_extractor" not in key
202
+ and "logmel_extractor" not in key
203
+ ):
204
+ v = audio_ckpt.pop(key)
205
+ audio_ckpt["audio_branch." + key[10:]] = v
206
+ elif os.path.basename(pretrained_audio).startswith(
207
+ "HTSAT"
208
+ ): # checkpoint trained via HTSAT codebase
209
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
210
+ audio_ckpt = audio_ckpt["state_dict"]
211
+ keys = list(audio_ckpt.keys())
212
+ for key in keys:
213
+ if key.startswith("sed_model"):
214
+ v = audio_ckpt.pop(key)
215
+ audio_ckpt["audio_branch." + key[10:]] = v
216
+ elif os.path.basename(pretrained_audio).startswith(
217
+ "finetuned"
218
+ ): # checkpoint trained via linear probe codebase
219
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
220
+ else:
221
+ raise ValueError("Unknown audio checkpoint")
222
+ else:
223
+ raise f"this audio encoder pretrained checkpoint is not support"
224
+
225
+ model.load_state_dict(audio_ckpt, strict=False)
226
+ logging.info(
227
+ f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
228
+ )
229
+ param_names = [n for n, p in model.named_parameters()]
230
+ for n in param_names:
231
+ print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
232
+
233
+ model.to(device=device)
234
+ if precision == "fp16":
235
+ assert device.type != "cpu"
236
+ convert_weights_to_fp16(model)
237
+
238
+ if jit:
239
+ model = torch.jit.script(model)
240
+
241
+ return model, model_cfg
242
+
243
+
244
+ def create_model_and_transforms(
245
+ model_name: str,
246
+ pretrained: str = "",
247
+ precision: str = "fp32",
248
+ device: torch.device = torch.device("cpu"),
249
+ jit: bool = False,
250
+ force_quick_gelu: bool = False,
251
+ # pretrained_image: bool = False,
252
+ ):
253
+ model = create_model(
254
+ model_name,
255
+ pretrained,
256
+ precision,
257
+ device,
258
+ jit,
259
+ force_quick_gelu=force_quick_gelu,
260
+ # pretrained_image=pretrained_image
261
+ )
262
+ preprocess_train = image_transform(model.visual.image_size, is_train=True)
263
+ preprocess_val = image_transform(model.visual.image_size, is_train=False)
264
+ return model, preprocess_train, preprocess_val
265
+
266
+
267
+ def list_models():
268
+ """enumerate available model architectures based on config files"""
269
+ return list(_MODEL_CONFIGS.keys())
270
+
271
+
272
+ def add_model_config(path):
273
+ """add model config path or file and update registry"""
274
+ if not isinstance(path, Path):
275
+ path = Path(path)
276
+ _MODEL_CONFIG_PATHS.append(path)
277
+ _rescan_model_configs()
audioldm/clap/open_clip/feature_fusion.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Feature Fusion for Varible-Length Data Processing
3
+ AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
4
+ According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class DAF(nn.Module):
12
+ """
13
+ 直接相加 DirectAddFuse
14
+ """
15
+
16
+ def __init__(self):
17
+ super(DAF, self).__init__()
18
+
19
+ def forward(self, x, residual):
20
+ return x + residual
21
+
22
+
23
+ class iAFF(nn.Module):
24
+ """
25
+ 多特征融合 iAFF
26
+ """
27
+
28
+ def __init__(self, channels=64, r=4, type="2D"):
29
+ super(iAFF, self).__init__()
30
+ inter_channels = int(channels // r)
31
+
32
+ if type == "1D":
33
+ # 本地注意力
34
+ self.local_att = nn.Sequential(
35
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
36
+ nn.BatchNorm1d(inter_channels),
37
+ nn.ReLU(inplace=True),
38
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
39
+ nn.BatchNorm1d(channels),
40
+ )
41
+
42
+ # 全局注意力
43
+ self.global_att = nn.Sequential(
44
+ nn.AdaptiveAvgPool1d(1),
45
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
46
+ nn.BatchNorm1d(inter_channels),
47
+ nn.ReLU(inplace=True),
48
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
49
+ nn.BatchNorm1d(channels),
50
+ )
51
+
52
+ # 第二次本地注意力
53
+ self.local_att2 = nn.Sequential(
54
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
55
+ nn.BatchNorm1d(inter_channels),
56
+ nn.ReLU(inplace=True),
57
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
58
+ nn.BatchNorm1d(channels),
59
+ )
60
+ # 第二次全局注意力
61
+ self.global_att2 = nn.Sequential(
62
+ nn.AdaptiveAvgPool1d(1),
63
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
64
+ nn.BatchNorm1d(inter_channels),
65
+ nn.ReLU(inplace=True),
66
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
67
+ nn.BatchNorm1d(channels),
68
+ )
69
+ elif type == "2D":
70
+ # 本地注意力
71
+ self.local_att = nn.Sequential(
72
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
73
+ nn.BatchNorm2d(inter_channels),
74
+ nn.ReLU(inplace=True),
75
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
76
+ nn.BatchNorm2d(channels),
77
+ )
78
+
79
+ # 全局注意力
80
+ self.global_att = nn.Sequential(
81
+ nn.AdaptiveAvgPool2d(1),
82
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
83
+ nn.BatchNorm2d(inter_channels),
84
+ nn.ReLU(inplace=True),
85
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
86
+ nn.BatchNorm2d(channels),
87
+ )
88
+
89
+ # 第二次本地注意力
90
+ self.local_att2 = nn.Sequential(
91
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
92
+ nn.BatchNorm2d(inter_channels),
93
+ nn.ReLU(inplace=True),
94
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
95
+ nn.BatchNorm2d(channels),
96
+ )
97
+ # 第二次全局注意力
98
+ self.global_att2 = nn.Sequential(
99
+ nn.AdaptiveAvgPool2d(1),
100
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
101
+ nn.BatchNorm2d(inter_channels),
102
+ nn.ReLU(inplace=True),
103
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
104
+ nn.BatchNorm2d(channels),
105
+ )
106
+ else:
107
+ raise f"the type is not supported"
108
+
109
+ self.sigmoid = nn.Sigmoid()
110
+
111
+ def forward(self, x, residual):
112
+ flag = False
113
+ xa = x + residual
114
+ if xa.size(0) == 1:
115
+ xa = torch.cat([xa, xa], dim=0)
116
+ flag = True
117
+ xl = self.local_att(xa)
118
+ xg = self.global_att(xa)
119
+ xlg = xl + xg
120
+ wei = self.sigmoid(xlg)
121
+ xi = x * wei + residual * (1 - wei)
122
+
123
+ xl2 = self.local_att2(xi)
124
+ xg2 = self.global_att(xi)
125
+ xlg2 = xl2 + xg2
126
+ wei2 = self.sigmoid(xlg2)
127
+ xo = x * wei2 + residual * (1 - wei2)
128
+ if flag:
129
+ xo = xo[0].unsqueeze(0)
130
+ return xo
131
+
132
+
133
+ class AFF(nn.Module):
134
+ """
135
+ 多特征融合 AFF
136
+ """
137
+
138
+ def __init__(self, channels=64, r=4, type="2D"):
139
+ super(AFF, self).__init__()
140
+ inter_channels = int(channels // r)
141
+
142
+ if type == "1D":
143
+ self.local_att = nn.Sequential(
144
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
145
+ nn.BatchNorm1d(inter_channels),
146
+ nn.ReLU(inplace=True),
147
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
148
+ nn.BatchNorm1d(channels),
149
+ )
150
+ self.global_att = nn.Sequential(
151
+ nn.AdaptiveAvgPool1d(1),
152
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
153
+ nn.BatchNorm1d(inter_channels),
154
+ nn.ReLU(inplace=True),
155
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
156
+ nn.BatchNorm1d(channels),
157
+ )
158
+ elif type == "2D":
159
+ self.local_att = nn.Sequential(
160
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
161
+ nn.BatchNorm2d(inter_channels),
162
+ nn.ReLU(inplace=True),
163
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
164
+ nn.BatchNorm2d(channels),
165
+ )
166
+ self.global_att = nn.Sequential(
167
+ nn.AdaptiveAvgPool2d(1),
168
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
169
+ nn.BatchNorm2d(inter_channels),
170
+ nn.ReLU(inplace=True),
171
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
172
+ nn.BatchNorm2d(channels),
173
+ )
174
+ else:
175
+ raise f"the type is not supported."
176
+
177
+ self.sigmoid = nn.Sigmoid()
178
+
179
+ def forward(self, x, residual):
180
+ flag = False
181
+ xa = x + residual
182
+ if xa.size(0) == 1:
183
+ xa = torch.cat([xa, xa], dim=0)
184
+ flag = True
185
+ xl = self.local_att(xa)
186
+ xg = self.global_att(xa)
187
+ xlg = xl + xg
188
+ wei = self.sigmoid(xlg)
189
+ xo = 2 * x * wei + 2 * residual * (1 - wei)
190
+ if flag:
191
+ xo = xo[0].unsqueeze(0)
192
+ return xo
audioldm/clap/open_clip/htsat.py ADDED
@@ -0,0 +1,1308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ke Chen
2
3
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
+ # Some layers designed on the model
5
+ # below codes are based and referred from https://github.com/microsoft/Swin-Transformer
6
+ # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from itertools import repeat
12
+ import collections.abc
13
+ import math
14
+ import warnings
15
+
16
+ from torch.nn.init import _calculate_fan_in_and_fan_out
17
+ import torch.utils.checkpoint as checkpoint
18
+
19
+ import random
20
+
21
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
22
+ from torchlibrosa.augmentation import SpecAugmentation
23
+
24
+ from itertools import repeat
25
+ from .utils import do_mixup, interpolate
26
+
27
+ from .feature_fusion import iAFF, AFF, DAF
28
+
29
+ # from PyTorch internals
30
+ def _ntuple(n):
31
+ def parse(x):
32
+ if isinstance(x, collections.abc.Iterable):
33
+ return x
34
+ return tuple(repeat(x, n))
35
+
36
+ return parse
37
+
38
+
39
+ to_1tuple = _ntuple(1)
40
+ to_2tuple = _ntuple(2)
41
+ to_3tuple = _ntuple(3)
42
+ to_4tuple = _ntuple(4)
43
+ to_ntuple = _ntuple
44
+
45
+
46
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
47
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
48
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
49
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
50
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
51
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
52
+ 'survival rate' as the argument.
53
+ """
54
+ if drop_prob == 0.0 or not training:
55
+ return x
56
+ keep_prob = 1 - drop_prob
57
+ shape = (x.shape[0],) + (1,) * (
58
+ x.ndim - 1
59
+ ) # work with diff dim tensors, not just 2D ConvNets
60
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
61
+ random_tensor.floor_() # binarize
62
+ output = x.div(keep_prob) * random_tensor
63
+ return output
64
+
65
+
66
+ class DropPath(nn.Module):
67
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
68
+
69
+ def __init__(self, drop_prob=None):
70
+ super(DropPath, self).__init__()
71
+ self.drop_prob = drop_prob
72
+
73
+ def forward(self, x):
74
+ return drop_path(x, self.drop_prob, self.training)
75
+
76
+
77
+ class PatchEmbed(nn.Module):
78
+ """2D Image to Patch Embedding"""
79
+
80
+ def __init__(
81
+ self,
82
+ img_size=224,
83
+ patch_size=16,
84
+ in_chans=3,
85
+ embed_dim=768,
86
+ norm_layer=None,
87
+ flatten=True,
88
+ patch_stride=16,
89
+ enable_fusion=False,
90
+ fusion_type="None",
91
+ ):
92
+ super().__init__()
93
+ img_size = to_2tuple(img_size)
94
+ patch_size = to_2tuple(patch_size)
95
+ patch_stride = to_2tuple(patch_stride)
96
+ self.img_size = img_size
97
+ self.patch_size = patch_size
98
+ self.patch_stride = patch_stride
99
+ self.grid_size = (
100
+ img_size[0] // patch_stride[0],
101
+ img_size[1] // patch_stride[1],
102
+ )
103
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
104
+ self.flatten = flatten
105
+ self.in_chans = in_chans
106
+ self.embed_dim = embed_dim
107
+
108
+ self.enable_fusion = enable_fusion
109
+ self.fusion_type = fusion_type
110
+
111
+ padding = (
112
+ (patch_size[0] - patch_stride[0]) // 2,
113
+ (patch_size[1] - patch_stride[1]) // 2,
114
+ )
115
+
116
+ if (self.enable_fusion) and (self.fusion_type == "channel_map"):
117
+ self.proj = nn.Conv2d(
118
+ in_chans * 4,
119
+ embed_dim,
120
+ kernel_size=patch_size,
121
+ stride=patch_stride,
122
+ padding=padding,
123
+ )
124
+ else:
125
+ self.proj = nn.Conv2d(
126
+ in_chans,
127
+ embed_dim,
128
+ kernel_size=patch_size,
129
+ stride=patch_stride,
130
+ padding=padding,
131
+ )
132
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
133
+
134
+ if (self.enable_fusion) and (
135
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
136
+ ):
137
+ self.mel_conv2d = nn.Conv2d(
138
+ in_chans,
139
+ embed_dim,
140
+ kernel_size=(patch_size[0], patch_size[1] * 3),
141
+ stride=(patch_stride[0], patch_stride[1] * 3),
142
+ padding=padding,
143
+ )
144
+ if self.fusion_type == "daf_2d":
145
+ self.fusion_model = DAF()
146
+ elif self.fusion_type == "aff_2d":
147
+ self.fusion_model = AFF(channels=embed_dim, type="2D")
148
+ elif self.fusion_type == "iaff_2d":
149
+ self.fusion_model = iAFF(channels=embed_dim, type="2D")
150
+
151
+ def forward(self, x, longer_idx=None):
152
+ if (self.enable_fusion) and (
153
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
154
+ ):
155
+ global_x = x[:, 0:1, :, :]
156
+
157
+ # global processing
158
+ B, C, H, W = global_x.shape
159
+ assert (
160
+ H == self.img_size[0] and W == self.img_size[1]
161
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
162
+ global_x = self.proj(global_x)
163
+ TW = global_x.size(-1)
164
+ if len(longer_idx) > 0:
165
+ # local processing
166
+ local_x = x[longer_idx, 1:, :, :].contiguous()
167
+ B, C, H, W = local_x.shape
168
+ local_x = local_x.view(B * C, 1, H, W)
169
+ local_x = self.mel_conv2d(local_x)
170
+ local_x = local_x.view(
171
+ B, C, local_x.size(1), local_x.size(2), local_x.size(3)
172
+ )
173
+ local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3)
174
+ TB, TC, TH, _ = local_x.size()
175
+ if local_x.size(-1) < TW:
176
+ local_x = torch.cat(
177
+ [
178
+ local_x,
179
+ torch.zeros(
180
+ (TB, TC, TH, TW - local_x.size(-1)),
181
+ device=global_x.device,
182
+ ),
183
+ ],
184
+ dim=-1,
185
+ )
186
+ else:
187
+ local_x = local_x[:, :, :, :TW]
188
+
189
+ global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x)
190
+ x = global_x
191
+ else:
192
+ B, C, H, W = x.shape
193
+ assert (
194
+ H == self.img_size[0] and W == self.img_size[1]
195
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
196
+ x = self.proj(x)
197
+
198
+ if self.flatten:
199
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
200
+ x = self.norm(x)
201
+ return x
202
+
203
+
204
+ class Mlp(nn.Module):
205
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
206
+
207
+ def __init__(
208
+ self,
209
+ in_features,
210
+ hidden_features=None,
211
+ out_features=None,
212
+ act_layer=nn.GELU,
213
+ drop=0.0,
214
+ ):
215
+ super().__init__()
216
+ out_features = out_features or in_features
217
+ hidden_features = hidden_features or in_features
218
+ self.fc1 = nn.Linear(in_features, hidden_features)
219
+ self.act = act_layer()
220
+ self.fc2 = nn.Linear(hidden_features, out_features)
221
+ self.drop = nn.Dropout(drop)
222
+
223
+ def forward(self, x):
224
+ x = self.fc1(x)
225
+ x = self.act(x)
226
+ x = self.drop(x)
227
+ x = self.fc2(x)
228
+ x = self.drop(x)
229
+ return x
230
+
231
+
232
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
233
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
234
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
235
+ def norm_cdf(x):
236
+ # Computes standard normal cumulative distribution function
237
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
238
+
239
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
240
+ warnings.warn(
241
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
242
+ "The distribution of values may be incorrect.",
243
+ stacklevel=2,
244
+ )
245
+
246
+ with torch.no_grad():
247
+ # Values are generated by using a truncated uniform distribution and
248
+ # then using the inverse CDF for the normal distribution.
249
+ # Get upper and lower cdf values
250
+ l = norm_cdf((a - mean) / std)
251
+ u = norm_cdf((b - mean) / std)
252
+
253
+ # Uniformly fill tensor with values from [l, u], then translate to
254
+ # [2l-1, 2u-1].
255
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
256
+
257
+ # Use inverse cdf transform for normal distribution to get truncated
258
+ # standard normal
259
+ tensor.erfinv_()
260
+
261
+ # Transform to proper mean, std
262
+ tensor.mul_(std * math.sqrt(2.0))
263
+ tensor.add_(mean)
264
+
265
+ # Clamp to ensure it's in the proper range
266
+ tensor.clamp_(min=a, max=b)
267
+ return tensor
268
+
269
+
270
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
271
+ # type: (Tensor, float, float, float, float) -> Tensor
272
+ r"""Fills the input Tensor with values drawn from a truncated
273
+ normal distribution. The values are effectively drawn from the
274
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
275
+ with values outside :math:`[a, b]` redrawn until they are within
276
+ the bounds. The method used for generating the random values works
277
+ best when :math:`a \leq \text{mean} \leq b`.
278
+ Args:
279
+ tensor: an n-dimensional `torch.Tensor`
280
+ mean: the mean of the normal distribution
281
+ std: the standard deviation of the normal distribution
282
+ a: the minimum cutoff value
283
+ b: the maximum cutoff value
284
+ Examples:
285
+ >>> w = torch.empty(3, 5)
286
+ >>> nn.init.trunc_normal_(w)
287
+ """
288
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
289
+
290
+
291
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
292
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
293
+ if mode == "fan_in":
294
+ denom = fan_in
295
+ elif mode == "fan_out":
296
+ denom = fan_out
297
+ elif mode == "fan_avg":
298
+ denom = (fan_in + fan_out) / 2
299
+
300
+ variance = scale / denom
301
+
302
+ if distribution == "truncated_normal":
303
+ # constant is stddev of standard normal truncated to (-2, 2)
304
+ trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
305
+ elif distribution == "normal":
306
+ tensor.normal_(std=math.sqrt(variance))
307
+ elif distribution == "uniform":
308
+ bound = math.sqrt(3 * variance)
309
+ tensor.uniform_(-bound, bound)
310
+ else:
311
+ raise ValueError(f"invalid distribution {distribution}")
312
+
313
+
314
+ def lecun_normal_(tensor):
315
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
316
+
317
+
318
+ def window_partition(x, window_size):
319
+ """
320
+ Args:
321
+ x: (B, H, W, C)
322
+ window_size (int): window size
323
+ Returns:
324
+ windows: (num_windows*B, window_size, window_size, C)
325
+ """
326
+ B, H, W, C = x.shape
327
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
328
+ windows = (
329
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
330
+ )
331
+ return windows
332
+
333
+
334
+ def window_reverse(windows, window_size, H, W):
335
+ """
336
+ Args:
337
+ windows: (num_windows*B, window_size, window_size, C)
338
+ window_size (int): Window size
339
+ H (int): Height of image
340
+ W (int): Width of image
341
+ Returns:
342
+ x: (B, H, W, C)
343
+ """
344
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
345
+ x = windows.view(
346
+ B, H // window_size, W // window_size, window_size, window_size, -1
347
+ )
348
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
349
+ return x
350
+
351
+
352
+ class WindowAttention(nn.Module):
353
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
354
+ It supports both of shifted and non-shifted window.
355
+ Args:
356
+ dim (int): Number of input channels.
357
+ window_size (tuple[int]): The height and width of the window.
358
+ num_heads (int): Number of attention heads.
359
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
360
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
361
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
362
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
363
+ """
364
+
365
+ def __init__(
366
+ self,
367
+ dim,
368
+ window_size,
369
+ num_heads,
370
+ qkv_bias=True,
371
+ qk_scale=None,
372
+ attn_drop=0.0,
373
+ proj_drop=0.0,
374
+ ):
375
+
376
+ super().__init__()
377
+ self.dim = dim
378
+ self.window_size = window_size # Wh, Ww
379
+ self.num_heads = num_heads
380
+ head_dim = dim // num_heads
381
+ self.scale = qk_scale or head_dim**-0.5
382
+
383
+ # define a parameter table of relative position bias
384
+ self.relative_position_bias_table = nn.Parameter(
385
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
386
+ ) # 2*Wh-1 * 2*Ww-1, nH
387
+
388
+ # get pair-wise relative position index for each token inside the window
389
+ coords_h = torch.arange(self.window_size[0])
390
+ coords_w = torch.arange(self.window_size[1])
391
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
392
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
393
+ relative_coords = (
394
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
395
+ ) # 2, Wh*Ww, Wh*Ww
396
+ relative_coords = relative_coords.permute(
397
+ 1, 2, 0
398
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
399
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
400
+ relative_coords[:, :, 1] += self.window_size[1] - 1
401
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
402
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
403
+ self.register_buffer("relative_position_index", relative_position_index)
404
+
405
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
406
+ self.attn_drop = nn.Dropout(attn_drop)
407
+ self.proj = nn.Linear(dim, dim)
408
+ self.proj_drop = nn.Dropout(proj_drop)
409
+
410
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
411
+ self.softmax = nn.Softmax(dim=-1)
412
+
413
+ def forward(self, x, mask=None):
414
+ """
415
+ Args:
416
+ x: input features with shape of (num_windows*B, N, C)
417
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
418
+ """
419
+ B_, N, C = x.shape
420
+ qkv = (
421
+ self.qkv(x)
422
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
423
+ .permute(2, 0, 3, 1, 4)
424
+ )
425
+ q, k, v = (
426
+ qkv[0],
427
+ qkv[1],
428
+ qkv[2],
429
+ ) # make torchscript happy (cannot use tensor as tuple)
430
+
431
+ q = q * self.scale
432
+ attn = q @ k.transpose(-2, -1)
433
+
434
+ relative_position_bias = self.relative_position_bias_table[
435
+ self.relative_position_index.view(-1)
436
+ ].view(
437
+ self.window_size[0] * self.window_size[1],
438
+ self.window_size[0] * self.window_size[1],
439
+ -1,
440
+ ) # Wh*Ww,Wh*Ww,nH
441
+ relative_position_bias = relative_position_bias.permute(
442
+ 2, 0, 1
443
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
444
+ attn = attn + relative_position_bias.unsqueeze(0)
445
+
446
+ if mask is not None:
447
+ nW = mask.shape[0]
448
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
449
+ 1
450
+ ).unsqueeze(0)
451
+ attn = attn.view(-1, self.num_heads, N, N)
452
+ attn = self.softmax(attn)
453
+ else:
454
+ attn = self.softmax(attn)
455
+
456
+ attn = self.attn_drop(attn)
457
+
458
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
459
+ x = self.proj(x)
460
+ x = self.proj_drop(x)
461
+ return x, attn
462
+
463
+ def extra_repr(self):
464
+ return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
465
+
466
+
467
+ # We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
468
+ class SwinTransformerBlock(nn.Module):
469
+ r"""Swin Transformer Block.
470
+ Args:
471
+ dim (int): Number of input channels.
472
+ input_resolution (tuple[int]): Input resulotion.
473
+ num_heads (int): Number of attention heads.
474
+ window_size (int): Window size.
475
+ shift_size (int): Shift size for SW-MSA.
476
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
477
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
478
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
479
+ drop (float, optional): Dropout rate. Default: 0.0
480
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
481
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
482
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
483
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
484
+ """
485
+
486
+ def __init__(
487
+ self,
488
+ dim,
489
+ input_resolution,
490
+ num_heads,
491
+ window_size=7,
492
+ shift_size=0,
493
+ mlp_ratio=4.0,
494
+ qkv_bias=True,
495
+ qk_scale=None,
496
+ drop=0.0,
497
+ attn_drop=0.0,
498
+ drop_path=0.0,
499
+ act_layer=nn.GELU,
500
+ norm_layer=nn.LayerNorm,
501
+ norm_before_mlp="ln",
502
+ ):
503
+ super().__init__()
504
+ self.dim = dim
505
+ self.input_resolution = input_resolution
506
+ self.num_heads = num_heads
507
+ self.window_size = window_size
508
+ self.shift_size = shift_size
509
+ self.mlp_ratio = mlp_ratio
510
+ self.norm_before_mlp = norm_before_mlp
511
+ if min(self.input_resolution) <= self.window_size:
512
+ # if window size is larger than input resolution, we don't partition windows
513
+ self.shift_size = 0
514
+ self.window_size = min(self.input_resolution)
515
+ assert (
516
+ 0 <= self.shift_size < self.window_size
517
+ ), "shift_size must in 0-window_size"
518
+
519
+ self.norm1 = norm_layer(dim)
520
+ self.attn = WindowAttention(
521
+ dim,
522
+ window_size=to_2tuple(self.window_size),
523
+ num_heads=num_heads,
524
+ qkv_bias=qkv_bias,
525
+ qk_scale=qk_scale,
526
+ attn_drop=attn_drop,
527
+ proj_drop=drop,
528
+ )
529
+
530
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
531
+ if self.norm_before_mlp == "ln":
532
+ self.norm2 = nn.LayerNorm(dim)
533
+ elif self.norm_before_mlp == "bn":
534
+ self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(
535
+ 1, 2
536
+ )
537
+ else:
538
+ raise NotImplementedError
539
+ mlp_hidden_dim = int(dim * mlp_ratio)
540
+ self.mlp = Mlp(
541
+ in_features=dim,
542
+ hidden_features=mlp_hidden_dim,
543
+ act_layer=act_layer,
544
+ drop=drop,
545
+ )
546
+
547
+ if self.shift_size > 0:
548
+ # calculate attention mask for SW-MSA
549
+ H, W = self.input_resolution
550
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
551
+ h_slices = (
552
+ slice(0, -self.window_size),
553
+ slice(-self.window_size, -self.shift_size),
554
+ slice(-self.shift_size, None),
555
+ )
556
+ w_slices = (
557
+ slice(0, -self.window_size),
558
+ slice(-self.window_size, -self.shift_size),
559
+ slice(-self.shift_size, None),
560
+ )
561
+ cnt = 0
562
+ for h in h_slices:
563
+ for w in w_slices:
564
+ img_mask[:, h, w, :] = cnt
565
+ cnt += 1
566
+
567
+ mask_windows = window_partition(
568
+ img_mask, self.window_size
569
+ ) # nW, window_size, window_size, 1
570
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
571
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
572
+ attn_mask = attn_mask.masked_fill(
573
+ attn_mask != 0, float(-100.0)
574
+ ).masked_fill(attn_mask == 0, float(0.0))
575
+ else:
576
+ attn_mask = None
577
+
578
+ self.register_buffer("attn_mask", attn_mask)
579
+
580
+ def forward(self, x):
581
+ # pdb.set_trace()
582
+ H, W = self.input_resolution
583
+ # print("H: ", H)
584
+ # print("W: ", W)
585
+ # pdb.set_trace()
586
+ B, L, C = x.shape
587
+ # assert L == H * W, "input feature has wrong size"
588
+
589
+ shortcut = x
590
+ x = self.norm1(x)
591
+ x = x.view(B, H, W, C)
592
+
593
+ # cyclic shift
594
+ if self.shift_size > 0:
595
+ shifted_x = torch.roll(
596
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
597
+ )
598
+ else:
599
+ shifted_x = x
600
+
601
+ # partition windows
602
+ x_windows = window_partition(
603
+ shifted_x, self.window_size
604
+ ) # nW*B, window_size, window_size, C
605
+ x_windows = x_windows.view(
606
+ -1, self.window_size * self.window_size, C
607
+ ) # nW*B, window_size*window_size, C
608
+
609
+ # W-MSA/SW-MSA
610
+ attn_windows, attn = self.attn(
611
+ x_windows, mask=self.attn_mask
612
+ ) # nW*B, window_size*window_size, C
613
+
614
+ # merge windows
615
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
616
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
617
+
618
+ # reverse cyclic shift
619
+ if self.shift_size > 0:
620
+ x = torch.roll(
621
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
622
+ )
623
+ else:
624
+ x = shifted_x
625
+ x = x.view(B, H * W, C)
626
+
627
+ # FFN
628
+ x = shortcut + self.drop_path(x)
629
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
630
+
631
+ return x, attn
632
+
633
+ def extra_repr(self):
634
+ return (
635
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
636
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
637
+ )
638
+
639
+
640
+ class PatchMerging(nn.Module):
641
+ r"""Patch Merging Layer.
642
+ Args:
643
+ input_resolution (tuple[int]): Resolution of input feature.
644
+ dim (int): Number of input channels.
645
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
646
+ """
647
+
648
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
649
+ super().__init__()
650
+ self.input_resolution = input_resolution
651
+ self.dim = dim
652
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
653
+ self.norm = norm_layer(4 * dim)
654
+
655
+ def forward(self, x):
656
+ """
657
+ x: B, H*W, C
658
+ """
659
+ H, W = self.input_resolution
660
+ B, L, C = x.shape
661
+ assert L == H * W, "input feature has wrong size"
662
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
663
+
664
+ x = x.view(B, H, W, C)
665
+
666
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
667
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
668
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
669
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
670
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
671
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
672
+
673
+ x = self.norm(x)
674
+ x = self.reduction(x)
675
+
676
+ return x
677
+
678
+ def extra_repr(self):
679
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
680
+
681
+
682
+ class BasicLayer(nn.Module):
683
+ """A basic Swin Transformer layer for one stage.
684
+ Args:
685
+ dim (int): Number of input channels.
686
+ input_resolution (tuple[int]): Input resolution.
687
+ depth (int): Number of blocks.
688
+ num_heads (int): Number of attention heads.
689
+ window_size (int): Local window size.
690
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
691
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
692
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
693
+ drop (float, optional): Dropout rate. Default: 0.0
694
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
695
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
696
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
697
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
698
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
699
+ """
700
+
701
+ def __init__(
702
+ self,
703
+ dim,
704
+ input_resolution,
705
+ depth,
706
+ num_heads,
707
+ window_size,
708
+ mlp_ratio=4.0,
709
+ qkv_bias=True,
710
+ qk_scale=None,
711
+ drop=0.0,
712
+ attn_drop=0.0,
713
+ drop_path=0.0,
714
+ norm_layer=nn.LayerNorm,
715
+ downsample=None,
716
+ use_checkpoint=False,
717
+ norm_before_mlp="ln",
718
+ ):
719
+
720
+ super().__init__()
721
+ self.dim = dim
722
+ self.input_resolution = input_resolution
723
+ self.depth = depth
724
+ self.use_checkpoint = use_checkpoint
725
+
726
+ # build blocks
727
+ self.blocks = nn.ModuleList(
728
+ [
729
+ SwinTransformerBlock(
730
+ dim=dim,
731
+ input_resolution=input_resolution,
732
+ num_heads=num_heads,
733
+ window_size=window_size,
734
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
735
+ mlp_ratio=mlp_ratio,
736
+ qkv_bias=qkv_bias,
737
+ qk_scale=qk_scale,
738
+ drop=drop,
739
+ attn_drop=attn_drop,
740
+ drop_path=drop_path[i]
741
+ if isinstance(drop_path, list)
742
+ else drop_path,
743
+ norm_layer=norm_layer,
744
+ norm_before_mlp=norm_before_mlp,
745
+ )
746
+ for i in range(depth)
747
+ ]
748
+ )
749
+
750
+ # patch merging layer
751
+ if downsample is not None:
752
+ self.downsample = downsample(
753
+ input_resolution, dim=dim, norm_layer=norm_layer
754
+ )
755
+ else:
756
+ self.downsample = None
757
+
758
+ def forward(self, x):
759
+ attns = []
760
+ for blk in self.blocks:
761
+ if self.use_checkpoint:
762
+ x = checkpoint.checkpoint(blk, x)
763
+ else:
764
+ x, attn = blk(x)
765
+ if not self.training:
766
+ attns.append(attn.unsqueeze(0))
767
+ if self.downsample is not None:
768
+ x = self.downsample(x)
769
+ if not self.training:
770
+ attn = torch.cat(attns, dim=0)
771
+ attn = torch.mean(attn, dim=0)
772
+ return x, attn
773
+
774
+ def extra_repr(self):
775
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
776
+
777
+
778
+ # The Core of HTSAT
779
+ class HTSAT_Swin_Transformer(nn.Module):
780
+ r"""HTSAT based on the Swin Transformer
781
+ Args:
782
+ spec_size (int | tuple(int)): Input Spectrogram size. Default 256
783
+ patch_size (int | tuple(int)): Patch size. Default: 4
784
+ path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4
785
+ in_chans (int): Number of input image channels. Default: 1 (mono)
786
+ num_classes (int): Number of classes for classification head. Default: 527
787
+ embed_dim (int): Patch embedding dimension. Default: 96
788
+ depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
789
+ num_heads (tuple(int)): Number of attention heads in different layers.
790
+ window_size (int): Window size. Default: 8
791
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
792
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
793
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
794
+ drop_rate (float): Dropout rate. Default: 0
795
+ attn_drop_rate (float): Attention dropout rate. Default: 0
796
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
797
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
798
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
799
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
800
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
801
+ config (module): The configuration Module from config.py
802
+ """
803
+
804
+ def __init__(
805
+ self,
806
+ spec_size=256,
807
+ patch_size=4,
808
+ patch_stride=(4, 4),
809
+ in_chans=1,
810
+ num_classes=527,
811
+ embed_dim=96,
812
+ depths=[2, 2, 6, 2],
813
+ num_heads=[4, 8, 16, 32],
814
+ window_size=8,
815
+ mlp_ratio=4.0,
816
+ qkv_bias=True,
817
+ qk_scale=None,
818
+ drop_rate=0.0,
819
+ attn_drop_rate=0.0,
820
+ drop_path_rate=0.1,
821
+ norm_layer=nn.LayerNorm,
822
+ ape=False,
823
+ patch_norm=True,
824
+ use_checkpoint=False,
825
+ norm_before_mlp="ln",
826
+ config=None,
827
+ enable_fusion=False,
828
+ fusion_type="None",
829
+ **kwargs,
830
+ ):
831
+ super(HTSAT_Swin_Transformer, self).__init__()
832
+
833
+ self.config = config
834
+ self.spec_size = spec_size
835
+ self.patch_stride = patch_stride
836
+ self.patch_size = patch_size
837
+ self.window_size = window_size
838
+ self.embed_dim = embed_dim
839
+ self.depths = depths
840
+ self.ape = ape
841
+ self.in_chans = in_chans
842
+ self.num_classes = num_classes
843
+ self.num_heads = num_heads
844
+ self.num_layers = len(self.depths)
845
+ self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
846
+
847
+ self.drop_rate = drop_rate
848
+ self.attn_drop_rate = attn_drop_rate
849
+ self.drop_path_rate = drop_path_rate
850
+
851
+ self.qkv_bias = qkv_bias
852
+ self.qk_scale = None
853
+
854
+ self.patch_norm = patch_norm
855
+ self.norm_layer = norm_layer if self.patch_norm else None
856
+ self.norm_before_mlp = norm_before_mlp
857
+ self.mlp_ratio = mlp_ratio
858
+
859
+ self.use_checkpoint = use_checkpoint
860
+
861
+ self.enable_fusion = enable_fusion
862
+ self.fusion_type = fusion_type
863
+
864
+ # process mel-spec ; used only once
865
+ self.freq_ratio = self.spec_size // self.config.mel_bins
866
+ window = "hann"
867
+ center = True
868
+ pad_mode = "reflect"
869
+ ref = 1.0
870
+ amin = 1e-10
871
+ top_db = None
872
+ self.interpolate_ratio = 32 # Downsampled ratio
873
+ # Spectrogram extractor
874
+ self.spectrogram_extractor = Spectrogram(
875
+ n_fft=config.window_size,
876
+ hop_length=config.hop_size,
877
+ win_length=config.window_size,
878
+ window=window,
879
+ center=center,
880
+ pad_mode=pad_mode,
881
+ freeze_parameters=True,
882
+ )
883
+ # Logmel feature extractor
884
+ self.logmel_extractor = LogmelFilterBank(
885
+ sr=config.sample_rate,
886
+ n_fft=config.window_size,
887
+ n_mels=config.mel_bins,
888
+ fmin=config.fmin,
889
+ fmax=config.fmax,
890
+ ref=ref,
891
+ amin=amin,
892
+ top_db=top_db,
893
+ freeze_parameters=True,
894
+ )
895
+ # Spec augmenter
896
+ self.spec_augmenter = SpecAugmentation(
897
+ time_drop_width=64,
898
+ time_stripes_num=2,
899
+ freq_drop_width=8,
900
+ freq_stripes_num=2,
901
+ ) # 2 2
902
+ self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
903
+
904
+ # split spctrogram into non-overlapping patches
905
+ self.patch_embed = PatchEmbed(
906
+ img_size=self.spec_size,
907
+ patch_size=self.patch_size,
908
+ in_chans=self.in_chans,
909
+ embed_dim=self.embed_dim,
910
+ norm_layer=self.norm_layer,
911
+ patch_stride=patch_stride,
912
+ enable_fusion=self.enable_fusion,
913
+ fusion_type=self.fusion_type,
914
+ )
915
+
916
+ num_patches = self.patch_embed.num_patches
917
+ patches_resolution = self.patch_embed.grid_size
918
+ self.patches_resolution = patches_resolution
919
+
920
+ # absolute position embedding
921
+ if self.ape:
922
+ self.absolute_pos_embed = nn.Parameter(
923
+ torch.zeros(1, num_patches, self.embed_dim)
924
+ )
925
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
926
+
927
+ self.pos_drop = nn.Dropout(p=self.drop_rate)
928
+
929
+ # stochastic depth
930
+ dpr = [
931
+ x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))
932
+ ] # stochastic depth decay rule
933
+
934
+ # build layers
935
+ self.layers = nn.ModuleList()
936
+ for i_layer in range(self.num_layers):
937
+ layer = BasicLayer(
938
+ dim=int(self.embed_dim * 2**i_layer),
939
+ input_resolution=(
940
+ patches_resolution[0] // (2**i_layer),
941
+ patches_resolution[1] // (2**i_layer),
942
+ ),
943
+ depth=self.depths[i_layer],
944
+ num_heads=self.num_heads[i_layer],
945
+ window_size=self.window_size,
946
+ mlp_ratio=self.mlp_ratio,
947
+ qkv_bias=self.qkv_bias,
948
+ qk_scale=self.qk_scale,
949
+ drop=self.drop_rate,
950
+ attn_drop=self.attn_drop_rate,
951
+ drop_path=dpr[
952
+ sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1])
953
+ ],
954
+ norm_layer=self.norm_layer,
955
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
956
+ use_checkpoint=use_checkpoint,
957
+ norm_before_mlp=self.norm_before_mlp,
958
+ )
959
+ self.layers.append(layer)
960
+
961
+ self.norm = self.norm_layer(self.num_features)
962
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
963
+ self.maxpool = nn.AdaptiveMaxPool1d(1)
964
+
965
+ SF = (
966
+ self.spec_size
967
+ // (2 ** (len(self.depths) - 1))
968
+ // self.patch_stride[0]
969
+ // self.freq_ratio
970
+ )
971
+ self.tscam_conv = nn.Conv2d(
972
+ in_channels=self.num_features,
973
+ out_channels=self.num_classes,
974
+ kernel_size=(SF, 3),
975
+ padding=(0, 1),
976
+ )
977
+ self.head = nn.Linear(num_classes, num_classes)
978
+
979
+ if (self.enable_fusion) and (
980
+ self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
981
+ ):
982
+ self.mel_conv1d = nn.Sequential(
983
+ nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
984
+ nn.BatchNorm1d(64),
985
+ )
986
+ if self.fusion_type == "daf_1d":
987
+ self.fusion_model = DAF()
988
+ elif self.fusion_type == "aff_1d":
989
+ self.fusion_model = AFF(channels=64, type="1D")
990
+ elif self.fusion_type == "iaff_1d":
991
+ self.fusion_model = iAFF(channels=64, type="1D")
992
+
993
+ self.apply(self._init_weights)
994
+
995
+ def _init_weights(self, m):
996
+ if isinstance(m, nn.Linear):
997
+ trunc_normal_(m.weight, std=0.02)
998
+ if isinstance(m, nn.Linear) and m.bias is not None:
999
+ nn.init.constant_(m.bias, 0)
1000
+ elif isinstance(m, nn.LayerNorm):
1001
+ nn.init.constant_(m.bias, 0)
1002
+ nn.init.constant_(m.weight, 1.0)
1003
+
1004
+ @torch.jit.ignore
1005
+ def no_weight_decay(self):
1006
+ return {"absolute_pos_embed"}
1007
+
1008
+ @torch.jit.ignore
1009
+ def no_weight_decay_keywords(self):
1010
+ return {"relative_position_bias_table"}
1011
+
1012
+ def forward_features(self, x, longer_idx=None):
1013
+ # A deprecated optimization for using a hierarchical output from different blocks
1014
+
1015
+ frames_num = x.shape[2]
1016
+ x = self.patch_embed(x, longer_idx=longer_idx)
1017
+ if self.ape:
1018
+ x = x + self.absolute_pos_embed
1019
+ x = self.pos_drop(x)
1020
+ for i, layer in enumerate(self.layers):
1021
+ x, attn = layer(x)
1022
+ # for x
1023
+ x = self.norm(x)
1024
+ B, N, C = x.shape
1025
+ SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
1026
+ ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
1027
+ x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST)
1028
+ B, C, F, T = x.shape
1029
+ # group 2D CNN
1030
+ c_freq_bin = F // self.freq_ratio
1031
+ x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
1032
+ x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1)
1033
+ # get latent_output
1034
+ fine_grained_latent_output = torch.mean(x, dim=2)
1035
+ fine_grained_latent_output = interpolate(
1036
+ fine_grained_latent_output.permute(0, 2, 1).contiguous(),
1037
+ 8 * self.patch_stride[1],
1038
+ )
1039
+
1040
+ latent_output = self.avgpool(torch.flatten(x, 2))
1041
+ latent_output = torch.flatten(latent_output, 1)
1042
+
1043
+ # display the attention map, if needed
1044
+
1045
+ x = self.tscam_conv(x)
1046
+ x = torch.flatten(x, 2) # B, C, T
1047
+
1048
+ fpx = interpolate(
1049
+ torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1]
1050
+ )
1051
+
1052
+ x = self.avgpool(x)
1053
+ x = torch.flatten(x, 1)
1054
+
1055
+ output_dict = {
1056
+ "framewise_output": fpx, # already sigmoided
1057
+ "clipwise_output": torch.sigmoid(x),
1058
+ "fine_grained_embedding": fine_grained_latent_output,
1059
+ "embedding": latent_output,
1060
+ }
1061
+
1062
+ return output_dict
1063
+
1064
+ def crop_wav(self, x, crop_size, spe_pos=None):
1065
+ time_steps = x.shape[2]
1066
+ tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
1067
+ for i in range(len(x)):
1068
+ if spe_pos is None:
1069
+ crop_pos = random.randint(0, time_steps - crop_size - 1)
1070
+ else:
1071
+ crop_pos = spe_pos
1072
+ tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :]
1073
+ return tx
1074
+
1075
+ # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model
1076
+ def reshape_wav2img(self, x):
1077
+ B, C, T, F = x.shape
1078
+ target_T = int(self.spec_size * self.freq_ratio)
1079
+ target_F = self.spec_size // self.freq_ratio
1080
+ assert (
1081
+ T <= target_T and F <= target_F
1082
+ ), "the wav size should less than or equal to the swin input size"
1083
+ # to avoid bicubic zero error
1084
+ if T < target_T:
1085
+ x = nn.functional.interpolate(
1086
+ x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
1087
+ )
1088
+ if F < target_F:
1089
+ x = nn.functional.interpolate(
1090
+ x, (x.shape[2], target_F), mode="bicubic", align_corners=True
1091
+ )
1092
+ x = x.permute(0, 1, 3, 2).contiguous()
1093
+ x = x.reshape(
1094
+ x.shape[0],
1095
+ x.shape[1],
1096
+ x.shape[2],
1097
+ self.freq_ratio,
1098
+ x.shape[3] // self.freq_ratio,
1099
+ )
1100
+ # print(x.shape)
1101
+ x = x.permute(0, 1, 3, 2, 4).contiguous()
1102
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
1103
+ return x
1104
+
1105
+ # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model
1106
+ def repeat_wat2img(self, x, cur_pos):
1107
+ B, C, T, F = x.shape
1108
+ target_T = int(self.spec_size * self.freq_ratio)
1109
+ target_F = self.spec_size // self.freq_ratio
1110
+ assert (
1111
+ T <= target_T and F <= target_F
1112
+ ), "the wav size should less than or equal to the swin input size"
1113
+ # to avoid bicubic zero error
1114
+ if T < target_T:
1115
+ x = nn.functional.interpolate(
1116
+ x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
1117
+ )
1118
+ if F < target_F:
1119
+ x = nn.functional.interpolate(
1120
+ x, (x.shape[2], target_F), mode="bicubic", align_corners=True
1121
+ )
1122
+ x = x.permute(0, 1, 3, 2).contiguous() # B C F T
1123
+ x = x[:, :, :, cur_pos : cur_pos + self.spec_size]
1124
+ x = x.repeat(repeats=(1, 1, 4, 1))
1125
+ return x
1126
+
1127
+ def forward(
1128
+ self, x: torch.Tensor, mixup_lambda=None, infer_mode=False, device=None
1129
+ ): # out_feat_keys: List[str] = None):
1130
+
1131
+ if self.enable_fusion and x["longer"].sum() == 0:
1132
+ # if no audio is longer than 10s, then randomly select one audio to be longer
1133
+ x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True
1134
+
1135
+ if not self.enable_fusion:
1136
+ x = x["waveform"].to(device=device, non_blocking=True)
1137
+ x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
1138
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
1139
+ x = x.transpose(1, 3)
1140
+ x = self.bn0(x)
1141
+ x = x.transpose(1, 3)
1142
+ if self.training:
1143
+ x = self.spec_augmenter(x)
1144
+
1145
+ if self.training and mixup_lambda is not None:
1146
+ x = do_mixup(x, mixup_lambda)
1147
+
1148
+ x = self.reshape_wav2img(x)
1149
+ output_dict = self.forward_features(x)
1150
+ else:
1151
+ longer_list = x["longer"].to(device=device, non_blocking=True)
1152
+ x = x["mel_fusion"].to(device=device, non_blocking=True)
1153
+ x = x.transpose(1, 3)
1154
+ x = self.bn0(x)
1155
+ x = x.transpose(1, 3)
1156
+ longer_list_idx = torch.where(longer_list)[0]
1157
+ if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
1158
+ new_x = x[:, 0:1, :, :].clone().contiguous()
1159
+ if len(longer_list_idx) > 0:
1160
+ # local processing
1161
+ fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
1162
+ FB, FC, FT, FF = fusion_x_local.size()
1163
+ fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
1164
+ fusion_x_local = torch.permute(
1165
+ fusion_x_local, (0, 2, 1)
1166
+ ).contiguous()
1167
+ fusion_x_local = self.mel_conv1d(fusion_x_local)
1168
+ fusion_x_local = fusion_x_local.view(
1169
+ FB, FC, FF, fusion_x_local.size(-1)
1170
+ )
1171
+ fusion_x_local = (
1172
+ torch.permute(fusion_x_local, (0, 2, 1, 3))
1173
+ .contiguous()
1174
+ .flatten(2)
1175
+ )
1176
+ if fusion_x_local.size(-1) < FT:
1177
+ fusion_x_local = torch.cat(
1178
+ [
1179
+ fusion_x_local,
1180
+ torch.zeros(
1181
+ (FB, FF, FT - fusion_x_local.size(-1)),
1182
+ device=device,
1183
+ ),
1184
+ ],
1185
+ dim=-1,
1186
+ )
1187
+ else:
1188
+ fusion_x_local = fusion_x_local[:, :, :FT]
1189
+ # 1D fusion
1190
+ new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
1191
+ new_x[longer_list_idx] = self.fusion_model(
1192
+ new_x[longer_list_idx], fusion_x_local
1193
+ )
1194
+ x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
1195
+ else:
1196
+ x = new_x
1197
+
1198
+ elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
1199
+ x = x # no change
1200
+
1201
+ if self.training:
1202
+ x = self.spec_augmenter(x)
1203
+ if self.training and mixup_lambda is not None:
1204
+ x = do_mixup(x, mixup_lambda)
1205
+
1206
+ x = self.reshape_wav2img(x)
1207
+ output_dict = self.forward_features(x, longer_idx=longer_list_idx)
1208
+
1209
+ # if infer_mode:
1210
+ # # in infer mode. we need to handle different length audio input
1211
+ # frame_num = x.shape[2]
1212
+ # target_T = int(self.spec_size * self.freq_ratio)
1213
+ # repeat_ratio = math.floor(target_T / frame_num)
1214
+ # x = x.repeat(repeats=(1,1,repeat_ratio,1))
1215
+ # x = self.reshape_wav2img(x)
1216
+ # output_dict = self.forward_features(x)
1217
+ # else:
1218
+ # if x.shape[2] > self.freq_ratio * self.spec_size:
1219
+ # if self.training:
1220
+ # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
1221
+ # x = self.reshape_wav2img(x)
1222
+ # output_dict = self.forward_features(x)
1223
+ # else:
1224
+ # # Change: Hard code here
1225
+ # overlap_size = (x.shape[2] - 1) // 4
1226
+ # output_dicts = []
1227
+ # crop_size = (x.shape[2] - 1) // 2
1228
+ # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
1229
+ # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
1230
+ # tx = self.reshape_wav2img(tx)
1231
+ # output_dicts.append(self.forward_features(tx))
1232
+ # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
1233
+ # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
1234
+ # for d in output_dicts:
1235
+ # clipwise_output += d["clipwise_output"]
1236
+ # framewise_output += d["framewise_output"]
1237
+ # clipwise_output = clipwise_output / len(output_dicts)
1238
+ # framewise_output = framewise_output / len(output_dicts)
1239
+ # output_dict = {
1240
+ # 'framewise_output': framewise_output,
1241
+ # 'clipwise_output': clipwise_output
1242
+ # }
1243
+ # else: # this part is typically used, and most easy one
1244
+ # x = self.reshape_wav2img(x)
1245
+ # output_dict = self.forward_features(x)
1246
+ # x = self.head(x)
1247
+
1248
+ # We process the data in the dataloader part, in that here we only consider the input_T < fixed_T
1249
+
1250
+ return output_dict
1251
+
1252
+
1253
+ def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"):
1254
+ try:
1255
+
1256
+ assert audio_cfg.model_name in [
1257
+ "tiny",
1258
+ "base",
1259
+ "large",
1260
+ ], "model name for HTS-AT is wrong!"
1261
+ if audio_cfg.model_name == "tiny":
1262
+ model = HTSAT_Swin_Transformer(
1263
+ spec_size=256,
1264
+ patch_size=4,
1265
+ patch_stride=(4, 4),
1266
+ num_classes=audio_cfg.class_num,
1267
+ embed_dim=96,
1268
+ depths=[2, 2, 6, 2],
1269
+ num_heads=[4, 8, 16, 32],
1270
+ window_size=8,
1271
+ config=audio_cfg,
1272
+ enable_fusion=enable_fusion,
1273
+ fusion_type=fusion_type,
1274
+ )
1275
+ elif audio_cfg.model_name == "base":
1276
+ model = HTSAT_Swin_Transformer(
1277
+ spec_size=256,
1278
+ patch_size=4,
1279
+ patch_stride=(4, 4),
1280
+ num_classes=audio_cfg.class_num,
1281
+ embed_dim=128,
1282
+ depths=[2, 2, 12, 2],
1283
+ num_heads=[4, 8, 16, 32],
1284
+ window_size=8,
1285
+ config=audio_cfg,
1286
+ enable_fusion=enable_fusion,
1287
+ fusion_type=fusion_type,
1288
+ )
1289
+ elif audio_cfg.model_name == "large":
1290
+ model = HTSAT_Swin_Transformer(
1291
+ spec_size=256,
1292
+ patch_size=4,
1293
+ patch_stride=(4, 4),
1294
+ num_classes=audio_cfg.class_num,
1295
+ embed_dim=256,
1296
+ depths=[2, 2, 12, 2],
1297
+ num_heads=[4, 8, 16, 32],
1298
+ window_size=8,
1299
+ config=audio_cfg,
1300
+ enable_fusion=enable_fusion,
1301
+ fusion_type=fusion_type,
1302
+ )
1303
+
1304
+ return model
1305
+ except:
1306
+ raise RuntimeError(
1307
+ f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
1308
+ )
audioldm/clap/open_clip/linear_probe.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from .model import MLPLayers
5
+
6
+
7
+ class LinearProbe(nn.Module):
8
+ def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
9
+ """
10
+ Args:
11
+ model: nn.Module
12
+ mlp: bool, if True, then use the MLP layer as the linear probe module
13
+ freeze: bool, if Ture, then freeze all the CLAP model's layers when training the linear probe
14
+ in_ch: int, the output channel from CLAP model
15
+ out_ch: int, the output channel from linear probe (class_num)
16
+ act: torch.nn.functional, the activation function before the loss function
17
+ """
18
+ super().__init__()
19
+ in_ch = 512
20
+ self.clap_model = model
21
+ self.clap_model.text_branch = None # to save memory
22
+ self.freeze = freeze
23
+ if mlp:
24
+ self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
25
+ else:
26
+ self.lp_layer = nn.Linear(in_ch, out_ch)
27
+
28
+ if self.freeze:
29
+ for param in self.clap_model.parameters():
30
+ param.requires_grad = False
31
+
32
+ if act == "None":
33
+ self.act = None
34
+ elif act == "relu":
35
+ self.act = nn.ReLU()
36
+ elif act == "elu":
37
+ self.act = nn.ELU()
38
+ elif act == "prelu":
39
+ self.act = nn.PReLU(num_parameters=in_ch)
40
+ elif act == "softmax":
41
+ self.act = nn.Softmax(dim=-1)
42
+ elif act == "sigmoid":
43
+ self.act = nn.Sigmoid()
44
+
45
+ def forward(self, x, mix_lambda=None, device=None):
46
+ """
47
+ Args:
48
+ x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list
49
+ mix_lambda: torch.tensor [batch], the mixup lambda
50
+ Returns:
51
+ class_prob: torch.tensor [batch, class_num]
52
+
53
+ """
54
+ # batchnorm cancel grandient
55
+ if self.freeze:
56
+ self.clap_model.eval()
57
+
58
+ x = self.clap_model.audio_projection(
59
+ self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[
60
+ "embedding"
61
+ ]
62
+ )
63
+ out = self.lp_layer(x)
64
+ if self.act is not None:
65
+ out = self.act(out)
66
+ return out
audioldm/clap/open_clip/loss.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocessing.sharedctypes import Value
2
+ import torch
3
+ import torch.distributed.nn
4
+ from torch import distributed as dist, nn as nn
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+ from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
8
+
9
+ try:
10
+ import horovod.torch as hvd
11
+ except ImportError:
12
+ hvd = None
13
+
14
+
15
+ def gather_features(
16
+ audio_features,
17
+ text_features,
18
+ audio_features_mlp=None,
19
+ text_features_mlp=None,
20
+ local_loss=False,
21
+ gather_with_grad=False,
22
+ rank=0,
23
+ world_size=1,
24
+ use_horovod=False,
25
+ mlp_loss=False,
26
+ ):
27
+ if use_horovod:
28
+ assert hvd is not None, "Please install horovod"
29
+ if gather_with_grad:
30
+ all_audio_features = hvd.allgather(audio_features)
31
+ all_text_features = hvd.allgather(text_features)
32
+ if mlp_loss:
33
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
34
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
35
+ else:
36
+ with torch.no_grad():
37
+ all_audio_features = hvd.allgather(audio_features)
38
+ all_text_features = hvd.allgather(text_features)
39
+ if mlp_loss:
40
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
41
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
42
+ if not local_loss:
43
+ # ensure grads for local rank when all_* features don't have a gradient
44
+ gathered_audio_features = list(
45
+ all_audio_features.chunk(world_size, dim=0)
46
+ )
47
+ gathered_text_features = list(
48
+ all_text_features.chunk(world_size, dim=0)
49
+ )
50
+ gathered_audio_features[rank] = audio_features
51
+ gathered_text_features[rank] = text_features
52
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
53
+ all_text_features = torch.cat(gathered_text_features, dim=0)
54
+ if mlp_loss:
55
+ gathered_audio_features_mlp = list(
56
+ all_audio_features_mlp.chunk(world_size, dim=0)
57
+ )
58
+ gathered_text_features_mlp = list(
59
+ all_text_features_mlp.chunk(world_size, dim=0)
60
+ )
61
+ gathered_audio_features_mlp[rank] = audio_features_mlp
62
+ gathered_text_features_mlp[rank] = text_features_mlp
63
+ all_audio_features_mlp = torch.cat(
64
+ gathered_audio_features_mlp, dim=0
65
+ )
66
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
67
+ else:
68
+ # We gather tensors from all gpus
69
+ if gather_with_grad:
70
+ all_audio_features = torch.cat(
71
+ torch.distributed.nn.all_gather(audio_features), dim=0
72
+ )
73
+ all_text_features = torch.cat(
74
+ torch.distributed.nn.all_gather(text_features), dim=0
75
+ )
76
+ if mlp_loss:
77
+ all_audio_features_mlp = torch.cat(
78
+ torch.distributed.nn.all_gather(audio_features_mlp), dim=0
79
+ )
80
+ all_text_features_mlp = torch.cat(
81
+ torch.distributed.nn.all_gather(text_features_mlp), dim=0
82
+ )
83
+ else:
84
+ gathered_audio_features = [
85
+ torch.zeros_like(audio_features) for _ in range(world_size)
86
+ ]
87
+ gathered_text_features = [
88
+ torch.zeros_like(text_features) for _ in range(world_size)
89
+ ]
90
+ dist.all_gather(gathered_audio_features, audio_features)
91
+ dist.all_gather(gathered_text_features, text_features)
92
+ if mlp_loss:
93
+ gathered_audio_features_mlp = [
94
+ torch.zeros_like(audio_features_mlp) for _ in range(world_size)
95
+ ]
96
+ gathered_text_features_mlp = [
97
+ torch.zeros_like(text_features_mlp) for _ in range(world_size)
98
+ ]
99
+ dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
100
+ dist.all_gather(gathered_text_features_mlp, text_features_mlp)
101
+ if not local_loss:
102
+ # ensure grads for local rank when all_* features don't have a gradient
103
+ gathered_audio_features[rank] = audio_features
104
+ gathered_text_features[rank] = text_features
105
+ if mlp_loss:
106
+ gathered_audio_features_mlp[rank] = audio_features_mlp
107
+ gathered_text_features_mlp[rank] = text_features_mlp
108
+
109
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
110
+ all_text_features = torch.cat(gathered_text_features, dim=0)
111
+ if mlp_loss:
112
+ all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
113
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
114
+ if mlp_loss:
115
+ return (
116
+ all_audio_features,
117
+ all_text_features,
118
+ all_audio_features_mlp,
119
+ all_text_features_mlp,
120
+ )
121
+ else:
122
+ return all_audio_features, all_text_features
123
+
124
+
125
+ class ClipLoss(nn.Module):
126
+ def __init__(
127
+ self,
128
+ local_loss=False,
129
+ gather_with_grad=False,
130
+ cache_labels=False,
131
+ rank=0,
132
+ world_size=1,
133
+ use_horovod=False,
134
+ mlp_loss=False,
135
+ weight_loss_kappa=0,
136
+ ):
137
+ super().__init__()
138
+ self.local_loss = local_loss
139
+ self.gather_with_grad = gather_with_grad
140
+ self.cache_labels = cache_labels
141
+ self.rank = rank
142
+ self.world_size = world_size
143
+ self.use_horovod = use_horovod
144
+ self.mlp_loss = mlp_loss
145
+ self.weighted_loss = bool(weight_loss_kappa != 0)
146
+ self.weight_loss_kappa = weight_loss_kappa
147
+ # cache state
148
+ self.prev_num_logits = 0
149
+ self.labels = {}
150
+
151
+ def forward(
152
+ self,
153
+ audio_features,
154
+ text_features,
155
+ logit_scale_a,
156
+ logit_scale_t=None,
157
+ audio_features_mlp=None,
158
+ text_features_mlp=None,
159
+ ):
160
+ device = audio_features.device
161
+ if self.mlp_loss:
162
+ if self.world_size > 1:
163
+ (
164
+ all_audio_features,
165
+ all_text_features,
166
+ all_audio_features_mlp,
167
+ all_text_features_mlp,
168
+ ) = gather_features(
169
+ audio_features=audio_features,
170
+ text_features=text_features,
171
+ audio_features_mlp=audio_features_mlp,
172
+ text_features_mlp=text_features_mlp,
173
+ local_loss=self.local_loss,
174
+ gather_with_grad=self.gather_with_grad,
175
+ rank=self.rank,
176
+ world_size=self.world_size,
177
+ use_horovod=self.use_horovod,
178
+ mlp_loss=self.mlp_loss,
179
+ )
180
+ if self.local_loss:
181
+ a_logits_per_audio = (
182
+ logit_scale_a * audio_features @ all_text_features_mlp.T
183
+ )
184
+ a_logits_per_text = (
185
+ logit_scale_a * text_features_mlp @ all_audio_features.T
186
+ )
187
+ t_logits_per_audio = (
188
+ logit_scale_t * audio_features_mlp @ all_text_features.T
189
+ )
190
+ t_logits_per_text = (
191
+ logit_scale_t * text_features @ all_audio_features_mlp.T
192
+ )
193
+ else:
194
+ a_logits_per_audio = (
195
+ logit_scale_a * all_audio_features @ all_text_features_mlp.T
196
+ )
197
+ a_logits_per_text = a_logits_per_audio.T
198
+ t_logits_per_audio = (
199
+ logit_scale_t * all_audio_features_mlp @ all_text_features.T
200
+ )
201
+ t_logits_per_text = t_logits_per_audio.T
202
+ else:
203
+ a_logits_per_audio = (
204
+ logit_scale_a * audio_features @ text_features_mlp.T
205
+ )
206
+ a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
207
+ t_logits_per_audio = (
208
+ logit_scale_t * audio_features_mlp @ text_features.T
209
+ )
210
+ t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T
211
+
212
+ # calculated ground-truth and cache if enabled
213
+ num_logits = a_logits_per_audio.shape[0]
214
+ if self.prev_num_logits != num_logits or device not in self.labels:
215
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
216
+ if self.world_size > 1 and self.local_loss:
217
+ labels = labels + num_logits * self.rank
218
+ if self.cache_labels:
219
+ self.labels[device] = labels
220
+ self.prev_num_logits = num_logits
221
+ else:
222
+ labels = self.labels[device]
223
+
224
+ if not self.weighted_loss:
225
+ total_loss = (
226
+ F.cross_entropy(a_logits_per_audio, labels)
227
+ + F.cross_entropy(a_logits_per_text, labels)
228
+ + F.cross_entropy(t_logits_per_audio, labels)
229
+ + F.cross_entropy(t_logits_per_text, labels)
230
+ ) / 4
231
+ else:
232
+ audio_weight = (audio_features @ audio_features.T).detach()
233
+ audio_weight = (
234
+ torch.exp(
235
+ torch.sum(audio_weight, axis=1)
236
+ / (self.weight_loss_kappa * len(audio_weight))
237
+ )
238
+ ).detach()
239
+ text_weight = (text_features @ text_features.T).detach()
240
+ text_weight = (
241
+ torch.exp(
242
+ torch.sum(text_weight, axis=1)
243
+ / (self.weight_loss_kappa * len(text_features))
244
+ )
245
+ ).detach()
246
+ total_loss = (
247
+ F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
248
+ + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
249
+ + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
250
+ + F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
251
+ ) / 4
252
+ else:
253
+ if self.world_size > 1:
254
+ all_audio_features, all_text_features = gather_features(
255
+ audio_features=audio_features,
256
+ text_features=text_features,
257
+ local_loss=self.local_loss,
258
+ gather_with_grad=self.gather_with_grad,
259
+ rank=self.rank,
260
+ world_size=self.world_size,
261
+ use_horovod=self.use_horovod,
262
+ mlp_loss=self.mlp_loss,
263
+ )
264
+
265
+ if self.local_loss:
266
+ logits_per_audio = (
267
+ logit_scale_a * audio_features @ all_text_features.T
268
+ )
269
+ logits_per_text = (
270
+ logit_scale_a * text_features @ all_audio_features.T
271
+ )
272
+ else:
273
+ logits_per_audio = (
274
+ logit_scale_a * all_audio_features @ all_text_features.T
275
+ )
276
+ logits_per_text = logits_per_audio.T
277
+ else:
278
+ logits_per_audio = logit_scale_a * audio_features @ text_features.T
279
+ logits_per_text = logit_scale_a * text_features @ audio_features.T
280
+
281
+ # calculated ground-truth and cache if enabled
282
+ num_logits = logits_per_audio.shape[0]
283
+ if self.prev_num_logits != num_logits or device not in self.labels:
284
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
285
+ if self.world_size > 1 and self.local_loss:
286
+ labels = labels + num_logits * self.rank
287
+ if self.cache_labels:
288
+ self.labels[device] = labels
289
+ self.prev_num_logits = num_logits
290
+ else:
291
+ labels = self.labels[device]
292
+ if not self.weighted_loss:
293
+ total_loss = (
294
+ F.cross_entropy(logits_per_audio, labels)
295
+ + F.cross_entropy(logits_per_text, labels)
296
+ ) / 2
297
+ else:
298
+ audio_weight = (all_audio_features @ all_audio_features.T).detach()
299
+ audio_weight = (
300
+ torch.exp(
301
+ torch.sum(audio_weight, axis=1)
302
+ / (self.weight_loss_kappa * len(all_audio_features))
303
+ )
304
+ ).detach()
305
+ text_weight = (all_text_features @ all_text_features.T).detach()
306
+ text_weight = (
307
+ torch.exp(
308
+ torch.sum(text_weight, axis=1)
309
+ / (self.weight_loss_kappa * len(all_text_features))
310
+ )
311
+ ).detach()
312
+ total_loss = (
313
+ F.cross_entropy(logits_per_audio, labels, weight=text_weight)
314
+ + F.cross_entropy(logits_per_text, labels, weight=audio_weight)
315
+ ) / 2
316
+ return total_loss
317
+
318
+
319
+ def lp_gather_features(pred, target, world_size=1, use_horovod=False):
320
+ if use_horovod:
321
+ assert hvd is not None, "Please install horovod"
322
+ with torch.no_grad():
323
+ all_preds = hvd.allgather(pred)
324
+ all_targets = hvd.allgath(target)
325
+ else:
326
+ gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
327
+ gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
328
+
329
+ dist.all_gather(gathered_preds, pred)
330
+ dist.all_gather(gathered_targets, target)
331
+ all_preds = torch.cat(gathered_preds, dim=0)
332
+ all_targets = torch.cat(gathered_targets, dim=0)
333
+
334
+ return all_preds, all_targets
335
+
336
+
337
+ def get_map(pred, target):
338
+ pred = torch.sigmoid(pred).numpy()
339
+ target = target.numpy()
340
+ return np.mean(average_precision_score(target, pred, average=None))
341
+
342
+
343
+ def get_acc(pred, target):
344
+ pred = torch.argmax(pred, 1).numpy()
345
+ target = torch.argmax(target, 1).numpy()
346
+ return accuracy_score(target, pred)
347
+
348
+
349
+ def get_mauc(pred, target):
350
+ pred = torch.sigmoid(pred).numpy()
351
+ target = target.numpy()
352
+ return np.mean(roc_auc_score(target, pred, average=None))
353
+
354
+
355
+ class LPMetrics(object):
356
+ def __init__(self, metric_names=["map", "acc", "mauc"]):
357
+ self.metrics = []
358
+ for name in metric_names:
359
+ self.metrics.append(self.get_metric(name))
360
+ self.metric_names = metric_names
361
+
362
+ def get_metric(self, name):
363
+ if name == "map":
364
+ return get_map
365
+ elif name == "acc":
366
+ return get_acc
367
+ elif name == "mauc":
368
+ return get_mauc
369
+ else:
370
+ raise ValueError(f"the metric should be at least one of [map, acc, mauc]")
371
+
372
+ def evaluate_mertics(self, pred, target):
373
+ metric_dict = {}
374
+ for i in range(len(self.metric_names)):
375
+ metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
376
+ return metric_dict
377
+
378
+
379
+ def calc_celoss(pred, target):
380
+ target = torch.argmax(target, 1).long()
381
+ return nn.CrossEntropyLoss()(pred, target)
382
+
383
+
384
+ class LPLoss(nn.Module):
385
+ def __init__(self, loss_name):
386
+ super().__init__()
387
+ if loss_name == "bce":
388
+ self.loss_func = nn.BCEWithLogitsLoss()
389
+ elif loss_name == "ce":
390
+ self.loss_func = calc_celoss
391
+ elif loss_name == "mse":
392
+ self.loss_func = nn.MSELoss()
393
+ else:
394
+ raise ValueError(f"the loss func should be at least one of [bce, ce, mse]")
395
+
396
+ def forward(self, pred, target):
397
+ loss = self.loss_func(pred, target)
398
+ return loss
audioldm/clap/open_clip/model.py ADDED
@@ -0,0 +1,934 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLAP Model
2
+
3
+ Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ Adapted to the Audio Task.
5
+ """
6
+
7
+ from collections import OrderedDict
8
+ from dataclasses import dataclass
9
+ from email.mime import audio
10
+ from typing import Tuple, Union, Callable, Optional
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch import nn
16
+
17
+ from .timm_model import TimmModel
18
+ import logging
19
+ from .utils import freeze_batch_norm_2d
20
+
21
+ from .pann_model import create_pann_model
22
+ from .htsat import create_htsat_model
23
+ from transformers import BertModel, RobertaModel, BartModel
24
+ from transformers.tokenization_utils_base import BatchEncoding
25
+
26
+
27
+ class MLPLayers(nn.Module):
28
+ def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
29
+ super(MLPLayers, self).__init__()
30
+ self.nonlin = nonlin
31
+ self.dropout = dropout
32
+
33
+ sequence = []
34
+ for u0, u1 in zip(units[:-1], units[1:]):
35
+ sequence.append(nn.Linear(u0, u1))
36
+ sequence.append(self.nonlin)
37
+ sequence.append(nn.Dropout(self.dropout))
38
+ sequence = sequence[:-2]
39
+
40
+ self.sequential = nn.Sequential(*sequence)
41
+
42
+ def forward(self, X):
43
+ X = self.sequential(X)
44
+ return X
45
+
46
+
47
+ class Bottleneck(nn.Module):
48
+ expansion = 4
49
+
50
+ def __init__(self, inplanes, planes, stride=1):
51
+ super().__init__()
52
+
53
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
54
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
55
+ self.bn1 = nn.BatchNorm2d(planes)
56
+
57
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
58
+ self.bn2 = nn.BatchNorm2d(planes)
59
+
60
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
61
+
62
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
63
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
64
+
65
+ self.relu = nn.ReLU(inplace=True)
66
+ self.downsample = None
67
+ self.stride = stride
68
+
69
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
70
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
71
+ self.downsample = nn.Sequential(
72
+ OrderedDict(
73
+ [
74
+ ("-1", nn.AvgPool2d(stride)),
75
+ (
76
+ "0",
77
+ nn.Conv2d(
78
+ inplanes,
79
+ planes * self.expansion,
80
+ 1,
81
+ stride=1,
82
+ bias=False,
83
+ ),
84
+ ),
85
+ ("1", nn.BatchNorm2d(planes * self.expansion)),
86
+ ]
87
+ )
88
+ )
89
+
90
+ def forward(self, x: torch.Tensor):
91
+ identity = x
92
+
93
+ out = self.relu(self.bn1(self.conv1(x)))
94
+ out = self.relu(self.bn2(self.conv2(out)))
95
+ out = self.avgpool(out)
96
+ out = self.bn3(self.conv3(out))
97
+
98
+ if self.downsample is not None:
99
+ identity = self.downsample(x)
100
+
101
+ out += identity
102
+ out = self.relu(out)
103
+ return out
104
+
105
+
106
+ class AttentionPool2d(nn.Module):
107
+ def __init__(
108
+ self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
109
+ ):
110
+ super().__init__()
111
+ self.positional_embedding = nn.Parameter(
112
+ torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
113
+ )
114
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
115
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
116
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
117
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
118
+ self.num_heads = num_heads
119
+
120
+ def forward(self, x):
121
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
122
+ 2, 0, 1
123
+ ) # NCHW -> (HW)NC
124
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
125
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
126
+ x, _ = F.multi_head_attention_forward(
127
+ query=x,
128
+ key=x,
129
+ value=x,
130
+ embed_dim_to_check=x.shape[-1],
131
+ num_heads=self.num_heads,
132
+ q_proj_weight=self.q_proj.weight,
133
+ k_proj_weight=self.k_proj.weight,
134
+ v_proj_weight=self.v_proj.weight,
135
+ in_proj_weight=None,
136
+ in_proj_bias=torch.cat(
137
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
138
+ ),
139
+ bias_k=None,
140
+ bias_v=None,
141
+ add_zero_attn=False,
142
+ dropout_p=0,
143
+ out_proj_weight=self.c_proj.weight,
144
+ out_proj_bias=self.c_proj.bias,
145
+ use_separate_proj_weight=True,
146
+ training=self.training,
147
+ need_weights=False,
148
+ )
149
+
150
+ return x[0]
151
+
152
+
153
+ class ModifiedResNet(nn.Module):
154
+ """
155
+ A ResNet class that is similar to torchvision's but contains the following changes:
156
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
157
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
158
+ - The final pooling layer is a QKV attention instead of an average pool
159
+ """
160
+
161
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
162
+ super().__init__()
163
+ self.output_dim = output_dim
164
+ self.image_size = image_size
165
+
166
+ # the 3-layer stem
167
+ self.conv1 = nn.Conv2d(
168
+ 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
169
+ )
170
+ self.bn1 = nn.BatchNorm2d(width // 2)
171
+ self.conv2 = nn.Conv2d(
172
+ width // 2, width // 2, kernel_size=3, padding=1, bias=False
173
+ )
174
+ self.bn2 = nn.BatchNorm2d(width // 2)
175
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
176
+ self.bn3 = nn.BatchNorm2d(width)
177
+ self.avgpool = nn.AvgPool2d(2)
178
+ self.relu = nn.ReLU(inplace=True)
179
+
180
+ # residual layers
181
+ self._inplanes = width # this is a *mutable* variable used during construction
182
+ self.layer1 = self._make_layer(width, layers[0])
183
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
184
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
185
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
186
+
187
+ embed_dim = width * 32 # the ResNet feature dimension
188
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
189
+
190
+ self.init_parameters()
191
+
192
+ def _make_layer(self, planes, blocks, stride=1):
193
+ layers = [Bottleneck(self._inplanes, planes, stride)]
194
+
195
+ self._inplanes = planes * Bottleneck.expansion
196
+ for _ in range(1, blocks):
197
+ layers.append(Bottleneck(self._inplanes, planes))
198
+
199
+ return nn.Sequential(*layers)
200
+
201
+ def init_parameters(self):
202
+ if self.attnpool is not None:
203
+ std = self.attnpool.c_proj.in_features**-0.5
204
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
205
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
206
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
207
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
208
+
209
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
210
+ for name, param in resnet_block.named_parameters():
211
+ if name.endswith("bn3.weight"):
212
+ nn.init.zeros_(param)
213
+
214
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
215
+ assert (
216
+ unlocked_groups == 0
217
+ ), "partial locking not currently supported for this model"
218
+ for param in self.parameters():
219
+ param.requires_grad = False
220
+ if freeze_bn_stats:
221
+ freeze_batch_norm_2d(self)
222
+
223
+ def stem(self, x):
224
+ for conv, bn in [
225
+ (self.conv1, self.bn1),
226
+ (self.conv2, self.bn2),
227
+ (self.conv3, self.bn3),
228
+ ]:
229
+ x = self.relu(bn(conv(x)))
230
+ x = self.avgpool(x)
231
+ return x
232
+
233
+ def forward(self, x):
234
+ x = self.stem(x)
235
+ x = self.layer1(x)
236
+ x = self.layer2(x)
237
+ x = self.layer3(x)
238
+ x = self.layer4(x)
239
+ x = self.attnpool(x)
240
+
241
+ return x
242
+
243
+
244
+ class LayerNorm(nn.LayerNorm):
245
+ """Subclass torch's LayerNorm to handle fp16."""
246
+
247
+ def forward(self, x: torch.Tensor):
248
+ orig_type = x.dtype
249
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
250
+ return x.to(orig_type)
251
+
252
+
253
+ class QuickGELU(nn.Module):
254
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
255
+ def forward(self, x: torch.Tensor):
256
+ return x * torch.sigmoid(1.702 * x)
257
+
258
+
259
+ class ResidualAttentionBlock(nn.Module):
260
+ def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
261
+ super().__init__()
262
+
263
+ self.attn = nn.MultiheadAttention(d_model, n_head)
264
+ self.ln_1 = LayerNorm(d_model)
265
+ self.mlp = nn.Sequential(
266
+ OrderedDict(
267
+ [
268
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
269
+ ("gelu", act_layer()),
270
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
271
+ ]
272
+ )
273
+ )
274
+ self.ln_2 = LayerNorm(d_model)
275
+
276
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
277
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
278
+
279
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
280
+ x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
281
+ x = x + self.mlp(self.ln_2(x))
282
+ return x
283
+
284
+
285
+ class Transformer(nn.Module):
286
+ def __init__(
287
+ self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
288
+ ):
289
+ super().__init__()
290
+ self.width = width
291
+ self.layers = layers
292
+ self.resblocks = nn.ModuleList(
293
+ [
294
+ ResidualAttentionBlock(width, heads, act_layer=act_layer)
295
+ for _ in range(layers)
296
+ ]
297
+ )
298
+
299
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
300
+ for r in self.resblocks:
301
+ x = r(x, attn_mask=attn_mask)
302
+ return x
303
+
304
+
305
+ class VisualTransformer(nn.Module):
306
+ def __init__(
307
+ self,
308
+ image_size: int,
309
+ patch_size: int,
310
+ width: int,
311
+ layers: int,
312
+ heads: int,
313
+ output_dim: int,
314
+ act_layer: Callable = nn.GELU,
315
+ ):
316
+ super().__init__()
317
+ self.image_size = image_size
318
+ self.output_dim = output_dim
319
+ self.conv1 = nn.Conv2d(
320
+ in_channels=3,
321
+ out_channels=width,
322
+ kernel_size=patch_size,
323
+ stride=patch_size,
324
+ bias=False,
325
+ )
326
+
327
+ scale = width**-0.5
328
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
329
+ self.positional_embedding = nn.Parameter(
330
+ scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
331
+ )
332
+ self.ln_pre = LayerNorm(width)
333
+
334
+ self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
335
+
336
+ self.ln_post = LayerNorm(width)
337
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
338
+
339
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
340
+ assert (
341
+ unlocked_groups == 0
342
+ ), "partial locking not currently supported for this model"
343
+ for param in self.parameters():
344
+ param.requires_grad = False
345
+
346
+ def forward(self, x: torch.Tensor):
347
+ x = self.conv1(x) # shape = [*, width, grid, grid]
348
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
349
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
350
+ x = torch.cat(
351
+ [
352
+ self.class_embedding.to(x.dtype)
353
+ + torch.zeros(
354
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
355
+ ),
356
+ x,
357
+ ],
358
+ dim=1,
359
+ ) # shape = [*, grid ** 2 + 1, width]
360
+ x = x + self.positional_embedding.to(x.dtype)
361
+ x = self.ln_pre(x)
362
+
363
+ x = x.permute(1, 0, 2) # NLD -> LND
364
+ x = self.text_branch(x)
365
+ x = x.permute(1, 0, 2) # LND -> NLD
366
+
367
+ x = self.ln_post(x[:, 0, :])
368
+
369
+ if self.proj is not None:
370
+ x = x @ self.proj
371
+
372
+ return x
373
+
374
+
375
+ @dataclass
376
+ class CLAPVisionCfg:
377
+ layers: Union[Tuple[int, int, int, int], int] = 12
378
+ width: int = 768
379
+ patch_size: int = 16
380
+ image_size: Union[Tuple[int, int], int] = 224
381
+ timm_model_name: str = (
382
+ None # a valid model name overrides layers, width, patch_size
383
+ )
384
+ timm_model_pretrained: bool = (
385
+ False # use (imagenet) pretrained weights for named model
386
+ )
387
+ timm_pool: str = (
388
+ "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
389
+ )
390
+ timm_proj: str = (
391
+ "linear" # linear projection for timm model output ('linear', 'mlp', '')
392
+ )
393
+
394
+
395
+ # Audio Config Class
396
+ @dataclass
397
+ class CLAPAudioCfp:
398
+ model_type: str = "PANN"
399
+ model_name: str = "Cnn14"
400
+ sample_rate: int = 48000
401
+ # Param
402
+ audio_length: int = 1024
403
+ window_size: int = 1024
404
+ hop_size: int = 1024
405
+ fmin: int = 50
406
+ fmax: int = 14000
407
+ class_num: int = 527
408
+ mel_bins: int = 64
409
+ clip_samples: int = 480000
410
+
411
+
412
+ @dataclass
413
+ class CLAPTextCfg:
414
+ context_length: int
415
+ vocab_size: int
416
+ width: int
417
+ heads: int
418
+ layers: int
419
+ model_type: str
420
+
421
+
422
+ class CLAP(nn.Module):
423
+ def __init__(
424
+ self,
425
+ embed_dim: int,
426
+ audio_cfg: CLAPAudioCfp,
427
+ text_cfg: CLAPTextCfg,
428
+ quick_gelu: bool = False,
429
+ enable_fusion: bool = False,
430
+ fusion_type: str = "None",
431
+ joint_embed_shape: int = 512,
432
+ mlp_act: str = "relu",
433
+ ):
434
+ super().__init__()
435
+ if isinstance(audio_cfg, dict):
436
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
437
+ if isinstance(text_cfg, dict):
438
+ text_cfg = CLAPTextCfg(**text_cfg)
439
+
440
+ self.audio_cfg = audio_cfg
441
+ self.text_cfg = text_cfg
442
+ self.enable_fusion = enable_fusion
443
+ self.fusion_type = fusion_type
444
+ self.joint_embed_shape = joint_embed_shape
445
+ self.mlp_act = mlp_act
446
+
447
+ self.context_length = text_cfg.context_length
448
+
449
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
450
+ # memory efficient in recent PyTorch releases (>= 1.10).
451
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
452
+ act_layer = QuickGELU if quick_gelu else nn.GELU
453
+
454
+ if mlp_act == "relu":
455
+ mlp_act_layer = nn.ReLU()
456
+ elif mlp_act == "gelu":
457
+ mlp_act_layer = nn.GELU()
458
+ else:
459
+ raise NotImplementedError
460
+
461
+ # audio branch
462
+ # audio branch parameters
463
+ if audio_cfg.model_type == "PANN":
464
+ self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
465
+ elif audio_cfg.model_type == "HTSAT":
466
+ self.audio_branch = create_htsat_model(
467
+ audio_cfg, enable_fusion, fusion_type
468
+ )
469
+ else:
470
+ logging.error(f"Model config for {audio_cfg.model_type} not found")
471
+ raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
472
+
473
+ # text branch
474
+ # text branch parameters
475
+ if text_cfg.model_type == "transformer":
476
+ self.text_branch = Transformer(
477
+ width=text_cfg.width,
478
+ layers=text_cfg.layers,
479
+ heads=text_cfg.heads,
480
+ act_layer=act_layer,
481
+ )
482
+ self.vocab_size = text_cfg.vocab_size
483
+ self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
484
+ self.positional_embedding = nn.Parameter(
485
+ torch.empty(self.context_length, text_cfg.width)
486
+ )
487
+ self.ln_final = LayerNorm(text_cfg.width)
488
+ self.text_transform = MLPLayers(
489
+ units=[
490
+ self.joint_embed_shape,
491
+ self.joint_embed_shape,
492
+ self.joint_embed_shape,
493
+ ],
494
+ dropout=0.1,
495
+ )
496
+ self.text_projection = nn.Sequential(
497
+ nn.Linear(text_cfg.width, self.joint_embed_shape),
498
+ mlp_act_layer,
499
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
500
+ )
501
+ elif text_cfg.model_type == "bert":
502
+ self.text_branch = BertModel.from_pretrained("bert-base-uncased")
503
+ self.text_transform = MLPLayers(
504
+ units=[
505
+ self.joint_embed_shape,
506
+ self.joint_embed_shape,
507
+ self.joint_embed_shape,
508
+ ],
509
+ dropout=0.1,
510
+ )
511
+ self.text_projection = nn.Sequential(
512
+ nn.Linear(768, self.joint_embed_shape),
513
+ mlp_act_layer,
514
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
515
+ )
516
+ elif text_cfg.model_type == "roberta":
517
+ self.text_branch = RobertaModel.from_pretrained("roberta-base")
518
+ self.text_transform = MLPLayers(
519
+ units=[
520
+ self.joint_embed_shape,
521
+ self.joint_embed_shape,
522
+ self.joint_embed_shape,
523
+ ],
524
+ dropout=0.1,
525
+ )
526
+ self.text_projection = nn.Sequential(
527
+ nn.Linear(768, self.joint_embed_shape),
528
+ mlp_act_layer,
529
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
530
+ )
531
+ elif text_cfg.model_type == "bart":
532
+ self.text_branch = BartModel.from_pretrained("facebook/bart-base")
533
+ self.text_transform = MLPLayers(
534
+ units=[
535
+ self.joint_embed_shape,
536
+ self.joint_embed_shape,
537
+ self.joint_embed_shape,
538
+ ],
539
+ dropout=0.1,
540
+ )
541
+ self.text_projection = nn.Sequential(
542
+ nn.Linear(768, self.joint_embed_shape),
543
+ mlp_act_layer,
544
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
545
+ )
546
+ else:
547
+ logging.error(f"Model config for {text_cfg.model_type} not found")
548
+ raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
549
+ self.text_branch_type = text_cfg.model_type
550
+ # text branch parameters
551
+
552
+ # audio branch parameters
553
+ self.audio_transform = MLPLayers(
554
+ units=[
555
+ self.joint_embed_shape,
556
+ self.joint_embed_shape,
557
+ self.joint_embed_shape,
558
+ ],
559
+ dropout=0.1,
560
+ )
561
+
562
+ # below here is text branch parameters
563
+
564
+ # ============================================================================================================
565
+ self.audio_projection = nn.Sequential(
566
+ nn.Linear(embed_dim, self.joint_embed_shape),
567
+ mlp_act_layer,
568
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
569
+ )
570
+
571
+ self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
572
+ self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
573
+ self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
574
+
575
+ self.init_text_branch_parameters()
576
+
577
+ def init_text_branch_parameters(self):
578
+ if self.text_branch_type == "transformer":
579
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
580
+ nn.init.normal_(self.positional_embedding, std=0.01)
581
+ proj_std = (self.text_branch.width**-0.5) * (
582
+ (2 * self.text_branch.layers) ** -0.5
583
+ )
584
+ attn_std = self.text_branch.width**-0.5
585
+ fc_std = (2 * self.text_branch.width) ** -0.5
586
+ for block in self.text_branch.resblocks:
587
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
588
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
589
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
590
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
591
+ if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
592
+ width = self.text_branch.embeddings.word_embeddings.weight.shape[-1]
593
+ elif self.text_branch_type == "bart":
594
+ width = self.text_branch.shared.weight.shape[-1]
595
+ else:
596
+ width = self.text_branch.width
597
+ nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
598
+ nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
599
+
600
+ # deprecated
601
+ # if hasattr(self.visual, 'init_parameters'):
602
+ # self.visual.init_parameters()
603
+
604
+ # if self.text_projection is not None:
605
+ # nn.init.normal_(self.text_projection, std=width**-0.5)
606
+
607
+ def build_attention_mask(self):
608
+ # lazily create causal attention mask, with full attention between the vision tokens
609
+ # pytorch uses additive attention mask; fill with -inf
610
+ mask = torch.empty(self.context_length, self.context_length)
611
+ mask.fill_(float("-inf"))
612
+ mask.triu_(1) # zero out the lower diagonal
613
+ return mask
614
+
615
+ def encode_audio(self, audio, device):
616
+ return self.audio_branch(
617
+ audio, mixup_lambda=None, device=device
618
+ ) # mix lambda needs to add
619
+
620
+ # def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
621
+ # tmp = {}
622
+ # for k in x[0].keys():
623
+ # tmp[k] = []
624
+ # for i in range(len(x)):
625
+ # tmp[k].append(x[i][k][:77])
626
+ # for k in x[0].keys():
627
+ # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
628
+ # return tmp
629
+
630
+ def encode_text(self, text, device):
631
+ if self.text_branch_type == "transformer":
632
+ text = text.to(device=device, non_blocking=True)
633
+ x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
634
+
635
+ x = x + self.positional_embedding
636
+ x = x.permute(1, 0, 2) # NLD -> LND
637
+ x = self.text_branch(x, attn_mask=self.attn_mask)
638
+ x = x.permute(1, 0, 2) # LND -> NLD
639
+ x = self.ln_final(x)
640
+
641
+ # x.shape = [batch_size, n_ctx, transformer.width]
642
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
643
+ x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
644
+ elif self.text_branch_type == "bert":
645
+ # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
646
+ # text = BatchEncoding(text)
647
+ x = self.text_branch(
648
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
649
+ attention_mask=text["attention_mask"].to(
650
+ device=device, non_blocking=True
651
+ ),
652
+ token_type_ids=text["token_type_ids"].to(
653
+ device=device, non_blocking=True
654
+ ),
655
+ )["pooler_output"]
656
+ x = self.text_projection(x)
657
+ elif self.text_branch_type == "roberta":
658
+ x = self.text_branch(
659
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
660
+ attention_mask=text["attention_mask"].to(
661
+ device=device, non_blocking=True
662
+ ),
663
+ )["pooler_output"]
664
+ x = self.text_projection(x)
665
+ elif self.text_branch_type == "bart":
666
+ x = torch.mean(
667
+ self.text_branch(
668
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
669
+ attention_mask=text["attention_mask"].to(
670
+ device=device, non_blocking=True
671
+ ),
672
+ )["encoder_last_hidden_state"],
673
+ axis=1,
674
+ )
675
+ x = self.text_projection(x)
676
+ else:
677
+ logging.error(f"Model type {self.text_branch_type} not found")
678
+ raise RuntimeError(f"Model type {self.text_branch_type} not found.")
679
+ return x
680
+
681
+ def forward(self, audio, text, device=None):
682
+ """Forward audio and text into the CLAP
683
+
684
+ Parameters
685
+ ----------
686
+ audio: torch.Tensor (batch_size, audio_length)
687
+ the time-domain audio input / the batch of mel_spec and longer list.
688
+ text: torch.Tensor () // need to add
689
+ the text token input
690
+ """
691
+ if device is None:
692
+ if audio is not None:
693
+ device = audio.device
694
+ elif text is not None:
695
+ device = text.device
696
+ if audio is None and text is None:
697
+ # a hack to get the logit scale
698
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
699
+ elif audio is None:
700
+ return self.encode_text(text, device=device)
701
+ elif text is None:
702
+ return self.audio_projection(
703
+ self.encode_audio(audio, device=device)["embedding"]
704
+ )
705
+ audio_features = self.audio_projection(
706
+ self.encode_audio(audio, device=device)["embedding"]
707
+ )
708
+ audio_features = F.normalize(audio_features, dim=-1)
709
+
710
+ text_features = self.encode_text(text, device=device)
711
+ # print("text_features", text_features)
712
+ # print("text_features.shape", text_features.shape)
713
+ # print("text_features.type", type(text_features))
714
+ text_features = F.normalize(text_features, dim=-1)
715
+
716
+ audio_features_mlp = self.audio_transform(audio_features)
717
+ text_features_mlp = self.text_transform(text_features)
718
+ # Four outputs: audio features (basic & MLP), text features (basic & MLP)
719
+ return (
720
+ audio_features,
721
+ text_features,
722
+ audio_features_mlp,
723
+ text_features_mlp,
724
+ self.logit_scale_a.exp(),
725
+ self.logit_scale_t.exp(),
726
+ )
727
+
728
+ def get_logit_scale(self):
729
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
730
+
731
+ def get_text_embedding(self, data):
732
+ """Get the text embedding from the model
733
+
734
+ Parameters
735
+ ----------
736
+ data: torch.Tensor
737
+ a tensor of text embedding
738
+
739
+ Returns
740
+ ----------
741
+ text_embed: torch.Tensor
742
+ a tensor of text_embeds (N, D)
743
+
744
+ """
745
+ device = next(self.parameters()).device
746
+ for k in data:
747
+ data[k] = data[k].to(device)
748
+ text_embeds = self.encode_text(data, device=device)
749
+ text_embeds = F.normalize(text_embeds, dim=-1)
750
+
751
+ return text_embeds
752
+
753
+ def get_audio_embedding(self, data):
754
+ """Get the audio embedding from the model
755
+
756
+ Parameters
757
+ ----------
758
+ data: a list of dict
759
+ the audio input dict list from 'get_audio_feature' method
760
+
761
+ Returns
762
+ ----------
763
+ audio_embed: torch.Tensor
764
+ a tensor of audio_embeds (N, D)
765
+
766
+ """
767
+ device = next(self.parameters()).device
768
+ input_dict = {}
769
+ keys = data[0].keys()
770
+ for k in keys:
771
+ input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(
772
+ device
773
+ )
774
+
775
+ audio_embeds = self.audio_projection(
776
+ self.encode_audio(input_dict, device=device)["embedding"]
777
+ )
778
+ audio_embeds = F.normalize(audio_embeds, dim=-1)
779
+
780
+ return audio_embeds
781
+
782
+ def audio_infer(self, audio, hopsize=None, device=None):
783
+ """Forward one audio and produce the audio embedding
784
+
785
+ Parameters
786
+ ----------
787
+ audio: (audio_length)
788
+ the time-domain audio input, notice that it must be only one input
789
+ hopsize: int
790
+ the overlap hopsize as the sliding window
791
+
792
+ Returns
793
+ ----------
794
+ output_dict: {
795
+ key: [n, (embedding_shape)] if "HTS-AT"
796
+ or
797
+ key: [(embedding_shape)] if "PANN"
798
+ }
799
+ the list of key values of the audio branch
800
+
801
+ """
802
+
803
+ assert not self.training, "the inference mode must be run at eval stage"
804
+ output_dict = {}
805
+ # PANN
806
+ if self.audio_cfg.model_type == "PANN":
807
+ audio_input = audio.unsqueeze(dim=0)
808
+ output_dict[key] = self.encode_audio(audio_input, device=device)[
809
+ key
810
+ ].squeeze(dim=0)
811
+ elif self.audio_cfg.model_type == "HTSAT":
812
+ # repeat
813
+ audio_len = len(audio)
814
+ k = self.audio_cfg.clip_samples // audio_len
815
+ if k > 1:
816
+ audio = audio.repeat(k)
817
+ audio_len = len(audio)
818
+
819
+ if hopsize is None:
820
+ hopsize = min(hopsize, audio_len)
821
+
822
+ if audio_len > self.audio_cfg.clip_samples:
823
+ audio_input = [
824
+ audio[pos : pos + self.audio_cfg.clip_samples].clone()
825
+ for pos in range(
826
+ 0, audio_len - self.audio_cfg.clip_samples, hopsize
827
+ )
828
+ ]
829
+ audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
830
+ audio_input = torch.stack(audio_input)
831
+ output_dict[key] = self.encode_audio(audio_input, device=device)[key]
832
+ else:
833
+ audio_input = audio.unsqueeze(dim=0)
834
+ output_dict[key] = self.encode_audio(audio_input, device=device)[
835
+ key
836
+ ].squeeze(dim=0)
837
+
838
+ return output_dict
839
+
840
+
841
+ def convert_weights_to_fp16(model: nn.Module):
842
+ """Convert applicable model parameters to fp16"""
843
+
844
+ def _convert_weights_to_fp16(l):
845
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
846
+ l.weight.data = l.weight.data.half()
847
+ if l.bias is not None:
848
+ l.bias.data = l.bias.data.half()
849
+
850
+ if isinstance(l, nn.MultiheadAttention):
851
+ for attr in [
852
+ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
853
+ "in_proj_bias",
854
+ "bias_k",
855
+ "bias_v",
856
+ ]:
857
+ tensor = getattr(l, attr)
858
+ if tensor is not None:
859
+ tensor.data = tensor.data.half()
860
+
861
+ for name in ["text_projection", "proj"]:
862
+ if hasattr(l, name):
863
+ attr = getattr(l, name)
864
+ if attr is not None:
865
+ attr.data = attr.data.half()
866
+
867
+ model.apply(_convert_weights_to_fp16)
868
+
869
+
870
+ # Ignore the state dict of the vision part
871
+ def build_model_from_openai_state_dict(
872
+ state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None"
873
+ ):
874
+
875
+ embed_dim = model_cfg["embed_dim"]
876
+ audio_cfg = model_cfg["audio_cfg"]
877
+ text_cfg = model_cfg["text_cfg"]
878
+ context_length = state_dict["positional_embedding"].shape[0]
879
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
880
+ transformer_width = state_dict["ln_final.weight"].shape[0]
881
+ transformer_heads = transformer_width // 64
882
+ transformer_layers = len(
883
+ set(
884
+ k.split(".")[2]
885
+ for k in state_dict
886
+ if k.startswith(f"transformer.resblocks")
887
+ )
888
+ )
889
+
890
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
891
+ text_cfg = CLAPTextCfg(**text_cfg)
892
+
893
+ model = CLAP(
894
+ embed_dim,
895
+ audio_cfg=audio_cfg,
896
+ text_cfg=text_cfg,
897
+ quick_gelu=True, # OpenAI models were trained with QuickGELU
898
+ enable_fusion=enable_fusion,
899
+ fusion_type=fusion_type,
900
+ )
901
+ state_dict["logit_scale_a"] = state_dict["logit_scale"]
902
+ state_dict["logit_scale_t"] = state_dict["logit_scale"]
903
+ pop_keys = list(state_dict.keys())[::]
904
+ # pop the visual branch saved weights
905
+ for key in pop_keys:
906
+ if key.startswith("visual."):
907
+ state_dict.pop(key, None)
908
+
909
+ for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
910
+ state_dict.pop(key, None)
911
+
912
+ # not use fp16
913
+ # convert_weights_to_fp16(model)
914
+ model.load_state_dict(state_dict, strict=False)
915
+ return model.eval()
916
+
917
+
918
+ def trace_model(model, batch_size=256, device=torch.device("cpu")):
919
+ model.eval()
920
+ audio_length = model.audio_cfg.audio_length
921
+ example_audio = torch.ones((batch_size, audio_length), device=device)
922
+ example_text = torch.zeros(
923
+ (batch_size, model.context_length), dtype=torch.int, device=device
924
+ )
925
+ model = torch.jit.trace_module(
926
+ model,
927
+ inputs=dict(
928
+ forward=(example_audio, example_text),
929
+ encode_text=(example_text,),
930
+ encode_image=(example_audio,),
931
+ ),
932
+ )
933
+ model.audio_cfg.audio_length = audio_length # Question: what does this do?
934
+ return model
audioldm/clap/open_clip/model_configs/HTSAT-base.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "base"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm/clap/open_clip/model_configs/HTSAT-large.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "large"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1536,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "tiny"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm/clap/open_clip/model_configs/HTSAT-tiny.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "HTSAT",
14
+ "model_name": "tiny"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm/clap/open_clip/model_configs/PANN-10.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn10"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 18000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 960000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 360,
10
+ "fmin": 50,
11
+ "fmax": 8000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 4
22
+ }
23
+ }
audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1536,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm/clap/open_clip/model_configs/PANN-14.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 2048,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn14"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm/clap/open_clip/model_configs/PANN-6.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "audio_cfg": {
4
+ "audio_length": 1024,
5
+ "clip_samples": 480000,
6
+ "mel_bins": 64,
7
+ "sample_rate": 48000,
8
+ "window_size": 1024,
9
+ "hop_size": 480,
10
+ "fmin": 50,
11
+ "fmax": 14000,
12
+ "class_num": 527,
13
+ "model_type": "PANN",
14
+ "model_name": "Cnn6"
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 512,
20
+ "heads": 8,
21
+ "layers": 12
22
+ }
23
+ }
audioldm/clap/open_clip/model_configs/RN101-quickgelu.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 23,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
audioldm/clap/open_clip/model_configs/RN101.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 23,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
audioldm/clap/open_clip/model_configs/RN50-quickgelu.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 6,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
audioldm/clap/open_clip/model_configs/RN50.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 6,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
audioldm/clap/open_clip/model_configs/RN50x16.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 384,
5
+ "layers": [
6
+ 6,
7
+ 8,
8
+ 18,
9
+ 8
10
+ ],
11
+ "width": 96,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 768,
18
+ "heads": 12,
19
+ "layers": 12
20
+ }
21
+ }
audioldm/clap/open_clip/model_configs/RN50x4.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 288,
5
+ "layers": [
6
+ 4,
7
+ 6,
8
+ 10,
9
+ 6
10
+ ],
11
+ "width": 80,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 640,
18
+ "heads": 10,
19
+ "layers": 12
20
+ }
21
+ }
audioldm/clap/open_clip/model_configs/ViT-B-16.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 512,
13
+ "heads": 8,
14
+ "layers": 12
15
+ }
16
+ }
audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": 12,
7
+ "width": 768,
8
+ "patch_size": 32
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 512,
14
+ "heads": 8,
15
+ "layers": 12
16
+ }
17
+ }
audioldm/clap/open_clip/model_configs/ViT-B-32.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 32
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 512,
13
+ "heads": 8,
14
+ "layers": 12
15
+ }
16
+ }
audioldm/clap/open_clip/model_configs/ViT-L-14.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 14
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }
audioldm/clap/open_clip/openai.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ OpenAI pretrained model functions
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+
6
+ import os
7
+ import warnings
8
+ from typing import Union, List
9
+
10
+ import torch
11
+
12
+ from .model import build_model_from_openai_state_dict
13
+ from .pretrained import (
14
+ get_pretrained_url,
15
+ list_pretrained_tag_models,
16
+ download_pretrained,
17
+ )
18
+
19
+ __all__ = ["list_openai_models", "load_openai_model"]
20
+
21
+
22
+ def list_openai_models() -> List[str]:
23
+ """Returns the names of available CLIP models"""
24
+ return list_pretrained_tag_models("openai")
25
+
26
+
27
+ def load_openai_model(
28
+ name: str,
29
+ model_cfg,
30
+ device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
31
+ jit=True,
32
+ cache_dir=os.path.expanduser("~/.cache/clip"),
33
+ enable_fusion: bool = False,
34
+ fusion_type: str = "None",
35
+ ):
36
+ """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model
37
+
38
+ Parameters
39
+ ----------
40
+ name : str
41
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
42
+ device : Union[str, torch.device]
43
+ The device to put the loaded model
44
+ jit : bool
45
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
46
+
47
+ Returns
48
+ -------
49
+ model : torch.nn.Module
50
+ The CLAP model
51
+ preprocess : Callable[[PIL.Image], torch.Tensor]
52
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
53
+ """
54
+ if get_pretrained_url(name, "openai"):
55
+ model_path = download_pretrained(
56
+ get_pretrained_url(name, "openai"), root=cache_dir
57
+ )
58
+ elif os.path.isfile(name):
59
+ model_path = name
60
+ else:
61
+ raise RuntimeError(
62
+ f"Model {name} not found; available models = {list_openai_models()}"
63
+ )
64
+
65
+ try:
66
+ # loading JIT archive
67
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
68
+ state_dict = None
69
+ except RuntimeError:
70
+ # loading saved state dict
71
+ if jit:
72
+ warnings.warn(
73
+ f"File {model_path} is not a JIT archive. Loading as a state dict instead"
74
+ )
75
+ jit = False
76
+ state_dict = torch.load(model_path, map_location="cpu")
77
+
78
+ if not jit:
79
+ try:
80
+ model = build_model_from_openai_state_dict(
81
+ state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type
82
+ ).to(device)
83
+ except KeyError:
84
+ sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
85
+ model = build_model_from_openai_state_dict(
86
+ sd, model_cfg, enable_fusion, fusion_type
87
+ ).to(device)
88
+
89
+ if str(device) == "cpu":
90
+ model.float()
91
+ return model
92
+
93
+ # patch the device names
94
+ device_holder = torch.jit.trace(
95
+ lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
96
+ )
97
+ device_node = [
98
+ n
99
+ for n in device_holder.graph.findAllNodes("prim::Constant")
100
+ if "Device" in repr(n)
101
+ ][-1]
102
+
103
+ def patch_device(module):
104
+ try:
105
+ graphs = [module.graph] if hasattr(module, "graph") else []
106
+ except RuntimeError:
107
+ graphs = []
108
+
109
+ if hasattr(module, "forward1"):
110
+ graphs.append(module.forward1.graph)
111
+
112
+ for graph in graphs:
113
+ for node in graph.findAllNodes("prim::Constant"):
114
+ if "value" in node.attributeNames() and str(node["value"]).startswith(
115
+ "cuda"
116
+ ):
117
+ node.copyAttributes(device_node)
118
+
119
+ model.apply(patch_device)
120
+ patch_device(model.encode_audio)
121
+ patch_device(model.encode_text)
122
+
123
+ # patch dtype to float32 on CPU
124
+ if str(device) == "cpu":
125
+ float_holder = torch.jit.trace(
126
+ lambda: torch.ones([]).float(), example_inputs=[]
127
+ )
128
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
129
+ float_node = float_input.node()
130
+
131
+ def patch_float(module):
132
+ try:
133
+ graphs = [module.graph] if hasattr(module, "graph") else []
134
+ except RuntimeError:
135
+ graphs = []
136
+
137
+ if hasattr(module, "forward1"):
138
+ graphs.append(module.forward1.graph)
139
+
140
+ for graph in graphs:
141
+ for node in graph.findAllNodes("aten::to"):
142
+ inputs = list(node.inputs())
143
+ for i in [
144
+ 1,
145
+ 2,
146
+ ]: # dtype can be the second or third argument to aten::to()
147
+ if inputs[i].node()["value"] == 5:
148
+ inputs[i].node().copyAttributes(float_node)
149
+
150
+ model.apply(patch_float)
151
+ patch_float(model.encode_audio)
152
+ patch_float(model.encode_text)
153
+ model.float()
154
+
155
+ model.audio_branch.audio_length = model.audio_cfg.audio_length
156
+ return model
audioldm/clap/open_clip/pann_model.py ADDED
@@ -0,0 +1,703 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
2
+ # Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn
3
+ # Some layers are re-designed for CLAP
4
+ import os
5
+
6
+ os.environ["NUMBA_CACHE_DIR"] = "/tmp/"
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
12
+ from torchlibrosa.augmentation import SpecAugmentation
13
+
14
+ from .utils import do_mixup, interpolate, pad_framewise_output
15
+ from .feature_fusion import iAFF, AFF, DAF
16
+
17
+
18
+ def init_layer(layer):
19
+ """Initialize a Linear or Convolutional layer."""
20
+ nn.init.xavier_uniform_(layer.weight)
21
+
22
+ if hasattr(layer, "bias"):
23
+ if layer.bias is not None:
24
+ layer.bias.data.fill_(0.0)
25
+
26
+ def init_bn(bn):
27
+ """Initialize a Batchnorm layer."""
28
+ bn.bias.data.fill_(0.0)
29
+ bn.weight.data.fill_(1.0)
30
+
31
+
32
+ class ConvBlock(nn.Module):
33
+ def __init__(self, in_channels, out_channels):
34
+
35
+ super(ConvBlock, self).__init__()
36
+
37
+ self.conv1 = nn.Conv2d(
38
+ in_channels=in_channels,
39
+ out_channels=out_channels,
40
+ kernel_size=(3, 3),
41
+ stride=(1, 1),
42
+ padding=(1, 1),
43
+ bias=False,
44
+ )
45
+
46
+ self.conv2 = nn.Conv2d(
47
+ in_channels=out_channels,
48
+ out_channels=out_channels,
49
+ kernel_size=(3, 3),
50
+ stride=(1, 1),
51
+ padding=(1, 1),
52
+ bias=False,
53
+ )
54
+
55
+ self.bn1 = nn.BatchNorm2d(out_channels)
56
+ self.bn2 = nn.BatchNorm2d(out_channels)
57
+
58
+ self.init_weight()
59
+
60
+ def init_weight(self):
61
+ init_layer(self.conv1)
62
+ init_layer(self.conv2)
63
+ init_bn(self.bn1)
64
+ init_bn(self.bn2)
65
+
66
+ def forward(self, input, pool_size=(2, 2), pool_type="avg"):
67
+
68
+ x = input
69
+ x = F.relu_(self.bn1(self.conv1(x)))
70
+ x = F.relu_(self.bn2(self.conv2(x)))
71
+ if pool_type == "max":
72
+ x = F.max_pool2d(x, kernel_size=pool_size)
73
+ elif pool_type == "avg":
74
+ x = F.avg_pool2d(x, kernel_size=pool_size)
75
+ elif pool_type == "avg+max":
76
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
77
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
78
+ x = x1 + x2
79
+ else:
80
+ raise Exception("Incorrect argument!")
81
+
82
+ return x
83
+
84
+
85
+ class ConvBlock5x5(nn.Module):
86
+ def __init__(self, in_channels, out_channels):
87
+
88
+ super(ConvBlock5x5, self).__init__()
89
+
90
+ self.conv1 = nn.Conv2d(
91
+ in_channels=in_channels,
92
+ out_channels=out_channels,
93
+ kernel_size=(5, 5),
94
+ stride=(1, 1),
95
+ padding=(2, 2),
96
+ bias=False,
97
+ )
98
+
99
+ self.bn1 = nn.BatchNorm2d(out_channels)
100
+
101
+ self.init_weight()
102
+
103
+ def init_weight(self):
104
+ init_layer(self.conv1)
105
+ init_bn(self.bn1)
106
+
107
+ def forward(self, input, pool_size=(2, 2), pool_type="avg"):
108
+
109
+ x = input
110
+ x = F.relu_(self.bn1(self.conv1(x)))
111
+ if pool_type == "max":
112
+ x = F.max_pool2d(x, kernel_size=pool_size)
113
+ elif pool_type == "avg":
114
+ x = F.avg_pool2d(x, kernel_size=pool_size)
115
+ elif pool_type == "avg+max":
116
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
117
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
118
+ x = x1 + x2
119
+ else:
120
+ raise Exception("Incorrect argument!")
121
+
122
+ return x
123
+
124
+
125
+ class AttBlock(nn.Module):
126
+ def __init__(self, n_in, n_out, activation="linear", temperature=1.0):
127
+ super(AttBlock, self).__init__()
128
+
129
+ self.activation = activation
130
+ self.temperature = temperature
131
+ self.att = nn.Conv1d(
132
+ in_channels=n_in,
133
+ out_channels=n_out,
134
+ kernel_size=1,
135
+ stride=1,
136
+ padding=0,
137
+ bias=True,
138
+ )
139
+ self.cla = nn.Conv1d(
140
+ in_channels=n_in,
141
+ out_channels=n_out,
142
+ kernel_size=1,
143
+ stride=1,
144
+ padding=0,
145
+ bias=True,
146
+ )
147
+
148
+ self.bn_att = nn.BatchNorm1d(n_out)
149
+ self.init_weights()
150
+
151
+ def init_weights(self):
152
+ init_layer(self.att)
153
+ init_layer(self.cla)
154
+ init_bn(self.bn_att)
155
+
156
+ def forward(self, x):
157
+ # x: (n_samples, n_in, n_time)
158
+ norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
159
+ cla = self.nonlinear_transform(self.cla(x))
160
+ x = torch.sum(norm_att * cla, dim=2)
161
+ return x, norm_att, cla
162
+
163
+ def nonlinear_transform(self, x):
164
+ if self.activation == "linear":
165
+ return x
166
+ elif self.activation == "sigmoid":
167
+ return torch.sigmoid(x)
168
+
169
+
170
+ class Cnn14(nn.Module):
171
+ def __init__(
172
+ self,
173
+ sample_rate,
174
+ window_size,
175
+ hop_size,
176
+ mel_bins,
177
+ fmin,
178
+ fmax,
179
+ classes_num,
180
+ enable_fusion=False,
181
+ fusion_type="None",
182
+ ):
183
+
184
+ super(Cnn14, self).__init__()
185
+
186
+ window = "hann"
187
+ center = True
188
+ pad_mode = "reflect"
189
+ ref = 1.0
190
+ amin = 1e-10
191
+ top_db = None
192
+
193
+ self.enable_fusion = enable_fusion
194
+ self.fusion_type = fusion_type
195
+
196
+ # Spectrogram extractor
197
+ self.spectrogram_extractor = Spectrogram(
198
+ n_fft=window_size,
199
+ hop_length=hop_size,
200
+ win_length=window_size,
201
+ window=window,
202
+ center=center,
203
+ pad_mode=pad_mode,
204
+ freeze_parameters=True,
205
+ )
206
+
207
+ # Logmel feature extractor
208
+ self.logmel_extractor = LogmelFilterBank(
209
+ sr=sample_rate,
210
+ n_fft=window_size,
211
+ n_mels=mel_bins,
212
+ fmin=fmin,
213
+ fmax=fmax,
214
+ ref=ref,
215
+ amin=amin,
216
+ top_db=top_db,
217
+ freeze_parameters=True,
218
+ )
219
+
220
+ # Spec augmenter
221
+ self.spec_augmenter = SpecAugmentation(
222
+ time_drop_width=64,
223
+ time_stripes_num=2,
224
+ freq_drop_width=8,
225
+ freq_stripes_num=2,
226
+ )
227
+
228
+ self.bn0 = nn.BatchNorm2d(64)
229
+
230
+ if (self.enable_fusion) and (self.fusion_type == "channel_map"):
231
+ self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
232
+ else:
233
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
234
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
235
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
236
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
237
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
238
+ self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
239
+
240
+ self.fc1 = nn.Linear(2048, 2048, bias=True)
241
+ self.fc_audioset = nn.Linear(2048, classes_num, bias=True)
242
+
243
+ if (self.enable_fusion) and (
244
+ self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
245
+ ):
246
+ self.mel_conv1d = nn.Sequential(
247
+ nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
248
+ nn.BatchNorm1d(64), # No Relu
249
+ )
250
+ if self.fusion_type == "daf_1d":
251
+ self.fusion_model = DAF()
252
+ elif self.fusion_type == "aff_1d":
253
+ self.fusion_model = AFF(channels=64, type="1D")
254
+ elif self.fusion_type == "iaff_1d":
255
+ self.fusion_model = iAFF(channels=64, type="1D")
256
+
257
+ if (self.enable_fusion) and (
258
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
259
+ ):
260
+ self.mel_conv2d = nn.Sequential(
261
+ nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)),
262
+ nn.BatchNorm2d(64),
263
+ nn.ReLU(inplace=True),
264
+ )
265
+
266
+ if self.fusion_type == "daf_2d":
267
+ self.fusion_model = DAF()
268
+ elif self.fusion_type == "aff_2d":
269
+ self.fusion_model = AFF(channels=64, type="2D")
270
+ elif self.fusion_type == "iaff_2d":
271
+ self.fusion_model = iAFF(channels=64, type="2D")
272
+ self.init_weight()
273
+
274
+ def init_weight(self):
275
+ init_bn(self.bn0)
276
+ init_layer(self.fc1)
277
+ init_layer(self.fc_audioset)
278
+
279
+ def forward(self, input, mixup_lambda=None, device=None):
280
+ """
281
+ Input: (batch_size, data_length)"""
282
+
283
+ if self.enable_fusion and input["longer"].sum() == 0:
284
+ # if no audio is longer than 10s, then randomly select one audio to be longer
285
+ input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True
286
+
287
+ if not self.enable_fusion:
288
+ x = self.spectrogram_extractor(
289
+ input["waveform"].to(device=device, non_blocking=True)
290
+ ) # (batch_size, 1, time_steps, freq_bins)
291
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
292
+
293
+ x = x.transpose(1, 3)
294
+ x = self.bn0(x)
295
+ x = x.transpose(1, 3)
296
+ else:
297
+ longer_list = input["longer"].to(device=device, non_blocking=True)
298
+ x = input["mel_fusion"].to(device=device, non_blocking=True)
299
+ longer_list_idx = torch.where(longer_list)[0]
300
+ x = x.transpose(1, 3)
301
+ x = self.bn0(x)
302
+ x = x.transpose(1, 3)
303
+ if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
304
+ new_x = x[:, 0:1, :, :].clone().contiguous()
305
+ # local processing
306
+ if len(longer_list_idx) > 0:
307
+ fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
308
+ FB, FC, FT, FF = fusion_x_local.size()
309
+ fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
310
+ fusion_x_local = torch.permute(
311
+ fusion_x_local, (0, 2, 1)
312
+ ).contiguous()
313
+ fusion_x_local = self.mel_conv1d(fusion_x_local)
314
+ fusion_x_local = fusion_x_local.view(
315
+ FB, FC, FF, fusion_x_local.size(-1)
316
+ )
317
+ fusion_x_local = (
318
+ torch.permute(fusion_x_local, (0, 2, 1, 3))
319
+ .contiguous()
320
+ .flatten(2)
321
+ )
322
+ if fusion_x_local.size(-1) < FT:
323
+ fusion_x_local = torch.cat(
324
+ [
325
+ fusion_x_local,
326
+ torch.zeros(
327
+ (FB, FF, FT - fusion_x_local.size(-1)),
328
+ device=device,
329
+ ),
330
+ ],
331
+ dim=-1,
332
+ )
333
+ else:
334
+ fusion_x_local = fusion_x_local[:, :, :FT]
335
+ # 1D fusion
336
+ new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
337
+ new_x[longer_list_idx] = self.fusion_model(
338
+ new_x[longer_list_idx], fusion_x_local
339
+ )
340
+ x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
341
+ else:
342
+ x = new_x
343
+ elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
344
+ x = x # no change
345
+
346
+ if self.training:
347
+ x = self.spec_augmenter(x)
348
+ # Mixup on spectrogram
349
+ if self.training and mixup_lambda is not None:
350
+ x = do_mixup(x, mixup_lambda)
351
+ if (self.enable_fusion) and (
352
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
353
+ ):
354
+ global_x = x[:, 0:1, :, :]
355
+
356
+ # global processing
357
+ B, C, H, W = global_x.shape
358
+ global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg")
359
+ if len(longer_list_idx) > 0:
360
+ local_x = x[longer_list_idx, 1:, :, :].contiguous()
361
+ TH = global_x.size(-2)
362
+ # local processing
363
+ B, C, H, W = local_x.shape
364
+ local_x = local_x.view(B * C, 1, H, W)
365
+ local_x = self.mel_conv2d(local_x)
366
+ local_x = local_x.view(
367
+ B, C, local_x.size(1), local_x.size(2), local_x.size(3)
368
+ )
369
+ local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3)
370
+ TB, TC, _, TW = local_x.size()
371
+ if local_x.size(-2) < TH:
372
+ local_x = torch.cat(
373
+ [
374
+ local_x,
375
+ torch.zeros(
376
+ (TB, TC, TH - local_x.size(-2), TW),
377
+ device=global_x.device,
378
+ ),
379
+ ],
380
+ dim=-2,
381
+ )
382
+ else:
383
+ local_x = local_x[:, :, :TH, :]
384
+
385
+ global_x[longer_list_idx] = self.fusion_model(
386
+ global_x[longer_list_idx], local_x
387
+ )
388
+ x = global_x
389
+ else:
390
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
391
+
392
+ x = F.dropout(x, p=0.2, training=self.training)
393
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
394
+ x = F.dropout(x, p=0.2, training=self.training)
395
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
396
+ x = F.dropout(x, p=0.2, training=self.training)
397
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
398
+ x = F.dropout(x, p=0.2, training=self.training)
399
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
400
+ x = F.dropout(x, p=0.2, training=self.training)
401
+ x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
402
+ x = F.dropout(x, p=0.2, training=self.training)
403
+ x = torch.mean(x, dim=3)
404
+
405
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
406
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
407
+ latent_x = latent_x1 + latent_x2
408
+ latent_x = latent_x.transpose(1, 2)
409
+ latent_x = F.relu_(self.fc1(latent_x))
410
+ latent_output = interpolate(latent_x, 32)
411
+
412
+ (x1, _) = torch.max(x, dim=2)
413
+ x2 = torch.mean(x, dim=2)
414
+ x = x1 + x2
415
+ x = F.dropout(x, p=0.5, training=self.training)
416
+ x = F.relu_(self.fc1(x))
417
+ embedding = F.dropout(x, p=0.5, training=self.training)
418
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
419
+
420
+ output_dict = {
421
+ "clipwise_output": clipwise_output,
422
+ "embedding": embedding,
423
+ "fine_grained_embedding": latent_output,
424
+ }
425
+ return output_dict
426
+
427
+
428
+ class Cnn6(nn.Module):
429
+ def __init__(
430
+ self,
431
+ sample_rate,
432
+ window_size,
433
+ hop_size,
434
+ mel_bins,
435
+ fmin,
436
+ fmax,
437
+ classes_num,
438
+ enable_fusion=False,
439
+ fusion_type="None",
440
+ ):
441
+
442
+ super(Cnn6, self).__init__()
443
+
444
+ window = "hann"
445
+ center = True
446
+ pad_mode = "reflect"
447
+ ref = 1.0
448
+ amin = 1e-10
449
+ top_db = None
450
+
451
+ self.enable_fusion = enable_fusion
452
+ self.fusion_type = fusion_type
453
+
454
+ # Spectrogram extractor
455
+ self.spectrogram_extractor = Spectrogram(
456
+ n_fft=window_size,
457
+ hop_length=hop_size,
458
+ win_length=window_size,
459
+ window=window,
460
+ center=center,
461
+ pad_mode=pad_mode,
462
+ freeze_parameters=True,
463
+ )
464
+
465
+ # Logmel feature extractor
466
+ self.logmel_extractor = LogmelFilterBank(
467
+ sr=sample_rate,
468
+ n_fft=window_size,
469
+ n_mels=mel_bins,
470
+ fmin=fmin,
471
+ fmax=fmax,
472
+ ref=ref,
473
+ amin=amin,
474
+ top_db=top_db,
475
+ freeze_parameters=True,
476
+ )
477
+
478
+ # Spec augmenter
479
+ self.spec_augmenter = SpecAugmentation(
480
+ time_drop_width=64,
481
+ time_stripes_num=2,
482
+ freq_drop_width=8,
483
+ freq_stripes_num=2,
484
+ )
485
+
486
+ self.bn0 = nn.BatchNorm2d(64)
487
+
488
+ self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
489
+ self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
490
+ self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
491
+ self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
492
+
493
+ self.fc1 = nn.Linear(512, 512, bias=True)
494
+ self.fc_audioset = nn.Linear(512, classes_num, bias=True)
495
+
496
+ self.init_weight()
497
+
498
+ def init_weight(self):
499
+ init_bn(self.bn0)
500
+ init_layer(self.fc1)
501
+ init_layer(self.fc_audioset)
502
+
503
+ def forward(self, input, mixup_lambda=None, device=None):
504
+ """
505
+ Input: (batch_size, data_length)"""
506
+
507
+ x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
508
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
509
+
510
+ x = x.transpose(1, 3)
511
+ x = self.bn0(x)
512
+ x = x.transpose(1, 3)
513
+
514
+ if self.training:
515
+ x = self.spec_augmenter(x)
516
+
517
+ # Mixup on spectrogram
518
+ if self.training and mixup_lambda is not None:
519
+ x = do_mixup(x, mixup_lambda)
520
+
521
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
522
+ x = F.dropout(x, p=0.2, training=self.training)
523
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
524
+ x = F.dropout(x, p=0.2, training=self.training)
525
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
526
+ x = F.dropout(x, p=0.2, training=self.training)
527
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
528
+ x = F.dropout(x, p=0.2, training=self.training)
529
+ x = torch.mean(x, dim=3)
530
+
531
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
532
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
533
+ latent_x = latent_x1 + latent_x2
534
+ latent_x = latent_x.transpose(1, 2)
535
+ latent_x = F.relu_(self.fc1(latent_x))
536
+ latent_output = interpolate(latent_x, 16)
537
+
538
+ (x1, _) = torch.max(x, dim=2)
539
+ x2 = torch.mean(x, dim=2)
540
+ x = x1 + x2
541
+ x = F.dropout(x, p=0.5, training=self.training)
542
+ x = F.relu_(self.fc1(x))
543
+ embedding = F.dropout(x, p=0.5, training=self.training)
544
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
545
+
546
+ output_dict = {
547
+ "clipwise_output": clipwise_output,
548
+ "embedding": embedding,
549
+ "fine_grained_embedding": latent_output,
550
+ }
551
+
552
+ return output_dict
553
+
554
+
555
+ class Cnn10(nn.Module):
556
+ def __init__(
557
+ self,
558
+ sample_rate,
559
+ window_size,
560
+ hop_size,
561
+ mel_bins,
562
+ fmin,
563
+ fmax,
564
+ classes_num,
565
+ enable_fusion=False,
566
+ fusion_type="None",
567
+ ):
568
+
569
+ super(Cnn10, self).__init__()
570
+
571
+ window = "hann"
572
+ center = True
573
+ pad_mode = "reflect"
574
+ ref = 1.0
575
+ amin = 1e-10
576
+ top_db = None
577
+
578
+ self.enable_fusion = enable_fusion
579
+ self.fusion_type = fusion_type
580
+
581
+ # Spectrogram extractor
582
+ self.spectrogram_extractor = Spectrogram(
583
+ n_fft=window_size,
584
+ hop_length=hop_size,
585
+ win_length=window_size,
586
+ window=window,
587
+ center=center,
588
+ pad_mode=pad_mode,
589
+ freeze_parameters=True,
590
+ )
591
+
592
+ # Logmel feature extractor
593
+ self.logmel_extractor = LogmelFilterBank(
594
+ sr=sample_rate,
595
+ n_fft=window_size,
596
+ n_mels=mel_bins,
597
+ fmin=fmin,
598
+ fmax=fmax,
599
+ ref=ref,
600
+ amin=amin,
601
+ top_db=top_db,
602
+ freeze_parameters=True,
603
+ )
604
+
605
+ # Spec augmenter
606
+ self.spec_augmenter = SpecAugmentation(
607
+ time_drop_width=64,
608
+ time_stripes_num=2,
609
+ freq_drop_width=8,
610
+ freq_stripes_num=2,
611
+ )
612
+
613
+ self.bn0 = nn.BatchNorm2d(64)
614
+
615
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
616
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
617
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
618
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
619
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
620
+
621
+ self.fc1 = nn.Linear(1024, 1024, bias=True)
622
+ self.fc_audioset = nn.Linear(1024, classes_num, bias=True)
623
+
624
+ self.init_weight()
625
+
626
+ def init_weight(self):
627
+ init_bn(self.bn0)
628
+ init_layer(self.fc1)
629
+ init_layer(self.fc_audioset)
630
+
631
+ def forward(self, input, mixup_lambda=None, device=None):
632
+ """
633
+ Input: (batch_size, data_length)"""
634
+
635
+ x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
636
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
637
+
638
+ x = x.transpose(1, 3)
639
+ x = self.bn0(x)
640
+ x = x.transpose(1, 3)
641
+
642
+ if self.training:
643
+ x = self.spec_augmenter(x)
644
+
645
+ # Mixup on spectrogram
646
+ if self.training and mixup_lambda is not None:
647
+ x = do_mixup(x, mixup_lambda)
648
+
649
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
650
+ x = F.dropout(x, p=0.2, training=self.training)
651
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
652
+ x = F.dropout(x, p=0.2, training=self.training)
653
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
654
+ x = F.dropout(x, p=0.2, training=self.training)
655
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
656
+ x = F.dropout(x, p=0.2, training=self.training)
657
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
658
+ x = F.dropout(x, p=0.2, training=self.training)
659
+ x = torch.mean(x, dim=3)
660
+
661
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
662
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
663
+ latent_x = latent_x1 + latent_x2
664
+ latent_x = latent_x.transpose(1, 2)
665
+ latent_x = F.relu_(self.fc1(latent_x))
666
+ latent_output = interpolate(latent_x, 32)
667
+
668
+ (x1, _) = torch.max(x, dim=2)
669
+ x2 = torch.mean(x, dim=2)
670
+ x = x1 + x2
671
+ x = F.dropout(x, p=0.5, training=self.training)
672
+ x = F.relu_(self.fc1(x))
673
+ embedding = F.dropout(x, p=0.5, training=self.training)
674
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
675
+
676
+ output_dict = {
677
+ "clipwise_output": clipwise_output,
678
+ "embedding": embedding,
679
+ "fine_grained_embedding": latent_output,
680
+ }
681
+
682
+ return output_dict
683
+
684
+
685
+ def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"):
686
+ try:
687
+ ModelProto = eval(audio_cfg.model_name)
688
+ model = ModelProto(
689
+ sample_rate=audio_cfg.sample_rate,
690
+ window_size=audio_cfg.window_size,
691
+ hop_size=audio_cfg.hop_size,
692
+ mel_bins=audio_cfg.mel_bins,
693
+ fmin=audio_cfg.fmin,
694
+ fmax=audio_cfg.fmax,
695
+ classes_num=audio_cfg.class_num,
696
+ enable_fusion=enable_fusion,
697
+ fusion_type=fusion_type,
698
+ )
699
+ return model
700
+ except:
701
+ raise RuntimeError(
702
+ f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
703
+ )
audioldm/clap/open_clip/pretrained.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+
6
+ from tqdm import tqdm
7
+
8
+ _RN50 = dict(
9
+ openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
10
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
11
+ cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
12
+ )
13
+
14
+ _RN50_quickgelu = dict(
15
+ openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
16
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
17
+ cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
18
+ )
19
+
20
+ _RN101 = dict(
21
+ openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
22
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
23
+ )
24
+
25
+ _RN101_quickgelu = dict(
26
+ openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
27
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
28
+ )
29
+
30
+ _RN50x4 = dict(
31
+ openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
32
+ )
33
+
34
+ _RN50x16 = dict(
35
+ openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
36
+ )
37
+
38
+ _RN50x64 = dict(
39
+ openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
40
+ )
41
+
42
+ _VITB32 = dict(
43
+ openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
44
+ laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
45
+ laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
46
+ laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
47
+ )
48
+
49
+ _VITB32_quickgelu = dict(
50
+ openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
51
+ laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
52
+ laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
53
+ laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
54
+ )
55
+
56
+ _VITB16 = dict(
57
+ openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
58
+ )
59
+
60
+ _VITL14 = dict(
61
+ openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
62
+ )
63
+
64
+ _PRETRAINED = {
65
+ "RN50": _RN50,
66
+ "RN50-quickgelu": _RN50_quickgelu,
67
+ "RN101": _RN101,
68
+ "RN101-quickgelu": _RN101_quickgelu,
69
+ "RN50x4": _RN50x4,
70
+ "RN50x16": _RN50x16,
71
+ "ViT-B-32": _VITB32,
72
+ "ViT-B-32-quickgelu": _VITB32_quickgelu,
73
+ "ViT-B-16": _VITB16,
74
+ "ViT-L-14": _VITL14,
75
+ }
76
+
77
+
78
+ def list_pretrained(as_str: bool = False):
79
+ """returns list of pretrained models
80
+ Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
81
+ """
82
+ return [
83
+ ":".join([k, t]) if as_str else (k, t)
84
+ for k in _PRETRAINED.keys()
85
+ for t in _PRETRAINED[k].keys()
86
+ ]
87
+
88
+
89
+ def list_pretrained_tag_models(tag: str):
90
+ """return all models having the specified pretrain tag"""
91
+ models = []
92
+ for k in _PRETRAINED.keys():
93
+ if tag in _PRETRAINED[k]:
94
+ models.append(k)
95
+ return models
96
+
97
+
98
+ def list_pretrained_model_tags(model: str):
99
+ """return all pretrain tags for the specified model architecture"""
100
+ tags = []
101
+ if model in _PRETRAINED:
102
+ tags.extend(_PRETRAINED[model].keys())
103
+ return tags
104
+
105
+
106
+ def get_pretrained_url(model: str, tag: str):
107
+ if model not in _PRETRAINED:
108
+ return ""
109
+ model_pretrained = _PRETRAINED[model]
110
+ if tag not in model_pretrained:
111
+ return ""
112
+ return model_pretrained[tag]
113
+
114
+
115
+ def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")):
116
+ os.makedirs(root, exist_ok=True)
117
+ filename = os.path.basename(url)
118
+
119
+ if "openaipublic" in url:
120
+ expected_sha256 = url.split("/")[-2]
121
+ else:
122
+ expected_sha256 = ""
123
+
124
+ download_target = os.path.join(root, filename)
125
+
126
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
127
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
128
+
129
+ if os.path.isfile(download_target):
130
+ if expected_sha256:
131
+ if (
132
+ hashlib.sha256(open(download_target, "rb").read()).hexdigest()
133
+ == expected_sha256
134
+ ):
135
+ return download_target
136
+ else:
137
+ warnings.warn(
138
+ f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
139
+ )
140
+ else:
141
+ return download_target
142
+
143
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
144
+ with tqdm(
145
+ total=int(source.info().get("Content-Length")),
146
+ ncols=80,
147
+ unit="iB",
148
+ unit_scale=True,
149
+ ) as loop:
150
+ while True:
151
+ buffer = source.read(8192)
152
+ if not buffer:
153
+ break
154
+
155
+ output.write(buffer)
156
+ loop.update(len(buffer))
157
+
158
+ if (
159
+ expected_sha256
160
+ and hashlib.sha256(open(download_target, "rb").read()).hexdigest()
161
+ != expected_sha256
162
+ ):
163
+ raise RuntimeError(
164
+ f"Model has been downloaded but the SHA256 checksum does not not match"
165
+ )
166
+
167
+ return download_target
audioldm/clap/open_clip/timm_model.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ timm model adapter
2
+
3
+ Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4
+ """
5
+ from collections import OrderedDict
6
+
7
+ import torch.nn as nn
8
+
9
+ try:
10
+ import timm
11
+ from timm.models.layers import Mlp, to_2tuple
12
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
13
+ from timm.models.layers.attention_pool2d import (
14
+ AttentionPool2d as AbsAttentionPool2d,
15
+ )
16
+ except ImportError as e:
17
+ timm = None
18
+
19
+ from .utils import freeze_batch_norm_2d
20
+
21
+
22
+ class TimmModel(nn.Module):
23
+ """timm model adapter
24
+ # FIXME this adapter is a work in progress, may change in ways that break weight compat
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ model_name,
30
+ embed_dim,
31
+ image_size=224,
32
+ pool="avg",
33
+ proj="linear",
34
+ drop=0.0,
35
+ pretrained=False,
36
+ ):
37
+ super().__init__()
38
+ if timm is None:
39
+ raise RuntimeError("Please `pip install timm` to use timm models.")
40
+
41
+ self.image_size = to_2tuple(image_size)
42
+ self.trunk = timm.create_model(model_name, pretrained=pretrained)
43
+ feat_size = self.trunk.default_cfg.get("pool_size", None)
44
+ feature_ndim = 1 if not feat_size else 2
45
+ if pool in ("abs_attn", "rot_attn"):
46
+ assert feature_ndim == 2
47
+ # if attn pooling used, remove both classifier and default pool
48
+ self.trunk.reset_classifier(0, global_pool="")
49
+ else:
50
+ # reset global pool if pool config set, otherwise leave as network default
51
+ reset_kwargs = dict(global_pool=pool) if pool else {}
52
+ self.trunk.reset_classifier(0, **reset_kwargs)
53
+ prev_chs = self.trunk.num_features
54
+
55
+ head_layers = OrderedDict()
56
+ if pool == "abs_attn":
57
+ head_layers["pool"] = AbsAttentionPool2d(
58
+ prev_chs, feat_size=feat_size, out_features=embed_dim
59
+ )
60
+ prev_chs = embed_dim
61
+ elif pool == "rot_attn":
62
+ head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
63
+ prev_chs = embed_dim
64
+ else:
65
+ assert proj, "projection layer needed if non-attention pooling is used."
66
+
67
+ # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
68
+ if proj == "linear":
69
+ head_layers["drop"] = nn.Dropout(drop)
70
+ head_layers["proj"] = nn.Linear(prev_chs, embed_dim)
71
+ elif proj == "mlp":
72
+ head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
73
+
74
+ self.head = nn.Sequential(head_layers)
75
+
76
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
77
+ """lock modules
78
+ Args:
79
+ unlocked_groups (int): leave last n layer groups unlocked (default: 0)
80
+ """
81
+ if not unlocked_groups:
82
+ # lock full model
83
+ for param in self.trunk.parameters():
84
+ param.requires_grad = False
85
+ if freeze_bn_stats:
86
+ freeze_batch_norm_2d(self.trunk)
87
+ else:
88
+ # NOTE: partial freeze requires latest timm (master) branch and is subject to change
89
+ try:
90
+ # FIXME import here until API stable and in an official release
91
+ from timm.models.helpers import group_parameters, group_modules
92
+ except ImportError:
93
+ raise RuntimeError(
94
+ "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`"
95
+ )
96
+ matcher = self.trunk.group_matcher()
97
+ gparams = group_parameters(self.trunk, matcher)
98
+ max_layer_id = max(gparams.keys())
99
+ max_layer_id = max_layer_id - unlocked_groups
100
+ for group_idx in range(max_layer_id + 1):
101
+ group = gparams[group_idx]
102
+ for param in group:
103
+ self.trunk.get_parameter(param).requires_grad = False
104
+ if freeze_bn_stats:
105
+ gmodules = group_modules(self.trunk, matcher, reverse=True)
106
+ gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
107
+ freeze_batch_norm_2d(self.trunk, gmodules)
108
+
109
+ def forward(self, x):
110
+ x = self.trunk(x)
111
+ x = self.head(x)
112
+ return x
audioldm/clap/open_clip/tokenizer.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLIP tokenizer
2
+
3
+ Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import gzip
6
+ import html
7
+ import os
8
+ from functools import lru_cache
9
+ from typing import Union, List
10
+
11
+ import ftfy
12
+ import regex as re
13
+ import torch
14
+
15
+
16
+ @lru_cache()
17
+ def default_bpe():
18
+ return os.path.join(
19
+ os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
20
+ )
21
+
22
+
23
+ @lru_cache()
24
+ def bytes_to_unicode():
25
+ """
26
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
27
+ The reversible bpe codes work on unicode strings.
28
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
29
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
30
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
31
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
32
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
33
+ """
34
+ bs = (
35
+ list(range(ord("!"), ord("~") + 1))
36
+ + list(range(ord("¡"), ord("¬") + 1))
37
+ + list(range(ord("®"), ord("ÿ") + 1))
38
+ )
39
+ cs = bs[:]
40
+ n = 0
41
+ for b in range(2**8):
42
+ if b not in bs:
43
+ bs.append(b)
44
+ cs.append(2**8 + n)
45
+ n += 1
46
+ cs = [chr(n) for n in cs]
47
+ return dict(zip(bs, cs))
48
+
49
+
50
+ def get_pairs(word):
51
+ """Return set of symbol pairs in a word.
52
+ Word is represented as tuple of symbols (symbols being variable-length strings).
53
+ """
54
+ pairs = set()
55
+ prev_char = word[0]
56
+ for char in word[1:]:
57
+ pairs.add((prev_char, char))
58
+ prev_char = char
59
+ return pairs
60
+
61
+
62
+ def basic_clean(text):
63
+ text = ftfy.fix_text(text)
64
+ text = html.unescape(html.unescape(text))
65
+ return text.strip()
66
+
67
+
68
+ def whitespace_clean(text):
69
+ text = re.sub(r"\s+", " ", text)
70
+ text = text.strip()
71
+ return text
72
+
73
+
74
+ class SimpleTokenizer(object):
75
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
76
+ self.byte_encoder = bytes_to_unicode()
77
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
78
+ merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
79
+ merges = merges[1 : 49152 - 256 - 2 + 1]
80
+ merges = [tuple(merge.split()) for merge in merges]
81
+ vocab = list(bytes_to_unicode().values())
82
+ vocab = vocab + [v + "</w>" for v in vocab]
83
+ for merge in merges:
84
+ vocab.append("".join(merge))
85
+ if not special_tokens:
86
+ special_tokens = ["<start_of_text>", "<end_of_text>"]
87
+ else:
88
+ special_tokens = ["<start_of_text>", "<end_of_text>"] + special_tokens
89
+ vocab.extend(special_tokens)
90
+ self.encoder = dict(zip(vocab, range(len(vocab))))
91
+ self.decoder = {v: k for k, v in self.encoder.items()}
92
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
93
+ self.cache = {t: t for t in special_tokens}
94
+ special = "|".join(special_tokens)
95
+ self.pat = re.compile(
96
+ special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
97
+ re.IGNORECASE,
98
+ )
99
+
100
+ self.vocab_size = len(self.encoder)
101
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
102
+
103
+ def bpe(self, token):
104
+ if token in self.cache:
105
+ return self.cache[token]
106
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
107
+ pairs = get_pairs(word)
108
+
109
+ if not pairs:
110
+ return token + "</w>"
111
+
112
+ while True:
113
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
114
+ if bigram not in self.bpe_ranks:
115
+ break
116
+ first, second = bigram
117
+ new_word = []
118
+ i = 0
119
+ while i < len(word):
120
+ try:
121
+ j = word.index(first, i)
122
+ new_word.extend(word[i:j])
123
+ i = j
124
+ except:
125
+ new_word.extend(word[i:])
126
+ break
127
+
128
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
129
+ new_word.append(first + second)
130
+ i += 2
131
+ else:
132
+ new_word.append(word[i])
133
+ i += 1
134
+ new_word = tuple(new_word)
135
+ word = new_word
136
+ if len(word) == 1:
137
+ break
138
+ else:
139
+ pairs = get_pairs(word)
140
+ word = " ".join(word)
141
+ self.cache[token] = word
142
+ return word
143
+
144
+ def encode(self, text):
145
+ bpe_tokens = []
146
+ text = whitespace_clean(basic_clean(text)).lower()
147
+ for token in re.findall(self.pat, text):
148
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
149
+ bpe_tokens.extend(
150
+ self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
151
+ )
152
+ return bpe_tokens
153
+
154
+ def decode(self, tokens):
155
+ text = "".join([self.decoder[token] for token in tokens])
156
+ text = (
157
+ bytearray([self.byte_decoder[c] for c in text])
158
+ .decode("utf-8", errors="replace")
159
+ .replace("</w>", " ")
160
+ )
161
+ return text
162
+
163
+
164
+ _tokenizer = SimpleTokenizer()
165
+
166
+
167
+ def tokenize(
168
+ texts: Union[str, List[str]], context_length: int = 77
169
+ ) -> torch.LongTensor:
170
+ """
171
+ Returns the tokenized representation of given input string(s)
172
+
173
+ Parameters
174
+ ----------
175
+ texts : Union[str, List[str]]
176
+ An input string or a list of input strings to tokenize
177
+ context_length : int
178
+ The context length to use; all CLIP models use 77 as the context length
179
+
180
+ Returns
181
+ -------
182
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
183
+ """
184
+ if isinstance(texts, str):
185
+ texts = [texts]
186
+
187
+ sot_token = _tokenizer.encoder["<start_of_text>"]
188
+ eot_token = _tokenizer.encoder["<end_of_text>"]
189
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
190
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
191
+
192
+ for i, tokens in enumerate(all_tokens):
193
+ if len(tokens) > context_length:
194
+ tokens = tokens[:context_length] # Truncate
195
+ result[i, : len(tokens)] = torch.tensor(tokens)
196
+
197
+ return result
audioldm/clap/open_clip/transform.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision.transforms import (
2
+ Normalize,
3
+ Compose,
4
+ RandomResizedCrop,
5
+ InterpolationMode,
6
+ ToTensor,
7
+ Resize,
8
+ CenterCrop,
9
+ )
10
+
11
+
12
+ def _convert_to_rgb(image):
13
+ return image.convert("RGB")
14
+
15
+
16
+ def image_transform(
17
+ image_size: int,
18
+ is_train: bool,
19
+ mean=(0.48145466, 0.4578275, 0.40821073),
20
+ std=(0.26862954, 0.26130258, 0.27577711),
21
+ ):
22
+ normalize = Normalize(mean=mean, std=std)
23
+ if is_train:
24
+ return Compose(
25
+ [
26
+ RandomResizedCrop(
27
+ image_size,
28
+ scale=(0.9, 1.0),
29
+ interpolation=InterpolationMode.BICUBIC,
30
+ ),
31
+ _convert_to_rgb,
32
+ ToTensor(),
33
+ normalize,
34
+ ]
35
+ )
36
+ else:
37
+ return Compose(
38
+ [
39
+ Resize(image_size, interpolation=InterpolationMode.BICUBIC),
40
+ CenterCrop(image_size),
41
+ _convert_to_rgb,
42
+ ToTensor(),
43
+ normalize,
44
+ ]
45
+ )
audioldm/clap/open_clip/utils.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn as nn
4
+ from torchvision.ops.misc import FrozenBatchNorm2d
5
+ import logging
6
+ # import h5py
7
+ from tqdm import tqdm
8
+ import random
9
+ import json
10
+ import os
11
+ import pathlib
12
+
13
+ # TODO: (yusong) this not a good place to store those information and does not scale. Need to be fixed later.
14
+ dataset_split = {
15
+ "audiocaps": ["train", "valid", "test"],
16
+ "audioset": ["balanced_train", "unbalanced_train", "eval"],
17
+ "BBCSoundEffects": ["train", "test"],
18
+ "Clotho": ["train", "test", "valid"],
19
+ "free_to_use_sounds": ["train", "test"],
20
+ "paramount_motion": ["train", "test"],
21
+ "sonniss_game_effects": ["train", "test"],
22
+ "wesoundeffects": ["train", "test"],
23
+ "MACS": ["train", "test"],
24
+ "freesound": ["train", "test"],
25
+ "FSD50K": ["train", "test", "valid"],
26
+ "fsd50k_class_label": ["train", "test", "valid"],
27
+ "esc50": ["train", "test"],
28
+ "audiostock": ["train", "test"],
29
+ "freesound_no_overlap_noesc50": ["train", "test"],
30
+ "epidemic_sound_effects": ["train", "test"],
31
+ "VGGSound": ["train", "test"],
32
+ "urbansound8k_class_label": ["train", "test"],
33
+ "audioset_t5": ["balanced_train", "unbalanced_train", "eval"],
34
+ "epidemic_sound_effects_t5": ["train", "test"],
35
+ "WavText5K": ["train", "test"],
36
+ "esc50_no_overlap": ["train", "test"],
37
+ "usd8k_no_overlap": ["train", "test"],
38
+ "fsd50k_200_class_label": ["train", "test", "valid"],
39
+ }
40
+
41
+
42
+ def freeze_batch_norm_2d(module, module_match={}, name=""):
43
+ """
44
+ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
45
+ itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
46
+ returned. Otherwise, the module is walked recursively and submodules are converted in place.
47
+
48
+ Args:
49
+ module (torch.nn.Module): Any PyTorch module.
50
+ module_match (dict): Dictionary of full module names to freeze (all if empty)
51
+ name (str): Full module name (prefix)
52
+
53
+ Returns:
54
+ torch.nn.Module: Resulting module
55
+
56
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
57
+ """
58
+ res = module
59
+ is_match = True
60
+ if module_match:
61
+ is_match = name in module_match
62
+ if is_match and isinstance(
63
+ module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)
64
+ ):
65
+ res = FrozenBatchNorm2d(module.num_features)
66
+ res.num_features = module.num_features
67
+ res.affine = module.affine
68
+ if module.affine:
69
+ res.weight.data = module.weight.data.clone().detach()
70
+ res.bias.data = module.bias.data.clone().detach()
71
+ res.running_mean.data = module.running_mean.data
72
+ res.running_var.data = module.running_var.data
73
+ res.eps = module.eps
74
+ else:
75
+ for child_name, child in module.named_children():
76
+ full_child_name = ".".join([name, child_name]) if name else child_name
77
+ new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
78
+ if new_child is not child:
79
+ res.add_module(child_name, new_child)
80
+ return res
81
+
82
+
83
+ def exist(dataset_name, dataset_type):
84
+ """
85
+ Check if dataset exists
86
+ """
87
+ if dataset_type in dataset_split[dataset_name]:
88
+ return True
89
+ else:
90
+ return False
91
+
92
+
93
+ def get_tar_path_from_dataset_name(
94
+ dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None
95
+ ):
96
+ """
97
+ Get tar path from dataset name and type
98
+ """
99
+ output = []
100
+ for n in dataset_names:
101
+ if full_dataset is not None and n in full_dataset:
102
+ current_dataset_types = dataset_split[n]
103
+ else:
104
+ current_dataset_types = dataset_types
105
+ for s in current_dataset_types:
106
+ tmp = []
107
+ if islocal:
108
+ sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json"
109
+ if not os.path.exists(sizefilepath_):
110
+ sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
111
+ else:
112
+ sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
113
+ if not os.path.exists(sizefilepath_):
114
+ continue
115
+ sizes = json.load(open(sizefilepath_, "r"))
116
+ for k in sizes.keys():
117
+ if islocal:
118
+ tmp.append(f"{dataset_path}/{n}/{s}/{k}")
119
+ else:
120
+ tmp.append(
121
+ f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -"
122
+ )
123
+ if proportion != 1:
124
+ tmp = random.sample(tmp, int(proportion * len(tmp)))
125
+ output.append(tmp)
126
+ return sum(output, [])
127
+
128
+
129
+ def get_tar_path_from_txts(txt_path, islocal, proportion=1):
130
+ """
131
+ Get tar path from txt path
132
+ """
133
+ if isinstance(txt_path, (list, tuple)):
134
+ return sum(
135
+ [
136
+ get_tar_path_from_txts(
137
+ txt_path[i], islocal=islocal, proportion=proportion
138
+ )
139
+ for i in range(len(txt_path))
140
+ ],
141
+ [],
142
+ )
143
+ if isinstance(txt_path, str):
144
+ with open(txt_path) as f:
145
+ lines = f.readlines()
146
+ if islocal:
147
+ lines = [
148
+ lines[i]
149
+ .split("\n")[0]
150
+ .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/")
151
+ for i in range(len(lines))
152
+ ]
153
+ else:
154
+ lines = [
155
+ lines[i].split("\n")[0].replace(".tar", ".tar -")
156
+ for i in range(len(lines))
157
+ ]
158
+ if proportion != 1:
159
+ print("Sampling tars with proportion of {}".format(proportion))
160
+ lines = random.sample(lines, int(proportion * len(lines)))
161
+ return lines
162
+
163
+
164
+ def get_mix_lambda(mixup_alpha, batch_size):
165
+ mixup_lambdas = [
166
+ np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size)
167
+ ]
168
+ return np.array(mixup_lambdas).astype(np.float32)
169
+
170
+
171
+ def do_mixup(x, mixup_lambda):
172
+ """
173
+ Args:
174
+ x: (batch_size , ...)
175
+ mixup_lambda: (batch_size,)
176
+ Returns:
177
+ out: (batch_size, ...)
178
+ """
179
+ out = (
180
+ x.transpose(0, -1) * mixup_lambda
181
+ + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda)
182
+ ).transpose(0, -1)
183
+ return out
184
+
185
+
186
+ def interpolate(x, ratio):
187
+ """Interpolate data in time domain. This is used to compensate the
188
+ resolution reduction in downsampling of a CNN.
189
+
190
+ Args:
191
+ x: (batch_size, time_steps, classes_num)
192
+ ratio: int, ratio to interpolate
193
+ Returns:
194
+ upsampled: (batch_size, time_steps * ratio, classes_num)
195
+ """
196
+ (batch_size, time_steps, classes_num) = x.shape
197
+ upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
198
+ upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
199
+ return upsampled
200
+
201
+
202
+ def pad_framewise_output(framewise_output, frames_num):
203
+ """Pad framewise_output to the same length as input frames. The pad value
204
+ is the same as the value of the last frame.
205
+ Args:
206
+ framewise_output: (batch_size, frames_num, classes_num)
207
+ frames_num: int, number of frames to pad
208
+ Outputs:
209
+ output: (batch_size, frames_num, classes_num)
210
+ """
211
+ pad = framewise_output[:, -1:, :].repeat(
212
+ 1, frames_num - framewise_output.shape[1], 1
213
+ )
214
+ """tensor for padding"""
215
+
216
+ output = torch.cat((framewise_output, pad), dim=1)
217
+ """(batch_size, frames_num, classes_num)"""
218
+
219
+
220
+ # def process_ipc(index_path, classes_num, filename):
221
+ # # load data
222
+ # logging.info("Load Data...............")
223
+ # ipc = [[] for _ in range(classes_num)]
224
+ # with h5py.File(index_path, "r") as f:
225
+ # for i in tqdm(range(len(f["target"]))):
226
+ # t_class = np.where(f["target"][i])[0]
227
+ # for t in t_class:
228
+ # ipc[t].append(i)
229
+ # print(ipc)
230
+ # np.save(filename, ipc)
231
+ # logging.info("Load Data Succeed...............")
232
+
233
+
234
+ def save_to_dict(s, o_={}):
235
+ sp = s.split(": ")
236
+ o_.update({sp[0]: float(sp[1])})
237
+ return o_
238
+
239
+
240
+ def get_data_from_log(txt_path):
241
+ """
242
+ Output dictionary from out.txt log file
243
+ """
244
+ with open(txt_path) as f:
245
+ lines = f.readlines()
246
+ val_data = {}
247
+ train_data = {}
248
+ train_losses = []
249
+ train_losses_epoch = []
250
+ for i in range(len(lines)):
251
+ if "| INFO |" in lines[i]:
252
+ if "Eval Epoch" in lines[i]:
253
+ if "val_loss" in lines[i]:
254
+ # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", ""))
255
+ line = lines[i].split("Eval Epoch: ")[-1]
256
+ num_epoch = int(line.split(" ")[0].split(" ")[0])
257
+ d = {
258
+ line.split(" ")[0]
259
+ .split(" ")[1]
260
+ .replace(":", ""): float(line.split(" ")[0].split(" ")[-1])
261
+ }
262
+ for i in range(1, len(line.split(" "))):
263
+ d = save_to_dict(line.split(" ")[i], d)
264
+ val_data[num_epoch] = d
265
+ elif "Train Epoch" in lines[i]:
266
+ num_epoch = int(lines[i].split("Train Epoch: ")[1][0])
267
+ loss = float(lines[i].split("Loss: ")[-1].split(" (")[0])
268
+ train_losses.append(loss)
269
+ train_losses_epoch.append(num_epoch)
270
+ for i in range(len(train_losses)):
271
+ train_data[i] = {
272
+ "num_epoch": train_losses_epoch[i],
273
+ "train_loss": train_losses[i],
274
+ }
275
+ return train_data, val_data
276
+
277
+
278
+ def save_p(obj, filename):
279
+ import pickle
280
+
281
+ try:
282
+ from deepdiff import DeepDiff
283
+ except:
284
+ os.system("pip install deepdiff")
285
+ from deepdiff import DeepDiff
286
+ with open(filename, "wb") as file:
287
+ pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol
288
+ with open(filename, "rb") as file:
289
+ z = pickle.load(file)
290
+ assert (
291
+ DeepDiff(obj, z, ignore_string_case=True) == {}
292
+ ), "there is something wrong with the saving process"
293
+ return
294
+
295
+
296
+ def load_p(filename):
297
+ import pickle
298
+
299
+ with open(filename, "rb") as file:
300
+ z = pickle.load(file)
301
+ return z
302
+
303
+
304
+ def save_json(data, name="data.json"):
305
+ import json
306
+
307
+ with open(name, "w") as fp:
308
+ json.dump(data, fp)
309
+ return
310
+
311
+
312
+ def load_json(name):
313
+ import json
314
+
315
+ with open(name, "r") as fp:
316
+ data = json.load(fp)
317
+ return data
318
+
319
+
320
+ from multiprocessing import Process, Manager
321
+ from multiprocessing import Process, Value, Array
322
+ from ctypes import c_wchar
323
+
324
+
325
+ def load_class_label(path):
326
+ # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
327
+ # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
328
+ out = None
329
+ if path is not None:
330
+ if pathlib.Path(path).suffix in [".pkl", ".pickle"]:
331
+ out = load_p(path)
332
+ elif pathlib.Path(path).suffix in [".json", ".txt"]:
333
+ out = load_json(path)
334
+ elif pathlib.Path(path).suffix in [".npy", ".npz"]:
335
+ out = np.load(path)
336
+ elif pathlib.Path(path).suffix in [".csv"]:
337
+ import pandas as pd
338
+
339
+ out = pd.read_csv(path)
340
+ return out
341
+ # if out is None:
342
+ # return None
343
+ # else:
344
+ # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False)
345
+ # val = Array('i', out.values(), lock=False)
346
+ # return (key, val)
347
+
348
+
349
+ from torch import optim
350
+
351
+
352
+ def get_optimizer(params, lr, betas, eps, momentum, optimizer_name):
353
+ if optimizer_name.lower() == "adamw":
354
+ optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps)
355
+ elif optimizer_name.lower() == "sgd":
356
+ optimizer = optim.SGD(params, lr=lr, momentum=momentum)
357
+ elif optimizer_name.lower() == "adam":
358
+ optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps)
359
+ else:
360
+ raise ValueError("optimizer name is not correct")
361
+ return optimizer
audioldm/clap/open_clip/version.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = "0.2.1"
audioldm/clap/training/__init__.py ADDED
File without changes
audioldm/clap/training/audioset_textmap.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bada103070d92f9eadd33e1b4f45ec8583f59080ef218c966b43294bd4c86d5b
3
+ size 84448
audioldm/clap/training/data.py ADDED
@@ -0,0 +1,977 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import json
3
+ import logging
4
+ import math
5
+ import os
6
+ import random
7
+ # import h5py
8
+ from dataclasses import dataclass
9
+ from audioldm.clap.training.params import parse_args
10
+ # import braceexpand
11
+ import numpy as np
12
+ import pandas as pd
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import torchvision.datasets as datasets
17
+ import torchvision.transforms
18
+ # import webdataset as wds
19
+ from PIL import Image
20
+ from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
21
+ from torch.utils.data.distributed import DistributedSampler
22
+ from functools import partial
23
+ import soundfile as sf
24
+ import io
25
+ from pathlib import Path
26
+ # import wget
27
+
28
+ from audioldm.clap.open_clip.utils import (
29
+ get_tar_path_from_dataset_name,
30
+ dataset_split,
31
+ )
32
+ from audioldm.clap.open_clip.utils import load_p, load_class_label
33
+ import copy
34
+
35
+ try:
36
+ import horovod.torch as hvd
37
+ except ImportError:
38
+ hvd = None
39
+
40
+ try:
41
+ import torchaudio
42
+ except ImportError:
43
+ torchaudio = None
44
+
45
+ from audioldm.clap.open_clip import tokenize
46
+
47
+
48
+ def tokenizer(text):
49
+ return tokenize(text).squeeze(0)
50
+
51
+
52
+ from transformers import RobertaTokenizer
53
+
54
+ tokenize = RobertaTokenizer.from_pretrained("roberta-base")
55
+
56
+
57
+ def tokenizer(text):
58
+ result = tokenize(
59
+ text,
60
+ padding="max_length",
61
+ truncation=True,
62
+ max_length=77,
63
+ return_tensors="pt",
64
+ )
65
+ return {k: v.squeeze(0) for k, v in result.items()}
66
+
67
+
68
+ # initizlied the audioset map
69
+ _AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy")
70
+ _AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True)
71
+
72
+
73
+ def int16_to_float32(x):
74
+ return (x / 32767.0).astype(np.float32)
75
+
76
+
77
+ def float32_to_int16(x):
78
+ x = np.clip(x, a_min=-1.0, a_max=1.0)
79
+ return (x * 32767.0).astype(np.int16)
80
+
81
+
82
+ # For Toy Dataset
83
+ # class ToyDataset(Dataset):
84
+ # def __init__(self, index_path, ipc, config, eval_mode=False):
85
+ # """Toy Dataset for testing the audioset input with text labels
86
+ # Parameters
87
+ # ----------
88
+ # index_path: str
89
+ # the link to the h5 file of each audio
90
+ # idc: str
91
+ # the link to the npy file, the number of samples in each class
92
+ # config: dict
93
+ # the audio cfg file
94
+ # eval_model (bool): to indicate if the dataset is a testing dataset
95
+ # """
96
+ # self.audio_cfg = config["audio_cfg"]
97
+ # self.text_cfg = config["text_cfg"]
98
+ # self.fp = h5py.File(index_path, "r")
99
+ # self.ipc = np.load(ipc, allow_pickle=True)
100
+ # self.total_size = len(self.fp["audio_name"])
101
+ # self.classes_num = self.audio_cfg["class_num"]
102
+ # self.eval_mode = eval_mode
103
+
104
+ # if not eval_mode:
105
+ # self.generate_queue()
106
+ # else:
107
+ # self.queue = []
108
+ # for i in range(self.total_size):
109
+ # target = self.fp["target"][i]
110
+ # if np.sum(target) > 0:
111
+ # self.queue.append(i)
112
+ # self.total_size = len(self.queue)
113
+ # logging.info("total dataset size: %d" % (self.total_size))
114
+ # logging.info("class num: %d" % (self.classes_num))
115
+
116
+ # def time_shifting(self, x):
117
+ # frame_num = len(x)
118
+ # shift_len = random.randint(0, frame_num - 1)
119
+ # new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0)
120
+ # return new_sample
121
+
122
+ # def generate_queue(self):
123
+ # self.queue = []
124
+ # while len(self.queue) < self.total_size:
125
+ # class_set = [*range(self.classes_num)]
126
+ # random.shuffle(class_set)
127
+ # self.queue += [
128
+ # self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set
129
+ # ]
130
+ # self.queue = self.queue[: self.total_size]
131
+
132
+ # logging.info("queue regenerated:%s" % (self.queue[-5:]))
133
+
134
+ # def crop_wav(self, x):
135
+ # crop_size = self.audio_cfg["crop_size"]
136
+ # crop_pos = random.randint(0, len(x) - crop_size - 1)
137
+ # return x[crop_pos : crop_pos + crop_size]
138
+
139
+ # def prompt_text(self, target):
140
+ # events = _AUDIOSET_MAP[np.where(target > 0)]
141
+ # event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1]
142
+ # text = tokenize(event_text)[0]
143
+ # return text
144
+
145
+ # def __getitem__(self, index):
146
+ # """Load waveform, text, and target of an audio clip
147
+
148
+ # Parameters
149
+ # ----------
150
+ # index: int
151
+ # the index number
152
+ # Return
153
+ # ------
154
+ # output: dict {
155
+ # "hdf5_path": str,
156
+ # "index_in_hdf5": int,
157
+ # "audio_name": str,
158
+ # "waveform": list (audio_length,),
159
+ # "target": list (class_num, ),
160
+ # "text": torch.tensor (context_length,)
161
+ # }
162
+ # the output dictionary
163
+ # """
164
+ # s_index = self.queue[index]
165
+
166
+ # audio_name = self.fp["audio_name"][s_index].decode()
167
+ # # Hardcode here CHANGE
168
+ # hdf5_path = (
169
+ # self.fp["hdf5_path"][s_index]
170
+ # .decode()
171
+ # .replace(
172
+ # "../workspace",
173
+ # "/home/la/kechen/Research/ke_zsasp/workspace",
174
+ # )
175
+ # )
176
+ # r_idx = self.fp["index_in_hdf5"][s_index]
177
+ # target = self.fp["target"][s_index].astype(np.float32)
178
+ # text = self.prompt_text(target)
179
+ # with h5py.File(hdf5_path, "r") as f:
180
+ # waveform = int16_to_float32(f["waveform"][r_idx])[
181
+ # : self.audio_cfg["clip_samples"]
182
+ # ]
183
+ # assert (
184
+ # len(waveform) == self.audio_cfg["clip_samples"]
185
+ # ), "The sample length is not match"
186
+ # # Time shift
187
+ # # if (self.config.enable_time_shift) and (not self.eval_mode):
188
+ # # waveform = self.time_shifting(waveform)
189
+ # # # Label Enhance
190
+ # # if (self.config.crop_size is not None) and (not self.eval_mode):
191
+ # # waveform = self.crop_wav(waveform)
192
+ # # # the label enhance rate is fixed 0.5
193
+ # # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5:
194
+ # # kidx = np.where(target)[0]
195
+ # # for k in kidx:
196
+ # # for add_key in self.class_map[k][1]:
197
+ # # target[add_key] = 1.0
198
+ # # if len(self.class_map[k][2]) > 0:
199
+ # # add_key = random.choice(self.class_map[k][2])
200
+ # # target[add_key] = 1.0
201
+
202
+ # # missing the text input
203
+ # mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :]
204
+ # mel_spec = (
205
+ # torch.cat(
206
+ # [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0
207
+ # )
208
+ # .cpu()
209
+ # .numpy()
210
+ # )
211
+ # longer = random.choice([True, False])
212
+ # if longer == False:
213
+ # mel_spec[1:, :, :] = 0.0
214
+ # data_dict = {
215
+ # "hdf5_path": hdf5_path,
216
+ # "index_in_hdf5": r_idx,
217
+ # "audio_name": audio_name,
218
+ # "waveform": waveform,
219
+ # "class_label": target,
220
+ # "text": text,
221
+ # "longer": longer,
222
+ # "mel_fusion": mel_spec,
223
+ # }
224
+ # return data_dict
225
+
226
+ # def __len__(self):
227
+ # return self.total_size
228
+
229
+
230
+ class CsvDataset(Dataset):
231
+ def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
232
+ logging.debug(f"Loading csv data from {input_filename}.")
233
+ df = pd.read_csv(input_filename, sep=sep)
234
+
235
+ self.images = df[img_key].tolist()
236
+ self.captions = df[caption_key].tolist()
237
+ self.transforms = transforms
238
+ logging.debug("Done loading data.")
239
+
240
+ def __len__(self):
241
+ return len(self.captions)
242
+
243
+ def __getitem__(self, idx):
244
+ images = self.transforms(Image.open(str(self.images[idx])))
245
+ texts = tokenize([str(self.captions[idx])])[0]
246
+ return images, texts
247
+
248
+
249
+ @dataclass
250
+ class DataInfo:
251
+ dataloader: DataLoader
252
+ sampler: DistributedSampler
253
+
254
+
255
+ def preprocess_txt(text):
256
+ return tokenize([str(text)])[0]
257
+
258
+
259
+ def get_dataset_size(shards, sizefilepath_=None, is_local=True):
260
+ if isinstance(shards, list):
261
+ size_list = []
262
+ for s in shards:
263
+ size_list.append(
264
+ get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0]
265
+ )
266
+ else:
267
+ if not is_local:
268
+ for n in dataset_split.keys():
269
+ if n in shards.split("/"):
270
+ break
271
+ for s in dataset_split[n]:
272
+ if s in shards.split("/"):
273
+ break
274
+ sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
275
+ shards_list = list(braceexpand.braceexpand(shards))
276
+ dir_path = os.path.dirname(shards)
277
+ if sizefilepath_ is not None:
278
+ sizes = json.load(open(sizefilepath_, "r"))
279
+ total_size = sum(
280
+ [
281
+ int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))])
282
+ for shard in shards_list
283
+ ]
284
+ )
285
+ else:
286
+ sizes_filename = os.path.join(dir_path, "sizes.json")
287
+ len_filename = os.path.join(dir_path, "__len__")
288
+ if os.path.exists(sizes_filename):
289
+ sizes = json.load(open(sizes_filename, "r"))
290
+ total_size = sum(
291
+ [int(sizes[os.path.basename(shard)]) for shard in shards_list]
292
+ )
293
+ elif os.path.exists(len_filename):
294
+ # FIXME this used to be eval(open(...)) but that seemed rather unsafe
295
+ total_size = ast.literal_eval(open(len_filename, "r").read())
296
+ else:
297
+ raise Exception(
298
+ "Cannot find sizes file for dataset. Please specify the path to the file."
299
+ )
300
+ # total_size = None # num samples undefined
301
+ # some common dataset sizes (at time of authors last download)
302
+ # cc3m-train: 2905954
303
+ # cc12m: 10968539
304
+ # LAION-400m: 407332084
305
+ num_shards = len(shards_list)
306
+ if isinstance(shards, list):
307
+ return sum(size_list), len(shards)
308
+ else:
309
+ return total_size, num_shards
310
+
311
+
312
+ def get_imagenet(args, preprocess_fns, split):
313
+ assert split in ["train", "val", "v2"]
314
+ is_train = split == "train"
315
+ preprocess_train, preprocess_val = preprocess_fns
316
+
317
+ if split == "v2":
318
+ from imagenetv2_pytorch import ImageNetV2Dataset
319
+
320
+ dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
321
+ else:
322
+ if is_train:
323
+ data_path = args.imagenet_train
324
+ preprocess_fn = preprocess_train
325
+ else:
326
+ data_path = args.imagenet_val
327
+ preprocess_fn = preprocess_val
328
+ assert data_path
329
+
330
+ dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
331
+
332
+ if is_train:
333
+ idxs = np.zeros(len(dataset.targets))
334
+ target_array = np.array(dataset.targets)
335
+ k = 50
336
+ for c in range(1000):
337
+ m = target_array == c
338
+ n = len(idxs[m])
339
+ arr = np.zeros(n)
340
+ arr[:k] = 1
341
+ np.random.shuffle(arr)
342
+ idxs[m] = arr
343
+
344
+ idxs = idxs.astype("int")
345
+ sampler = SubsetRandomSampler(np.where(idxs)[0])
346
+ else:
347
+ sampler = None
348
+
349
+ dataloader = torch.utils.data.DataLoader(
350
+ dataset,
351
+ batch_size=args.batch_size,
352
+ num_workers=args.workers,
353
+ sampler=sampler,
354
+ )
355
+
356
+ return DataInfo(dataloader, sampler)
357
+
358
+
359
+ def count_samples(dataloader):
360
+ os.environ["WDS_EPOCH"] = "0"
361
+ n_elements, n_batches = 0, 0
362
+ for images, texts in dataloader:
363
+ n_batches += 1
364
+ n_elements += len(images)
365
+ assert len(images) == len(texts)
366
+ return n_elements, n_batches
367
+
368
+
369
+ def filter_no_caption(sample):
370
+ return "txt" in sample
371
+
372
+
373
+ def log_and_continue(exn):
374
+ """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
375
+ logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
376
+ return True
377
+
378
+
379
+ _SHARD_SHUFFLE_SIZE = 2000
380
+ _SHARD_SHUFFLE_INITIAL = 500
381
+ _SAMPLE_SHUFFLE_SIZE = 5000
382
+ _SAMPLE_SHUFFLE_INITIAL = 1000
383
+
384
+
385
+ def sample_prop(sizefile, inputs, proportion, is_local=True):
386
+ """
387
+ Sample a proportion of the data.
388
+ """
389
+ file_path_dict = {
390
+ os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0]
391
+ for i in range(len(inputs))
392
+ }
393
+ sampled_filepath_dict = {}
394
+ sampled_size_dict = {}
395
+ if not is_local:
396
+ if os.path.exists("sizes.json"):
397
+ os.remove("sizes.json")
398
+ wget.download(sizefile, "sizes.json")
399
+ sizefile = "sizes.json"
400
+ with open(sizefile, "r", encoding="UTF-8") as f:
401
+ load_dict = json.load(f)
402
+ L = int(len(file_path_dict) * proportion)
403
+ subkeys = random.sample(file_path_dict.keys(), L)
404
+ for k in subkeys:
405
+ sampled_size_dict[k] = load_dict[k]
406
+ sampled_filepath_dict[k] = file_path_dict[k]
407
+ return (
408
+ sum(sampled_size_dict.values()),
409
+ L,
410
+ [os.path.join(v, k) for k, v in sampled_filepath_dict.items()],
411
+ sampled_size_dict,
412
+ )
413
+
414
+
415
+ def get_mel(audio_data, audio_cfg):
416
+ # mel shape: (n_mels, T)
417
+ mel = torchaudio.transforms.MelSpectrogram(
418
+ sample_rate=audio_cfg["sample_rate"],
419
+ n_fft=audio_cfg["window_size"],
420
+ win_length=audio_cfg["window_size"],
421
+ hop_length=audio_cfg["hop_size"],
422
+ center=True,
423
+ pad_mode="reflect",
424
+ power=2.0,
425
+ norm=None,
426
+ onesided=True,
427
+ n_mels=64,
428
+ f_min=audio_cfg["fmin"],
429
+ f_max=audio_cfg["fmax"],
430
+ ).to(audio_data.device)
431
+ mel = mel(audio_data)
432
+ # Align to librosa:
433
+ # librosa_melspec = librosa.feature.melspectrogram(
434
+ # waveform,
435
+ # sr=audio_cfg['sample_rate'],
436
+ # n_fft=audio_cfg['window_size'],
437
+ # hop_length=audio_cfg['hop_size'],
438
+ # win_length=audio_cfg['window_size'],
439
+ # center=True,
440
+ # pad_mode="reflect",
441
+ # power=2.0,
442
+ # n_mels=64,
443
+ # norm=None,
444
+ # htk=True,
445
+ # f_min=audio_cfg['fmin'],
446
+ # f_max=audio_cfg['fmax']
447
+ # )
448
+ # we use log mel spectrogram as input
449
+ mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
450
+ return mel.T # (T, n_mels)
451
+
452
+
453
+ def get_audio_features(
454
+ sample, audio_data, max_len, data_truncating, data_filling, audio_cfg
455
+ ):
456
+ """
457
+ Calculate and add audio features to sample.
458
+ Sample: a dict containing all the data of current sample.
459
+ audio_data: a tensor of shape (T) containing audio data.
460
+ max_len: the maximum length of audio data.
461
+ data_truncating: the method of truncating data.
462
+ data_filling: the method of filling data.
463
+ audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg'].
464
+ """
465
+ with torch.no_grad():
466
+ if len(audio_data) > max_len:
467
+ if data_truncating == "rand_trunc":
468
+ longer = torch.tensor([True])
469
+ elif data_truncating == "fusion":
470
+ # fusion
471
+ mel = get_mel(audio_data, audio_cfg)
472
+ # split to three parts
473
+ chunk_frames = (
474
+ max_len // audio_cfg["hop_size"] + 1
475
+ ) # the +1 related to how the spectrogram is computed
476
+ total_frames = mel.shape[0]
477
+ if chunk_frames == total_frames:
478
+ # there is a corner case where the audio length is
479
+ # larger than max_len but smaller than max_len+hop_size.
480
+ # In this case, we just use the whole audio.
481
+ mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
482
+ sample["mel_fusion"] = mel_fusion
483
+ longer = torch.tensor([False])
484
+ else:
485
+ ranges = np.array_split(
486
+ list(range(0, total_frames - chunk_frames + 1)), 3
487
+ )
488
+ # print('total_frames-chunk_frames:', total_frames-chunk_frames,
489
+ # 'len(audio_data):', len(audio_data),
490
+ # 'chunk_frames:', chunk_frames,
491
+ # 'total_frames:', total_frames)
492
+ if len(ranges[1]) == 0:
493
+ # if the audio is too short, we just use the first chunk
494
+ ranges[1] = [0]
495
+ if len(ranges[2]) == 0:
496
+ # if the audio is too short, we just use the first chunk
497
+ ranges[2] = [0]
498
+ # randomly choose index for each part
499
+ idx_front = np.random.choice(ranges[0])
500
+ idx_middle = np.random.choice(ranges[1])
501
+ idx_back = np.random.choice(ranges[2])
502
+ # select mel
503
+ mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :]
504
+ mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :]
505
+ mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :]
506
+
507
+ # shrink the mel
508
+ mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, 64])(
509
+ mel[None]
510
+ )[0]
511
+ # logging.info(f"mel_shrink.shape: {mel_shrink.shape}")
512
+
513
+ # stack
514
+ mel_fusion = torch.stack(
515
+ [mel_chunk_front, mel_chunk_middle, mel_chunk_back, mel_shrink],
516
+ dim=0,
517
+ )
518
+ sample["mel_fusion"] = mel_fusion
519
+ longer = torch.tensor([True])
520
+ else:
521
+ raise NotImplementedError(
522
+ f"data_truncating {data_truncating} not implemented"
523
+ )
524
+ # random crop to max_len (for compatibility)
525
+ overflow = len(audio_data) - max_len
526
+ idx = np.random.randint(0, overflow + 1)
527
+ audio_data = audio_data[idx : idx + max_len]
528
+
529
+ else: # padding if too short
530
+ if len(audio_data) < max_len: # do nothing if equal
531
+ if data_filling == "repeatpad":
532
+ n_repeat = int(max_len / len(audio_data))
533
+ audio_data = audio_data.repeat(n_repeat)
534
+ # audio_data = audio_data.unsqueeze(0).unsqueeze(0).unsqueeze(0)
535
+ # audio_data = F.interpolate(audio_data,size=max_len,mode="bicubic")[0,0,0]
536
+ audio_data = F.pad(
537
+ audio_data,
538
+ (0, max_len - len(audio_data)),
539
+ mode="constant",
540
+ value=0,
541
+ )
542
+ elif data_filling == "pad":
543
+ audio_data = F.pad(
544
+ audio_data,
545
+ (0, max_len - len(audio_data)),
546
+ mode="constant",
547
+ value=0,
548
+ )
549
+ elif data_filling == "repeat":
550
+ n_repeat = int(max_len / len(audio_data))
551
+ audio_data = audio_data.repeat(n_repeat + 1)[:max_len]
552
+ else:
553
+ raise NotImplementedError(
554
+ f"data_filling {data_filling} not implemented"
555
+ )
556
+ if data_truncating == "fusion":
557
+ mel = get_mel(audio_data, audio_cfg)
558
+ mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
559
+ sample["mel_fusion"] = mel_fusion
560
+ longer = torch.tensor([False])
561
+
562
+ sample["longer"] = longer
563
+ sample["waveform"] = audio_data
564
+
565
+ return sample
566
+
567
+
568
+ def preprocess(
569
+ sample,
570
+ audio_ext,
571
+ text_ext,
572
+ max_len,
573
+ audio_cfg,
574
+ class_index_dict=None,
575
+ data_filling="pad",
576
+ data_truncating="rand_trunc",
577
+ text_augment_selection=None,
578
+ ):
579
+ """
580
+ Preprocess a single sample for wdsdataloader.
581
+ """
582
+ audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
583
+ audio_data = int16_to_float32(float32_to_int16(audio_data))
584
+ audio_data = torch.tensor(audio_data).float()
585
+
586
+ # TODO: (yusong) to be include in the future
587
+ # # if torchaudio not installed, use soundfile to load audio
588
+ # if torchaudio is None:
589
+ # audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
590
+ # audio_data = torch.tensor(audio_data).float()
591
+ # else:
592
+ # # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py
593
+ # with tempfile.TemporaryDirectory() as dirname:
594
+ # os.makedirs(dirname, exist_ok=True)
595
+ # fname = os.path.join(dirname, f"file.flac")
596
+ # with open(fname, "wb") as stream:
597
+ # stream.write(sample[audio_ext])
598
+ # audio_data, orig_sr = torchaudio.load(fname)
599
+ # audio_data = audio_data[0, :].float()
600
+
601
+ sample = get_audio_features(
602
+ sample, audio_data, max_len, data_truncating, data_filling, audio_cfg
603
+ )
604
+ del sample[audio_ext]
605
+
606
+ try:
607
+ json_dict_raw = json.loads(sample[text_ext].decode("utf-8"))
608
+ except:
609
+ print("sample[__url__]:", sample["__url__"])
610
+
611
+ # For selecting augmented text from dataset
612
+ if text_augment_selection is None or text_augment_selection == "none":
613
+ texts = json_dict_raw["text"]
614
+ elif text_augment_selection == "all":
615
+ if "text_augment_all" in json_dict_raw.keys():
616
+ texts = json_dict_raw["text_augment_all"]
617
+ else:
618
+ texts = json_dict_raw["text"]
619
+ elif text_augment_selection == "augment_only":
620
+ if "text_augment_all" in json_dict_raw.keys():
621
+ if json_dict_raw["text_augment_t5"] is None:
622
+ texts = json_dict_raw["text"]
623
+ else:
624
+ texts = json_dict_raw["text_augment_t5"]
625
+ else:
626
+ texts = json_dict_raw["text"]
627
+ else:
628
+ raise NotImplementedError(
629
+ f"text_augment_selection {text_augment_selection} not implemented"
630
+ )
631
+ sample["full_text"] = texts
632
+
633
+ if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1:
634
+ texts = random.choice(texts)
635
+ sample["raw_text"] = texts
636
+ sample["text"] = tokenizer(texts) # text shape: [num_token]
637
+ if class_index_dict is not None:
638
+ # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
639
+ # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
640
+ # key, val = class_index_dict
641
+ # key = key[:].split('\n')
642
+ # _dict = {k: v for k, v in zip(key, val)}
643
+ sample["class_label"] = np.zeros(len(class_index_dict.keys()))
644
+ for x in json_dict_raw["tag"]:
645
+ sample["class_label"][class_index_dict[x]] = 1
646
+ sample["class_label"] = torch.tensor(sample["class_label"]).float()
647
+ del sample[text_ext]
648
+ sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext
649
+ sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext
650
+ sample["audio_orig_sr"] = orig_sr
651
+ return sample
652
+
653
+
654
+ def collate_fn(batch):
655
+ """
656
+ Collate function for wdsdataloader.
657
+ batch: a list of dict, each dict is a sample
658
+ """
659
+ # concatenate values in each dictionary. if it is a tensor, concatenate. if it is a list, extend.
660
+ batch_dict = {}
661
+ for k in batch[0].keys():
662
+ if isinstance(batch[0][k], dict): # dealwith bert tokenizer output
663
+ batch_dict[k] = {}
664
+ for kk in batch[0][k].keys():
665
+ tmp = []
666
+ for i in range(len(batch)):
667
+ tmp.append(batch[i][k][kk])
668
+ batch_dict[k][kk] = torch.vstack(tmp)
669
+ elif isinstance(batch[0][k], torch.Tensor):
670
+ batch_dict[k] = torch.stack([sample[k] for sample in batch])
671
+ elif isinstance(batch[0][k], np.ndarray):
672
+ batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch]))
673
+ else:
674
+ batch_dict[k] = [sample[k] for sample in batch]
675
+ return batch_dict
676
+
677
+
678
+ def get_wds_dataset(
679
+ args,
680
+ model_cfg,
681
+ is_train,
682
+ audio_ext="flac",
683
+ text_ext="json",
684
+ max_len=480000,
685
+ proportion=1.0,
686
+ sizefilepath_=None,
687
+ is_local=None,
688
+ ):
689
+ """
690
+ Get a dataset for wdsdataloader.
691
+ """
692
+ if is_local is None and (not args.remotedata is None):
693
+ is_local = not args.remotedata
694
+
695
+ input_shards = args.train_data if is_train else args.val_data
696
+ assert input_shards is not None
697
+
698
+ if not sizefilepath_ is None:
699
+ sizefilepath = sizefilepath_
700
+ else:
701
+ sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json")
702
+
703
+ if proportion != 1.0:
704
+ num_samples, num_shards, input_shards, _ = sample_prop(
705
+ sizefilepath, input_shards, proportion, is_local=is_local
706
+ )
707
+ else:
708
+ num_samples, num_shards = get_dataset_size(
709
+ input_shards, sizefilepath_=sizefilepath_, is_local=is_local
710
+ )
711
+
712
+ if not num_samples:
713
+ if is_train:
714
+ num_samples = args.train_num_samples
715
+ if not num_samples:
716
+ raise RuntimeError(
717
+ "Currently, number of dataset samples must be specified for training dataset. "
718
+ "Please specify via `--train-num-samples` if no dataset length info present."
719
+ )
720
+ else:
721
+ num_samples = (
722
+ args.val_num_samples or 0
723
+ ) # eval will just exhaust the iterator if not specified
724
+
725
+ pipeline = [wds.SimpleShardList(input_shards)]
726
+ # at this point we have an iterator over all the shards
727
+ # TODO: (yusong): add a if statement of distributed. If not, we don't need to split_by_node
728
+ if is_train or args.parallel_eval:
729
+ pipeline.extend(
730
+ [
731
+ wds.detshuffle(
732
+ bufsize=_SHARD_SHUFFLE_SIZE,
733
+ initial=_SHARD_SHUFFLE_INITIAL,
734
+ seed=args.seed,
735
+ ),
736
+ wds.split_by_node,
737
+ wds.split_by_worker,
738
+ # at this point, we have an iterator over the shards assigned to each worker at each node
739
+ wds.tarfile_to_samples(handler=log_and_continue),
740
+ wds.shuffle(
741
+ bufsize=_SAMPLE_SHUFFLE_SIZE,
742
+ initial=_SAMPLE_SHUFFLE_INITIAL,
743
+ rng=random.Random(args.seed),
744
+ ),
745
+ # wds.repeatedly, # FIXME determine if this is beneficial
746
+ ]
747
+ )
748
+ else:
749
+ pipeline.extend(
750
+ [
751
+ wds.split_by_worker,
752
+ # at this point, we have an iterator over the shards assigned to each worker
753
+ wds.tarfile_to_samples(handler=log_and_continue),
754
+ ]
755
+ )
756
+ pipeline.append(
757
+ wds.map(
758
+ partial(
759
+ preprocess,
760
+ audio_ext=audio_ext,
761
+ text_ext=text_ext,
762
+ max_len=max_len,
763
+ audio_cfg=model_cfg["audio_cfg"],
764
+ class_index_dict=copy.deepcopy(args.class_index_dict),
765
+ data_filling=args.data_filling,
766
+ data_truncating=args.data_truncating,
767
+ text_augment_selection=args.text_augment_selection,
768
+ )
769
+ ),
770
+ )
771
+
772
+ pipeline.append(
773
+ wds.batched(
774
+ args.batch_size,
775
+ partial=not (is_train or args.parallel_eval),
776
+ collation_fn=collate_fn,
777
+ )
778
+ )
779
+
780
+ dataset = wds.DataPipeline(*pipeline)
781
+ if is_train or args.parallel_eval:
782
+ # (yusong): Currently parallel evaluation will be not precise as we are repeat the last few samples.
783
+ # (yusong): See comments below.
784
+ # roll over and repeat a few samples to get same number of full batches on each node
785
+ global_batch_size = args.batch_size * args.world_size
786
+ num_batches = math.ceil(num_samples / global_batch_size)
787
+ num_workers = max(1, args.workers)
788
+ num_worker_batches = math.ceil(
789
+ num_batches / num_workers
790
+ ) # per dataloader worker
791
+ num_batches = num_worker_batches * num_workers
792
+ num_samples = num_batches * global_batch_size
793
+ dataset = dataset.with_epoch(
794
+ num_worker_batches
795
+ ) # each worker is iterating over this
796
+ else:
797
+ # last batches are partial, eval is done on single (master) node
798
+ num_batches = math.ceil(num_samples / args.batch_size)
799
+
800
+ kwargs = {}
801
+ if args.horovod: # multi-node training on summit
802
+ kwargs["multiprocessing_context"] = "forkserver"
803
+
804
+ dataloader = wds.WebLoader(
805
+ dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs
806
+ )
807
+
808
+ # FIXME not clear which approach is better, with_epoch before vs after dataloader?
809
+ # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
810
+ # if is_train:
811
+ # # roll over and repeat a few samples to get same number of full batches on each node
812
+ # global_batch_size = args.batch_size * args.world_size
813
+ # num_batches = math.ceil(num_samples / global_batch_size)
814
+ # num_workers = max(1, args.workers)
815
+ # num_batches = math.ceil(num_batches / num_workers) * num_workers
816
+ # num_samples = num_batches * global_batch_size
817
+ # dataloader = dataloader.with_epoch(num_batches)
818
+ # else:
819
+ # # last batches are partial, eval is done on single (master) node
820
+ # num_batches = math.ceil(num_samples / args.batch_size)
821
+
822
+ # add meta-data to dataloader instance for convenience
823
+ dataloader.num_batches = num_batches
824
+ dataloader.num_samples = num_samples
825
+
826
+ return DataInfo(dataloader, None)
827
+
828
+
829
+ def wds_batch_list2dict(
830
+ batch,
831
+ keys=[
832
+ "__url__",
833
+ "__key__",
834
+ "waveform",
835
+ "text",
836
+ "raw_text",
837
+ "audio_name",
838
+ "text_name",
839
+ "audio_orig_sr",
840
+ ],
841
+ ):
842
+ """
843
+ Return a dictionary of the batch, with keys as the names of the fields.
844
+ """
845
+ assert len(keys) == len(
846
+ batch
847
+ ), "batch must have same number of keys as keys argument"
848
+ return {keys[i]: batch[i] for i in range(len(batch))}
849
+
850
+
851
+ def get_csv_dataset(args, preprocess_fn, is_train):
852
+ input_filename = args.train_data if is_train else args.val_data
853
+ assert input_filename
854
+ dataset = CsvDataset(
855
+ input_filename,
856
+ preprocess_fn,
857
+ img_key=args.csv_img_key,
858
+ caption_key=args.csv_caption_key,
859
+ sep=args.csv_separator,
860
+ )
861
+ num_samples = len(dataset)
862
+ sampler = DistributedSampler(dataset) if args.distributed and is_train else None
863
+ shuffle = is_train and sampler is None
864
+
865
+ dataloader = DataLoader(
866
+ dataset,
867
+ batch_size=args.batch_size,
868
+ shuffle=shuffle,
869
+ num_workers=args.workers,
870
+ pin_memory=True,
871
+ sampler=sampler,
872
+ drop_last=is_train,
873
+ )
874
+ dataloader.num_samples = num_samples
875
+ dataloader.num_batches = len(dataloader)
876
+
877
+ return DataInfo(dataloader, sampler)
878
+
879
+
880
+ def get_toy_dataset(args, model_cfg, is_train):
881
+ index_path = args.train_data if is_train else args.val_data
882
+ ipc_path = args.train_ipc if is_train else args.val_ipc
883
+ assert index_path and ipc_path
884
+ eval_mode = not is_train
885
+ dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode)
886
+
887
+ num_samples = len(dataset)
888
+ sampler = (
889
+ DistributedSampler(dataset, shuffle=False)
890
+ if args.distributed and is_train
891
+ else None
892
+ )
893
+
894
+ dataloader = DataLoader(
895
+ dataset,
896
+ batch_size=args.batch_size,
897
+ shuffle=False,
898
+ num_workers=args.workers,
899
+ sampler=sampler,
900
+ drop_last=is_train,
901
+ )
902
+ dataloader.num_samples = num_samples
903
+ dataloader.num_batches = len(dataloader)
904
+
905
+ return DataInfo(dataloader, sampler)
906
+
907
+
908
+ def get_dataset_fn(data_path, dataset_type):
909
+ if dataset_type == "webdataset":
910
+ return get_wds_dataset
911
+ elif dataset_type == "csv":
912
+ return get_csv_dataset
913
+ elif dataset_type == "auto":
914
+ ext = data_path.split(".")[-1]
915
+ if ext in ["csv", "tsv"]:
916
+ return get_csv_dataset
917
+ elif ext in ["tar"]:
918
+ return get_wds_dataset
919
+ else:
920
+ raise ValueError(
921
+ f"Tried to figure out dataset type, but failed for extention {ext}."
922
+ )
923
+ elif dataset_type == "toy":
924
+ return get_toy_dataset
925
+ else:
926
+ raise ValueError(f"Unsupported dataset type: {dataset_type}")
927
+
928
+
929
+ def get_data(args, model_cfg):
930
+ data = {}
931
+
932
+ args.class_index_dict = load_class_label(args.class_label_path)
933
+
934
+ if args.datasetinfos is None:
935
+ args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]
936
+ if args.dataset_type == "webdataset":
937
+ args.train_data = get_tar_path_from_dataset_name(
938
+ args.datasetnames,
939
+ args.datasetinfos,
940
+ islocal=not args.remotedata,
941
+ proportion=args.dataset_proportion,
942
+ dataset_path=args.datasetpath,
943
+ full_dataset=args.full_train_dataset,
944
+ )
945
+
946
+ if args.full_train_dataset is None:
947
+ args.full_train_dataset = []
948
+ if args.exclude_eval_dataset is None:
949
+ args.exclude_eval_dataset = []
950
+ excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset
951
+
952
+ val_dataset_names = (
953
+ [n for n in args.datasetnames if n not in excluded_eval_datasets]
954
+ if excluded_eval_datasets
955
+ else args.datasetnames
956
+ )
957
+ args.val_dataset_names = val_dataset_names
958
+ args.val_data = get_tar_path_from_dataset_name(
959
+ val_dataset_names,
960
+ ["valid", "test", "eval"],
961
+ islocal=not args.remotedata,
962
+ proportion=1,
963
+ dataset_path=args.datasetpath,
964
+ full_dataset=None,
965
+ )
966
+
967
+ if args.train_data:
968
+ data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
969
+ args, model_cfg, is_train=True
970
+ )
971
+
972
+ if args.val_data:
973
+ data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
974
+ args, model_cfg, is_train=False
975
+ )
976
+
977
+ return data