sanchit-gandhi commited on
Commit
f1daa60
·
1 Parent(s): d03edfa

camera ready

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +1 -0
  2. audioldm/__init__.py +0 -3
  3. audioldm/audio/__init__.py +0 -0
  4. audioldm/audio/audio_processing.py +0 -100
  5. audioldm/audio/stft.py +0 -180
  6. audioldm/audio/tools.py +0 -33
  7. audioldm/clap/__init__.py +0 -0
  8. audioldm/clap/encoders.py +0 -170
  9. audioldm/clap/open_clip/__init__.py +0 -25
  10. audioldm/clap/open_clip/bert.py +0 -40
  11. audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz +0 -3
  12. audioldm/clap/open_clip/factory.py +0 -277
  13. audioldm/clap/open_clip/feature_fusion.py +0 -192
  14. audioldm/clap/open_clip/htsat.py +0 -1308
  15. audioldm/clap/open_clip/linear_probe.py +0 -66
  16. audioldm/clap/open_clip/loss.py +0 -398
  17. audioldm/clap/open_clip/model.py +0 -936
  18. audioldm/clap/open_clip/model_configs/HTSAT-base.json +0 -23
  19. audioldm/clap/open_clip/model_configs/HTSAT-large.json +0 -23
  20. audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json +0 -23
  21. audioldm/clap/open_clip/model_configs/HTSAT-tiny.json +0 -23
  22. audioldm/clap/open_clip/model_configs/PANN-10.json +0 -23
  23. audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json +0 -23
  24. audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json +0 -23
  25. audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json +0 -23
  26. audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json +0 -23
  27. audioldm/clap/open_clip/model_configs/PANN-14.json +0 -23
  28. audioldm/clap/open_clip/model_configs/PANN-6.json +0 -23
  29. audioldm/clap/open_clip/model_configs/RN101-quickgelu.json +0 -22
  30. audioldm/clap/open_clip/model_configs/RN101.json +0 -21
  31. audioldm/clap/open_clip/model_configs/RN50-quickgelu.json +0 -22
  32. audioldm/clap/open_clip/model_configs/RN50.json +0 -21
  33. audioldm/clap/open_clip/model_configs/RN50x16.json +0 -21
  34. audioldm/clap/open_clip/model_configs/RN50x4.json +0 -21
  35. audioldm/clap/open_clip/model_configs/ViT-B-16.json +0 -16
  36. audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json +0 -17
  37. audioldm/clap/open_clip/model_configs/ViT-B-32.json +0 -16
  38. audioldm/clap/open_clip/model_configs/ViT-L-14.json +0 -16
  39. audioldm/clap/open_clip/openai.py +0 -156
  40. audioldm/clap/open_clip/pann_model.py +0 -703
  41. audioldm/clap/open_clip/pretrained.py +0 -167
  42. audioldm/clap/open_clip/timm_model.py +0 -112
  43. audioldm/clap/open_clip/tokenizer.py +0 -197
  44. audioldm/clap/open_clip/transform.py +0 -45
  45. audioldm/clap/open_clip/utils.py +0 -361
  46. audioldm/clap/open_clip/version.py +0 -1
  47. audioldm/clap/training/__init__.py +0 -0
  48. audioldm/clap/training/audioset_textmap.npy +0 -3
  49. audioldm/clap/training/data.py +0 -977
  50. audioldm/clap/training/distributed.py +0 -150
README.md CHANGED
@@ -8,6 +8,7 @@ sdk_version: 3.27.0
8
  app_file: app.py
9
  pinned: false
10
  license: bigscience-openrail-m
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
8
  app_file: app.py
9
  pinned: false
10
  license: bigscience-openrail-m
11
+ duplicated_from: haoheliu/audioldm-text-to-audio-generation
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
audioldm/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from .ldm import LatentDiffusion
2
- from .utils import seed_everything
3
- from .pipeline import *
 
 
 
 
audioldm/audio/__init__.py DELETED
File without changes
audioldm/audio/audio_processing.py DELETED
@@ -1,100 +0,0 @@
1
- import torch
2
- import numpy as np
3
- import librosa.util as librosa_util
4
- from scipy.signal import get_window
5
-
6
-
7
- def window_sumsquare(
8
- window,
9
- n_frames,
10
- hop_length,
11
- win_length,
12
- n_fft,
13
- dtype=np.float32,
14
- norm=None,
15
- ):
16
- """
17
- # from librosa 0.6
18
- Compute the sum-square envelope of a window function at a given hop length.
19
-
20
- This is used to estimate modulation effects induced by windowing
21
- observations in short-time fourier transforms.
22
-
23
- Parameters
24
- ----------
25
- window : string, tuple, number, callable, or list-like
26
- Window specification, as in `get_window`
27
-
28
- n_frames : int > 0
29
- The number of analysis frames
30
-
31
- hop_length : int > 0
32
- The number of samples to advance between frames
33
-
34
- win_length : [optional]
35
- The length of the window function. By default, this matches `n_fft`.
36
-
37
- n_fft : int > 0
38
- The length of each analysis frame.
39
-
40
- dtype : np.dtype
41
- The data type of the output
42
-
43
- Returns
44
- -------
45
- wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
46
- The sum-squared envelope of the window function
47
- """
48
- if win_length is None:
49
- win_length = n_fft
50
-
51
- n = n_fft + hop_length * (n_frames - 1)
52
- x = np.zeros(n, dtype=dtype)
53
-
54
- # Compute the squared window at the desired length
55
- win_sq = get_window(window, win_length, fftbins=True)
56
- win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
57
- win_sq = librosa_util.pad_center(win_sq, n_fft)
58
-
59
- # Fill the envelope
60
- for i in range(n_frames):
61
- sample = i * hop_length
62
- x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
63
- return x
64
-
65
-
66
- def griffin_lim(magnitudes, stft_fn, n_iters=30):
67
- """
68
- PARAMS
69
- ------
70
- magnitudes: spectrogram magnitudes
71
- stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
72
- """
73
-
74
- angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
75
- angles = angles.astype(np.float32)
76
- angles = torch.autograd.Variable(torch.from_numpy(angles))
77
- signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
78
-
79
- for i in range(n_iters):
80
- _, angles = stft_fn.transform(signal)
81
- signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
82
- return signal
83
-
84
-
85
- def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
86
- """
87
- PARAMS
88
- ------
89
- C: compression factor
90
- """
91
- return normalize_fun(torch.clamp(x, min=clip_val) * C)
92
-
93
-
94
- def dynamic_range_decompression(x, C=1):
95
- """
96
- PARAMS
97
- ------
98
- C: compression factor used to compress
99
- """
100
- return torch.exp(x) / C
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/audio/stft.py DELETED
@@ -1,180 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import numpy as np
4
- from scipy.signal import get_window
5
- from librosa.util import pad_center, tiny
6
- from librosa.filters import mel as librosa_mel_fn
7
-
8
- from audioldm.audio.audio_processing import (
9
- dynamic_range_compression,
10
- dynamic_range_decompression,
11
- window_sumsquare,
12
- )
13
-
14
-
15
- class STFT(torch.nn.Module):
16
- """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
17
-
18
- def __init__(self, filter_length, hop_length, win_length, window="hann"):
19
- super(STFT, self).__init__()
20
- self.filter_length = filter_length
21
- self.hop_length = hop_length
22
- self.win_length = win_length
23
- self.window = window
24
- self.forward_transform = None
25
- scale = self.filter_length / self.hop_length
26
- fourier_basis = np.fft.fft(np.eye(self.filter_length))
27
-
28
- cutoff = int((self.filter_length / 2 + 1))
29
- fourier_basis = np.vstack(
30
- [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
31
- )
32
-
33
- forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
34
- inverse_basis = torch.FloatTensor(
35
- np.linalg.pinv(scale * fourier_basis).T[:, None, :]
36
- )
37
-
38
- if window is not None:
39
- assert filter_length >= win_length
40
- # get window and zero center pad it to filter_length
41
- fft_window = get_window(window, win_length, fftbins=True)
42
- fft_window = pad_center(fft_window, filter_length)
43
- fft_window = torch.from_numpy(fft_window).float()
44
-
45
- # window the bases
46
- forward_basis *= fft_window
47
- inverse_basis *= fft_window
48
-
49
- self.register_buffer("forward_basis", forward_basis.float())
50
- self.register_buffer("inverse_basis", inverse_basis.float())
51
-
52
- def transform(self, input_data):
53
- num_batches = input_data.size(0)
54
- num_samples = input_data.size(1)
55
-
56
- self.num_samples = num_samples
57
-
58
- # similar to librosa, reflect-pad the input
59
- input_data = input_data.view(num_batches, 1, num_samples)
60
- input_data = F.pad(
61
- input_data.unsqueeze(1),
62
- (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
63
- mode="reflect",
64
- )
65
- input_data = input_data.squeeze(1)
66
-
67
- forward_transform = F.conv1d(
68
- input_data,
69
- torch.autograd.Variable(self.forward_basis, requires_grad=False),
70
- stride=self.hop_length,
71
- padding=0,
72
- ).cpu()
73
-
74
- cutoff = int((self.filter_length / 2) + 1)
75
- real_part = forward_transform[:, :cutoff, :]
76
- imag_part = forward_transform[:, cutoff:, :]
77
-
78
- magnitude = torch.sqrt(real_part**2 + imag_part**2)
79
- phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
80
-
81
- return magnitude, phase
82
-
83
- def inverse(self, magnitude, phase):
84
- recombine_magnitude_phase = torch.cat(
85
- [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
86
- )
87
-
88
- inverse_transform = F.conv_transpose1d(
89
- recombine_magnitude_phase,
90
- torch.autograd.Variable(self.inverse_basis, requires_grad=False),
91
- stride=self.hop_length,
92
- padding=0,
93
- )
94
-
95
- if self.window is not None:
96
- window_sum = window_sumsquare(
97
- self.window,
98
- magnitude.size(-1),
99
- hop_length=self.hop_length,
100
- win_length=self.win_length,
101
- n_fft=self.filter_length,
102
- dtype=np.float32,
103
- )
104
- # remove modulation effects
105
- approx_nonzero_indices = torch.from_numpy(
106
- np.where(window_sum > tiny(window_sum))[0]
107
- )
108
- window_sum = torch.autograd.Variable(
109
- torch.from_numpy(window_sum), requires_grad=False
110
- )
111
- window_sum = window_sum
112
- inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
113
- approx_nonzero_indices
114
- ]
115
-
116
- # scale by hop ratio
117
- inverse_transform *= float(self.filter_length) / self.hop_length
118
-
119
- inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
120
- inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
121
-
122
- return inverse_transform
123
-
124
- def forward(self, input_data):
125
- self.magnitude, self.phase = self.transform(input_data)
126
- reconstruction = self.inverse(self.magnitude, self.phase)
127
- return reconstruction
128
-
129
-
130
- class TacotronSTFT(torch.nn.Module):
131
- def __init__(
132
- self,
133
- filter_length,
134
- hop_length,
135
- win_length,
136
- n_mel_channels,
137
- sampling_rate,
138
- mel_fmin,
139
- mel_fmax,
140
- ):
141
- super(TacotronSTFT, self).__init__()
142
- self.n_mel_channels = n_mel_channels
143
- self.sampling_rate = sampling_rate
144
- self.stft_fn = STFT(filter_length, hop_length, win_length)
145
- mel_basis = librosa_mel_fn(
146
- sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
147
- )
148
- mel_basis = torch.from_numpy(mel_basis).float()
149
- self.register_buffer("mel_basis", mel_basis)
150
-
151
- def spectral_normalize(self, magnitudes, normalize_fun):
152
- output = dynamic_range_compression(magnitudes, normalize_fun)
153
- return output
154
-
155
- def spectral_de_normalize(self, magnitudes):
156
- output = dynamic_range_decompression(magnitudes)
157
- return output
158
-
159
- def mel_spectrogram(self, y, normalize_fun=torch.log):
160
- """Computes mel-spectrograms from a batch of waves
161
- PARAMS
162
- ------
163
- y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
164
-
165
- RETURNS
166
- -------
167
- mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
168
- """
169
- assert torch.min(y.data) >= -1, torch.min(y.data)
170
- assert torch.max(y.data) <= 1, torch.max(y.data)
171
-
172
- magnitudes, phases = self.stft_fn.transform(y)
173
- magnitudes = magnitudes.data
174
- mel_output = torch.matmul(self.mel_basis, magnitudes)
175
- mel_output = self.spectral_normalize(mel_output, normalize_fun)
176
- energy = torch.norm(magnitudes, dim=1)
177
-
178
- log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
179
-
180
- return mel_output, log_magnitudes, energy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/audio/tools.py DELETED
@@ -1,33 +0,0 @@
1
- import torch
2
- import numpy as np
3
-
4
-
5
- def get_mel_from_wav(audio, _stft):
6
- audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
7
- audio = torch.autograd.Variable(audio, requires_grad=False)
8
- melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
9
- melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
10
- log_magnitudes_stft = (
11
- torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32)
12
- )
13
- energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
14
- return melspec, log_magnitudes_stft, energy
15
-
16
-
17
- # def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
18
- # mel = torch.stack([mel])
19
- # mel_decompress = _stft.spectral_de_normalize(mel)
20
- # mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
21
- # spec_from_mel_scaling = 1000
22
- # spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
23
- # spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
24
- # spec_from_mel = spec_from_mel * spec_from_mel_scaling
25
-
26
- # audio = griffin_lim(
27
- # torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters
28
- # )
29
-
30
- # audio = audio.squeeze()
31
- # audio = audio.cpu().numpy()
32
- # audio_path = out_filename
33
- # write(audio_path, _stft.sampling_rate, audio)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/__init__.py DELETED
File without changes
audioldm/clap/encoders.py DELETED
@@ -1,170 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from audioldm.clap.open_clip import create_model
4
- from audioldm.clap.training.data import get_audio_features
5
- import torchaudio
6
- from transformers import RobertaTokenizer
7
- import torch.nn.functional as F
8
-
9
-
10
- class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
11
- def __init__(
12
- self,
13
- pretrained_path="",
14
- key="class",
15
- sampling_rate=16000,
16
- embed_mode="audio",
17
- amodel = "HTSAT-tiny",
18
- unconditional_prob=0.1,
19
- random_mute=False,
20
- max_random_mute_portion=0.5,
21
- training_mode=True,
22
- ):
23
- super().__init__()
24
-
25
- self.key = key
26
- self.device = "cpu"
27
- self.precision = "fp32"
28
- self.amodel = amodel
29
- self.tmodel = "roberta" # the best text encoder in our training
30
- self.enable_fusion = False # False if you do not want to use the fusion model
31
- self.fusion_type = "aff_2d"
32
- self.pretrained = pretrained_path
33
- self.embed_mode = embed_mode
34
- self.embed_mode_orig = embed_mode
35
- self.sampling_rate = sampling_rate
36
- self.unconditional_prob = unconditional_prob
37
- self.random_mute = random_mute
38
- self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")
39
- self.max_random_mute_portion = max_random_mute_portion
40
- self.training_mode = training_mode
41
- self.model, self.model_cfg = create_model(
42
- self.amodel,
43
- self.tmodel,
44
- self.pretrained,
45
- precision=self.precision,
46
- device=self.device,
47
- enable_fusion=self.enable_fusion,
48
- fusion_type=self.fusion_type,
49
- )
50
- for p in self.model.parameters():
51
- p.requires_grad = False
52
-
53
- self.model.eval()
54
-
55
- def get_unconditional_condition(self, batchsize):
56
- self.unconditional_token = self.model.get_text_embedding(
57
- self.tokenizer(["", ""])
58
- )[0:1]
59
- return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)
60
-
61
- def batch_to_list(self, batch):
62
- ret = []
63
- for i in range(batch.size(0)):
64
- ret.append(batch[i])
65
- return ret
66
-
67
- def make_decision(self, probability):
68
- if float(torch.rand(1)) < probability:
69
- return True
70
- else:
71
- return False
72
-
73
- def random_uniform(self, start, end):
74
- val = torch.rand(1).item()
75
- return start + (end - start) * val
76
-
77
- def _random_mute(self, waveform):
78
- # waveform: [bs, t-steps]
79
- t_steps = waveform.size(-1)
80
- for i in range(waveform.size(0)):
81
- mute_size = int(
82
- self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
83
- )
84
- mute_start = int(self.random_uniform(0, t_steps - mute_size))
85
- waveform[i, mute_start : mute_start + mute_size] = 0
86
- return waveform
87
-
88
- def cos_similarity(self, waveform, text):
89
- # waveform: [bs, t_steps]
90
- with torch.no_grad():
91
- self.embed_mode = "audio"
92
- audio_emb = self(waveform.cuda())
93
- self.embed_mode = "text"
94
- text_emb = self(text)
95
- similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
96
- return similarity.squeeze()
97
-
98
- def forward(self, batch, key=None):
99
- # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0
100
- # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0
101
- if self.model.training == True and not self.training_mode:
102
- print(
103
- "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters."
104
- )
105
- self.model, self.model_cfg = create_model(
106
- self.amodel,
107
- self.tmodel,
108
- self.pretrained,
109
- precision=self.precision,
110
- device="cuda",
111
- enable_fusion=self.enable_fusion,
112
- fusion_type=self.fusion_type,
113
- )
114
- for p in self.model.parameters():
115
- p.requires_grad = False
116
- self.model.eval()
117
-
118
- # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
119
- if self.embed_mode == "audio":
120
- with torch.no_grad():
121
- audio_dict_list = []
122
- assert (
123
- self.sampling_rate == 16000
124
- ), "We only support 16000 sampling rate"
125
- if self.random_mute:
126
- batch = self._random_mute(batch)
127
- # batch: [bs, 1, t-samples]
128
- batch = torchaudio.functional.resample(
129
- batch, orig_freq=self.sampling_rate, new_freq=48000
130
- )
131
- for waveform in self.batch_to_list(batch):
132
- audio_dict = {}
133
- audio_dict = get_audio_features(
134
- audio_dict,
135
- waveform,
136
- 480000,
137
- data_truncating="fusion",
138
- data_filling="repeatpad",
139
- audio_cfg=self.model_cfg["audio_cfg"],
140
- )
141
- audio_dict_list.append(audio_dict)
142
- # [bs, 512]
143
- embed = self.model.get_audio_embedding(audio_dict_list)
144
- elif self.embed_mode == "text":
145
- with torch.no_grad():
146
- # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
147
- text_data = self.tokenizer(batch)
148
- embed = self.model.get_text_embedding(text_data)
149
-
150
- embed = embed.unsqueeze(1)
151
- self.unconditional_token = self.model.get_text_embedding(
152
- self.tokenizer(["", ""])
153
- )[0:1]
154
-
155
- for i in range(embed.size(0)):
156
- if self.make_decision(self.unconditional_prob):
157
- embed[i] = self.unconditional_token
158
-
159
- # [bs, 1, 512]
160
- return embed.detach()
161
-
162
- def tokenizer(self, text):
163
- result = self.tokenize(
164
- text,
165
- padding="max_length",
166
- truncation=True,
167
- max_length=512,
168
- return_tensors="pt",
169
- )
170
- return {k: v.squeeze(0) for k, v in result.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/__init__.py DELETED
@@ -1,25 +0,0 @@
1
- from .factory import (
2
- list_models,
3
- create_model,
4
- create_model_and_transforms,
5
- add_model_config,
6
- )
7
- from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
8
- from .model import (
9
- CLAP,
10
- CLAPTextCfg,
11
- CLAPVisionCfg,
12
- CLAPAudioCfp,
13
- convert_weights_to_fp16,
14
- trace_model,
15
- )
16
- from .openai import load_openai_model, list_openai_models
17
- from .pretrained import (
18
- list_pretrained,
19
- list_pretrained_tag_models,
20
- list_pretrained_model_tags,
21
- get_pretrained_url,
22
- download_pretrained,
23
- )
24
- from .tokenizer import SimpleTokenizer, tokenize
25
- from .transform import image_transform
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/bert.py DELETED
@@ -1,40 +0,0 @@
1
- from transformers import BertTokenizer, BertModel
2
-
3
- tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
4
- model = BertModel.from_pretrained("bert-base-uncased")
5
- text = "Replace me by any text you'd like."
6
-
7
-
8
- def bert_embeddings(text):
9
- # text = "Replace me by any text you'd like."
10
- encoded_input = tokenizer(text, return_tensors="pt")
11
- output = model(**encoded_input)
12
- return output
13
-
14
-
15
- from transformers import RobertaTokenizer, RobertaModel
16
-
17
- tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
18
- model = RobertaModel.from_pretrained("roberta-base")
19
- text = "Replace me by any text you'd like."
20
-
21
-
22
- def Roberta_embeddings(text):
23
- # text = "Replace me by any text you'd like."
24
- encoded_input = tokenizer(text, return_tensors="pt")
25
- output = model(**encoded_input)
26
- return output
27
-
28
-
29
- from transformers import BartTokenizer, BartModel
30
-
31
- tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
32
- model = BartModel.from_pretrained("facebook/bart-base")
33
- text = "Replace me by any text you'd like."
34
-
35
-
36
- def bart_embeddings(text):
37
- # text = "Replace me by any text you'd like."
38
- encoded_input = tokenizer(text, return_tensors="pt")
39
- output = model(**encoded_input)
40
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
- size 1356917
 
 
 
 
audioldm/clap/open_clip/factory.py DELETED
@@ -1,277 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- import pathlib
5
- import re
6
- from copy import deepcopy
7
- from pathlib import Path
8
-
9
- import torch
10
-
11
- from .model import CLAP, convert_weights_to_fp16
12
- from .openai import load_openai_model
13
- from .pretrained import get_pretrained_url, download_pretrained
14
- from .transform import image_transform
15
-
16
- _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
17
- _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
18
-
19
-
20
- def _natural_key(string_):
21
- return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
22
-
23
-
24
- def _rescan_model_configs():
25
- global _MODEL_CONFIGS
26
-
27
- config_ext = (".json",)
28
- config_files = []
29
- for config_path in _MODEL_CONFIG_PATHS:
30
- if config_path.is_file() and config_path.suffix in config_ext:
31
- config_files.append(config_path)
32
- elif config_path.is_dir():
33
- for ext in config_ext:
34
- config_files.extend(config_path.glob(f"*{ext}"))
35
-
36
- for cf in config_files:
37
- if os.path.basename(cf)[0] == ".":
38
- continue # Ignore hidden files
39
-
40
- with open(cf, "r") as f:
41
- model_cfg = json.load(f)
42
- if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
43
- _MODEL_CONFIGS[cf.stem] = model_cfg
44
-
45
- _MODEL_CONFIGS = {
46
- k: v
47
- for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
48
- }
49
-
50
-
51
- _rescan_model_configs() # initial populate of model config registry
52
-
53
-
54
- def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
55
- checkpoint = torch.load(checkpoint_path, map_location=map_location)
56
- if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
57
- state_dict = checkpoint["state_dict"]
58
- else:
59
- state_dict = checkpoint
60
- if skip_params:
61
- if next(iter(state_dict.items()))[0].startswith("module"):
62
- state_dict = {k[7:]: v for k, v in state_dict.items()}
63
- # for k in state_dict:
64
- # if k.startswith('transformer'):
65
- # v = state_dict.pop(k)
66
- # state_dict['text_branch.' + k[12:]] = v
67
- return state_dict
68
-
69
-
70
- def create_model(
71
- amodel_name: str,
72
- tmodel_name: str,
73
- pretrained: str = "",
74
- precision: str = "fp32",
75
- device: torch.device = torch.device("cpu"),
76
- jit: bool = False,
77
- force_quick_gelu: bool = False,
78
- openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
79
- skip_params=True,
80
- pretrained_audio: str = "",
81
- pretrained_text: str = "",
82
- enable_fusion: bool = False,
83
- fusion_type: str = "None"
84
- # pretrained_image: bool = False,
85
- ):
86
- amodel_name = amodel_name.replace(
87
- "/", "-"
88
- ) # for callers using old naming with / in ViT names
89
- pretrained_orig = pretrained
90
- pretrained = pretrained.lower()
91
- if pretrained == "openai":
92
- if amodel_name in _MODEL_CONFIGS:
93
- logging.info(f"Loading {amodel_name} model config.")
94
- model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
95
- else:
96
- logging.error(
97
- f"Model config for {amodel_name} not found; available models {list_models()}."
98
- )
99
- raise RuntimeError(f"Model config for {amodel_name} not found.")
100
-
101
- logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
102
- # Hard Code in model name
103
- model_cfg["text_cfg"]["model_type"] = tmodel_name
104
- model = load_openai_model(
105
- "ViT-B-16",
106
- model_cfg,
107
- device=device,
108
- jit=jit,
109
- cache_dir=openai_model_cache_dir,
110
- enable_fusion=enable_fusion,
111
- fusion_type=fusion_type,
112
- )
113
- # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
114
- if precision == "amp" or precision == "fp32":
115
- model = model.float()
116
- else:
117
- if amodel_name in _MODEL_CONFIGS:
118
- logging.info(f"Loading {amodel_name} model config.")
119
- model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
120
- else:
121
- logging.error(
122
- f"Model config for {amodel_name} not found; available models {list_models()}."
123
- )
124
- raise RuntimeError(f"Model config for {amodel_name} not found.")
125
-
126
- if force_quick_gelu:
127
- # override for use of QuickGELU on non-OpenAI transformer models
128
- model_cfg["quick_gelu"] = True
129
-
130
- # if pretrained_image:
131
- # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
132
- # # pretrained weight loading for timm models set via vision_cfg
133
- # model_cfg['vision_cfg']['timm_model_pretrained'] = True
134
- # else:
135
- # assert False, 'pretrained image towers currently only supported for timm models'
136
- model_cfg["text_cfg"]["model_type"] = tmodel_name
137
- model_cfg["enable_fusion"] = enable_fusion
138
- model_cfg["fusion_type"] = fusion_type
139
- model = CLAP(**model_cfg)
140
-
141
- if pretrained:
142
- checkpoint_path = ""
143
- url = get_pretrained_url(amodel_name, pretrained)
144
- if url:
145
- checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
146
- elif os.path.exists(pretrained_orig):
147
- checkpoint_path = pretrained_orig
148
- if checkpoint_path:
149
- logging.info(
150
- f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
151
- )
152
- ckpt = load_state_dict(checkpoint_path, skip_params=True)
153
- model.load_state_dict(ckpt)
154
- param_names = [n for n, p in model.named_parameters()]
155
- # for n in param_names:
156
- # print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
157
- else:
158
- logging.warning(
159
- f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
160
- )
161
- raise RuntimeError(
162
- f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
163
- )
164
-
165
- if pretrained_audio:
166
- if amodel_name.startswith("PANN"):
167
- if "Cnn14_mAP" in pretrained_audio: # official checkpoint
168
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
169
- audio_ckpt = audio_ckpt["model"]
170
- keys = list(audio_ckpt.keys())
171
- for key in keys:
172
- if (
173
- "spectrogram_extractor" not in key
174
- and "logmel_extractor" not in key
175
- ):
176
- v = audio_ckpt.pop(key)
177
- audio_ckpt["audio_branch." + key] = v
178
- elif os.path.basename(pretrained_audio).startswith(
179
- "PANN"
180
- ): # checkpoint trained via HTSAT codebase
181
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
182
- audio_ckpt = audio_ckpt["state_dict"]
183
- keys = list(audio_ckpt.keys())
184
- for key in keys:
185
- if key.startswith("sed_model"):
186
- v = audio_ckpt.pop(key)
187
- audio_ckpt["audio_branch." + key[10:]] = v
188
- elif os.path.basename(pretrained_audio).startswith(
189
- "finetuned"
190
- ): # checkpoint trained via linear probe codebase
191
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
192
- else:
193
- raise ValueError("Unknown audio checkpoint")
194
- elif amodel_name.startswith("HTSAT"):
195
- if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint
196
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
197
- audio_ckpt = audio_ckpt["state_dict"]
198
- keys = list(audio_ckpt.keys())
199
- for key in keys:
200
- if key.startswith("sed_model") and (
201
- "spectrogram_extractor" not in key
202
- and "logmel_extractor" not in key
203
- ):
204
- v = audio_ckpt.pop(key)
205
- audio_ckpt["audio_branch." + key[10:]] = v
206
- elif os.path.basename(pretrained_audio).startswith(
207
- "HTSAT"
208
- ): # checkpoint trained via HTSAT codebase
209
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
210
- audio_ckpt = audio_ckpt["state_dict"]
211
- keys = list(audio_ckpt.keys())
212
- for key in keys:
213
- if key.startswith("sed_model"):
214
- v = audio_ckpt.pop(key)
215
- audio_ckpt["audio_branch." + key[10:]] = v
216
- elif os.path.basename(pretrained_audio).startswith(
217
- "finetuned"
218
- ): # checkpoint trained via linear probe codebase
219
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
220
- else:
221
- raise ValueError("Unknown audio checkpoint")
222
- else:
223
- raise f"this audio encoder pretrained checkpoint is not support"
224
-
225
- model.load_state_dict(audio_ckpt, strict=False)
226
- logging.info(
227
- f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
228
- )
229
- param_names = [n for n, p in model.named_parameters()]
230
- for n in param_names:
231
- print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
232
-
233
- model.to(device=device)
234
- if precision == "fp16":
235
- assert device.type != "cpu"
236
- convert_weights_to_fp16(model)
237
-
238
- if jit:
239
- model = torch.jit.script(model)
240
-
241
- return model, model_cfg
242
-
243
-
244
- def create_model_and_transforms(
245
- model_name: str,
246
- pretrained: str = "",
247
- precision: str = "fp32",
248
- device: torch.device = torch.device("cpu"),
249
- jit: bool = False,
250
- force_quick_gelu: bool = False,
251
- # pretrained_image: bool = False,
252
- ):
253
- model = create_model(
254
- model_name,
255
- pretrained,
256
- precision,
257
- device,
258
- jit,
259
- force_quick_gelu=force_quick_gelu,
260
- # pretrained_image=pretrained_image
261
- )
262
- preprocess_train = image_transform(model.visual.image_size, is_train=True)
263
- preprocess_val = image_transform(model.visual.image_size, is_train=False)
264
- return model, preprocess_train, preprocess_val
265
-
266
-
267
- def list_models():
268
- """enumerate available model architectures based on config files"""
269
- return list(_MODEL_CONFIGS.keys())
270
-
271
-
272
- def add_model_config(path):
273
- """add model config path or file and update registry"""
274
- if not isinstance(path, Path):
275
- path = Path(path)
276
- _MODEL_CONFIG_PATHS.append(path)
277
- _rescan_model_configs()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/feature_fusion.py DELETED
@@ -1,192 +0,0 @@
1
- """
2
- Feature Fusion for Varible-Length Data Processing
3
- AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
4
- According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
5
- """
6
-
7
- import torch
8
- import torch.nn as nn
9
-
10
-
11
- class DAF(nn.Module):
12
- """
13
- 直接相加 DirectAddFuse
14
- """
15
-
16
- def __init__(self):
17
- super(DAF, self).__init__()
18
-
19
- def forward(self, x, residual):
20
- return x + residual
21
-
22
-
23
- class iAFF(nn.Module):
24
- """
25
- 多特征融合 iAFF
26
- """
27
-
28
- def __init__(self, channels=64, r=4, type="2D"):
29
- super(iAFF, self).__init__()
30
- inter_channels = int(channels // r)
31
-
32
- if type == "1D":
33
- # 本地注意力
34
- self.local_att = nn.Sequential(
35
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
36
- nn.BatchNorm1d(inter_channels),
37
- nn.ReLU(inplace=True),
38
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
39
- nn.BatchNorm1d(channels),
40
- )
41
-
42
- # 全局注意力
43
- self.global_att = nn.Sequential(
44
- nn.AdaptiveAvgPool1d(1),
45
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
46
- nn.BatchNorm1d(inter_channels),
47
- nn.ReLU(inplace=True),
48
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
49
- nn.BatchNorm1d(channels),
50
- )
51
-
52
- # 第二次本地注意力
53
- self.local_att2 = nn.Sequential(
54
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
55
- nn.BatchNorm1d(inter_channels),
56
- nn.ReLU(inplace=True),
57
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
58
- nn.BatchNorm1d(channels),
59
- )
60
- # 第二次全局注意力
61
- self.global_att2 = nn.Sequential(
62
- nn.AdaptiveAvgPool1d(1),
63
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
64
- nn.BatchNorm1d(inter_channels),
65
- nn.ReLU(inplace=True),
66
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
67
- nn.BatchNorm1d(channels),
68
- )
69
- elif type == "2D":
70
- # 本地注意力
71
- self.local_att = nn.Sequential(
72
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
73
- nn.BatchNorm2d(inter_channels),
74
- nn.ReLU(inplace=True),
75
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
76
- nn.BatchNorm2d(channels),
77
- )
78
-
79
- # 全局注意力
80
- self.global_att = nn.Sequential(
81
- nn.AdaptiveAvgPool2d(1),
82
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
83
- nn.BatchNorm2d(inter_channels),
84
- nn.ReLU(inplace=True),
85
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
86
- nn.BatchNorm2d(channels),
87
- )
88
-
89
- # 第二次本地注意力
90
- self.local_att2 = nn.Sequential(
91
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
92
- nn.BatchNorm2d(inter_channels),
93
- nn.ReLU(inplace=True),
94
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
95
- nn.BatchNorm2d(channels),
96
- )
97
- # 第二次全局注意力
98
- self.global_att2 = nn.Sequential(
99
- nn.AdaptiveAvgPool2d(1),
100
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
101
- nn.BatchNorm2d(inter_channels),
102
- nn.ReLU(inplace=True),
103
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
104
- nn.BatchNorm2d(channels),
105
- )
106
- else:
107
- raise f"the type is not supported"
108
-
109
- self.sigmoid = nn.Sigmoid()
110
-
111
- def forward(self, x, residual):
112
- flag = False
113
- xa = x + residual
114
- if xa.size(0) == 1:
115
- xa = torch.cat([xa, xa], dim=0)
116
- flag = True
117
- xl = self.local_att(xa)
118
- xg = self.global_att(xa)
119
- xlg = xl + xg
120
- wei = self.sigmoid(xlg)
121
- xi = x * wei + residual * (1 - wei)
122
-
123
- xl2 = self.local_att2(xi)
124
- xg2 = self.global_att(xi)
125
- xlg2 = xl2 + xg2
126
- wei2 = self.sigmoid(xlg2)
127
- xo = x * wei2 + residual * (1 - wei2)
128
- if flag:
129
- xo = xo[0].unsqueeze(0)
130
- return xo
131
-
132
-
133
- class AFF(nn.Module):
134
- """
135
- 多特征融合 AFF
136
- """
137
-
138
- def __init__(self, channels=64, r=4, type="2D"):
139
- super(AFF, self).__init__()
140
- inter_channels = int(channels // r)
141
-
142
- if type == "1D":
143
- self.local_att = nn.Sequential(
144
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
145
- nn.BatchNorm1d(inter_channels),
146
- nn.ReLU(inplace=True),
147
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
148
- nn.BatchNorm1d(channels),
149
- )
150
- self.global_att = nn.Sequential(
151
- nn.AdaptiveAvgPool1d(1),
152
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
153
- nn.BatchNorm1d(inter_channels),
154
- nn.ReLU(inplace=True),
155
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
156
- nn.BatchNorm1d(channels),
157
- )
158
- elif type == "2D":
159
- self.local_att = nn.Sequential(
160
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
161
- nn.BatchNorm2d(inter_channels),
162
- nn.ReLU(inplace=True),
163
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
164
- nn.BatchNorm2d(channels),
165
- )
166
- self.global_att = nn.Sequential(
167
- nn.AdaptiveAvgPool2d(1),
168
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
169
- nn.BatchNorm2d(inter_channels),
170
- nn.ReLU(inplace=True),
171
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
172
- nn.BatchNorm2d(channels),
173
- )
174
- else:
175
- raise f"the type is not supported."
176
-
177
- self.sigmoid = nn.Sigmoid()
178
-
179
- def forward(self, x, residual):
180
- flag = False
181
- xa = x + residual
182
- if xa.size(0) == 1:
183
- xa = torch.cat([xa, xa], dim=0)
184
- flag = True
185
- xl = self.local_att(xa)
186
- xg = self.global_att(xa)
187
- xlg = xl + xg
188
- wei = self.sigmoid(xlg)
189
- xo = 2 * x * wei + 2 * residual * (1 - wei)
190
- if flag:
191
- xo = xo[0].unsqueeze(0)
192
- return xo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/htsat.py DELETED
@@ -1,1308 +0,0 @@
1
- # Ke Chen
2
3
- # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
- # Some layers designed on the model
5
- # below codes are based and referred from https://github.com/microsoft/Swin-Transformer
6
- # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- from itertools import repeat
12
- import collections.abc
13
- import math
14
- import warnings
15
-
16
- from torch.nn.init import _calculate_fan_in_and_fan_out
17
- import torch.utils.checkpoint as checkpoint
18
-
19
- import random
20
-
21
- from torchlibrosa.stft import Spectrogram, LogmelFilterBank
22
- from torchlibrosa.augmentation import SpecAugmentation
23
-
24
- from itertools import repeat
25
- from .utils import do_mixup, interpolate
26
-
27
- from .feature_fusion import iAFF, AFF, DAF
28
-
29
- # from PyTorch internals
30
- def _ntuple(n):
31
- def parse(x):
32
- if isinstance(x, collections.abc.Iterable):
33
- return x
34
- return tuple(repeat(x, n))
35
-
36
- return parse
37
-
38
-
39
- to_1tuple = _ntuple(1)
40
- to_2tuple = _ntuple(2)
41
- to_3tuple = _ntuple(3)
42
- to_4tuple = _ntuple(4)
43
- to_ntuple = _ntuple
44
-
45
-
46
- def drop_path(x, drop_prob: float = 0.0, training: bool = False):
47
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
48
- This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
49
- the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
50
- See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
51
- changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
52
- 'survival rate' as the argument.
53
- """
54
- if drop_prob == 0.0 or not training:
55
- return x
56
- keep_prob = 1 - drop_prob
57
- shape = (x.shape[0],) + (1,) * (
58
- x.ndim - 1
59
- ) # work with diff dim tensors, not just 2D ConvNets
60
- random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
61
- random_tensor.floor_() # binarize
62
- output = x.div(keep_prob) * random_tensor
63
- return output
64
-
65
-
66
- class DropPath(nn.Module):
67
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
68
-
69
- def __init__(self, drop_prob=None):
70
- super(DropPath, self).__init__()
71
- self.drop_prob = drop_prob
72
-
73
- def forward(self, x):
74
- return drop_path(x, self.drop_prob, self.training)
75
-
76
-
77
- class PatchEmbed(nn.Module):
78
- """2D Image to Patch Embedding"""
79
-
80
- def __init__(
81
- self,
82
- img_size=224,
83
- patch_size=16,
84
- in_chans=3,
85
- embed_dim=768,
86
- norm_layer=None,
87
- flatten=True,
88
- patch_stride=16,
89
- enable_fusion=False,
90
- fusion_type="None",
91
- ):
92
- super().__init__()
93
- img_size = to_2tuple(img_size)
94
- patch_size = to_2tuple(patch_size)
95
- patch_stride = to_2tuple(patch_stride)
96
- self.img_size = img_size
97
- self.patch_size = patch_size
98
- self.patch_stride = patch_stride
99
- self.grid_size = (
100
- img_size[0] // patch_stride[0],
101
- img_size[1] // patch_stride[1],
102
- )
103
- self.num_patches = self.grid_size[0] * self.grid_size[1]
104
- self.flatten = flatten
105
- self.in_chans = in_chans
106
- self.embed_dim = embed_dim
107
-
108
- self.enable_fusion = enable_fusion
109
- self.fusion_type = fusion_type
110
-
111
- padding = (
112
- (patch_size[0] - patch_stride[0]) // 2,
113
- (patch_size[1] - patch_stride[1]) // 2,
114
- )
115
-
116
- if (self.enable_fusion) and (self.fusion_type == "channel_map"):
117
- self.proj = nn.Conv2d(
118
- in_chans * 4,
119
- embed_dim,
120
- kernel_size=patch_size,
121
- stride=patch_stride,
122
- padding=padding,
123
- )
124
- else:
125
- self.proj = nn.Conv2d(
126
- in_chans,
127
- embed_dim,
128
- kernel_size=patch_size,
129
- stride=patch_stride,
130
- padding=padding,
131
- )
132
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
133
-
134
- if (self.enable_fusion) and (
135
- self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
136
- ):
137
- self.mel_conv2d = nn.Conv2d(
138
- in_chans,
139
- embed_dim,
140
- kernel_size=(patch_size[0], patch_size[1] * 3),
141
- stride=(patch_stride[0], patch_stride[1] * 3),
142
- padding=padding,
143
- )
144
- if self.fusion_type == "daf_2d":
145
- self.fusion_model = DAF()
146
- elif self.fusion_type == "aff_2d":
147
- self.fusion_model = AFF(channels=embed_dim, type="2D")
148
- elif self.fusion_type == "iaff_2d":
149
- self.fusion_model = iAFF(channels=embed_dim, type="2D")
150
-
151
- def forward(self, x, longer_idx=None):
152
- if (self.enable_fusion) and (
153
- self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
154
- ):
155
- global_x = x[:, 0:1, :, :]
156
-
157
- # global processing
158
- B, C, H, W = global_x.shape
159
- assert (
160
- H == self.img_size[0] and W == self.img_size[1]
161
- ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
162
- global_x = self.proj(global_x)
163
- TW = global_x.size(-1)
164
- if len(longer_idx) > 0:
165
- # local processing
166
- local_x = x[longer_idx, 1:, :, :].contiguous()
167
- B, C, H, W = local_x.shape
168
- local_x = local_x.view(B * C, 1, H, W)
169
- local_x = self.mel_conv2d(local_x)
170
- local_x = local_x.view(
171
- B, C, local_x.size(1), local_x.size(2), local_x.size(3)
172
- )
173
- local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3)
174
- TB, TC, TH, _ = local_x.size()
175
- if local_x.size(-1) < TW:
176
- local_x = torch.cat(
177
- [
178
- local_x,
179
- torch.zeros(
180
- (TB, TC, TH, TW - local_x.size(-1)),
181
- device=global_x.device,
182
- ),
183
- ],
184
- dim=-1,
185
- )
186
- else:
187
- local_x = local_x[:, :, :, :TW]
188
-
189
- global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x)
190
- x = global_x
191
- else:
192
- B, C, H, W = x.shape
193
- assert (
194
- H == self.img_size[0] and W == self.img_size[1]
195
- ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
196
- x = self.proj(x)
197
-
198
- if self.flatten:
199
- x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
200
- x = self.norm(x)
201
- return x
202
-
203
-
204
- class Mlp(nn.Module):
205
- """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
206
-
207
- def __init__(
208
- self,
209
- in_features,
210
- hidden_features=None,
211
- out_features=None,
212
- act_layer=nn.GELU,
213
- drop=0.0,
214
- ):
215
- super().__init__()
216
- out_features = out_features or in_features
217
- hidden_features = hidden_features or in_features
218
- self.fc1 = nn.Linear(in_features, hidden_features)
219
- self.act = act_layer()
220
- self.fc2 = nn.Linear(hidden_features, out_features)
221
- self.drop = nn.Dropout(drop)
222
-
223
- def forward(self, x):
224
- x = self.fc1(x)
225
- x = self.act(x)
226
- x = self.drop(x)
227
- x = self.fc2(x)
228
- x = self.drop(x)
229
- return x
230
-
231
-
232
- def _no_grad_trunc_normal_(tensor, mean, std, a, b):
233
- # Cut & paste from PyTorch official master until it's in a few official releases - RW
234
- # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
235
- def norm_cdf(x):
236
- # Computes standard normal cumulative distribution function
237
- return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
238
-
239
- if (mean < a - 2 * std) or (mean > b + 2 * std):
240
- warnings.warn(
241
- "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
242
- "The distribution of values may be incorrect.",
243
- stacklevel=2,
244
- )
245
-
246
- with torch.no_grad():
247
- # Values are generated by using a truncated uniform distribution and
248
- # then using the inverse CDF for the normal distribution.
249
- # Get upper and lower cdf values
250
- l = norm_cdf((a - mean) / std)
251
- u = norm_cdf((b - mean) / std)
252
-
253
- # Uniformly fill tensor with values from [l, u], then translate to
254
- # [2l-1, 2u-1].
255
- tensor.uniform_(2 * l - 1, 2 * u - 1)
256
-
257
- # Use inverse cdf transform for normal distribution to get truncated
258
- # standard normal
259
- tensor.erfinv_()
260
-
261
- # Transform to proper mean, std
262
- tensor.mul_(std * math.sqrt(2.0))
263
- tensor.add_(mean)
264
-
265
- # Clamp to ensure it's in the proper range
266
- tensor.clamp_(min=a, max=b)
267
- return tensor
268
-
269
-
270
- def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
271
- # type: (Tensor, float, float, float, float) -> Tensor
272
- r"""Fills the input Tensor with values drawn from a truncated
273
- normal distribution. The values are effectively drawn from the
274
- normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
275
- with values outside :math:`[a, b]` redrawn until they are within
276
- the bounds. The method used for generating the random values works
277
- best when :math:`a \leq \text{mean} \leq b`.
278
- Args:
279
- tensor: an n-dimensional `torch.Tensor`
280
- mean: the mean of the normal distribution
281
- std: the standard deviation of the normal distribution
282
- a: the minimum cutoff value
283
- b: the maximum cutoff value
284
- Examples:
285
- >>> w = torch.empty(3, 5)
286
- >>> nn.init.trunc_normal_(w)
287
- """
288
- return _no_grad_trunc_normal_(tensor, mean, std, a, b)
289
-
290
-
291
- def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
292
- fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
293
- if mode == "fan_in":
294
- denom = fan_in
295
- elif mode == "fan_out":
296
- denom = fan_out
297
- elif mode == "fan_avg":
298
- denom = (fan_in + fan_out) / 2
299
-
300
- variance = scale / denom
301
-
302
- if distribution == "truncated_normal":
303
- # constant is stddev of standard normal truncated to (-2, 2)
304
- trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
305
- elif distribution == "normal":
306
- tensor.normal_(std=math.sqrt(variance))
307
- elif distribution == "uniform":
308
- bound = math.sqrt(3 * variance)
309
- tensor.uniform_(-bound, bound)
310
- else:
311
- raise ValueError(f"invalid distribution {distribution}")
312
-
313
-
314
- def lecun_normal_(tensor):
315
- variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
316
-
317
-
318
- def window_partition(x, window_size):
319
- """
320
- Args:
321
- x: (B, H, W, C)
322
- window_size (int): window size
323
- Returns:
324
- windows: (num_windows*B, window_size, window_size, C)
325
- """
326
- B, H, W, C = x.shape
327
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
328
- windows = (
329
- x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
330
- )
331
- return windows
332
-
333
-
334
- def window_reverse(windows, window_size, H, W):
335
- """
336
- Args:
337
- windows: (num_windows*B, window_size, window_size, C)
338
- window_size (int): Window size
339
- H (int): Height of image
340
- W (int): Width of image
341
- Returns:
342
- x: (B, H, W, C)
343
- """
344
- B = int(windows.shape[0] / (H * W / window_size / window_size))
345
- x = windows.view(
346
- B, H // window_size, W // window_size, window_size, window_size, -1
347
- )
348
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
349
- return x
350
-
351
-
352
- class WindowAttention(nn.Module):
353
- r"""Window based multi-head self attention (W-MSA) module with relative position bias.
354
- It supports both of shifted and non-shifted window.
355
- Args:
356
- dim (int): Number of input channels.
357
- window_size (tuple[int]): The height and width of the window.
358
- num_heads (int): Number of attention heads.
359
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
360
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
361
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
362
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
363
- """
364
-
365
- def __init__(
366
- self,
367
- dim,
368
- window_size,
369
- num_heads,
370
- qkv_bias=True,
371
- qk_scale=None,
372
- attn_drop=0.0,
373
- proj_drop=0.0,
374
- ):
375
-
376
- super().__init__()
377
- self.dim = dim
378
- self.window_size = window_size # Wh, Ww
379
- self.num_heads = num_heads
380
- head_dim = dim // num_heads
381
- self.scale = qk_scale or head_dim**-0.5
382
-
383
- # define a parameter table of relative position bias
384
- self.relative_position_bias_table = nn.Parameter(
385
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
386
- ) # 2*Wh-1 * 2*Ww-1, nH
387
-
388
- # get pair-wise relative position index for each token inside the window
389
- coords_h = torch.arange(self.window_size[0])
390
- coords_w = torch.arange(self.window_size[1])
391
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
392
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
393
- relative_coords = (
394
- coords_flatten[:, :, None] - coords_flatten[:, None, :]
395
- ) # 2, Wh*Ww, Wh*Ww
396
- relative_coords = relative_coords.permute(
397
- 1, 2, 0
398
- ).contiguous() # Wh*Ww, Wh*Ww, 2
399
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
400
- relative_coords[:, :, 1] += self.window_size[1] - 1
401
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
402
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
403
- self.register_buffer("relative_position_index", relative_position_index)
404
-
405
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
406
- self.attn_drop = nn.Dropout(attn_drop)
407
- self.proj = nn.Linear(dim, dim)
408
- self.proj_drop = nn.Dropout(proj_drop)
409
-
410
- trunc_normal_(self.relative_position_bias_table, std=0.02)
411
- self.softmax = nn.Softmax(dim=-1)
412
-
413
- def forward(self, x, mask=None):
414
- """
415
- Args:
416
- x: input features with shape of (num_windows*B, N, C)
417
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
418
- """
419
- B_, N, C = x.shape
420
- qkv = (
421
- self.qkv(x)
422
- .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
423
- .permute(2, 0, 3, 1, 4)
424
- )
425
- q, k, v = (
426
- qkv[0],
427
- qkv[1],
428
- qkv[2],
429
- ) # make torchscript happy (cannot use tensor as tuple)
430
-
431
- q = q * self.scale
432
- attn = q @ k.transpose(-2, -1)
433
-
434
- relative_position_bias = self.relative_position_bias_table[
435
- self.relative_position_index.view(-1)
436
- ].view(
437
- self.window_size[0] * self.window_size[1],
438
- self.window_size[0] * self.window_size[1],
439
- -1,
440
- ) # Wh*Ww,Wh*Ww,nH
441
- relative_position_bias = relative_position_bias.permute(
442
- 2, 0, 1
443
- ).contiguous() # nH, Wh*Ww, Wh*Ww
444
- attn = attn + relative_position_bias.unsqueeze(0)
445
-
446
- if mask is not None:
447
- nW = mask.shape[0]
448
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
449
- 1
450
- ).unsqueeze(0)
451
- attn = attn.view(-1, self.num_heads, N, N)
452
- attn = self.softmax(attn)
453
- else:
454
- attn = self.softmax(attn)
455
-
456
- attn = self.attn_drop(attn)
457
-
458
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
459
- x = self.proj(x)
460
- x = self.proj_drop(x)
461
- return x, attn
462
-
463
- def extra_repr(self):
464
- return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
465
-
466
-
467
- # We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
468
- class SwinTransformerBlock(nn.Module):
469
- r"""Swin Transformer Block.
470
- Args:
471
- dim (int): Number of input channels.
472
- input_resolution (tuple[int]): Input resulotion.
473
- num_heads (int): Number of attention heads.
474
- window_size (int): Window size.
475
- shift_size (int): Shift size for SW-MSA.
476
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
477
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
478
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
479
- drop (float, optional): Dropout rate. Default: 0.0
480
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
481
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
482
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
483
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
484
- """
485
-
486
- def __init__(
487
- self,
488
- dim,
489
- input_resolution,
490
- num_heads,
491
- window_size=7,
492
- shift_size=0,
493
- mlp_ratio=4.0,
494
- qkv_bias=True,
495
- qk_scale=None,
496
- drop=0.0,
497
- attn_drop=0.0,
498
- drop_path=0.0,
499
- act_layer=nn.GELU,
500
- norm_layer=nn.LayerNorm,
501
- norm_before_mlp="ln",
502
- ):
503
- super().__init__()
504
- self.dim = dim
505
- self.input_resolution = input_resolution
506
- self.num_heads = num_heads
507
- self.window_size = window_size
508
- self.shift_size = shift_size
509
- self.mlp_ratio = mlp_ratio
510
- self.norm_before_mlp = norm_before_mlp
511
- if min(self.input_resolution) <= self.window_size:
512
- # if window size is larger than input resolution, we don't partition windows
513
- self.shift_size = 0
514
- self.window_size = min(self.input_resolution)
515
- assert (
516
- 0 <= self.shift_size < self.window_size
517
- ), "shift_size must in 0-window_size"
518
-
519
- self.norm1 = norm_layer(dim)
520
- self.attn = WindowAttention(
521
- dim,
522
- window_size=to_2tuple(self.window_size),
523
- num_heads=num_heads,
524
- qkv_bias=qkv_bias,
525
- qk_scale=qk_scale,
526
- attn_drop=attn_drop,
527
- proj_drop=drop,
528
- )
529
-
530
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
531
- if self.norm_before_mlp == "ln":
532
- self.norm2 = nn.LayerNorm(dim)
533
- elif self.norm_before_mlp == "bn":
534
- self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(
535
- 1, 2
536
- )
537
- else:
538
- raise NotImplementedError
539
- mlp_hidden_dim = int(dim * mlp_ratio)
540
- self.mlp = Mlp(
541
- in_features=dim,
542
- hidden_features=mlp_hidden_dim,
543
- act_layer=act_layer,
544
- drop=drop,
545
- )
546
-
547
- if self.shift_size > 0:
548
- # calculate attention mask for SW-MSA
549
- H, W = self.input_resolution
550
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
551
- h_slices = (
552
- slice(0, -self.window_size),
553
- slice(-self.window_size, -self.shift_size),
554
- slice(-self.shift_size, None),
555
- )
556
- w_slices = (
557
- slice(0, -self.window_size),
558
- slice(-self.window_size, -self.shift_size),
559
- slice(-self.shift_size, None),
560
- )
561
- cnt = 0
562
- for h in h_slices:
563
- for w in w_slices:
564
- img_mask[:, h, w, :] = cnt
565
- cnt += 1
566
-
567
- mask_windows = window_partition(
568
- img_mask, self.window_size
569
- ) # nW, window_size, window_size, 1
570
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
571
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
572
- attn_mask = attn_mask.masked_fill(
573
- attn_mask != 0, float(-100.0)
574
- ).masked_fill(attn_mask == 0, float(0.0))
575
- else:
576
- attn_mask = None
577
-
578
- self.register_buffer("attn_mask", attn_mask)
579
-
580
- def forward(self, x):
581
- # pdb.set_trace()
582
- H, W = self.input_resolution
583
- # print("H: ", H)
584
- # print("W: ", W)
585
- # pdb.set_trace()
586
- B, L, C = x.shape
587
- # assert L == H * W, "input feature has wrong size"
588
-
589
- shortcut = x
590
- x = self.norm1(x)
591
- x = x.view(B, H, W, C)
592
-
593
- # cyclic shift
594
- if self.shift_size > 0:
595
- shifted_x = torch.roll(
596
- x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
597
- )
598
- else:
599
- shifted_x = x
600
-
601
- # partition windows
602
- x_windows = window_partition(
603
- shifted_x, self.window_size
604
- ) # nW*B, window_size, window_size, C
605
- x_windows = x_windows.view(
606
- -1, self.window_size * self.window_size, C
607
- ) # nW*B, window_size*window_size, C
608
-
609
- # W-MSA/SW-MSA
610
- attn_windows, attn = self.attn(
611
- x_windows, mask=self.attn_mask
612
- ) # nW*B, window_size*window_size, C
613
-
614
- # merge windows
615
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
616
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
617
-
618
- # reverse cyclic shift
619
- if self.shift_size > 0:
620
- x = torch.roll(
621
- shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
622
- )
623
- else:
624
- x = shifted_x
625
- x = x.view(B, H * W, C)
626
-
627
- # FFN
628
- x = shortcut + self.drop_path(x)
629
- x = x + self.drop_path(self.mlp(self.norm2(x)))
630
-
631
- return x, attn
632
-
633
- def extra_repr(self):
634
- return (
635
- f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
636
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
637
- )
638
-
639
-
640
- class PatchMerging(nn.Module):
641
- r"""Patch Merging Layer.
642
- Args:
643
- input_resolution (tuple[int]): Resolution of input feature.
644
- dim (int): Number of input channels.
645
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
646
- """
647
-
648
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
649
- super().__init__()
650
- self.input_resolution = input_resolution
651
- self.dim = dim
652
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
653
- self.norm = norm_layer(4 * dim)
654
-
655
- def forward(self, x):
656
- """
657
- x: B, H*W, C
658
- """
659
- H, W = self.input_resolution
660
- B, L, C = x.shape
661
- assert L == H * W, "input feature has wrong size"
662
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
663
-
664
- x = x.view(B, H, W, C)
665
-
666
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
667
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
668
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
669
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
670
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
671
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
672
-
673
- x = self.norm(x)
674
- x = self.reduction(x)
675
-
676
- return x
677
-
678
- def extra_repr(self):
679
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
680
-
681
-
682
- class BasicLayer(nn.Module):
683
- """A basic Swin Transformer layer for one stage.
684
- Args:
685
- dim (int): Number of input channels.
686
- input_resolution (tuple[int]): Input resolution.
687
- depth (int): Number of blocks.
688
- num_heads (int): Number of attention heads.
689
- window_size (int): Local window size.
690
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
691
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
692
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
693
- drop (float, optional): Dropout rate. Default: 0.0
694
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
695
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
696
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
697
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
698
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
699
- """
700
-
701
- def __init__(
702
- self,
703
- dim,
704
- input_resolution,
705
- depth,
706
- num_heads,
707
- window_size,
708
- mlp_ratio=4.0,
709
- qkv_bias=True,
710
- qk_scale=None,
711
- drop=0.0,
712
- attn_drop=0.0,
713
- drop_path=0.0,
714
- norm_layer=nn.LayerNorm,
715
- downsample=None,
716
- use_checkpoint=False,
717
- norm_before_mlp="ln",
718
- ):
719
-
720
- super().__init__()
721
- self.dim = dim
722
- self.input_resolution = input_resolution
723
- self.depth = depth
724
- self.use_checkpoint = use_checkpoint
725
-
726
- # build blocks
727
- self.blocks = nn.ModuleList(
728
- [
729
- SwinTransformerBlock(
730
- dim=dim,
731
- input_resolution=input_resolution,
732
- num_heads=num_heads,
733
- window_size=window_size,
734
- shift_size=0 if (i % 2 == 0) else window_size // 2,
735
- mlp_ratio=mlp_ratio,
736
- qkv_bias=qkv_bias,
737
- qk_scale=qk_scale,
738
- drop=drop,
739
- attn_drop=attn_drop,
740
- drop_path=drop_path[i]
741
- if isinstance(drop_path, list)
742
- else drop_path,
743
- norm_layer=norm_layer,
744
- norm_before_mlp=norm_before_mlp,
745
- )
746
- for i in range(depth)
747
- ]
748
- )
749
-
750
- # patch merging layer
751
- if downsample is not None:
752
- self.downsample = downsample(
753
- input_resolution, dim=dim, norm_layer=norm_layer
754
- )
755
- else:
756
- self.downsample = None
757
-
758
- def forward(self, x):
759
- attns = []
760
- for blk in self.blocks:
761
- if self.use_checkpoint:
762
- x = checkpoint.checkpoint(blk, x)
763
- else:
764
- x, attn = blk(x)
765
- if not self.training:
766
- attns.append(attn.unsqueeze(0))
767
- if self.downsample is not None:
768
- x = self.downsample(x)
769
- if not self.training:
770
- attn = torch.cat(attns, dim=0)
771
- attn = torch.mean(attn, dim=0)
772
- return x, attn
773
-
774
- def extra_repr(self):
775
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
776
-
777
-
778
- # The Core of HTSAT
779
- class HTSAT_Swin_Transformer(nn.Module):
780
- r"""HTSAT based on the Swin Transformer
781
- Args:
782
- spec_size (int | tuple(int)): Input Spectrogram size. Default 256
783
- patch_size (int | tuple(int)): Patch size. Default: 4
784
- path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4
785
- in_chans (int): Number of input image channels. Default: 1 (mono)
786
- num_classes (int): Number of classes for classification head. Default: 527
787
- embed_dim (int): Patch embedding dimension. Default: 96
788
- depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
789
- num_heads (tuple(int)): Number of attention heads in different layers.
790
- window_size (int): Window size. Default: 8
791
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
792
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
793
- qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
794
- drop_rate (float): Dropout rate. Default: 0
795
- attn_drop_rate (float): Attention dropout rate. Default: 0
796
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
797
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
798
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
799
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
800
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
801
- config (module): The configuration Module from config.py
802
- """
803
-
804
- def __init__(
805
- self,
806
- spec_size=256,
807
- patch_size=4,
808
- patch_stride=(4, 4),
809
- in_chans=1,
810
- num_classes=527,
811
- embed_dim=96,
812
- depths=[2, 2, 6, 2],
813
- num_heads=[4, 8, 16, 32],
814
- window_size=8,
815
- mlp_ratio=4.0,
816
- qkv_bias=True,
817
- qk_scale=None,
818
- drop_rate=0.0,
819
- attn_drop_rate=0.0,
820
- drop_path_rate=0.1,
821
- norm_layer=nn.LayerNorm,
822
- ape=False,
823
- patch_norm=True,
824
- use_checkpoint=False,
825
- norm_before_mlp="ln",
826
- config=None,
827
- enable_fusion=False,
828
- fusion_type="None",
829
- **kwargs,
830
- ):
831
- super(HTSAT_Swin_Transformer, self).__init__()
832
-
833
- self.config = config
834
- self.spec_size = spec_size
835
- self.patch_stride = patch_stride
836
- self.patch_size = patch_size
837
- self.window_size = window_size
838
- self.embed_dim = embed_dim
839
- self.depths = depths
840
- self.ape = ape
841
- self.in_chans = in_chans
842
- self.num_classes = num_classes
843
- self.num_heads = num_heads
844
- self.num_layers = len(self.depths)
845
- self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
846
-
847
- self.drop_rate = drop_rate
848
- self.attn_drop_rate = attn_drop_rate
849
- self.drop_path_rate = drop_path_rate
850
-
851
- self.qkv_bias = qkv_bias
852
- self.qk_scale = None
853
-
854
- self.patch_norm = patch_norm
855
- self.norm_layer = norm_layer if self.patch_norm else None
856
- self.norm_before_mlp = norm_before_mlp
857
- self.mlp_ratio = mlp_ratio
858
-
859
- self.use_checkpoint = use_checkpoint
860
-
861
- self.enable_fusion = enable_fusion
862
- self.fusion_type = fusion_type
863
-
864
- # process mel-spec ; used only once
865
- self.freq_ratio = self.spec_size // self.config.mel_bins
866
- window = "hann"
867
- center = True
868
- pad_mode = "reflect"
869
- ref = 1.0
870
- amin = 1e-10
871
- top_db = None
872
- self.interpolate_ratio = 32 # Downsampled ratio
873
- # Spectrogram extractor
874
- self.spectrogram_extractor = Spectrogram(
875
- n_fft=config.window_size,
876
- hop_length=config.hop_size,
877
- win_length=config.window_size,
878
- window=window,
879
- center=center,
880
- pad_mode=pad_mode,
881
- freeze_parameters=True,
882
- )
883
- # Logmel feature extractor
884
- self.logmel_extractor = LogmelFilterBank(
885
- sr=config.sample_rate,
886
- n_fft=config.window_size,
887
- n_mels=config.mel_bins,
888
- fmin=config.fmin,
889
- fmax=config.fmax,
890
- ref=ref,
891
- amin=amin,
892
- top_db=top_db,
893
- freeze_parameters=True,
894
- )
895
- # Spec augmenter
896
- self.spec_augmenter = SpecAugmentation(
897
- time_drop_width=64,
898
- time_stripes_num=2,
899
- freq_drop_width=8,
900
- freq_stripes_num=2,
901
- ) # 2 2
902
- self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
903
-
904
- # split spctrogram into non-overlapping patches
905
- self.patch_embed = PatchEmbed(
906
- img_size=self.spec_size,
907
- patch_size=self.patch_size,
908
- in_chans=self.in_chans,
909
- embed_dim=self.embed_dim,
910
- norm_layer=self.norm_layer,
911
- patch_stride=patch_stride,
912
- enable_fusion=self.enable_fusion,
913
- fusion_type=self.fusion_type,
914
- )
915
-
916
- num_patches = self.patch_embed.num_patches
917
- patches_resolution = self.patch_embed.grid_size
918
- self.patches_resolution = patches_resolution
919
-
920
- # absolute position embedding
921
- if self.ape:
922
- self.absolute_pos_embed = nn.Parameter(
923
- torch.zeros(1, num_patches, self.embed_dim)
924
- )
925
- trunc_normal_(self.absolute_pos_embed, std=0.02)
926
-
927
- self.pos_drop = nn.Dropout(p=self.drop_rate)
928
-
929
- # stochastic depth
930
- dpr = [
931
- x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))
932
- ] # stochastic depth decay rule
933
-
934
- # build layers
935
- self.layers = nn.ModuleList()
936
- for i_layer in range(self.num_layers):
937
- layer = BasicLayer(
938
- dim=int(self.embed_dim * 2**i_layer),
939
- input_resolution=(
940
- patches_resolution[0] // (2**i_layer),
941
- patches_resolution[1] // (2**i_layer),
942
- ),
943
- depth=self.depths[i_layer],
944
- num_heads=self.num_heads[i_layer],
945
- window_size=self.window_size,
946
- mlp_ratio=self.mlp_ratio,
947
- qkv_bias=self.qkv_bias,
948
- qk_scale=self.qk_scale,
949
- drop=self.drop_rate,
950
- attn_drop=self.attn_drop_rate,
951
- drop_path=dpr[
952
- sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1])
953
- ],
954
- norm_layer=self.norm_layer,
955
- downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
956
- use_checkpoint=use_checkpoint,
957
- norm_before_mlp=self.norm_before_mlp,
958
- )
959
- self.layers.append(layer)
960
-
961
- self.norm = self.norm_layer(self.num_features)
962
- self.avgpool = nn.AdaptiveAvgPool1d(1)
963
- self.maxpool = nn.AdaptiveMaxPool1d(1)
964
-
965
- SF = (
966
- self.spec_size
967
- // (2 ** (len(self.depths) - 1))
968
- // self.patch_stride[0]
969
- // self.freq_ratio
970
- )
971
- self.tscam_conv = nn.Conv2d(
972
- in_channels=self.num_features,
973
- out_channels=self.num_classes,
974
- kernel_size=(SF, 3),
975
- padding=(0, 1),
976
- )
977
- self.head = nn.Linear(num_classes, num_classes)
978
-
979
- if (self.enable_fusion) and (
980
- self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
981
- ):
982
- self.mel_conv1d = nn.Sequential(
983
- nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
984
- nn.BatchNorm1d(64),
985
- )
986
- if self.fusion_type == "daf_1d":
987
- self.fusion_model = DAF()
988
- elif self.fusion_type == "aff_1d":
989
- self.fusion_model = AFF(channels=64, type="1D")
990
- elif self.fusion_type == "iaff_1d":
991
- self.fusion_model = iAFF(channels=64, type="1D")
992
-
993
- self.apply(self._init_weights)
994
-
995
- def _init_weights(self, m):
996
- if isinstance(m, nn.Linear):
997
- trunc_normal_(m.weight, std=0.02)
998
- if isinstance(m, nn.Linear) and m.bias is not None:
999
- nn.init.constant_(m.bias, 0)
1000
- elif isinstance(m, nn.LayerNorm):
1001
- nn.init.constant_(m.bias, 0)
1002
- nn.init.constant_(m.weight, 1.0)
1003
-
1004
- @torch.jit.ignore
1005
- def no_weight_decay(self):
1006
- return {"absolute_pos_embed"}
1007
-
1008
- @torch.jit.ignore
1009
- def no_weight_decay_keywords(self):
1010
- return {"relative_position_bias_table"}
1011
-
1012
- def forward_features(self, x, longer_idx=None):
1013
- # A deprecated optimization for using a hierarchical output from different blocks
1014
-
1015
- frames_num = x.shape[2]
1016
- x = self.patch_embed(x, longer_idx=longer_idx)
1017
- if self.ape:
1018
- x = x + self.absolute_pos_embed
1019
- x = self.pos_drop(x)
1020
- for i, layer in enumerate(self.layers):
1021
- x, attn = layer(x)
1022
- # for x
1023
- x = self.norm(x)
1024
- B, N, C = x.shape
1025
- SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
1026
- ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
1027
- x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST)
1028
- B, C, F, T = x.shape
1029
- # group 2D CNN
1030
- c_freq_bin = F // self.freq_ratio
1031
- x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
1032
- x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1)
1033
- # get latent_output
1034
- fine_grained_latent_output = torch.mean(x, dim=2)
1035
- fine_grained_latent_output = interpolate(
1036
- fine_grained_latent_output.permute(0, 2, 1).contiguous(),
1037
- 8 * self.patch_stride[1],
1038
- )
1039
-
1040
- latent_output = self.avgpool(torch.flatten(x, 2))
1041
- latent_output = torch.flatten(latent_output, 1)
1042
-
1043
- # display the attention map, if needed
1044
-
1045
- x = self.tscam_conv(x)
1046
- x = torch.flatten(x, 2) # B, C, T
1047
-
1048
- fpx = interpolate(
1049
- torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1]
1050
- )
1051
-
1052
- x = self.avgpool(x)
1053
- x = torch.flatten(x, 1)
1054
-
1055
- output_dict = {
1056
- "framewise_output": fpx, # already sigmoided
1057
- "clipwise_output": torch.sigmoid(x),
1058
- "fine_grained_embedding": fine_grained_latent_output,
1059
- "embedding": latent_output,
1060
- }
1061
-
1062
- return output_dict
1063
-
1064
- def crop_wav(self, x, crop_size, spe_pos=None):
1065
- time_steps = x.shape[2]
1066
- tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
1067
- for i in range(len(x)):
1068
- if spe_pos is None:
1069
- crop_pos = random.randint(0, time_steps - crop_size - 1)
1070
- else:
1071
- crop_pos = spe_pos
1072
- tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :]
1073
- return tx
1074
-
1075
- # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model
1076
- def reshape_wav2img(self, x):
1077
- B, C, T, F = x.shape
1078
- target_T = int(self.spec_size * self.freq_ratio)
1079
- target_F = self.spec_size // self.freq_ratio
1080
- assert (
1081
- T <= target_T and F <= target_F
1082
- ), "the wav size should less than or equal to the swin input size"
1083
- # to avoid bicubic zero error
1084
- if T < target_T:
1085
- x = nn.functional.interpolate(
1086
- x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
1087
- )
1088
- if F < target_F:
1089
- x = nn.functional.interpolate(
1090
- x, (x.shape[2], target_F), mode="bicubic", align_corners=True
1091
- )
1092
- x = x.permute(0, 1, 3, 2).contiguous()
1093
- x = x.reshape(
1094
- x.shape[0],
1095
- x.shape[1],
1096
- x.shape[2],
1097
- self.freq_ratio,
1098
- x.shape[3] // self.freq_ratio,
1099
- )
1100
- # print(x.shape)
1101
- x = x.permute(0, 1, 3, 2, 4).contiguous()
1102
- x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
1103
- return x
1104
-
1105
- # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model
1106
- def repeat_wat2img(self, x, cur_pos):
1107
- B, C, T, F = x.shape
1108
- target_T = int(self.spec_size * self.freq_ratio)
1109
- target_F = self.spec_size // self.freq_ratio
1110
- assert (
1111
- T <= target_T and F <= target_F
1112
- ), "the wav size should less than or equal to the swin input size"
1113
- # to avoid bicubic zero error
1114
- if T < target_T:
1115
- x = nn.functional.interpolate(
1116
- x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
1117
- )
1118
- if F < target_F:
1119
- x = nn.functional.interpolate(
1120
- x, (x.shape[2], target_F), mode="bicubic", align_corners=True
1121
- )
1122
- x = x.permute(0, 1, 3, 2).contiguous() # B C F T
1123
- x = x[:, :, :, cur_pos : cur_pos + self.spec_size]
1124
- x = x.repeat(repeats=(1, 1, 4, 1))
1125
- return x
1126
-
1127
- def forward(
1128
- self, x: torch.Tensor, mixup_lambda=None, infer_mode=False, device=None
1129
- ): # out_feat_keys: List[str] = None):
1130
-
1131
- if self.enable_fusion and x["longer"].sum() == 0:
1132
- # if no audio is longer than 10s, then randomly select one audio to be longer
1133
- x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True
1134
-
1135
- if not self.enable_fusion:
1136
- x = x["waveform"].to(device=device, non_blocking=True)
1137
- x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
1138
- x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
1139
- x = x.transpose(1, 3)
1140
- x = self.bn0(x)
1141
- x = x.transpose(1, 3)
1142
- if self.training:
1143
- x = self.spec_augmenter(x)
1144
-
1145
- if self.training and mixup_lambda is not None:
1146
- x = do_mixup(x, mixup_lambda)
1147
-
1148
- x = self.reshape_wav2img(x)
1149
- output_dict = self.forward_features(x)
1150
- else:
1151
- longer_list = x["longer"].to(device=device, non_blocking=True)
1152
- x = x["mel_fusion"].to(device=device, non_blocking=True)
1153
- x = x.transpose(1, 3)
1154
- x = self.bn0(x)
1155
- x = x.transpose(1, 3)
1156
- longer_list_idx = torch.where(longer_list)[0]
1157
- if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
1158
- new_x = x[:, 0:1, :, :].clone().contiguous()
1159
- if len(longer_list_idx) > 0:
1160
- # local processing
1161
- fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
1162
- FB, FC, FT, FF = fusion_x_local.size()
1163
- fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
1164
- fusion_x_local = torch.permute(
1165
- fusion_x_local, (0, 2, 1)
1166
- ).contiguous()
1167
- fusion_x_local = self.mel_conv1d(fusion_x_local)
1168
- fusion_x_local = fusion_x_local.view(
1169
- FB, FC, FF, fusion_x_local.size(-1)
1170
- )
1171
- fusion_x_local = (
1172
- torch.permute(fusion_x_local, (0, 2, 1, 3))
1173
- .contiguous()
1174
- .flatten(2)
1175
- )
1176
- if fusion_x_local.size(-1) < FT:
1177
- fusion_x_local = torch.cat(
1178
- [
1179
- fusion_x_local,
1180
- torch.zeros(
1181
- (FB, FF, FT - fusion_x_local.size(-1)),
1182
- device=device,
1183
- ),
1184
- ],
1185
- dim=-1,
1186
- )
1187
- else:
1188
- fusion_x_local = fusion_x_local[:, :, :FT]
1189
- # 1D fusion
1190
- new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
1191
- new_x[longer_list_idx] = self.fusion_model(
1192
- new_x[longer_list_idx], fusion_x_local
1193
- )
1194
- x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
1195
- else:
1196
- x = new_x
1197
-
1198
- elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
1199
- x = x # no change
1200
-
1201
- if self.training:
1202
- x = self.spec_augmenter(x)
1203
- if self.training and mixup_lambda is not None:
1204
- x = do_mixup(x, mixup_lambda)
1205
-
1206
- x = self.reshape_wav2img(x)
1207
- output_dict = self.forward_features(x, longer_idx=longer_list_idx)
1208
-
1209
- # if infer_mode:
1210
- # # in infer mode. we need to handle different length audio input
1211
- # frame_num = x.shape[2]
1212
- # target_T = int(self.spec_size * self.freq_ratio)
1213
- # repeat_ratio = math.floor(target_T / frame_num)
1214
- # x = x.repeat(repeats=(1,1,repeat_ratio,1))
1215
- # x = self.reshape_wav2img(x)
1216
- # output_dict = self.forward_features(x)
1217
- # else:
1218
- # if x.shape[2] > self.freq_ratio * self.spec_size:
1219
- # if self.training:
1220
- # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
1221
- # x = self.reshape_wav2img(x)
1222
- # output_dict = self.forward_features(x)
1223
- # else:
1224
- # # Change: Hard code here
1225
- # overlap_size = (x.shape[2] - 1) // 4
1226
- # output_dicts = []
1227
- # crop_size = (x.shape[2] - 1) // 2
1228
- # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
1229
- # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
1230
- # tx = self.reshape_wav2img(tx)
1231
- # output_dicts.append(self.forward_features(tx))
1232
- # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
1233
- # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
1234
- # for d in output_dicts:
1235
- # clipwise_output += d["clipwise_output"]
1236
- # framewise_output += d["framewise_output"]
1237
- # clipwise_output = clipwise_output / len(output_dicts)
1238
- # framewise_output = framewise_output / len(output_dicts)
1239
- # output_dict = {
1240
- # 'framewise_output': framewise_output,
1241
- # 'clipwise_output': clipwise_output
1242
- # }
1243
- # else: # this part is typically used, and most easy one
1244
- # x = self.reshape_wav2img(x)
1245
- # output_dict = self.forward_features(x)
1246
- # x = self.head(x)
1247
-
1248
- # We process the data in the dataloader part, in that here we only consider the input_T < fixed_T
1249
-
1250
- return output_dict
1251
-
1252
-
1253
- def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"):
1254
- try:
1255
-
1256
- assert audio_cfg.model_name in [
1257
- "tiny",
1258
- "base",
1259
- "large",
1260
- ], "model name for HTS-AT is wrong!"
1261
- if audio_cfg.model_name == "tiny":
1262
- model = HTSAT_Swin_Transformer(
1263
- spec_size=256,
1264
- patch_size=4,
1265
- patch_stride=(4, 4),
1266
- num_classes=audio_cfg.class_num,
1267
- embed_dim=96,
1268
- depths=[2, 2, 6, 2],
1269
- num_heads=[4, 8, 16, 32],
1270
- window_size=8,
1271
- config=audio_cfg,
1272
- enable_fusion=enable_fusion,
1273
- fusion_type=fusion_type,
1274
- )
1275
- elif audio_cfg.model_name == "base":
1276
- model = HTSAT_Swin_Transformer(
1277
- spec_size=256,
1278
- patch_size=4,
1279
- patch_stride=(4, 4),
1280
- num_classes=audio_cfg.class_num,
1281
- embed_dim=128,
1282
- depths=[2, 2, 12, 2],
1283
- num_heads=[4, 8, 16, 32],
1284
- window_size=8,
1285
- config=audio_cfg,
1286
- enable_fusion=enable_fusion,
1287
- fusion_type=fusion_type,
1288
- )
1289
- elif audio_cfg.model_name == "large":
1290
- model = HTSAT_Swin_Transformer(
1291
- spec_size=256,
1292
- patch_size=4,
1293
- patch_stride=(4, 4),
1294
- num_classes=audio_cfg.class_num,
1295
- embed_dim=256,
1296
- depths=[2, 2, 12, 2],
1297
- num_heads=[4, 8, 16, 32],
1298
- window_size=8,
1299
- config=audio_cfg,
1300
- enable_fusion=enable_fusion,
1301
- fusion_type=fusion_type,
1302
- )
1303
-
1304
- return model
1305
- except:
1306
- raise RuntimeError(
1307
- f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
1308
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/linear_probe.py DELETED
@@ -1,66 +0,0 @@
1
- import numpy as np
2
- import torch.nn.functional as F
3
- from torch import nn
4
- from .model import MLPLayers
5
-
6
-
7
- class LinearProbe(nn.Module):
8
- def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
9
- """
10
- Args:
11
- model: nn.Module
12
- mlp: bool, if True, then use the MLP layer as the linear probe module
13
- freeze: bool, if Ture, then freeze all the CLAP model's layers when training the linear probe
14
- in_ch: int, the output channel from CLAP model
15
- out_ch: int, the output channel from linear probe (class_num)
16
- act: torch.nn.functional, the activation function before the loss function
17
- """
18
- super().__init__()
19
- in_ch = 512
20
- self.clap_model = model
21
- self.clap_model.text_branch = None # to save memory
22
- self.freeze = freeze
23
- if mlp:
24
- self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
25
- else:
26
- self.lp_layer = nn.Linear(in_ch, out_ch)
27
-
28
- if self.freeze:
29
- for param in self.clap_model.parameters():
30
- param.requires_grad = False
31
-
32
- if act == "None":
33
- self.act = None
34
- elif act == "relu":
35
- self.act = nn.ReLU()
36
- elif act == "elu":
37
- self.act = nn.ELU()
38
- elif act == "prelu":
39
- self.act = nn.PReLU(num_parameters=in_ch)
40
- elif act == "softmax":
41
- self.act = nn.Softmax(dim=-1)
42
- elif act == "sigmoid":
43
- self.act = nn.Sigmoid()
44
-
45
- def forward(self, x, mix_lambda=None, device=None):
46
- """
47
- Args:
48
- x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list
49
- mix_lambda: torch.tensor [batch], the mixup lambda
50
- Returns:
51
- class_prob: torch.tensor [batch, class_num]
52
-
53
- """
54
- # batchnorm cancel grandient
55
- if self.freeze:
56
- self.clap_model.eval()
57
-
58
- x = self.clap_model.audio_projection(
59
- self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[
60
- "embedding"
61
- ]
62
- )
63
- out = self.lp_layer(x)
64
- if self.act is not None:
65
- out = self.act(out)
66
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/loss.py DELETED
@@ -1,398 +0,0 @@
1
- from multiprocessing.sharedctypes import Value
2
- import torch
3
- import torch.distributed.nn
4
- from torch import distributed as dist, nn as nn
5
- from torch.nn import functional as F
6
- import numpy as np
7
- from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
8
-
9
- try:
10
- import horovod.torch as hvd
11
- except ImportError:
12
- hvd = None
13
-
14
-
15
- def gather_features(
16
- audio_features,
17
- text_features,
18
- audio_features_mlp=None,
19
- text_features_mlp=None,
20
- local_loss=False,
21
- gather_with_grad=False,
22
- rank=0,
23
- world_size=1,
24
- use_horovod=False,
25
- mlp_loss=False,
26
- ):
27
- if use_horovod:
28
- assert hvd is not None, "Please install horovod"
29
- if gather_with_grad:
30
- all_audio_features = hvd.allgather(audio_features)
31
- all_text_features = hvd.allgather(text_features)
32
- if mlp_loss:
33
- all_audio_features_mlp = hvd.allgather(audio_features_mlp)
34
- all_text_features_mlp = hvd.allgather(text_features_mlp)
35
- else:
36
- with torch.no_grad():
37
- all_audio_features = hvd.allgather(audio_features)
38
- all_text_features = hvd.allgather(text_features)
39
- if mlp_loss:
40
- all_audio_features_mlp = hvd.allgather(audio_features_mlp)
41
- all_text_features_mlp = hvd.allgather(text_features_mlp)
42
- if not local_loss:
43
- # ensure grads for local rank when all_* features don't have a gradient
44
- gathered_audio_features = list(
45
- all_audio_features.chunk(world_size, dim=0)
46
- )
47
- gathered_text_features = list(
48
- all_text_features.chunk(world_size, dim=0)
49
- )
50
- gathered_audio_features[rank] = audio_features
51
- gathered_text_features[rank] = text_features
52
- all_audio_features = torch.cat(gathered_audio_features, dim=0)
53
- all_text_features = torch.cat(gathered_text_features, dim=0)
54
- if mlp_loss:
55
- gathered_audio_features_mlp = list(
56
- all_audio_features_mlp.chunk(world_size, dim=0)
57
- )
58
- gathered_text_features_mlp = list(
59
- all_text_features_mlp.chunk(world_size, dim=0)
60
- )
61
- gathered_audio_features_mlp[rank] = audio_features_mlp
62
- gathered_text_features_mlp[rank] = text_features_mlp
63
- all_audio_features_mlp = torch.cat(
64
- gathered_audio_features_mlp, dim=0
65
- )
66
- all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
67
- else:
68
- # We gather tensors from all gpus
69
- if gather_with_grad:
70
- all_audio_features = torch.cat(
71
- torch.distributed.nn.all_gather(audio_features), dim=0
72
- )
73
- all_text_features = torch.cat(
74
- torch.distributed.nn.all_gather(text_features), dim=0
75
- )
76
- if mlp_loss:
77
- all_audio_features_mlp = torch.cat(
78
- torch.distributed.nn.all_gather(audio_features_mlp), dim=0
79
- )
80
- all_text_features_mlp = torch.cat(
81
- torch.distributed.nn.all_gather(text_features_mlp), dim=0
82
- )
83
- else:
84
- gathered_audio_features = [
85
- torch.zeros_like(audio_features) for _ in range(world_size)
86
- ]
87
- gathered_text_features = [
88
- torch.zeros_like(text_features) for _ in range(world_size)
89
- ]
90
- dist.all_gather(gathered_audio_features, audio_features)
91
- dist.all_gather(gathered_text_features, text_features)
92
- if mlp_loss:
93
- gathered_audio_features_mlp = [
94
- torch.zeros_like(audio_features_mlp) for _ in range(world_size)
95
- ]
96
- gathered_text_features_mlp = [
97
- torch.zeros_like(text_features_mlp) for _ in range(world_size)
98
- ]
99
- dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
100
- dist.all_gather(gathered_text_features_mlp, text_features_mlp)
101
- if not local_loss:
102
- # ensure grads for local rank when all_* features don't have a gradient
103
- gathered_audio_features[rank] = audio_features
104
- gathered_text_features[rank] = text_features
105
- if mlp_loss:
106
- gathered_audio_features_mlp[rank] = audio_features_mlp
107
- gathered_text_features_mlp[rank] = text_features_mlp
108
-
109
- all_audio_features = torch.cat(gathered_audio_features, dim=0)
110
- all_text_features = torch.cat(gathered_text_features, dim=0)
111
- if mlp_loss:
112
- all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
113
- all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
114
- if mlp_loss:
115
- return (
116
- all_audio_features,
117
- all_text_features,
118
- all_audio_features_mlp,
119
- all_text_features_mlp,
120
- )
121
- else:
122
- return all_audio_features, all_text_features
123
-
124
-
125
- class ClipLoss(nn.Module):
126
- def __init__(
127
- self,
128
- local_loss=False,
129
- gather_with_grad=False,
130
- cache_labels=False,
131
- rank=0,
132
- world_size=1,
133
- use_horovod=False,
134
- mlp_loss=False,
135
- weight_loss_kappa=0,
136
- ):
137
- super().__init__()
138
- self.local_loss = local_loss
139
- self.gather_with_grad = gather_with_grad
140
- self.cache_labels = cache_labels
141
- self.rank = rank
142
- self.world_size = world_size
143
- self.use_horovod = use_horovod
144
- self.mlp_loss = mlp_loss
145
- self.weighted_loss = bool(weight_loss_kappa != 0)
146
- self.weight_loss_kappa = weight_loss_kappa
147
- # cache state
148
- self.prev_num_logits = 0
149
- self.labels = {}
150
-
151
- def forward(
152
- self,
153
- audio_features,
154
- text_features,
155
- logit_scale_a,
156
- logit_scale_t=None,
157
- audio_features_mlp=None,
158
- text_features_mlp=None,
159
- ):
160
- device = audio_features.device
161
- if self.mlp_loss:
162
- if self.world_size > 1:
163
- (
164
- all_audio_features,
165
- all_text_features,
166
- all_audio_features_mlp,
167
- all_text_features_mlp,
168
- ) = gather_features(
169
- audio_features=audio_features,
170
- text_features=text_features,
171
- audio_features_mlp=audio_features_mlp,
172
- text_features_mlp=text_features_mlp,
173
- local_loss=self.local_loss,
174
- gather_with_grad=self.gather_with_grad,
175
- rank=self.rank,
176
- world_size=self.world_size,
177
- use_horovod=self.use_horovod,
178
- mlp_loss=self.mlp_loss,
179
- )
180
- if self.local_loss:
181
- a_logits_per_audio = (
182
- logit_scale_a * audio_features @ all_text_features_mlp.T
183
- )
184
- a_logits_per_text = (
185
- logit_scale_a * text_features_mlp @ all_audio_features.T
186
- )
187
- t_logits_per_audio = (
188
- logit_scale_t * audio_features_mlp @ all_text_features.T
189
- )
190
- t_logits_per_text = (
191
- logit_scale_t * text_features @ all_audio_features_mlp.T
192
- )
193
- else:
194
- a_logits_per_audio = (
195
- logit_scale_a * all_audio_features @ all_text_features_mlp.T
196
- )
197
- a_logits_per_text = a_logits_per_audio.T
198
- t_logits_per_audio = (
199
- logit_scale_t * all_audio_features_mlp @ all_text_features.T
200
- )
201
- t_logits_per_text = t_logits_per_audio.T
202
- else:
203
- a_logits_per_audio = (
204
- logit_scale_a * audio_features @ text_features_mlp.T
205
- )
206
- a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
207
- t_logits_per_audio = (
208
- logit_scale_t * audio_features_mlp @ text_features.T
209
- )
210
- t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T
211
-
212
- # calculated ground-truth and cache if enabled
213
- num_logits = a_logits_per_audio.shape[0]
214
- if self.prev_num_logits != num_logits or device not in self.labels:
215
- labels = torch.arange(num_logits, device=device, dtype=torch.long)
216
- if self.world_size > 1 and self.local_loss:
217
- labels = labels + num_logits * self.rank
218
- if self.cache_labels:
219
- self.labels[device] = labels
220
- self.prev_num_logits = num_logits
221
- else:
222
- labels = self.labels[device]
223
-
224
- if not self.weighted_loss:
225
- total_loss = (
226
- F.cross_entropy(a_logits_per_audio, labels)
227
- + F.cross_entropy(a_logits_per_text, labels)
228
- + F.cross_entropy(t_logits_per_audio, labels)
229
- + F.cross_entropy(t_logits_per_text, labels)
230
- ) / 4
231
- else:
232
- audio_weight = (audio_features @ audio_features.T).detach()
233
- audio_weight = (
234
- torch.exp(
235
- torch.sum(audio_weight, axis=1)
236
- / (self.weight_loss_kappa * len(audio_weight))
237
- )
238
- ).detach()
239
- text_weight = (text_features @ text_features.T).detach()
240
- text_weight = (
241
- torch.exp(
242
- torch.sum(text_weight, axis=1)
243
- / (self.weight_loss_kappa * len(text_features))
244
- )
245
- ).detach()
246
- total_loss = (
247
- F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
248
- + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
249
- + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
250
- + F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
251
- ) / 4
252
- else:
253
- if self.world_size > 1:
254
- all_audio_features, all_text_features = gather_features(
255
- audio_features=audio_features,
256
- text_features=text_features,
257
- local_loss=self.local_loss,
258
- gather_with_grad=self.gather_with_grad,
259
- rank=self.rank,
260
- world_size=self.world_size,
261
- use_horovod=self.use_horovod,
262
- mlp_loss=self.mlp_loss,
263
- )
264
-
265
- if self.local_loss:
266
- logits_per_audio = (
267
- logit_scale_a * audio_features @ all_text_features.T
268
- )
269
- logits_per_text = (
270
- logit_scale_a * text_features @ all_audio_features.T
271
- )
272
- else:
273
- logits_per_audio = (
274
- logit_scale_a * all_audio_features @ all_text_features.T
275
- )
276
- logits_per_text = logits_per_audio.T
277
- else:
278
- logits_per_audio = logit_scale_a * audio_features @ text_features.T
279
- logits_per_text = logit_scale_a * text_features @ audio_features.T
280
-
281
- # calculated ground-truth and cache if enabled
282
- num_logits = logits_per_audio.shape[0]
283
- if self.prev_num_logits != num_logits or device not in self.labels:
284
- labels = torch.arange(num_logits, device=device, dtype=torch.long)
285
- if self.world_size > 1 and self.local_loss:
286
- labels = labels + num_logits * self.rank
287
- if self.cache_labels:
288
- self.labels[device] = labels
289
- self.prev_num_logits = num_logits
290
- else:
291
- labels = self.labels[device]
292
- if not self.weighted_loss:
293
- total_loss = (
294
- F.cross_entropy(logits_per_audio, labels)
295
- + F.cross_entropy(logits_per_text, labels)
296
- ) / 2
297
- else:
298
- audio_weight = (all_audio_features @ all_audio_features.T).detach()
299
- audio_weight = (
300
- torch.exp(
301
- torch.sum(audio_weight, axis=1)
302
- / (self.weight_loss_kappa * len(all_audio_features))
303
- )
304
- ).detach()
305
- text_weight = (all_text_features @ all_text_features.T).detach()
306
- text_weight = (
307
- torch.exp(
308
- torch.sum(text_weight, axis=1)
309
- / (self.weight_loss_kappa * len(all_text_features))
310
- )
311
- ).detach()
312
- total_loss = (
313
- F.cross_entropy(logits_per_audio, labels, weight=text_weight)
314
- + F.cross_entropy(logits_per_text, labels, weight=audio_weight)
315
- ) / 2
316
- return total_loss
317
-
318
-
319
- def lp_gather_features(pred, target, world_size=1, use_horovod=False):
320
- if use_horovod:
321
- assert hvd is not None, "Please install horovod"
322
- with torch.no_grad():
323
- all_preds = hvd.allgather(pred)
324
- all_targets = hvd.allgath(target)
325
- else:
326
- gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
327
- gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
328
-
329
- dist.all_gather(gathered_preds, pred)
330
- dist.all_gather(gathered_targets, target)
331
- all_preds = torch.cat(gathered_preds, dim=0)
332
- all_targets = torch.cat(gathered_targets, dim=0)
333
-
334
- return all_preds, all_targets
335
-
336
-
337
- def get_map(pred, target):
338
- pred = torch.sigmoid(pred).numpy()
339
- target = target.numpy()
340
- return np.mean(average_precision_score(target, pred, average=None))
341
-
342
-
343
- def get_acc(pred, target):
344
- pred = torch.argmax(pred, 1).numpy()
345
- target = torch.argmax(target, 1).numpy()
346
- return accuracy_score(target, pred)
347
-
348
-
349
- def get_mauc(pred, target):
350
- pred = torch.sigmoid(pred).numpy()
351
- target = target.numpy()
352
- return np.mean(roc_auc_score(target, pred, average=None))
353
-
354
-
355
- class LPMetrics(object):
356
- def __init__(self, metric_names=["map", "acc", "mauc"]):
357
- self.metrics = []
358
- for name in metric_names:
359
- self.metrics.append(self.get_metric(name))
360
- self.metric_names = metric_names
361
-
362
- def get_metric(self, name):
363
- if name == "map":
364
- return get_map
365
- elif name == "acc":
366
- return get_acc
367
- elif name == "mauc":
368
- return get_mauc
369
- else:
370
- raise ValueError(f"the metric should be at least one of [map, acc, mauc]")
371
-
372
- def evaluate_mertics(self, pred, target):
373
- metric_dict = {}
374
- for i in range(len(self.metric_names)):
375
- metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
376
- return metric_dict
377
-
378
-
379
- def calc_celoss(pred, target):
380
- target = torch.argmax(target, 1).long()
381
- return nn.CrossEntropyLoss()(pred, target)
382
-
383
-
384
- class LPLoss(nn.Module):
385
- def __init__(self, loss_name):
386
- super().__init__()
387
- if loss_name == "bce":
388
- self.loss_func = nn.BCEWithLogitsLoss()
389
- elif loss_name == "ce":
390
- self.loss_func = calc_celoss
391
- elif loss_name == "mse":
392
- self.loss_func = nn.MSELoss()
393
- else:
394
- raise ValueError(f"the loss func should be at least one of [bce, ce, mse]")
395
-
396
- def forward(self, pred, target):
397
- loss = self.loss_func(pred, target)
398
- return loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model.py DELETED
@@ -1,936 +0,0 @@
1
- """ CLAP Model
2
-
3
- Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- Adapted to the Audio Task.
5
- """
6
-
7
- from collections import OrderedDict
8
- from dataclasses import dataclass
9
- from email.mime import audio
10
- from typing import Tuple, Union, Callable, Optional
11
-
12
- import numpy as np
13
- import torch
14
- import torch.nn.functional as F
15
- from torch import nn
16
-
17
- from .timm_model import TimmModel
18
- import logging
19
- from .utils import freeze_batch_norm_2d
20
-
21
- from .pann_model import create_pann_model
22
- from .htsat import create_htsat_model
23
- from transformers import BertModel, RobertaModel, BartModel
24
- from transformers.tokenization_utils_base import BatchEncoding
25
-
26
-
27
- class MLPLayers(nn.Module):
28
- def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
29
- super(MLPLayers, self).__init__()
30
- self.nonlin = nonlin
31
- self.dropout = dropout
32
-
33
- sequence = []
34
- for u0, u1 in zip(units[:-1], units[1:]):
35
- sequence.append(nn.Linear(u0, u1))
36
- sequence.append(self.nonlin)
37
- sequence.append(nn.Dropout(self.dropout))
38
- sequence = sequence[:-2]
39
-
40
- self.sequential = nn.Sequential(*sequence)
41
-
42
- def forward(self, X):
43
- X = self.sequential(X)
44
- return X
45
-
46
-
47
- class Bottleneck(nn.Module):
48
- expansion = 4
49
-
50
- def __init__(self, inplanes, planes, stride=1):
51
- super().__init__()
52
-
53
- # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
54
- self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
55
- self.bn1 = nn.BatchNorm2d(planes)
56
-
57
- self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
58
- self.bn2 = nn.BatchNorm2d(planes)
59
-
60
- self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
61
-
62
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
63
- self.bn3 = nn.BatchNorm2d(planes * self.expansion)
64
-
65
- self.relu = nn.ReLU(inplace=True)
66
- self.downsample = None
67
- self.stride = stride
68
-
69
- if stride > 1 or inplanes != planes * Bottleneck.expansion:
70
- # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
71
- self.downsample = nn.Sequential(
72
- OrderedDict(
73
- [
74
- ("-1", nn.AvgPool2d(stride)),
75
- (
76
- "0",
77
- nn.Conv2d(
78
- inplanes,
79
- planes * self.expansion,
80
- 1,
81
- stride=1,
82
- bias=False,
83
- ),
84
- ),
85
- ("1", nn.BatchNorm2d(planes * self.expansion)),
86
- ]
87
- )
88
- )
89
-
90
- def forward(self, x: torch.Tensor):
91
- identity = x
92
-
93
- out = self.relu(self.bn1(self.conv1(x)))
94
- out = self.relu(self.bn2(self.conv2(out)))
95
- out = self.avgpool(out)
96
- out = self.bn3(self.conv3(out))
97
-
98
- if self.downsample is not None:
99
- identity = self.downsample(x)
100
-
101
- out += identity
102
- out = self.relu(out)
103
- return out
104
-
105
-
106
- class AttentionPool2d(nn.Module):
107
- def __init__(
108
- self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
109
- ):
110
- super().__init__()
111
- self.positional_embedding = nn.Parameter(
112
- torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
113
- )
114
- self.k_proj = nn.Linear(embed_dim, embed_dim)
115
- self.q_proj = nn.Linear(embed_dim, embed_dim)
116
- self.v_proj = nn.Linear(embed_dim, embed_dim)
117
- self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
118
- self.num_heads = num_heads
119
-
120
- def forward(self, x):
121
- x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
122
- 2, 0, 1
123
- ) # NCHW -> (HW)NC
124
- x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
125
- x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
126
- x, _ = F.multi_head_attention_forward(
127
- query=x,
128
- key=x,
129
- value=x,
130
- embed_dim_to_check=x.shape[-1],
131
- num_heads=self.num_heads,
132
- q_proj_weight=self.q_proj.weight,
133
- k_proj_weight=self.k_proj.weight,
134
- v_proj_weight=self.v_proj.weight,
135
- in_proj_weight=None,
136
- in_proj_bias=torch.cat(
137
- [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
138
- ),
139
- bias_k=None,
140
- bias_v=None,
141
- add_zero_attn=False,
142
- dropout_p=0,
143
- out_proj_weight=self.c_proj.weight,
144
- out_proj_bias=self.c_proj.bias,
145
- use_separate_proj_weight=True,
146
- training=self.training,
147
- need_weights=False,
148
- )
149
-
150
- return x[0]
151
-
152
-
153
- class ModifiedResNet(nn.Module):
154
- """
155
- A ResNet class that is similar to torchvision's but contains the following changes:
156
- - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
157
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
158
- - The final pooling layer is a QKV attention instead of an average pool
159
- """
160
-
161
- def __init__(self, layers, output_dim, heads, image_size=224, width=64):
162
- super().__init__()
163
- self.output_dim = output_dim
164
- self.image_size = image_size
165
-
166
- # the 3-layer stem
167
- self.conv1 = nn.Conv2d(
168
- 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
169
- )
170
- self.bn1 = nn.BatchNorm2d(width // 2)
171
- self.conv2 = nn.Conv2d(
172
- width // 2, width // 2, kernel_size=3, padding=1, bias=False
173
- )
174
- self.bn2 = nn.BatchNorm2d(width // 2)
175
- self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
176
- self.bn3 = nn.BatchNorm2d(width)
177
- self.avgpool = nn.AvgPool2d(2)
178
- self.relu = nn.ReLU(inplace=True)
179
-
180
- # residual layers
181
- self._inplanes = width # this is a *mutable* variable used during construction
182
- self.layer1 = self._make_layer(width, layers[0])
183
- self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
184
- self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
185
- self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
186
-
187
- embed_dim = width * 32 # the ResNet feature dimension
188
- self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
189
-
190
- self.init_parameters()
191
-
192
- def _make_layer(self, planes, blocks, stride=1):
193
- layers = [Bottleneck(self._inplanes, planes, stride)]
194
-
195
- self._inplanes = planes * Bottleneck.expansion
196
- for _ in range(1, blocks):
197
- layers.append(Bottleneck(self._inplanes, planes))
198
-
199
- return nn.Sequential(*layers)
200
-
201
- def init_parameters(self):
202
- if self.attnpool is not None:
203
- std = self.attnpool.c_proj.in_features**-0.5
204
- nn.init.normal_(self.attnpool.q_proj.weight, std=std)
205
- nn.init.normal_(self.attnpool.k_proj.weight, std=std)
206
- nn.init.normal_(self.attnpool.v_proj.weight, std=std)
207
- nn.init.normal_(self.attnpool.c_proj.weight, std=std)
208
-
209
- for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
210
- for name, param in resnet_block.named_parameters():
211
- if name.endswith("bn3.weight"):
212
- nn.init.zeros_(param)
213
-
214
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
215
- assert (
216
- unlocked_groups == 0
217
- ), "partial locking not currently supported for this model"
218
- for param in self.parameters():
219
- param.requires_grad = False
220
- if freeze_bn_stats:
221
- freeze_batch_norm_2d(self)
222
-
223
- def stem(self, x):
224
- for conv, bn in [
225
- (self.conv1, self.bn1),
226
- (self.conv2, self.bn2),
227
- (self.conv3, self.bn3),
228
- ]:
229
- x = self.relu(bn(conv(x)))
230
- x = self.avgpool(x)
231
- return x
232
-
233
- def forward(self, x):
234
- x = self.stem(x)
235
- x = self.layer1(x)
236
- x = self.layer2(x)
237
- x = self.layer3(x)
238
- x = self.layer4(x)
239
- x = self.attnpool(x)
240
-
241
- return x
242
-
243
-
244
- class LayerNorm(nn.LayerNorm):
245
- """Subclass torch's LayerNorm to handle fp16."""
246
-
247
- def forward(self, x: torch.Tensor):
248
- orig_type = x.dtype
249
- x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
250
- return x.to(orig_type)
251
-
252
-
253
- class QuickGELU(nn.Module):
254
- # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
255
- def forward(self, x: torch.Tensor):
256
- return x * torch.sigmoid(1.702 * x)
257
-
258
-
259
- class ResidualAttentionBlock(nn.Module):
260
- def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
261
- super().__init__()
262
-
263
- self.attn = nn.MultiheadAttention(d_model, n_head)
264
- self.ln_1 = LayerNorm(d_model)
265
- self.mlp = nn.Sequential(
266
- OrderedDict(
267
- [
268
- ("c_fc", nn.Linear(d_model, d_model * 4)),
269
- ("gelu", act_layer()),
270
- ("c_proj", nn.Linear(d_model * 4, d_model)),
271
- ]
272
- )
273
- )
274
- self.ln_2 = LayerNorm(d_model)
275
-
276
- def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
277
- return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
278
-
279
- def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
280
- x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
281
- x = x + self.mlp(self.ln_2(x))
282
- return x
283
-
284
-
285
- class Transformer(nn.Module):
286
- def __init__(
287
- self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
288
- ):
289
- super().__init__()
290
- self.width = width
291
- self.layers = layers
292
- self.resblocks = nn.ModuleList(
293
- [
294
- ResidualAttentionBlock(width, heads, act_layer=act_layer)
295
- for _ in range(layers)
296
- ]
297
- )
298
-
299
- def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
300
- for r in self.resblocks:
301
- x = r(x, attn_mask=attn_mask)
302
- return x
303
-
304
-
305
- class VisualTransformer(nn.Module):
306
- def __init__(
307
- self,
308
- image_size: int,
309
- patch_size: int,
310
- width: int,
311
- layers: int,
312
- heads: int,
313
- output_dim: int,
314
- act_layer: Callable = nn.GELU,
315
- ):
316
- super().__init__()
317
- self.image_size = image_size
318
- self.output_dim = output_dim
319
- self.conv1 = nn.Conv2d(
320
- in_channels=3,
321
- out_channels=width,
322
- kernel_size=patch_size,
323
- stride=patch_size,
324
- bias=False,
325
- )
326
-
327
- scale = width**-0.5
328
- self.class_embedding = nn.Parameter(scale * torch.randn(width))
329
- self.positional_embedding = nn.Parameter(
330
- scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
331
- )
332
- self.ln_pre = LayerNorm(width)
333
-
334
- self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
335
-
336
- self.ln_post = LayerNorm(width)
337
- self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
338
-
339
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
340
- assert (
341
- unlocked_groups == 0
342
- ), "partial locking not currently supported for this model"
343
- for param in self.parameters():
344
- param.requires_grad = False
345
-
346
- def forward(self, x: torch.Tensor):
347
- x = self.conv1(x) # shape = [*, width, grid, grid]
348
- x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
349
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
350
- x = torch.cat(
351
- [
352
- self.class_embedding.to(x.dtype)
353
- + torch.zeros(
354
- x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
355
- ),
356
- x,
357
- ],
358
- dim=1,
359
- ) # shape = [*, grid ** 2 + 1, width]
360
- x = x + self.positional_embedding.to(x.dtype)
361
- x = self.ln_pre(x)
362
-
363
- x = x.permute(1, 0, 2) # NLD -> LND
364
- x = self.text_branch(x)
365
- x = x.permute(1, 0, 2) # LND -> NLD
366
-
367
- x = self.ln_post(x[:, 0, :])
368
-
369
- if self.proj is not None:
370
- x = x @ self.proj
371
-
372
- return x
373
-
374
-
375
- @dataclass
376
- class CLAPVisionCfg:
377
- layers: Union[Tuple[int, int, int, int], int] = 12
378
- width: int = 768
379
- patch_size: int = 16
380
- image_size: Union[Tuple[int, int], int] = 224
381
- timm_model_name: str = (
382
- None # a valid model name overrides layers, width, patch_size
383
- )
384
- timm_model_pretrained: bool = (
385
- False # use (imagenet) pretrained weights for named model
386
- )
387
- timm_pool: str = (
388
- "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
389
- )
390
- timm_proj: str = (
391
- "linear" # linear projection for timm model output ('linear', 'mlp', '')
392
- )
393
-
394
-
395
- # Audio Config Class
396
- @dataclass
397
- class CLAPAudioCfp:
398
- model_type: str = "PANN"
399
- model_name: str = "Cnn14"
400
- sample_rate: int = 48000
401
- # Param
402
- audio_length: int = 1024
403
- window_size: int = 1024
404
- hop_size: int = 1024
405
- fmin: int = 50
406
- fmax: int = 14000
407
- class_num: int = 527
408
- mel_bins: int = 64
409
- clip_samples: int = 480000
410
-
411
-
412
- @dataclass
413
- class CLAPTextCfg:
414
- context_length: int
415
- vocab_size: int
416
- width: int
417
- heads: int
418
- layers: int
419
- model_type: str
420
-
421
-
422
- class CLAP(nn.Module):
423
- def __init__(
424
- self,
425
- embed_dim: int,
426
- audio_cfg: CLAPAudioCfp,
427
- text_cfg: CLAPTextCfg,
428
- quick_gelu: bool = False,
429
- enable_fusion: bool = False,
430
- fusion_type: str = "None",
431
- joint_embed_shape: int = 512,
432
- mlp_act: str = "relu",
433
- ):
434
- super().__init__()
435
- if isinstance(audio_cfg, dict):
436
- audio_cfg = CLAPAudioCfp(**audio_cfg)
437
- if isinstance(text_cfg, dict):
438
- text_cfg = CLAPTextCfg(**text_cfg)
439
-
440
- self.audio_cfg = audio_cfg
441
- self.text_cfg = text_cfg
442
- self.enable_fusion = enable_fusion
443
- self.fusion_type = fusion_type
444
- self.joint_embed_shape = joint_embed_shape
445
- self.mlp_act = mlp_act
446
-
447
- self.context_length = text_cfg.context_length
448
-
449
- # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
450
- # memory efficient in recent PyTorch releases (>= 1.10).
451
- # NOTE: timm models always use native GELU regardless of quick_gelu flag.
452
- act_layer = QuickGELU if quick_gelu else nn.GELU
453
-
454
- if mlp_act == "relu":
455
- mlp_act_layer = nn.ReLU()
456
- elif mlp_act == "gelu":
457
- mlp_act_layer = nn.GELU()
458
- else:
459
- raise NotImplementedError
460
-
461
- # audio branch
462
- # audio branch parameters
463
- if audio_cfg.model_type == "PANN":
464
- self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
465
- elif audio_cfg.model_type == "HTSAT":
466
- self.audio_branch = create_htsat_model(
467
- audio_cfg, enable_fusion, fusion_type
468
- )
469
- else:
470
- logging.error(f"Model config for {audio_cfg.model_type} not found")
471
- raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
472
-
473
- # text branch
474
- # text branch parameters
475
- if text_cfg.model_type == "transformer":
476
- self.text_branch = Transformer(
477
- width=text_cfg.width,
478
- layers=text_cfg.layers,
479
- heads=text_cfg.heads,
480
- act_layer=act_layer,
481
- )
482
- self.vocab_size = text_cfg.vocab_size
483
- self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
484
- self.positional_embedding = nn.Parameter(
485
- torch.empty(self.context_length, text_cfg.width)
486
- )
487
- self.ln_final = LayerNorm(text_cfg.width)
488
- self.text_transform = MLPLayers(
489
- units=[
490
- self.joint_embed_shape,
491
- self.joint_embed_shape,
492
- self.joint_embed_shape,
493
- ],
494
- dropout=0.1,
495
- )
496
- self.text_projection = nn.Sequential(
497
- nn.Linear(text_cfg.width, self.joint_embed_shape),
498
- mlp_act_layer,
499
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
500
- )
501
- elif text_cfg.model_type == "bert":
502
- self.text_branch = BertModel.from_pretrained("bert-base-uncased")
503
- self.text_transform = MLPLayers(
504
- units=[
505
- self.joint_embed_shape,
506
- self.joint_embed_shape,
507
- self.joint_embed_shape,
508
- ],
509
- dropout=0.1,
510
- )
511
- self.text_projection = nn.Sequential(
512
- nn.Linear(768, self.joint_embed_shape),
513
- mlp_act_layer,
514
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
515
- )
516
- elif text_cfg.model_type == "roberta":
517
- self.text_branch = RobertaModel.from_pretrained("roberta-base")
518
- self.text_transform = MLPLayers(
519
- units=[
520
- self.joint_embed_shape,
521
- self.joint_embed_shape,
522
- self.joint_embed_shape,
523
- ],
524
- dropout=0.1,
525
- )
526
- self.text_projection = nn.Sequential(
527
- nn.Linear(768, self.joint_embed_shape),
528
- mlp_act_layer,
529
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
530
- )
531
- elif text_cfg.model_type == "bart":
532
- self.text_branch = BartModel.from_pretrained("facebook/bart-base")
533
- self.text_transform = MLPLayers(
534
- units=[
535
- self.joint_embed_shape,
536
- self.joint_embed_shape,
537
- self.joint_embed_shape,
538
- ],
539
- dropout=0.1,
540
- )
541
- self.text_projection = nn.Sequential(
542
- nn.Linear(768, self.joint_embed_shape),
543
- mlp_act_layer,
544
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
545
- )
546
- else:
547
- logging.error(f"Model config for {text_cfg.model_type} not found")
548
- raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
549
- self.text_branch_type = text_cfg.model_type
550
- # text branch parameters
551
-
552
- # audio branch parameters
553
- self.audio_transform = MLPLayers(
554
- units=[
555
- self.joint_embed_shape,
556
- self.joint_embed_shape,
557
- self.joint_embed_shape,
558
- ],
559
- dropout=0.1,
560
- )
561
-
562
- # below here is text branch parameters
563
-
564
- # ============================================================================================================
565
- self.audio_projection = nn.Sequential(
566
- nn.Linear(embed_dim, self.joint_embed_shape),
567
- mlp_act_layer,
568
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
569
- )
570
-
571
- self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
572
- self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
573
- self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
574
-
575
- self.init_text_branch_parameters()
576
-
577
- def init_text_branch_parameters(self):
578
- if self.text_branch_type == "transformer":
579
- nn.init.normal_(self.token_embedding.weight, std=0.02)
580
- nn.init.normal_(self.positional_embedding, std=0.01)
581
- proj_std = (self.text_branch.width**-0.5) * (
582
- (2 * self.text_branch.layers) ** -0.5
583
- )
584
- attn_std = self.text_branch.width**-0.5
585
- fc_std = (2 * self.text_branch.width) ** -0.5
586
- for block in self.text_branch.resblocks:
587
- nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
588
- nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
589
- nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
590
- nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
591
- if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
592
- width = self.text_branch.embeddings.word_embeddings.weight.shape[-1]
593
- elif self.text_branch_type == "bart":
594
- width = self.text_branch.shared.weight.shape[-1]
595
- else:
596
- width = self.text_branch.width
597
- nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
598
- nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
599
-
600
- # deprecated
601
- # if hasattr(self.visual, 'init_parameters'):
602
- # self.visual.init_parameters()
603
-
604
- # if self.text_projection is not None:
605
- # nn.init.normal_(self.text_projection, std=width**-0.5)
606
-
607
- def build_attention_mask(self):
608
- # lazily create causal attention mask, with full attention between the vision tokens
609
- # pytorch uses additive attention mask; fill with -inf
610
- mask = torch.empty(self.context_length, self.context_length)
611
- mask.fill_(float("-inf"))
612
- mask.triu_(1) # zero out the lower diagonal
613
- return mask
614
-
615
- def encode_audio(self, audio, device):
616
- return self.audio_branch(
617
- audio, mixup_lambda=None, device=device
618
- ) # mix lambda needs to add
619
-
620
- # def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
621
- # tmp = {}
622
- # for k in x[0].keys():
623
- # tmp[k] = []
624
- # for i in range(len(x)):
625
- # tmp[k].append(x[i][k][:77])
626
- # for k in x[0].keys():
627
- # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
628
- # return tmp
629
-
630
- def encode_text(self, text, device):
631
- if self.text_branch_type == "transformer":
632
- text = text.to(device=device, non_blocking=True)
633
- x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
634
-
635
- x = x + self.positional_embedding
636
- x = x.permute(1, 0, 2) # NLD -> LND
637
- x = self.text_branch(x, attn_mask=self.attn_mask)
638
- x = x.permute(1, 0, 2) # LND -> NLD
639
- x = self.ln_final(x)
640
-
641
- # x.shape = [batch_size, n_ctx, transformer.width]
642
- # take features from the eot embedding (eot_token is the highest number in each sequence)
643
- x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
644
- elif self.text_branch_type == "bert":
645
- # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
646
- # text = BatchEncoding(text)
647
- x = self.text_branch(
648
- input_ids=text["input_ids"].to(device=device, non_blocking=True),
649
- attention_mask=text["attention_mask"].to(
650
- device=device, non_blocking=True
651
- ),
652
- token_type_ids=text["token_type_ids"].to(
653
- device=device, non_blocking=True
654
- ),
655
- )["pooler_output"]
656
- x = self.text_projection(x)
657
- elif self.text_branch_type == "roberta":
658
- x = self.text_branch(
659
- input_ids=text["input_ids"].to(device=device, non_blocking=True),
660
- attention_mask=text["attention_mask"].to(
661
- device=device, non_blocking=True
662
- ),
663
- )["pooler_output"]
664
- x = self.text_projection(x)
665
- elif self.text_branch_type == "bart":
666
- x = torch.mean(
667
- self.text_branch(
668
- input_ids=text["input_ids"].to(device=device, non_blocking=True),
669
- attention_mask=text["attention_mask"].to(
670
- device=device, non_blocking=True
671
- ),
672
- )["encoder_last_hidden_state"],
673
- axis=1,
674
- )
675
- x = self.text_projection(x)
676
- else:
677
- logging.error(f"Model type {self.text_branch_type} not found")
678
- raise RuntimeError(f"Model type {self.text_branch_type} not found.")
679
- return x
680
-
681
- def forward(self, audio, text, device=None):
682
- """Forward audio and text into the CLAP
683
-
684
- Parameters
685
- ----------
686
- audio: torch.Tensor (batch_size, audio_length)
687
- the time-domain audio input / the batch of mel_spec and longer list.
688
- text: torch.Tensor () // need to add
689
- the text token input
690
- """
691
- if device is None:
692
- if audio is not None:
693
- device = audio.device
694
- elif text is not None:
695
- device = text.device
696
- if audio is None and text is None:
697
- # a hack to get the logit scale
698
- return self.logit_scale_a.exp(), self.logit_scale_t.exp()
699
- elif audio is None:
700
- return self.encode_text(text, device=device)
701
- elif text is None:
702
- return self.audio_projection(
703
- self.encode_audio(audio, device=device)["embedding"]
704
- )
705
- audio_features = self.audio_projection(
706
- self.encode_audio(audio, device=device)["embedding"]
707
- )
708
- audio_features = F.normalize(audio_features, dim=-1)
709
-
710
- text_features = self.encode_text(text, device=device)
711
- # print("text_features", text_features)
712
- # print("text_features.shape", text_features.shape)
713
- # print("text_features.type", type(text_features))
714
- text_features = F.normalize(text_features, dim=-1)
715
-
716
- audio_features_mlp = self.audio_transform(audio_features)
717
- text_features_mlp = self.text_transform(text_features)
718
- # Four outputs: audio features (basic & MLP), text features (basic & MLP)
719
- return (
720
- audio_features,
721
- text_features,
722
- audio_features_mlp,
723
- text_features_mlp,
724
- self.logit_scale_a.exp(),
725
- self.logit_scale_t.exp(),
726
- )
727
-
728
- def get_logit_scale(self):
729
- return self.logit_scale_a.exp(), self.logit_scale_t.exp()
730
-
731
- def get_text_embedding(self, data):
732
- """Get the text embedding from the model
733
-
734
- Parameters
735
- ----------
736
- data: torch.Tensor
737
- a tensor of text embedding
738
-
739
- Returns
740
- ----------
741
- text_embed: torch.Tensor
742
- a tensor of text_embeds (N, D)
743
-
744
- """
745
- device = next(self.parameters()).device
746
- for k in data:
747
- data[k] = data[k].to(device)
748
- if(len(data[k].size()) < 2):
749
- data[k] = data[k].unsqueeze(0)
750
- text_embeds = self.encode_text(data, device=device)
751
- text_embeds = F.normalize(text_embeds, dim=-1)
752
-
753
- return text_embeds
754
-
755
- def get_audio_embedding(self, data):
756
- """Get the audio embedding from the model
757
-
758
- Parameters
759
- ----------
760
- data: a list of dict
761
- the audio input dict list from 'get_audio_feature' method
762
-
763
- Returns
764
- ----------
765
- audio_embed: torch.Tensor
766
- a tensor of audio_embeds (N, D)
767
-
768
- """
769
- device = next(self.parameters()).device
770
- input_dict = {}
771
- keys = data[0].keys()
772
- for k in keys:
773
- input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(
774
- device
775
- )
776
-
777
- audio_embeds = self.audio_projection(
778
- self.encode_audio(input_dict, device=device)["embedding"]
779
- )
780
- audio_embeds = F.normalize(audio_embeds, dim=-1)
781
-
782
- return audio_embeds
783
-
784
- def audio_infer(self, audio, hopsize=None, device=None):
785
- """Forward one audio and produce the audio embedding
786
-
787
- Parameters
788
- ----------
789
- audio: (audio_length)
790
- the time-domain audio input, notice that it must be only one input
791
- hopsize: int
792
- the overlap hopsize as the sliding window
793
-
794
- Returns
795
- ----------
796
- output_dict: {
797
- key: [n, (embedding_shape)] if "HTS-AT"
798
- or
799
- key: [(embedding_shape)] if "PANN"
800
- }
801
- the list of key values of the audio branch
802
-
803
- """
804
-
805
- assert not self.training, "the inference mode must be run at eval stage"
806
- output_dict = {}
807
- # PANN
808
- if self.audio_cfg.model_type == "PANN":
809
- audio_input = audio.unsqueeze(dim=0)
810
- output_dict[key] = self.encode_audio(audio_input, device=device)[
811
- key
812
- ].squeeze(dim=0)
813
- elif self.audio_cfg.model_type == "HTSAT":
814
- # repeat
815
- audio_len = len(audio)
816
- k = self.audio_cfg.clip_samples // audio_len
817
- if k > 1:
818
- audio = audio.repeat(k)
819
- audio_len = len(audio)
820
-
821
- if hopsize is None:
822
- hopsize = min(hopsize, audio_len)
823
-
824
- if audio_len > self.audio_cfg.clip_samples:
825
- audio_input = [
826
- audio[pos : pos + self.audio_cfg.clip_samples].clone()
827
- for pos in range(
828
- 0, audio_len - self.audio_cfg.clip_samples, hopsize
829
- )
830
- ]
831
- audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
832
- audio_input = torch.stack(audio_input)
833
- output_dict[key] = self.encode_audio(audio_input, device=device)[key]
834
- else:
835
- audio_input = audio.unsqueeze(dim=0)
836
- output_dict[key] = self.encode_audio(audio_input, device=device)[
837
- key
838
- ].squeeze(dim=0)
839
-
840
- return output_dict
841
-
842
-
843
- def convert_weights_to_fp16(model: nn.Module):
844
- """Convert applicable model parameters to fp16"""
845
-
846
- def _convert_weights_to_fp16(l):
847
- if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
848
- l.weight.data = l.weight.data.half()
849
- if l.bias is not None:
850
- l.bias.data = l.bias.data.half()
851
-
852
- if isinstance(l, nn.MultiheadAttention):
853
- for attr in [
854
- *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
855
- "in_proj_bias",
856
- "bias_k",
857
- "bias_v",
858
- ]:
859
- tensor = getattr(l, attr)
860
- if tensor is not None:
861
- tensor.data = tensor.data.half()
862
-
863
- for name in ["text_projection", "proj"]:
864
- if hasattr(l, name):
865
- attr = getattr(l, name)
866
- if attr is not None:
867
- attr.data = attr.data.half()
868
-
869
- model.apply(_convert_weights_to_fp16)
870
-
871
-
872
- # Ignore the state dict of the vision part
873
- def build_model_from_openai_state_dict(
874
- state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None"
875
- ):
876
-
877
- embed_dim = model_cfg["embed_dim"]
878
- audio_cfg = model_cfg["audio_cfg"]
879
- text_cfg = model_cfg["text_cfg"]
880
- context_length = state_dict["positional_embedding"].shape[0]
881
- vocab_size = state_dict["token_embedding.weight"].shape[0]
882
- transformer_width = state_dict["ln_final.weight"].shape[0]
883
- transformer_heads = transformer_width // 64
884
- transformer_layers = len(
885
- set(
886
- k.split(".")[2]
887
- for k in state_dict
888
- if k.startswith(f"transformer.resblocks")
889
- )
890
- )
891
-
892
- audio_cfg = CLAPAudioCfp(**audio_cfg)
893
- text_cfg = CLAPTextCfg(**text_cfg)
894
-
895
- model = CLAP(
896
- embed_dim,
897
- audio_cfg=audio_cfg,
898
- text_cfg=text_cfg,
899
- quick_gelu=True, # OpenAI models were trained with QuickGELU
900
- enable_fusion=enable_fusion,
901
- fusion_type=fusion_type,
902
- )
903
- state_dict["logit_scale_a"] = state_dict["logit_scale"]
904
- state_dict["logit_scale_t"] = state_dict["logit_scale"]
905
- pop_keys = list(state_dict.keys())[::]
906
- # pop the visual branch saved weights
907
- for key in pop_keys:
908
- if key.startswith("visual."):
909
- state_dict.pop(key, None)
910
-
911
- for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
912
- state_dict.pop(key, None)
913
-
914
- # not use fp16
915
- # convert_weights_to_fp16(model)
916
- model.load_state_dict(state_dict, strict=False)
917
- return model.eval()
918
-
919
-
920
- def trace_model(model, batch_size=256, device=torch.device("cpu")):
921
- model.eval()
922
- audio_length = model.audio_cfg.audio_length
923
- example_audio = torch.ones((batch_size, audio_length), device=device)
924
- example_text = torch.zeros(
925
- (batch_size, model.context_length), dtype=torch.int, device=device
926
- )
927
- model = torch.jit.trace_module(
928
- model,
929
- inputs=dict(
930
- forward=(example_audio, example_text),
931
- encode_text=(example_text,),
932
- encode_image=(example_audio,),
933
- ),
934
- )
935
- model.audio_cfg.audio_length = audio_length # Question: what does this do?
936
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/HTSAT-base.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 1024,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "HTSAT",
14
- "model_name": "base"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/HTSAT-large.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "HTSAT",
14
- "model_name": "large"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 768,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1536,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "HTSAT",
14
- "model_name": "tiny"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/HTSAT-tiny.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 768,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "HTSAT",
14
- "model_name": "tiny"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/PANN-10.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 1024,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn10"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 18000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 960000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 360,
10
- "fmin": 50,
11
- "fmax": 8000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 4
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1536,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/PANN-14.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/PANN-6.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 512,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn6"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/RN101-quickgelu.json DELETED
@@ -1,22 +0,0 @@
1
- {
2
- "embed_dim": 512,
3
- "quick_gelu": true,
4
- "vision_cfg": {
5
- "image_size": 224,
6
- "layers": [
7
- 3,
8
- 4,
9
- 23,
10
- 3
11
- ],
12
- "width": 64,
13
- "patch_size": null
14
- },
15
- "text_cfg": {
16
- "context_length": 77,
17
- "vocab_size": 49408,
18
- "width": 512,
19
- "heads": 8,
20
- "layers": 12
21
- }
22
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/RN101.json DELETED
@@ -1,21 +0,0 @@
1
- {
2
- "embed_dim": 512,
3
- "vision_cfg": {
4
- "image_size": 224,
5
- "layers": [
6
- 3,
7
- 4,
8
- 23,
9
- 3
10
- ],
11
- "width": 64,
12
- "patch_size": null
13
- },
14
- "text_cfg": {
15
- "context_length": 77,
16
- "vocab_size": 49408,
17
- "width": 512,
18
- "heads": 8,
19
- "layers": 12
20
- }
21
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/RN50-quickgelu.json DELETED
@@ -1,22 +0,0 @@
1
- {
2
- "embed_dim": 1024,
3
- "quick_gelu": true,
4
- "vision_cfg": {
5
- "image_size": 224,
6
- "layers": [
7
- 3,
8
- 4,
9
- 6,
10
- 3
11
- ],
12
- "width": 64,
13
- "patch_size": null
14
- },
15
- "text_cfg": {
16
- "context_length": 77,
17
- "vocab_size": 49408,
18
- "width": 512,
19
- "heads": 8,
20
- "layers": 12
21
- }
22
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/RN50.json DELETED
@@ -1,21 +0,0 @@
1
- {
2
- "embed_dim": 1024,
3
- "vision_cfg": {
4
- "image_size": 224,
5
- "layers": [
6
- 3,
7
- 4,
8
- 6,
9
- 3
10
- ],
11
- "width": 64,
12
- "patch_size": null
13
- },
14
- "text_cfg": {
15
- "context_length": 77,
16
- "vocab_size": 49408,
17
- "width": 512,
18
- "heads": 8,
19
- "layers": 12
20
- }
21
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/RN50x16.json DELETED
@@ -1,21 +0,0 @@
1
- {
2
- "embed_dim": 768,
3
- "vision_cfg": {
4
- "image_size": 384,
5
- "layers": [
6
- 6,
7
- 8,
8
- 18,
9
- 8
10
- ],
11
- "width": 96,
12
- "patch_size": null
13
- },
14
- "text_cfg": {
15
- "context_length": 77,
16
- "vocab_size": 49408,
17
- "width": 768,
18
- "heads": 12,
19
- "layers": 12
20
- }
21
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/RN50x4.json DELETED
@@ -1,21 +0,0 @@
1
- {
2
- "embed_dim": 640,
3
- "vision_cfg": {
4
- "image_size": 288,
5
- "layers": [
6
- 4,
7
- 6,
8
- 10,
9
- 6
10
- ],
11
- "width": 80,
12
- "patch_size": null
13
- },
14
- "text_cfg": {
15
- "context_length": 77,
16
- "vocab_size": 49408,
17
- "width": 640,
18
- "heads": 10,
19
- "layers": 12
20
- }
21
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/ViT-B-16.json DELETED
@@ -1,16 +0,0 @@
1
- {
2
- "embed_dim": 512,
3
- "vision_cfg": {
4
- "image_size": 224,
5
- "layers": 12,
6
- "width": 768,
7
- "patch_size": 16
8
- },
9
- "text_cfg": {
10
- "context_length": 77,
11
- "vocab_size": 49408,
12
- "width": 512,
13
- "heads": 8,
14
- "layers": 12
15
- }
16
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json DELETED
@@ -1,17 +0,0 @@
1
- {
2
- "embed_dim": 512,
3
- "quick_gelu": true,
4
- "vision_cfg": {
5
- "image_size": 224,
6
- "layers": 12,
7
- "width": 768,
8
- "patch_size": 32
9
- },
10
- "text_cfg": {
11
- "context_length": 77,
12
- "vocab_size": 49408,
13
- "width": 512,
14
- "heads": 8,
15
- "layers": 12
16
- }
17
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/ViT-B-32.json DELETED
@@ -1,16 +0,0 @@
1
- {
2
- "embed_dim": 512,
3
- "vision_cfg": {
4
- "image_size": 224,
5
- "layers": 12,
6
- "width": 768,
7
- "patch_size": 32
8
- },
9
- "text_cfg": {
10
- "context_length": 77,
11
- "vocab_size": 49408,
12
- "width": 512,
13
- "heads": 8,
14
- "layers": 12
15
- }
16
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/model_configs/ViT-L-14.json DELETED
@@ -1,16 +0,0 @@
1
- {
2
- "embed_dim": 768,
3
- "vision_cfg": {
4
- "image_size": 224,
5
- "layers": 24,
6
- "width": 1024,
7
- "patch_size": 14
8
- },
9
- "text_cfg": {
10
- "context_length": 77,
11
- "vocab_size": 49408,
12
- "width": 768,
13
- "heads": 12,
14
- "layers": 12
15
- }
16
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/openai.py DELETED
@@ -1,156 +0,0 @@
1
- """ OpenAI pretrained model functions
2
-
3
- Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- """
5
-
6
- import os
7
- import warnings
8
- from typing import Union, List
9
-
10
- import torch
11
-
12
- from .model import build_model_from_openai_state_dict
13
- from .pretrained import (
14
- get_pretrained_url,
15
- list_pretrained_tag_models,
16
- download_pretrained,
17
- )
18
-
19
- __all__ = ["list_openai_models", "load_openai_model"]
20
-
21
-
22
- def list_openai_models() -> List[str]:
23
- """Returns the names of available CLIP models"""
24
- return list_pretrained_tag_models("openai")
25
-
26
-
27
- def load_openai_model(
28
- name: str,
29
- model_cfg,
30
- device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
31
- jit=True,
32
- cache_dir=os.path.expanduser("~/.cache/clip"),
33
- enable_fusion: bool = False,
34
- fusion_type: str = "None",
35
- ):
36
- """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model
37
-
38
- Parameters
39
- ----------
40
- name : str
41
- A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
42
- device : Union[str, torch.device]
43
- The device to put the loaded model
44
- jit : bool
45
- Whether to load the optimized JIT model (default) or more hackable non-JIT model.
46
-
47
- Returns
48
- -------
49
- model : torch.nn.Module
50
- The CLAP model
51
- preprocess : Callable[[PIL.Image], torch.Tensor]
52
- A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
53
- """
54
- if get_pretrained_url(name, "openai"):
55
- model_path = download_pretrained(
56
- get_pretrained_url(name, "openai"), root=cache_dir
57
- )
58
- elif os.path.isfile(name):
59
- model_path = name
60
- else:
61
- raise RuntimeError(
62
- f"Model {name} not found; available models = {list_openai_models()}"
63
- )
64
-
65
- try:
66
- # loading JIT archive
67
- model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
68
- state_dict = None
69
- except RuntimeError:
70
- # loading saved state dict
71
- if jit:
72
- warnings.warn(
73
- f"File {model_path} is not a JIT archive. Loading as a state dict instead"
74
- )
75
- jit = False
76
- state_dict = torch.load(model_path, map_location="cpu")
77
-
78
- if not jit:
79
- try:
80
- model = build_model_from_openai_state_dict(
81
- state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type
82
- ).to(device)
83
- except KeyError:
84
- sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
85
- model = build_model_from_openai_state_dict(
86
- sd, model_cfg, enable_fusion, fusion_type
87
- ).to(device)
88
-
89
- if str(device) == "cpu":
90
- model.float()
91
- return model
92
-
93
- # patch the device names
94
- device_holder = torch.jit.trace(
95
- lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
96
- )
97
- device_node = [
98
- n
99
- for n in device_holder.graph.findAllNodes("prim::Constant")
100
- if "Device" in repr(n)
101
- ][-1]
102
-
103
- def patch_device(module):
104
- try:
105
- graphs = [module.graph] if hasattr(module, "graph") else []
106
- except RuntimeError:
107
- graphs = []
108
-
109
- if hasattr(module, "forward1"):
110
- graphs.append(module.forward1.graph)
111
-
112
- for graph in graphs:
113
- for node in graph.findAllNodes("prim::Constant"):
114
- if "value" in node.attributeNames() and str(node["value"]).startswith(
115
- "cuda"
116
- ):
117
- node.copyAttributes(device_node)
118
-
119
- model.apply(patch_device)
120
- patch_device(model.encode_audio)
121
- patch_device(model.encode_text)
122
-
123
- # patch dtype to float32 on CPU
124
- if str(device) == "cpu":
125
- float_holder = torch.jit.trace(
126
- lambda: torch.ones([]).float(), example_inputs=[]
127
- )
128
- float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
129
- float_node = float_input.node()
130
-
131
- def patch_float(module):
132
- try:
133
- graphs = [module.graph] if hasattr(module, "graph") else []
134
- except RuntimeError:
135
- graphs = []
136
-
137
- if hasattr(module, "forward1"):
138
- graphs.append(module.forward1.graph)
139
-
140
- for graph in graphs:
141
- for node in graph.findAllNodes("aten::to"):
142
- inputs = list(node.inputs())
143
- for i in [
144
- 1,
145
- 2,
146
- ]: # dtype can be the second or third argument to aten::to()
147
- if inputs[i].node()["value"] == 5:
148
- inputs[i].node().copyAttributes(float_node)
149
-
150
- model.apply(patch_float)
151
- patch_float(model.encode_audio)
152
- patch_float(model.encode_text)
153
- model.float()
154
-
155
- model.audio_branch.audio_length = model.audio_cfg.audio_length
156
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/pann_model.py DELETED
@@ -1,703 +0,0 @@
1
- # PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
2
- # Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn
3
- # Some layers are re-designed for CLAP
4
- import os
5
-
6
- os.environ["NUMBA_CACHE_DIR"] = "/tmp/"
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- from torchlibrosa.stft import Spectrogram, LogmelFilterBank
12
- from torchlibrosa.augmentation import SpecAugmentation
13
-
14
- from .utils import do_mixup, interpolate, pad_framewise_output
15
- from .feature_fusion import iAFF, AFF, DAF
16
-
17
-
18
- def init_layer(layer):
19
- """Initialize a Linear or Convolutional layer."""
20
- nn.init.xavier_uniform_(layer.weight)
21
-
22
- if hasattr(layer, "bias"):
23
- if layer.bias is not None:
24
- layer.bias.data.fill_(0.0)
25
-
26
- def init_bn(bn):
27
- """Initialize a Batchnorm layer."""
28
- bn.bias.data.fill_(0.0)
29
- bn.weight.data.fill_(1.0)
30
-
31
-
32
- class ConvBlock(nn.Module):
33
- def __init__(self, in_channels, out_channels):
34
-
35
- super(ConvBlock, self).__init__()
36
-
37
- self.conv1 = nn.Conv2d(
38
- in_channels=in_channels,
39
- out_channels=out_channels,
40
- kernel_size=(3, 3),
41
- stride=(1, 1),
42
- padding=(1, 1),
43
- bias=False,
44
- )
45
-
46
- self.conv2 = nn.Conv2d(
47
- in_channels=out_channels,
48
- out_channels=out_channels,
49
- kernel_size=(3, 3),
50
- stride=(1, 1),
51
- padding=(1, 1),
52
- bias=False,
53
- )
54
-
55
- self.bn1 = nn.BatchNorm2d(out_channels)
56
- self.bn2 = nn.BatchNorm2d(out_channels)
57
-
58
- self.init_weight()
59
-
60
- def init_weight(self):
61
- init_layer(self.conv1)
62
- init_layer(self.conv2)
63
- init_bn(self.bn1)
64
- init_bn(self.bn2)
65
-
66
- def forward(self, input, pool_size=(2, 2), pool_type="avg"):
67
-
68
- x = input
69
- x = F.relu_(self.bn1(self.conv1(x)))
70
- x = F.relu_(self.bn2(self.conv2(x)))
71
- if pool_type == "max":
72
- x = F.max_pool2d(x, kernel_size=pool_size)
73
- elif pool_type == "avg":
74
- x = F.avg_pool2d(x, kernel_size=pool_size)
75
- elif pool_type == "avg+max":
76
- x1 = F.avg_pool2d(x, kernel_size=pool_size)
77
- x2 = F.max_pool2d(x, kernel_size=pool_size)
78
- x = x1 + x2
79
- else:
80
- raise Exception("Incorrect argument!")
81
-
82
- return x
83
-
84
-
85
- class ConvBlock5x5(nn.Module):
86
- def __init__(self, in_channels, out_channels):
87
-
88
- super(ConvBlock5x5, self).__init__()
89
-
90
- self.conv1 = nn.Conv2d(
91
- in_channels=in_channels,
92
- out_channels=out_channels,
93
- kernel_size=(5, 5),
94
- stride=(1, 1),
95
- padding=(2, 2),
96
- bias=False,
97
- )
98
-
99
- self.bn1 = nn.BatchNorm2d(out_channels)
100
-
101
- self.init_weight()
102
-
103
- def init_weight(self):
104
- init_layer(self.conv1)
105
- init_bn(self.bn1)
106
-
107
- def forward(self, input, pool_size=(2, 2), pool_type="avg"):
108
-
109
- x = input
110
- x = F.relu_(self.bn1(self.conv1(x)))
111
- if pool_type == "max":
112
- x = F.max_pool2d(x, kernel_size=pool_size)
113
- elif pool_type == "avg":
114
- x = F.avg_pool2d(x, kernel_size=pool_size)
115
- elif pool_type == "avg+max":
116
- x1 = F.avg_pool2d(x, kernel_size=pool_size)
117
- x2 = F.max_pool2d(x, kernel_size=pool_size)
118
- x = x1 + x2
119
- else:
120
- raise Exception("Incorrect argument!")
121
-
122
- return x
123
-
124
-
125
- class AttBlock(nn.Module):
126
- def __init__(self, n_in, n_out, activation="linear", temperature=1.0):
127
- super(AttBlock, self).__init__()
128
-
129
- self.activation = activation
130
- self.temperature = temperature
131
- self.att = nn.Conv1d(
132
- in_channels=n_in,
133
- out_channels=n_out,
134
- kernel_size=1,
135
- stride=1,
136
- padding=0,
137
- bias=True,
138
- )
139
- self.cla = nn.Conv1d(
140
- in_channels=n_in,
141
- out_channels=n_out,
142
- kernel_size=1,
143
- stride=1,
144
- padding=0,
145
- bias=True,
146
- )
147
-
148
- self.bn_att = nn.BatchNorm1d(n_out)
149
- self.init_weights()
150
-
151
- def init_weights(self):
152
- init_layer(self.att)
153
- init_layer(self.cla)
154
- init_bn(self.bn_att)
155
-
156
- def forward(self, x):
157
- # x: (n_samples, n_in, n_time)
158
- norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
159
- cla = self.nonlinear_transform(self.cla(x))
160
- x = torch.sum(norm_att * cla, dim=2)
161
- return x, norm_att, cla
162
-
163
- def nonlinear_transform(self, x):
164
- if self.activation == "linear":
165
- return x
166
- elif self.activation == "sigmoid":
167
- return torch.sigmoid(x)
168
-
169
-
170
- class Cnn14(nn.Module):
171
- def __init__(
172
- self,
173
- sample_rate,
174
- window_size,
175
- hop_size,
176
- mel_bins,
177
- fmin,
178
- fmax,
179
- classes_num,
180
- enable_fusion=False,
181
- fusion_type="None",
182
- ):
183
-
184
- super(Cnn14, self).__init__()
185
-
186
- window = "hann"
187
- center = True
188
- pad_mode = "reflect"
189
- ref = 1.0
190
- amin = 1e-10
191
- top_db = None
192
-
193
- self.enable_fusion = enable_fusion
194
- self.fusion_type = fusion_type
195
-
196
- # Spectrogram extractor
197
- self.spectrogram_extractor = Spectrogram(
198
- n_fft=window_size,
199
- hop_length=hop_size,
200
- win_length=window_size,
201
- window=window,
202
- center=center,
203
- pad_mode=pad_mode,
204
- freeze_parameters=True,
205
- )
206
-
207
- # Logmel feature extractor
208
- self.logmel_extractor = LogmelFilterBank(
209
- sr=sample_rate,
210
- n_fft=window_size,
211
- n_mels=mel_bins,
212
- fmin=fmin,
213
- fmax=fmax,
214
- ref=ref,
215
- amin=amin,
216
- top_db=top_db,
217
- freeze_parameters=True,
218
- )
219
-
220
- # Spec augmenter
221
- self.spec_augmenter = SpecAugmentation(
222
- time_drop_width=64,
223
- time_stripes_num=2,
224
- freq_drop_width=8,
225
- freq_stripes_num=2,
226
- )
227
-
228
- self.bn0 = nn.BatchNorm2d(64)
229
-
230
- if (self.enable_fusion) and (self.fusion_type == "channel_map"):
231
- self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
232
- else:
233
- self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
234
- self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
235
- self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
236
- self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
237
- self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
238
- self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
239
-
240
- self.fc1 = nn.Linear(2048, 2048, bias=True)
241
- self.fc_audioset = nn.Linear(2048, classes_num, bias=True)
242
-
243
- if (self.enable_fusion) and (
244
- self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
245
- ):
246
- self.mel_conv1d = nn.Sequential(
247
- nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
248
- nn.BatchNorm1d(64), # No Relu
249
- )
250
- if self.fusion_type == "daf_1d":
251
- self.fusion_model = DAF()
252
- elif self.fusion_type == "aff_1d":
253
- self.fusion_model = AFF(channels=64, type="1D")
254
- elif self.fusion_type == "iaff_1d":
255
- self.fusion_model = iAFF(channels=64, type="1D")
256
-
257
- if (self.enable_fusion) and (
258
- self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
259
- ):
260
- self.mel_conv2d = nn.Sequential(
261
- nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)),
262
- nn.BatchNorm2d(64),
263
- nn.ReLU(inplace=True),
264
- )
265
-
266
- if self.fusion_type == "daf_2d":
267
- self.fusion_model = DAF()
268
- elif self.fusion_type == "aff_2d":
269
- self.fusion_model = AFF(channels=64, type="2D")
270
- elif self.fusion_type == "iaff_2d":
271
- self.fusion_model = iAFF(channels=64, type="2D")
272
- self.init_weight()
273
-
274
- def init_weight(self):
275
- init_bn(self.bn0)
276
- init_layer(self.fc1)
277
- init_layer(self.fc_audioset)
278
-
279
- def forward(self, input, mixup_lambda=None, device=None):
280
- """
281
- Input: (batch_size, data_length)"""
282
-
283
- if self.enable_fusion and input["longer"].sum() == 0:
284
- # if no audio is longer than 10s, then randomly select one audio to be longer
285
- input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True
286
-
287
- if not self.enable_fusion:
288
- x = self.spectrogram_extractor(
289
- input["waveform"].to(device=device, non_blocking=True)
290
- ) # (batch_size, 1, time_steps, freq_bins)
291
- x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
292
-
293
- x = x.transpose(1, 3)
294
- x = self.bn0(x)
295
- x = x.transpose(1, 3)
296
- else:
297
- longer_list = input["longer"].to(device=device, non_blocking=True)
298
- x = input["mel_fusion"].to(device=device, non_blocking=True)
299
- longer_list_idx = torch.where(longer_list)[0]
300
- x = x.transpose(1, 3)
301
- x = self.bn0(x)
302
- x = x.transpose(1, 3)
303
- if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
304
- new_x = x[:, 0:1, :, :].clone().contiguous()
305
- # local processing
306
- if len(longer_list_idx) > 0:
307
- fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
308
- FB, FC, FT, FF = fusion_x_local.size()
309
- fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
310
- fusion_x_local = torch.permute(
311
- fusion_x_local, (0, 2, 1)
312
- ).contiguous()
313
- fusion_x_local = self.mel_conv1d(fusion_x_local)
314
- fusion_x_local = fusion_x_local.view(
315
- FB, FC, FF, fusion_x_local.size(-1)
316
- )
317
- fusion_x_local = (
318
- torch.permute(fusion_x_local, (0, 2, 1, 3))
319
- .contiguous()
320
- .flatten(2)
321
- )
322
- if fusion_x_local.size(-1) < FT:
323
- fusion_x_local = torch.cat(
324
- [
325
- fusion_x_local,
326
- torch.zeros(
327
- (FB, FF, FT - fusion_x_local.size(-1)),
328
- device=device,
329
- ),
330
- ],
331
- dim=-1,
332
- )
333
- else:
334
- fusion_x_local = fusion_x_local[:, :, :FT]
335
- # 1D fusion
336
- new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
337
- new_x[longer_list_idx] = self.fusion_model(
338
- new_x[longer_list_idx], fusion_x_local
339
- )
340
- x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
341
- else:
342
- x = new_x
343
- elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
344
- x = x # no change
345
-
346
- if self.training:
347
- x = self.spec_augmenter(x)
348
- # Mixup on spectrogram
349
- if self.training and mixup_lambda is not None:
350
- x = do_mixup(x, mixup_lambda)
351
- if (self.enable_fusion) and (
352
- self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
353
- ):
354
- global_x = x[:, 0:1, :, :]
355
-
356
- # global processing
357
- B, C, H, W = global_x.shape
358
- global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg")
359
- if len(longer_list_idx) > 0:
360
- local_x = x[longer_list_idx, 1:, :, :].contiguous()
361
- TH = global_x.size(-2)
362
- # local processing
363
- B, C, H, W = local_x.shape
364
- local_x = local_x.view(B * C, 1, H, W)
365
- local_x = self.mel_conv2d(local_x)
366
- local_x = local_x.view(
367
- B, C, local_x.size(1), local_x.size(2), local_x.size(3)
368
- )
369
- local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3)
370
- TB, TC, _, TW = local_x.size()
371
- if local_x.size(-2) < TH:
372
- local_x = torch.cat(
373
- [
374
- local_x,
375
- torch.zeros(
376
- (TB, TC, TH - local_x.size(-2), TW),
377
- device=global_x.device,
378
- ),
379
- ],
380
- dim=-2,
381
- )
382
- else:
383
- local_x = local_x[:, :, :TH, :]
384
-
385
- global_x[longer_list_idx] = self.fusion_model(
386
- global_x[longer_list_idx], local_x
387
- )
388
- x = global_x
389
- else:
390
- x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
391
-
392
- x = F.dropout(x, p=0.2, training=self.training)
393
- x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
394
- x = F.dropout(x, p=0.2, training=self.training)
395
- x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
396
- x = F.dropout(x, p=0.2, training=self.training)
397
- x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
398
- x = F.dropout(x, p=0.2, training=self.training)
399
- x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
400
- x = F.dropout(x, p=0.2, training=self.training)
401
- x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
402
- x = F.dropout(x, p=0.2, training=self.training)
403
- x = torch.mean(x, dim=3)
404
-
405
- latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
406
- latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
407
- latent_x = latent_x1 + latent_x2
408
- latent_x = latent_x.transpose(1, 2)
409
- latent_x = F.relu_(self.fc1(latent_x))
410
- latent_output = interpolate(latent_x, 32)
411
-
412
- (x1, _) = torch.max(x, dim=2)
413
- x2 = torch.mean(x, dim=2)
414
- x = x1 + x2
415
- x = F.dropout(x, p=0.5, training=self.training)
416
- x = F.relu_(self.fc1(x))
417
- embedding = F.dropout(x, p=0.5, training=self.training)
418
- clipwise_output = torch.sigmoid(self.fc_audioset(x))
419
-
420
- output_dict = {
421
- "clipwise_output": clipwise_output,
422
- "embedding": embedding,
423
- "fine_grained_embedding": latent_output,
424
- }
425
- return output_dict
426
-
427
-
428
- class Cnn6(nn.Module):
429
- def __init__(
430
- self,
431
- sample_rate,
432
- window_size,
433
- hop_size,
434
- mel_bins,
435
- fmin,
436
- fmax,
437
- classes_num,
438
- enable_fusion=False,
439
- fusion_type="None",
440
- ):
441
-
442
- super(Cnn6, self).__init__()
443
-
444
- window = "hann"
445
- center = True
446
- pad_mode = "reflect"
447
- ref = 1.0
448
- amin = 1e-10
449
- top_db = None
450
-
451
- self.enable_fusion = enable_fusion
452
- self.fusion_type = fusion_type
453
-
454
- # Spectrogram extractor
455
- self.spectrogram_extractor = Spectrogram(
456
- n_fft=window_size,
457
- hop_length=hop_size,
458
- win_length=window_size,
459
- window=window,
460
- center=center,
461
- pad_mode=pad_mode,
462
- freeze_parameters=True,
463
- )
464
-
465
- # Logmel feature extractor
466
- self.logmel_extractor = LogmelFilterBank(
467
- sr=sample_rate,
468
- n_fft=window_size,
469
- n_mels=mel_bins,
470
- fmin=fmin,
471
- fmax=fmax,
472
- ref=ref,
473
- amin=amin,
474
- top_db=top_db,
475
- freeze_parameters=True,
476
- )
477
-
478
- # Spec augmenter
479
- self.spec_augmenter = SpecAugmentation(
480
- time_drop_width=64,
481
- time_stripes_num=2,
482
- freq_drop_width=8,
483
- freq_stripes_num=2,
484
- )
485
-
486
- self.bn0 = nn.BatchNorm2d(64)
487
-
488
- self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
489
- self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
490
- self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
491
- self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
492
-
493
- self.fc1 = nn.Linear(512, 512, bias=True)
494
- self.fc_audioset = nn.Linear(512, classes_num, bias=True)
495
-
496
- self.init_weight()
497
-
498
- def init_weight(self):
499
- init_bn(self.bn0)
500
- init_layer(self.fc1)
501
- init_layer(self.fc_audioset)
502
-
503
- def forward(self, input, mixup_lambda=None, device=None):
504
- """
505
- Input: (batch_size, data_length)"""
506
-
507
- x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
508
- x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
509
-
510
- x = x.transpose(1, 3)
511
- x = self.bn0(x)
512
- x = x.transpose(1, 3)
513
-
514
- if self.training:
515
- x = self.spec_augmenter(x)
516
-
517
- # Mixup on spectrogram
518
- if self.training and mixup_lambda is not None:
519
- x = do_mixup(x, mixup_lambda)
520
-
521
- x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
522
- x = F.dropout(x, p=0.2, training=self.training)
523
- x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
524
- x = F.dropout(x, p=0.2, training=self.training)
525
- x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
526
- x = F.dropout(x, p=0.2, training=self.training)
527
- x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
528
- x = F.dropout(x, p=0.2, training=self.training)
529
- x = torch.mean(x, dim=3)
530
-
531
- latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
532
- latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
533
- latent_x = latent_x1 + latent_x2
534
- latent_x = latent_x.transpose(1, 2)
535
- latent_x = F.relu_(self.fc1(latent_x))
536
- latent_output = interpolate(latent_x, 16)
537
-
538
- (x1, _) = torch.max(x, dim=2)
539
- x2 = torch.mean(x, dim=2)
540
- x = x1 + x2
541
- x = F.dropout(x, p=0.5, training=self.training)
542
- x = F.relu_(self.fc1(x))
543
- embedding = F.dropout(x, p=0.5, training=self.training)
544
- clipwise_output = torch.sigmoid(self.fc_audioset(x))
545
-
546
- output_dict = {
547
- "clipwise_output": clipwise_output,
548
- "embedding": embedding,
549
- "fine_grained_embedding": latent_output,
550
- }
551
-
552
- return output_dict
553
-
554
-
555
- class Cnn10(nn.Module):
556
- def __init__(
557
- self,
558
- sample_rate,
559
- window_size,
560
- hop_size,
561
- mel_bins,
562
- fmin,
563
- fmax,
564
- classes_num,
565
- enable_fusion=False,
566
- fusion_type="None",
567
- ):
568
-
569
- super(Cnn10, self).__init__()
570
-
571
- window = "hann"
572
- center = True
573
- pad_mode = "reflect"
574
- ref = 1.0
575
- amin = 1e-10
576
- top_db = None
577
-
578
- self.enable_fusion = enable_fusion
579
- self.fusion_type = fusion_type
580
-
581
- # Spectrogram extractor
582
- self.spectrogram_extractor = Spectrogram(
583
- n_fft=window_size,
584
- hop_length=hop_size,
585
- win_length=window_size,
586
- window=window,
587
- center=center,
588
- pad_mode=pad_mode,
589
- freeze_parameters=True,
590
- )
591
-
592
- # Logmel feature extractor
593
- self.logmel_extractor = LogmelFilterBank(
594
- sr=sample_rate,
595
- n_fft=window_size,
596
- n_mels=mel_bins,
597
- fmin=fmin,
598
- fmax=fmax,
599
- ref=ref,
600
- amin=amin,
601
- top_db=top_db,
602
- freeze_parameters=True,
603
- )
604
-
605
- # Spec augmenter
606
- self.spec_augmenter = SpecAugmentation(
607
- time_drop_width=64,
608
- time_stripes_num=2,
609
- freq_drop_width=8,
610
- freq_stripes_num=2,
611
- )
612
-
613
- self.bn0 = nn.BatchNorm2d(64)
614
-
615
- self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
616
- self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
617
- self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
618
- self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
619
- self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
620
-
621
- self.fc1 = nn.Linear(1024, 1024, bias=True)
622
- self.fc_audioset = nn.Linear(1024, classes_num, bias=True)
623
-
624
- self.init_weight()
625
-
626
- def init_weight(self):
627
- init_bn(self.bn0)
628
- init_layer(self.fc1)
629
- init_layer(self.fc_audioset)
630
-
631
- def forward(self, input, mixup_lambda=None, device=None):
632
- """
633
- Input: (batch_size, data_length)"""
634
-
635
- x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
636
- x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
637
-
638
- x = x.transpose(1, 3)
639
- x = self.bn0(x)
640
- x = x.transpose(1, 3)
641
-
642
- if self.training:
643
- x = self.spec_augmenter(x)
644
-
645
- # Mixup on spectrogram
646
- if self.training and mixup_lambda is not None:
647
- x = do_mixup(x, mixup_lambda)
648
-
649
- x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
650
- x = F.dropout(x, p=0.2, training=self.training)
651
- x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
652
- x = F.dropout(x, p=0.2, training=self.training)
653
- x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
654
- x = F.dropout(x, p=0.2, training=self.training)
655
- x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
656
- x = F.dropout(x, p=0.2, training=self.training)
657
- x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
658
- x = F.dropout(x, p=0.2, training=self.training)
659
- x = torch.mean(x, dim=3)
660
-
661
- latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
662
- latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
663
- latent_x = latent_x1 + latent_x2
664
- latent_x = latent_x.transpose(1, 2)
665
- latent_x = F.relu_(self.fc1(latent_x))
666
- latent_output = interpolate(latent_x, 32)
667
-
668
- (x1, _) = torch.max(x, dim=2)
669
- x2 = torch.mean(x, dim=2)
670
- x = x1 + x2
671
- x = F.dropout(x, p=0.5, training=self.training)
672
- x = F.relu_(self.fc1(x))
673
- embedding = F.dropout(x, p=0.5, training=self.training)
674
- clipwise_output = torch.sigmoid(self.fc_audioset(x))
675
-
676
- output_dict = {
677
- "clipwise_output": clipwise_output,
678
- "embedding": embedding,
679
- "fine_grained_embedding": latent_output,
680
- }
681
-
682
- return output_dict
683
-
684
-
685
- def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"):
686
- try:
687
- ModelProto = eval(audio_cfg.model_name)
688
- model = ModelProto(
689
- sample_rate=audio_cfg.sample_rate,
690
- window_size=audio_cfg.window_size,
691
- hop_size=audio_cfg.hop_size,
692
- mel_bins=audio_cfg.mel_bins,
693
- fmin=audio_cfg.fmin,
694
- fmax=audio_cfg.fmax,
695
- classes_num=audio_cfg.class_num,
696
- enable_fusion=enable_fusion,
697
- fusion_type=fusion_type,
698
- )
699
- return model
700
- except:
701
- raise RuntimeError(
702
- f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
703
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/pretrained.py DELETED
@@ -1,167 +0,0 @@
1
- import hashlib
2
- import os
3
- import urllib
4
- import warnings
5
-
6
- from tqdm import tqdm
7
-
8
- _RN50 = dict(
9
- openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
10
- yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
11
- cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
12
- )
13
-
14
- _RN50_quickgelu = dict(
15
- openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
16
- yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
17
- cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
18
- )
19
-
20
- _RN101 = dict(
21
- openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
22
- yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
23
- )
24
-
25
- _RN101_quickgelu = dict(
26
- openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
27
- yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
28
- )
29
-
30
- _RN50x4 = dict(
31
- openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
32
- )
33
-
34
- _RN50x16 = dict(
35
- openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
36
- )
37
-
38
- _RN50x64 = dict(
39
- openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
40
- )
41
-
42
- _VITB32 = dict(
43
- openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
44
- laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
45
- laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
46
- laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
47
- )
48
-
49
- _VITB32_quickgelu = dict(
50
- openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
51
- laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
52
- laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
53
- laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
54
- )
55
-
56
- _VITB16 = dict(
57
- openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
58
- )
59
-
60
- _VITL14 = dict(
61
- openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
62
- )
63
-
64
- _PRETRAINED = {
65
- "RN50": _RN50,
66
- "RN50-quickgelu": _RN50_quickgelu,
67
- "RN101": _RN101,
68
- "RN101-quickgelu": _RN101_quickgelu,
69
- "RN50x4": _RN50x4,
70
- "RN50x16": _RN50x16,
71
- "ViT-B-32": _VITB32,
72
- "ViT-B-32-quickgelu": _VITB32_quickgelu,
73
- "ViT-B-16": _VITB16,
74
- "ViT-L-14": _VITL14,
75
- }
76
-
77
-
78
- def list_pretrained(as_str: bool = False):
79
- """returns list of pretrained models
80
- Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
81
- """
82
- return [
83
- ":".join([k, t]) if as_str else (k, t)
84
- for k in _PRETRAINED.keys()
85
- for t in _PRETRAINED[k].keys()
86
- ]
87
-
88
-
89
- def list_pretrained_tag_models(tag: str):
90
- """return all models having the specified pretrain tag"""
91
- models = []
92
- for k in _PRETRAINED.keys():
93
- if tag in _PRETRAINED[k]:
94
- models.append(k)
95
- return models
96
-
97
-
98
- def list_pretrained_model_tags(model: str):
99
- """return all pretrain tags for the specified model architecture"""
100
- tags = []
101
- if model in _PRETRAINED:
102
- tags.extend(_PRETRAINED[model].keys())
103
- return tags
104
-
105
-
106
- def get_pretrained_url(model: str, tag: str):
107
- if model not in _PRETRAINED:
108
- return ""
109
- model_pretrained = _PRETRAINED[model]
110
- if tag not in model_pretrained:
111
- return ""
112
- return model_pretrained[tag]
113
-
114
-
115
- def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")):
116
- os.makedirs(root, exist_ok=True)
117
- filename = os.path.basename(url)
118
-
119
- if "openaipublic" in url:
120
- expected_sha256 = url.split("/")[-2]
121
- else:
122
- expected_sha256 = ""
123
-
124
- download_target = os.path.join(root, filename)
125
-
126
- if os.path.exists(download_target) and not os.path.isfile(download_target):
127
- raise RuntimeError(f"{download_target} exists and is not a regular file")
128
-
129
- if os.path.isfile(download_target):
130
- if expected_sha256:
131
- if (
132
- hashlib.sha256(open(download_target, "rb").read()).hexdigest()
133
- == expected_sha256
134
- ):
135
- return download_target
136
- else:
137
- warnings.warn(
138
- f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
139
- )
140
- else:
141
- return download_target
142
-
143
- with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
144
- with tqdm(
145
- total=int(source.info().get("Content-Length")),
146
- ncols=80,
147
- unit="iB",
148
- unit_scale=True,
149
- ) as loop:
150
- while True:
151
- buffer = source.read(8192)
152
- if not buffer:
153
- break
154
-
155
- output.write(buffer)
156
- loop.update(len(buffer))
157
-
158
- if (
159
- expected_sha256
160
- and hashlib.sha256(open(download_target, "rb").read()).hexdigest()
161
- != expected_sha256
162
- ):
163
- raise RuntimeError(
164
- f"Model has been downloaded but the SHA256 checksum does not not match"
165
- )
166
-
167
- return download_target
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/timm_model.py DELETED
@@ -1,112 +0,0 @@
1
- """ timm model adapter
2
-
3
- Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4
- """
5
- from collections import OrderedDict
6
-
7
- import torch.nn as nn
8
-
9
- try:
10
- import timm
11
- from timm.models.layers import Mlp, to_2tuple
12
- from timm.models.layers.attention_pool2d import RotAttentionPool2d
13
- from timm.models.layers.attention_pool2d import (
14
- AttentionPool2d as AbsAttentionPool2d,
15
- )
16
- except ImportError as e:
17
- timm = None
18
-
19
- from .utils import freeze_batch_norm_2d
20
-
21
-
22
- class TimmModel(nn.Module):
23
- """timm model adapter
24
- # FIXME this adapter is a work in progress, may change in ways that break weight compat
25
- """
26
-
27
- def __init__(
28
- self,
29
- model_name,
30
- embed_dim,
31
- image_size=224,
32
- pool="avg",
33
- proj="linear",
34
- drop=0.0,
35
- pretrained=False,
36
- ):
37
- super().__init__()
38
- if timm is None:
39
- raise RuntimeError("Please `pip install timm` to use timm models.")
40
-
41
- self.image_size = to_2tuple(image_size)
42
- self.trunk = timm.create_model(model_name, pretrained=pretrained)
43
- feat_size = self.trunk.default_cfg.get("pool_size", None)
44
- feature_ndim = 1 if not feat_size else 2
45
- if pool in ("abs_attn", "rot_attn"):
46
- assert feature_ndim == 2
47
- # if attn pooling used, remove both classifier and default pool
48
- self.trunk.reset_classifier(0, global_pool="")
49
- else:
50
- # reset global pool if pool config set, otherwise leave as network default
51
- reset_kwargs = dict(global_pool=pool) if pool else {}
52
- self.trunk.reset_classifier(0, **reset_kwargs)
53
- prev_chs = self.trunk.num_features
54
-
55
- head_layers = OrderedDict()
56
- if pool == "abs_attn":
57
- head_layers["pool"] = AbsAttentionPool2d(
58
- prev_chs, feat_size=feat_size, out_features=embed_dim
59
- )
60
- prev_chs = embed_dim
61
- elif pool == "rot_attn":
62
- head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
63
- prev_chs = embed_dim
64
- else:
65
- assert proj, "projection layer needed if non-attention pooling is used."
66
-
67
- # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
68
- if proj == "linear":
69
- head_layers["drop"] = nn.Dropout(drop)
70
- head_layers["proj"] = nn.Linear(prev_chs, embed_dim)
71
- elif proj == "mlp":
72
- head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
73
-
74
- self.head = nn.Sequential(head_layers)
75
-
76
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
77
- """lock modules
78
- Args:
79
- unlocked_groups (int): leave last n layer groups unlocked (default: 0)
80
- """
81
- if not unlocked_groups:
82
- # lock full model
83
- for param in self.trunk.parameters():
84
- param.requires_grad = False
85
- if freeze_bn_stats:
86
- freeze_batch_norm_2d(self.trunk)
87
- else:
88
- # NOTE: partial freeze requires latest timm (master) branch and is subject to change
89
- try:
90
- # FIXME import here until API stable and in an official release
91
- from timm.models.helpers import group_parameters, group_modules
92
- except ImportError:
93
- raise RuntimeError(
94
- "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`"
95
- )
96
- matcher = self.trunk.group_matcher()
97
- gparams = group_parameters(self.trunk, matcher)
98
- max_layer_id = max(gparams.keys())
99
- max_layer_id = max_layer_id - unlocked_groups
100
- for group_idx in range(max_layer_id + 1):
101
- group = gparams[group_idx]
102
- for param in group:
103
- self.trunk.get_parameter(param).requires_grad = False
104
- if freeze_bn_stats:
105
- gmodules = group_modules(self.trunk, matcher, reverse=True)
106
- gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
107
- freeze_batch_norm_2d(self.trunk, gmodules)
108
-
109
- def forward(self, x):
110
- x = self.trunk(x)
111
- x = self.head(x)
112
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/tokenizer.py DELETED
@@ -1,197 +0,0 @@
1
- """ CLIP tokenizer
2
-
3
- Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- """
5
- import gzip
6
- import html
7
- import os
8
- from functools import lru_cache
9
- from typing import Union, List
10
-
11
- import ftfy
12
- import regex as re
13
- import torch
14
-
15
-
16
- @lru_cache()
17
- def default_bpe():
18
- return os.path.join(
19
- os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
20
- )
21
-
22
-
23
- @lru_cache()
24
- def bytes_to_unicode():
25
- """
26
- Returns list of utf-8 byte and a corresponding list of unicode strings.
27
- The reversible bpe codes work on unicode strings.
28
- This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
29
- When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
30
- This is a signficant percentage of your normal, say, 32K bpe vocab.
31
- To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
32
- And avoids mapping to whitespace/control characters the bpe code barfs on.
33
- """
34
- bs = (
35
- list(range(ord("!"), ord("~") + 1))
36
- + list(range(ord("¡"), ord("¬") + 1))
37
- + list(range(ord("®"), ord("ÿ") + 1))
38
- )
39
- cs = bs[:]
40
- n = 0
41
- for b in range(2**8):
42
- if b not in bs:
43
- bs.append(b)
44
- cs.append(2**8 + n)
45
- n += 1
46
- cs = [chr(n) for n in cs]
47
- return dict(zip(bs, cs))
48
-
49
-
50
- def get_pairs(word):
51
- """Return set of symbol pairs in a word.
52
- Word is represented as tuple of symbols (symbols being variable-length strings).
53
- """
54
- pairs = set()
55
- prev_char = word[0]
56
- for char in word[1:]:
57
- pairs.add((prev_char, char))
58
- prev_char = char
59
- return pairs
60
-
61
-
62
- def basic_clean(text):
63
- text = ftfy.fix_text(text)
64
- text = html.unescape(html.unescape(text))
65
- return text.strip()
66
-
67
-
68
- def whitespace_clean(text):
69
- text = re.sub(r"\s+", " ", text)
70
- text = text.strip()
71
- return text
72
-
73
-
74
- class SimpleTokenizer(object):
75
- def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
76
- self.byte_encoder = bytes_to_unicode()
77
- self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
78
- merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
79
- merges = merges[1 : 49152 - 256 - 2 + 1]
80
- merges = [tuple(merge.split()) for merge in merges]
81
- vocab = list(bytes_to_unicode().values())
82
- vocab = vocab + [v + "</w>" for v in vocab]
83
- for merge in merges:
84
- vocab.append("".join(merge))
85
- if not special_tokens:
86
- special_tokens = ["<start_of_text>", "<end_of_text>"]
87
- else:
88
- special_tokens = ["<start_of_text>", "<end_of_text>"] + special_tokens
89
- vocab.extend(special_tokens)
90
- self.encoder = dict(zip(vocab, range(len(vocab))))
91
- self.decoder = {v: k for k, v in self.encoder.items()}
92
- self.bpe_ranks = dict(zip(merges, range(len(merges))))
93
- self.cache = {t: t for t in special_tokens}
94
- special = "|".join(special_tokens)
95
- self.pat = re.compile(
96
- special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
97
- re.IGNORECASE,
98
- )
99
-
100
- self.vocab_size = len(self.encoder)
101
- self.all_special_ids = [self.encoder[t] for t in special_tokens]
102
-
103
- def bpe(self, token):
104
- if token in self.cache:
105
- return self.cache[token]
106
- word = tuple(token[:-1]) + (token[-1] + "</w>",)
107
- pairs = get_pairs(word)
108
-
109
- if not pairs:
110
- return token + "</w>"
111
-
112
- while True:
113
- bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
114
- if bigram not in self.bpe_ranks:
115
- break
116
- first, second = bigram
117
- new_word = []
118
- i = 0
119
- while i < len(word):
120
- try:
121
- j = word.index(first, i)
122
- new_word.extend(word[i:j])
123
- i = j
124
- except:
125
- new_word.extend(word[i:])
126
- break
127
-
128
- if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
129
- new_word.append(first + second)
130
- i += 2
131
- else:
132
- new_word.append(word[i])
133
- i += 1
134
- new_word = tuple(new_word)
135
- word = new_word
136
- if len(word) == 1:
137
- break
138
- else:
139
- pairs = get_pairs(word)
140
- word = " ".join(word)
141
- self.cache[token] = word
142
- return word
143
-
144
- def encode(self, text):
145
- bpe_tokens = []
146
- text = whitespace_clean(basic_clean(text)).lower()
147
- for token in re.findall(self.pat, text):
148
- token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
149
- bpe_tokens.extend(
150
- self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
151
- )
152
- return bpe_tokens
153
-
154
- def decode(self, tokens):
155
- text = "".join([self.decoder[token] for token in tokens])
156
- text = (
157
- bytearray([self.byte_decoder[c] for c in text])
158
- .decode("utf-8", errors="replace")
159
- .replace("</w>", " ")
160
- )
161
- return text
162
-
163
-
164
- _tokenizer = SimpleTokenizer()
165
-
166
-
167
- def tokenize(
168
- texts: Union[str, List[str]], context_length: int = 77
169
- ) -> torch.LongTensor:
170
- """
171
- Returns the tokenized representation of given input string(s)
172
-
173
- Parameters
174
- ----------
175
- texts : Union[str, List[str]]
176
- An input string or a list of input strings to tokenize
177
- context_length : int
178
- The context length to use; all CLIP models use 77 as the context length
179
-
180
- Returns
181
- -------
182
- A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
183
- """
184
- if isinstance(texts, str):
185
- texts = [texts]
186
-
187
- sot_token = _tokenizer.encoder["<start_of_text>"]
188
- eot_token = _tokenizer.encoder["<end_of_text>"]
189
- all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
190
- result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
191
-
192
- for i, tokens in enumerate(all_tokens):
193
- if len(tokens) > context_length:
194
- tokens = tokens[:context_length] # Truncate
195
- result[i, : len(tokens)] = torch.tensor(tokens)
196
-
197
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/transform.py DELETED
@@ -1,45 +0,0 @@
1
- from torchvision.transforms import (
2
- Normalize,
3
- Compose,
4
- RandomResizedCrop,
5
- InterpolationMode,
6
- ToTensor,
7
- Resize,
8
- CenterCrop,
9
- )
10
-
11
-
12
- def _convert_to_rgb(image):
13
- return image.convert("RGB")
14
-
15
-
16
- def image_transform(
17
- image_size: int,
18
- is_train: bool,
19
- mean=(0.48145466, 0.4578275, 0.40821073),
20
- std=(0.26862954, 0.26130258, 0.27577711),
21
- ):
22
- normalize = Normalize(mean=mean, std=std)
23
- if is_train:
24
- return Compose(
25
- [
26
- RandomResizedCrop(
27
- image_size,
28
- scale=(0.9, 1.0),
29
- interpolation=InterpolationMode.BICUBIC,
30
- ),
31
- _convert_to_rgb,
32
- ToTensor(),
33
- normalize,
34
- ]
35
- )
36
- else:
37
- return Compose(
38
- [
39
- Resize(image_size, interpolation=InterpolationMode.BICUBIC),
40
- CenterCrop(image_size),
41
- _convert_to_rgb,
42
- ToTensor(),
43
- normalize,
44
- ]
45
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/utils.py DELETED
@@ -1,361 +0,0 @@
1
- import numpy as np
2
- import torch
3
- from torch import nn as nn
4
- from torchvision.ops.misc import FrozenBatchNorm2d
5
- import logging
6
- # import h5py
7
- from tqdm import tqdm
8
- import random
9
- import json
10
- import os
11
- import pathlib
12
-
13
- # TODO: (yusong) this not a good place to store those information and does not scale. Need to be fixed later.
14
- dataset_split = {
15
- "audiocaps": ["train", "valid", "test"],
16
- "audioset": ["balanced_train", "unbalanced_train", "eval"],
17
- "BBCSoundEffects": ["train", "test"],
18
- "Clotho": ["train", "test", "valid"],
19
- "free_to_use_sounds": ["train", "test"],
20
- "paramount_motion": ["train", "test"],
21
- "sonniss_game_effects": ["train", "test"],
22
- "wesoundeffects": ["train", "test"],
23
- "MACS": ["train", "test"],
24
- "freesound": ["train", "test"],
25
- "FSD50K": ["train", "test", "valid"],
26
- "fsd50k_class_label": ["train", "test", "valid"],
27
- "esc50": ["train", "test"],
28
- "audiostock": ["train", "test"],
29
- "freesound_no_overlap_noesc50": ["train", "test"],
30
- "epidemic_sound_effects": ["train", "test"],
31
- "VGGSound": ["train", "test"],
32
- "urbansound8k_class_label": ["train", "test"],
33
- "audioset_t5": ["balanced_train", "unbalanced_train", "eval"],
34
- "epidemic_sound_effects_t5": ["train", "test"],
35
- "WavText5K": ["train", "test"],
36
- "esc50_no_overlap": ["train", "test"],
37
- "usd8k_no_overlap": ["train", "test"],
38
- "fsd50k_200_class_label": ["train", "test", "valid"],
39
- }
40
-
41
-
42
- def freeze_batch_norm_2d(module, module_match={}, name=""):
43
- """
44
- Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
45
- itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
46
- returned. Otherwise, the module is walked recursively and submodules are converted in place.
47
-
48
- Args:
49
- module (torch.nn.Module): Any PyTorch module.
50
- module_match (dict): Dictionary of full module names to freeze (all if empty)
51
- name (str): Full module name (prefix)
52
-
53
- Returns:
54
- torch.nn.Module: Resulting module
55
-
56
- Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
57
- """
58
- res = module
59
- is_match = True
60
- if module_match:
61
- is_match = name in module_match
62
- if is_match and isinstance(
63
- module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)
64
- ):
65
- res = FrozenBatchNorm2d(module.num_features)
66
- res.num_features = module.num_features
67
- res.affine = module.affine
68
- if module.affine:
69
- res.weight.data = module.weight.data.clone().detach()
70
- res.bias.data = module.bias.data.clone().detach()
71
- res.running_mean.data = module.running_mean.data
72
- res.running_var.data = module.running_var.data
73
- res.eps = module.eps
74
- else:
75
- for child_name, child in module.named_children():
76
- full_child_name = ".".join([name, child_name]) if name else child_name
77
- new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
78
- if new_child is not child:
79
- res.add_module(child_name, new_child)
80
- return res
81
-
82
-
83
- def exist(dataset_name, dataset_type):
84
- """
85
- Check if dataset exists
86
- """
87
- if dataset_type in dataset_split[dataset_name]:
88
- return True
89
- else:
90
- return False
91
-
92
-
93
- def get_tar_path_from_dataset_name(
94
- dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None
95
- ):
96
- """
97
- Get tar path from dataset name and type
98
- """
99
- output = []
100
- for n in dataset_names:
101
- if full_dataset is not None and n in full_dataset:
102
- current_dataset_types = dataset_split[n]
103
- else:
104
- current_dataset_types = dataset_types
105
- for s in current_dataset_types:
106
- tmp = []
107
- if islocal:
108
- sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json"
109
- if not os.path.exists(sizefilepath_):
110
- sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
111
- else:
112
- sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
113
- if not os.path.exists(sizefilepath_):
114
- continue
115
- sizes = json.load(open(sizefilepath_, "r"))
116
- for k in sizes.keys():
117
- if islocal:
118
- tmp.append(f"{dataset_path}/{n}/{s}/{k}")
119
- else:
120
- tmp.append(
121
- f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -"
122
- )
123
- if proportion != 1:
124
- tmp = random.sample(tmp, int(proportion * len(tmp)))
125
- output.append(tmp)
126
- return sum(output, [])
127
-
128
-
129
- def get_tar_path_from_txts(txt_path, islocal, proportion=1):
130
- """
131
- Get tar path from txt path
132
- """
133
- if isinstance(txt_path, (list, tuple)):
134
- return sum(
135
- [
136
- get_tar_path_from_txts(
137
- txt_path[i], islocal=islocal, proportion=proportion
138
- )
139
- for i in range(len(txt_path))
140
- ],
141
- [],
142
- )
143
- if isinstance(txt_path, str):
144
- with open(txt_path) as f:
145
- lines = f.readlines()
146
- if islocal:
147
- lines = [
148
- lines[i]
149
- .split("\n")[0]
150
- .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/")
151
- for i in range(len(lines))
152
- ]
153
- else:
154
- lines = [
155
- lines[i].split("\n")[0].replace(".tar", ".tar -")
156
- for i in range(len(lines))
157
- ]
158
- if proportion != 1:
159
- print("Sampling tars with proportion of {}".format(proportion))
160
- lines = random.sample(lines, int(proportion * len(lines)))
161
- return lines
162
-
163
-
164
- def get_mix_lambda(mixup_alpha, batch_size):
165
- mixup_lambdas = [
166
- np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size)
167
- ]
168
- return np.array(mixup_lambdas).astype(np.float32)
169
-
170
-
171
- def do_mixup(x, mixup_lambda):
172
- """
173
- Args:
174
- x: (batch_size , ...)
175
- mixup_lambda: (batch_size,)
176
- Returns:
177
- out: (batch_size, ...)
178
- """
179
- out = (
180
- x.transpose(0, -1) * mixup_lambda
181
- + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda)
182
- ).transpose(0, -1)
183
- return out
184
-
185
-
186
- def interpolate(x, ratio):
187
- """Interpolate data in time domain. This is used to compensate the
188
- resolution reduction in downsampling of a CNN.
189
-
190
- Args:
191
- x: (batch_size, time_steps, classes_num)
192
- ratio: int, ratio to interpolate
193
- Returns:
194
- upsampled: (batch_size, time_steps * ratio, classes_num)
195
- """
196
- (batch_size, time_steps, classes_num) = x.shape
197
- upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
198
- upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
199
- return upsampled
200
-
201
-
202
- def pad_framewise_output(framewise_output, frames_num):
203
- """Pad framewise_output to the same length as input frames. The pad value
204
- is the same as the value of the last frame.
205
- Args:
206
- framewise_output: (batch_size, frames_num, classes_num)
207
- frames_num: int, number of frames to pad
208
- Outputs:
209
- output: (batch_size, frames_num, classes_num)
210
- """
211
- pad = framewise_output[:, -1:, :].repeat(
212
- 1, frames_num - framewise_output.shape[1], 1
213
- )
214
- """tensor for padding"""
215
-
216
- output = torch.cat((framewise_output, pad), dim=1)
217
- """(batch_size, frames_num, classes_num)"""
218
-
219
-
220
- # def process_ipc(index_path, classes_num, filename):
221
- # # load data
222
- # logging.info("Load Data...............")
223
- # ipc = [[] for _ in range(classes_num)]
224
- # with h5py.File(index_path, "r") as f:
225
- # for i in tqdm(range(len(f["target"]))):
226
- # t_class = np.where(f["target"][i])[0]
227
- # for t in t_class:
228
- # ipc[t].append(i)
229
- # print(ipc)
230
- # np.save(filename, ipc)
231
- # logging.info("Load Data Succeed...............")
232
-
233
-
234
- def save_to_dict(s, o_={}):
235
- sp = s.split(": ")
236
- o_.update({sp[0]: float(sp[1])})
237
- return o_
238
-
239
-
240
- def get_data_from_log(txt_path):
241
- """
242
- Output dictionary from out.txt log file
243
- """
244
- with open(txt_path) as f:
245
- lines = f.readlines()
246
- val_data = {}
247
- train_data = {}
248
- train_losses = []
249
- train_losses_epoch = []
250
- for i in range(len(lines)):
251
- if "| INFO |" in lines[i]:
252
- if "Eval Epoch" in lines[i]:
253
- if "val_loss" in lines[i]:
254
- # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", ""))
255
- line = lines[i].split("Eval Epoch: ")[-1]
256
- num_epoch = int(line.split(" ")[0].split(" ")[0])
257
- d = {
258
- line.split(" ")[0]
259
- .split(" ")[1]
260
- .replace(":", ""): float(line.split(" ")[0].split(" ")[-1])
261
- }
262
- for i in range(1, len(line.split(" "))):
263
- d = save_to_dict(line.split(" ")[i], d)
264
- val_data[num_epoch] = d
265
- elif "Train Epoch" in lines[i]:
266
- num_epoch = int(lines[i].split("Train Epoch: ")[1][0])
267
- loss = float(lines[i].split("Loss: ")[-1].split(" (")[0])
268
- train_losses.append(loss)
269
- train_losses_epoch.append(num_epoch)
270
- for i in range(len(train_losses)):
271
- train_data[i] = {
272
- "num_epoch": train_losses_epoch[i],
273
- "train_loss": train_losses[i],
274
- }
275
- return train_data, val_data
276
-
277
-
278
- def save_p(obj, filename):
279
- import pickle
280
-
281
- try:
282
- from deepdiff import DeepDiff
283
- except:
284
- os.system("pip install deepdiff")
285
- from deepdiff import DeepDiff
286
- with open(filename, "wb") as file:
287
- pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol
288
- with open(filename, "rb") as file:
289
- z = pickle.load(file)
290
- assert (
291
- DeepDiff(obj, z, ignore_string_case=True) == {}
292
- ), "there is something wrong with the saving process"
293
- return
294
-
295
-
296
- def load_p(filename):
297
- import pickle
298
-
299
- with open(filename, "rb") as file:
300
- z = pickle.load(file)
301
- return z
302
-
303
-
304
- def save_json(data, name="data.json"):
305
- import json
306
-
307
- with open(name, "w") as fp:
308
- json.dump(data, fp)
309
- return
310
-
311
-
312
- def load_json(name):
313
- import json
314
-
315
- with open(name, "r") as fp:
316
- data = json.load(fp)
317
- return data
318
-
319
-
320
- from multiprocessing import Process, Manager
321
- from multiprocessing import Process, Value, Array
322
- from ctypes import c_wchar
323
-
324
-
325
- def load_class_label(path):
326
- # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
327
- # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
328
- out = None
329
- if path is not None:
330
- if pathlib.Path(path).suffix in [".pkl", ".pickle"]:
331
- out = load_p(path)
332
- elif pathlib.Path(path).suffix in [".json", ".txt"]:
333
- out = load_json(path)
334
- elif pathlib.Path(path).suffix in [".npy", ".npz"]:
335
- out = np.load(path)
336
- elif pathlib.Path(path).suffix in [".csv"]:
337
- import pandas as pd
338
-
339
- out = pd.read_csv(path)
340
- return out
341
- # if out is None:
342
- # return None
343
- # else:
344
- # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False)
345
- # val = Array('i', out.values(), lock=False)
346
- # return (key, val)
347
-
348
-
349
- from torch import optim
350
-
351
-
352
- def get_optimizer(params, lr, betas, eps, momentum, optimizer_name):
353
- if optimizer_name.lower() == "adamw":
354
- optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps)
355
- elif optimizer_name.lower() == "sgd":
356
- optimizer = optim.SGD(params, lr=lr, momentum=momentum)
357
- elif optimizer_name.lower() == "adam":
358
- optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps)
359
- else:
360
- raise ValueError("optimizer name is not correct")
361
- return optimizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/version.py DELETED
@@ -1 +0,0 @@
1
- __version__ = "0.2.1"
 
 
audioldm/clap/training/__init__.py DELETED
File without changes
audioldm/clap/training/audioset_textmap.npy DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:bada103070d92f9eadd33e1b4f45ec8583f59080ef218c966b43294bd4c86d5b
3
- size 84448
 
 
 
 
audioldm/clap/training/data.py DELETED
@@ -1,977 +0,0 @@
1
- import ast
2
- import json
3
- import logging
4
- import math
5
- import os
6
- import random
7
- # import h5py
8
- from dataclasses import dataclass
9
- from audioldm.clap.training.params import parse_args
10
- # import braceexpand
11
- import numpy as np
12
- import pandas as pd
13
- import torch
14
- import torch.nn as nn
15
- import torch.nn.functional as F
16
- import torchvision.datasets as datasets
17
- import torchvision.transforms
18
- # import webdataset as wds
19
- from PIL import Image
20
- from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
21
- from torch.utils.data.distributed import DistributedSampler
22
- from functools import partial
23
- import soundfile as sf
24
- import io
25
- from pathlib import Path
26
- # import wget
27
-
28
- from audioldm.clap.open_clip.utils import (
29
- get_tar_path_from_dataset_name,
30
- dataset_split,
31
- )
32
- from audioldm.clap.open_clip.utils import load_p, load_class_label
33
- import copy
34
-
35
- try:
36
- import horovod.torch as hvd
37
- except ImportError:
38
- hvd = None
39
-
40
- try:
41
- import torchaudio
42
- except ImportError:
43
- torchaudio = None
44
-
45
- from audioldm.clap.open_clip import tokenize
46
-
47
-
48
- def tokenizer(text):
49
- return tokenize(text).squeeze(0)
50
-
51
-
52
- from transformers import RobertaTokenizer
53
-
54
- tokenize = RobertaTokenizer.from_pretrained("roberta-base")
55
-
56
-
57
- def tokenizer(text):
58
- result = tokenize(
59
- text,
60
- padding="max_length",
61
- truncation=True,
62
- max_length=77,
63
- return_tensors="pt",
64
- )
65
- return {k: v.squeeze(0) for k, v in result.items()}
66
-
67
-
68
- # initizlied the audioset map
69
- _AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy")
70
- _AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True)
71
-
72
-
73
- def int16_to_float32(x):
74
- return (x / 32767.0).astype(np.float32)
75
-
76
-
77
- def float32_to_int16(x):
78
- x = np.clip(x, a_min=-1.0, a_max=1.0)
79
- return (x * 32767.0).astype(np.int16)
80
-
81
-
82
- # For Toy Dataset
83
- # class ToyDataset(Dataset):
84
- # def __init__(self, index_path, ipc, config, eval_mode=False):
85
- # """Toy Dataset for testing the audioset input with text labels
86
- # Parameters
87
- # ----------
88
- # index_path: str
89
- # the link to the h5 file of each audio
90
- # idc: str
91
- # the link to the npy file, the number of samples in each class
92
- # config: dict
93
- # the audio cfg file
94
- # eval_model (bool): to indicate if the dataset is a testing dataset
95
- # """
96
- # self.audio_cfg = config["audio_cfg"]
97
- # self.text_cfg = config["text_cfg"]
98
- # self.fp = h5py.File(index_path, "r")
99
- # self.ipc = np.load(ipc, allow_pickle=True)
100
- # self.total_size = len(self.fp["audio_name"])
101
- # self.classes_num = self.audio_cfg["class_num"]
102
- # self.eval_mode = eval_mode
103
-
104
- # if not eval_mode:
105
- # self.generate_queue()
106
- # else:
107
- # self.queue = []
108
- # for i in range(self.total_size):
109
- # target = self.fp["target"][i]
110
- # if np.sum(target) > 0:
111
- # self.queue.append(i)
112
- # self.total_size = len(self.queue)
113
- # logging.info("total dataset size: %d" % (self.total_size))
114
- # logging.info("class num: %d" % (self.classes_num))
115
-
116
- # def time_shifting(self, x):
117
- # frame_num = len(x)
118
- # shift_len = random.randint(0, frame_num - 1)
119
- # new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0)
120
- # return new_sample
121
-
122
- # def generate_queue(self):
123
- # self.queue = []
124
- # while len(self.queue) < self.total_size:
125
- # class_set = [*range(self.classes_num)]
126
- # random.shuffle(class_set)
127
- # self.queue += [
128
- # self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set
129
- # ]
130
- # self.queue = self.queue[: self.total_size]
131
-
132
- # logging.info("queue regenerated:%s" % (self.queue[-5:]))
133
-
134
- # def crop_wav(self, x):
135
- # crop_size = self.audio_cfg["crop_size"]
136
- # crop_pos = random.randint(0, len(x) - crop_size - 1)
137
- # return x[crop_pos : crop_pos + crop_size]
138
-
139
- # def prompt_text(self, target):
140
- # events = _AUDIOSET_MAP[np.where(target > 0)]
141
- # event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1]
142
- # text = tokenize(event_text)[0]
143
- # return text
144
-
145
- # def __getitem__(self, index):
146
- # """Load waveform, text, and target of an audio clip
147
-
148
- # Parameters
149
- # ----------
150
- # index: int
151
- # the index number
152
- # Return
153
- # ------
154
- # output: dict {
155
- # "hdf5_path": str,
156
- # "index_in_hdf5": int,
157
- # "audio_name": str,
158
- # "waveform": list (audio_length,),
159
- # "target": list (class_num, ),
160
- # "text": torch.tensor (context_length,)
161
- # }
162
- # the output dictionary
163
- # """
164
- # s_index = self.queue[index]
165
-
166
- # audio_name = self.fp["audio_name"][s_index].decode()
167
- # # Hardcode here CHANGE
168
- # hdf5_path = (
169
- # self.fp["hdf5_path"][s_index]
170
- # .decode()
171
- # .replace(
172
- # "../workspace",
173
- # "/home/la/kechen/Research/ke_zsasp/workspace",
174
- # )
175
- # )
176
- # r_idx = self.fp["index_in_hdf5"][s_index]
177
- # target = self.fp["target"][s_index].astype(np.float32)
178
- # text = self.prompt_text(target)
179
- # with h5py.File(hdf5_path, "r") as f:
180
- # waveform = int16_to_float32(f["waveform"][r_idx])[
181
- # : self.audio_cfg["clip_samples"]
182
- # ]
183
- # assert (
184
- # len(waveform) == self.audio_cfg["clip_samples"]
185
- # ), "The sample length is not match"
186
- # # Time shift
187
- # # if (self.config.enable_time_shift) and (not self.eval_mode):
188
- # # waveform = self.time_shifting(waveform)
189
- # # # Label Enhance
190
- # # if (self.config.crop_size is not None) and (not self.eval_mode):
191
- # # waveform = self.crop_wav(waveform)
192
- # # # the label enhance rate is fixed 0.5
193
- # # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5:
194
- # # kidx = np.where(target)[0]
195
- # # for k in kidx:
196
- # # for add_key in self.class_map[k][1]:
197
- # # target[add_key] = 1.0
198
- # # if len(self.class_map[k][2]) > 0:
199
- # # add_key = random.choice(self.class_map[k][2])
200
- # # target[add_key] = 1.0
201
-
202
- # # missing the text input
203
- # mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :]
204
- # mel_spec = (
205
- # torch.cat(
206
- # [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0
207
- # )
208
- # .cpu()
209
- # .numpy()
210
- # )
211
- # longer = random.choice([True, False])
212
- # if longer == False:
213
- # mel_spec[1:, :, :] = 0.0
214
- # data_dict = {
215
- # "hdf5_path": hdf5_path,
216
- # "index_in_hdf5": r_idx,
217
- # "audio_name": audio_name,
218
- # "waveform": waveform,
219
- # "class_label": target,
220
- # "text": text,
221
- # "longer": longer,
222
- # "mel_fusion": mel_spec,
223
- # }
224
- # return data_dict
225
-
226
- # def __len__(self):
227
- # return self.total_size
228
-
229
-
230
- class CsvDataset(Dataset):
231
- def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
232
- logging.debug(f"Loading csv data from {input_filename}.")
233
- df = pd.read_csv(input_filename, sep=sep)
234
-
235
- self.images = df[img_key].tolist()
236
- self.captions = df[caption_key].tolist()
237
- self.transforms = transforms
238
- logging.debug("Done loading data.")
239
-
240
- def __len__(self):
241
- return len(self.captions)
242
-
243
- def __getitem__(self, idx):
244
- images = self.transforms(Image.open(str(self.images[idx])))
245
- texts = tokenize([str(self.captions[idx])])[0]
246
- return images, texts
247
-
248
-
249
- @dataclass
250
- class DataInfo:
251
- dataloader: DataLoader
252
- sampler: DistributedSampler
253
-
254
-
255
- def preprocess_txt(text):
256
- return tokenize([str(text)])[0]
257
-
258
-
259
- def get_dataset_size(shards, sizefilepath_=None, is_local=True):
260
- if isinstance(shards, list):
261
- size_list = []
262
- for s in shards:
263
- size_list.append(
264
- get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0]
265
- )
266
- else:
267
- if not is_local:
268
- for n in dataset_split.keys():
269
- if n in shards.split("/"):
270
- break
271
- for s in dataset_split[n]:
272
- if s in shards.split("/"):
273
- break
274
- sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
275
- shards_list = list(braceexpand.braceexpand(shards))
276
- dir_path = os.path.dirname(shards)
277
- if sizefilepath_ is not None:
278
- sizes = json.load(open(sizefilepath_, "r"))
279
- total_size = sum(
280
- [
281
- int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))])
282
- for shard in shards_list
283
- ]
284
- )
285
- else:
286
- sizes_filename = os.path.join(dir_path, "sizes.json")
287
- len_filename = os.path.join(dir_path, "__len__")
288
- if os.path.exists(sizes_filename):
289
- sizes = json.load(open(sizes_filename, "r"))
290
- total_size = sum(
291
- [int(sizes[os.path.basename(shard)]) for shard in shards_list]
292
- )
293
- elif os.path.exists(len_filename):
294
- # FIXME this used to be eval(open(...)) but that seemed rather unsafe
295
- total_size = ast.literal_eval(open(len_filename, "r").read())
296
- else:
297
- raise Exception(
298
- "Cannot find sizes file for dataset. Please specify the path to the file."
299
- )
300
- # total_size = None # num samples undefined
301
- # some common dataset sizes (at time of authors last download)
302
- # cc3m-train: 2905954
303
- # cc12m: 10968539
304
- # LAION-400m: 407332084
305
- num_shards = len(shards_list)
306
- if isinstance(shards, list):
307
- return sum(size_list), len(shards)
308
- else:
309
- return total_size, num_shards
310
-
311
-
312
- def get_imagenet(args, preprocess_fns, split):
313
- assert split in ["train", "val", "v2"]
314
- is_train = split == "train"
315
- preprocess_train, preprocess_val = preprocess_fns
316
-
317
- if split == "v2":
318
- from imagenetv2_pytorch import ImageNetV2Dataset
319
-
320
- dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
321
- else:
322
- if is_train:
323
- data_path = args.imagenet_train
324
- preprocess_fn = preprocess_train
325
- else:
326
- data_path = args.imagenet_val
327
- preprocess_fn = preprocess_val
328
- assert data_path
329
-
330
- dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
331
-
332
- if is_train:
333
- idxs = np.zeros(len(dataset.targets))
334
- target_array = np.array(dataset.targets)
335
- k = 50
336
- for c in range(1000):
337
- m = target_array == c
338
- n = len(idxs[m])
339
- arr = np.zeros(n)
340
- arr[:k] = 1
341
- np.random.shuffle(arr)
342
- idxs[m] = arr
343
-
344
- idxs = idxs.astype("int")
345
- sampler = SubsetRandomSampler(np.where(idxs)[0])
346
- else:
347
- sampler = None
348
-
349
- dataloader = torch.utils.data.DataLoader(
350
- dataset,
351
- batch_size=args.batch_size,
352
- num_workers=args.workers,
353
- sampler=sampler,
354
- )
355
-
356
- return DataInfo(dataloader, sampler)
357
-
358
-
359
- def count_samples(dataloader):
360
- os.environ["WDS_EPOCH"] = "0"
361
- n_elements, n_batches = 0, 0
362
- for images, texts in dataloader:
363
- n_batches += 1
364
- n_elements += len(images)
365
- assert len(images) == len(texts)
366
- return n_elements, n_batches
367
-
368
-
369
- def filter_no_caption(sample):
370
- return "txt" in sample
371
-
372
-
373
- def log_and_continue(exn):
374
- """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
375
- logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
376
- return True
377
-
378
-
379
- _SHARD_SHUFFLE_SIZE = 2000
380
- _SHARD_SHUFFLE_INITIAL = 500
381
- _SAMPLE_SHUFFLE_SIZE = 5000
382
- _SAMPLE_SHUFFLE_INITIAL = 1000
383
-
384
-
385
- def sample_prop(sizefile, inputs, proportion, is_local=True):
386
- """
387
- Sample a proportion of the data.
388
- """
389
- file_path_dict = {
390
- os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0]
391
- for i in range(len(inputs))
392
- }
393
- sampled_filepath_dict = {}
394
- sampled_size_dict = {}
395
- if not is_local:
396
- if os.path.exists("sizes.json"):
397
- os.remove("sizes.json")
398
- wget.download(sizefile, "sizes.json")
399
- sizefile = "sizes.json"
400
- with open(sizefile, "r", encoding="UTF-8") as f:
401
- load_dict = json.load(f)
402
- L = int(len(file_path_dict) * proportion)
403
- subkeys = random.sample(file_path_dict.keys(), L)
404
- for k in subkeys:
405
- sampled_size_dict[k] = load_dict[k]
406
- sampled_filepath_dict[k] = file_path_dict[k]
407
- return (
408
- sum(sampled_size_dict.values()),
409
- L,
410
- [os.path.join(v, k) for k, v in sampled_filepath_dict.items()],
411
- sampled_size_dict,
412
- )
413
-
414
-
415
- def get_mel(audio_data, audio_cfg):
416
- # mel shape: (n_mels, T)
417
- mel = torchaudio.transforms.MelSpectrogram(
418
- sample_rate=audio_cfg["sample_rate"],
419
- n_fft=audio_cfg["window_size"],
420
- win_length=audio_cfg["window_size"],
421
- hop_length=audio_cfg["hop_size"],
422
- center=True,
423
- pad_mode="reflect",
424
- power=2.0,
425
- norm=None,
426
- onesided=True,
427
- n_mels=64,
428
- f_min=audio_cfg["fmin"],
429
- f_max=audio_cfg["fmax"],
430
- ).to(audio_data.device)
431
- mel = mel(audio_data)
432
- # Align to librosa:
433
- # librosa_melspec = librosa.feature.melspectrogram(
434
- # waveform,
435
- # sr=audio_cfg['sample_rate'],
436
- # n_fft=audio_cfg['window_size'],
437
- # hop_length=audio_cfg['hop_size'],
438
- # win_length=audio_cfg['window_size'],
439
- # center=True,
440
- # pad_mode="reflect",
441
- # power=2.0,
442
- # n_mels=64,
443
- # norm=None,
444
- # htk=True,
445
- # f_min=audio_cfg['fmin'],
446
- # f_max=audio_cfg['fmax']
447
- # )
448
- # we use log mel spectrogram as input
449
- mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
450
- return mel.T # (T, n_mels)
451
-
452
-
453
- def get_audio_features(
454
- sample, audio_data, max_len, data_truncating, data_filling, audio_cfg
455
- ):
456
- """
457
- Calculate and add audio features to sample.
458
- Sample: a dict containing all the data of current sample.
459
- audio_data: a tensor of shape (T) containing audio data.
460
- max_len: the maximum length of audio data.
461
- data_truncating: the method of truncating data.
462
- data_filling: the method of filling data.
463
- audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg'].
464
- """
465
- with torch.no_grad():
466
- if len(audio_data) > max_len:
467
- if data_truncating == "rand_trunc":
468
- longer = torch.tensor([True])
469
- elif data_truncating == "fusion":
470
- # fusion
471
- mel = get_mel(audio_data, audio_cfg)
472
- # split to three parts
473
- chunk_frames = (
474
- max_len // audio_cfg["hop_size"] + 1
475
- ) # the +1 related to how the spectrogram is computed
476
- total_frames = mel.shape[0]
477
- if chunk_frames == total_frames:
478
- # there is a corner case where the audio length is
479
- # larger than max_len but smaller than max_len+hop_size.
480
- # In this case, we just use the whole audio.
481
- mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
482
- sample["mel_fusion"] = mel_fusion
483
- longer = torch.tensor([False])
484
- else:
485
- ranges = np.array_split(
486
- list(range(0, total_frames - chunk_frames + 1)), 3
487
- )
488
- # print('total_frames-chunk_frames:', total_frames-chunk_frames,
489
- # 'len(audio_data):', len(audio_data),
490
- # 'chunk_frames:', chunk_frames,
491
- # 'total_frames:', total_frames)
492
- if len(ranges[1]) == 0:
493
- # if the audio is too short, we just use the first chunk
494
- ranges[1] = [0]
495
- if len(ranges[2]) == 0:
496
- # if the audio is too short, we just use the first chunk
497
- ranges[2] = [0]
498
- # randomly choose index for each part
499
- idx_front = np.random.choice(ranges[0])
500
- idx_middle = np.random.choice(ranges[1])
501
- idx_back = np.random.choice(ranges[2])
502
- # select mel
503
- mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :]
504
- mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :]
505
- mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :]
506
-
507
- # shrink the mel
508
- mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, 64])(
509
- mel[None]
510
- )[0]
511
- # logging.info(f"mel_shrink.shape: {mel_shrink.shape}")
512
-
513
- # stack
514
- mel_fusion = torch.stack(
515
- [mel_chunk_front, mel_chunk_middle, mel_chunk_back, mel_shrink],
516
- dim=0,
517
- )
518
- sample["mel_fusion"] = mel_fusion
519
- longer = torch.tensor([True])
520
- else:
521
- raise NotImplementedError(
522
- f"data_truncating {data_truncating} not implemented"
523
- )
524
- # random crop to max_len (for compatibility)
525
- overflow = len(audio_data) - max_len
526
- idx = np.random.randint(0, overflow + 1)
527
- audio_data = audio_data[idx : idx + max_len]
528
-
529
- else: # padding if too short
530
- if len(audio_data) < max_len: # do nothing if equal
531
- if data_filling == "repeatpad":
532
- n_repeat = int(max_len / len(audio_data))
533
- audio_data = audio_data.repeat(n_repeat)
534
- # audio_data = audio_data.unsqueeze(0).unsqueeze(0).unsqueeze(0)
535
- # audio_data = F.interpolate(audio_data,size=max_len,mode="bicubic")[0,0,0]
536
- audio_data = F.pad(
537
- audio_data,
538
- (0, max_len - len(audio_data)),
539
- mode="constant",
540
- value=0,
541
- )
542
- elif data_filling == "pad":
543
- audio_data = F.pad(
544
- audio_data,
545
- (0, max_len - len(audio_data)),
546
- mode="constant",
547
- value=0,
548
- )
549
- elif data_filling == "repeat":
550
- n_repeat = int(max_len / len(audio_data))
551
- audio_data = audio_data.repeat(n_repeat + 1)[:max_len]
552
- else:
553
- raise NotImplementedError(
554
- f"data_filling {data_filling} not implemented"
555
- )
556
- if data_truncating == "fusion":
557
- mel = get_mel(audio_data, audio_cfg)
558
- mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
559
- sample["mel_fusion"] = mel_fusion
560
- longer = torch.tensor([False])
561
-
562
- sample["longer"] = longer
563
- sample["waveform"] = audio_data
564
-
565
- return sample
566
-
567
-
568
- def preprocess(
569
- sample,
570
- audio_ext,
571
- text_ext,
572
- max_len,
573
- audio_cfg,
574
- class_index_dict=None,
575
- data_filling="pad",
576
- data_truncating="rand_trunc",
577
- text_augment_selection=None,
578
- ):
579
- """
580
- Preprocess a single sample for wdsdataloader.
581
- """
582
- audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
583
- audio_data = int16_to_float32(float32_to_int16(audio_data))
584
- audio_data = torch.tensor(audio_data).float()
585
-
586
- # TODO: (yusong) to be include in the future
587
- # # if torchaudio not installed, use soundfile to load audio
588
- # if torchaudio is None:
589
- # audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
590
- # audio_data = torch.tensor(audio_data).float()
591
- # else:
592
- # # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py
593
- # with tempfile.TemporaryDirectory() as dirname:
594
- # os.makedirs(dirname, exist_ok=True)
595
- # fname = os.path.join(dirname, f"file.flac")
596
- # with open(fname, "wb") as stream:
597
- # stream.write(sample[audio_ext])
598
- # audio_data, orig_sr = torchaudio.load(fname)
599
- # audio_data = audio_data[0, :].float()
600
-
601
- sample = get_audio_features(
602
- sample, audio_data, max_len, data_truncating, data_filling, audio_cfg
603
- )
604
- del sample[audio_ext]
605
-
606
- try:
607
- json_dict_raw = json.loads(sample[text_ext].decode("utf-8"))
608
- except:
609
- print("sample[__url__]:", sample["__url__"])
610
-
611
- # For selecting augmented text from dataset
612
- if text_augment_selection is None or text_augment_selection == "none":
613
- texts = json_dict_raw["text"]
614
- elif text_augment_selection == "all":
615
- if "text_augment_all" in json_dict_raw.keys():
616
- texts = json_dict_raw["text_augment_all"]
617
- else:
618
- texts = json_dict_raw["text"]
619
- elif text_augment_selection == "augment_only":
620
- if "text_augment_all" in json_dict_raw.keys():
621
- if json_dict_raw["text_augment_t5"] is None:
622
- texts = json_dict_raw["text"]
623
- else:
624
- texts = json_dict_raw["text_augment_t5"]
625
- else:
626
- texts = json_dict_raw["text"]
627
- else:
628
- raise NotImplementedError(
629
- f"text_augment_selection {text_augment_selection} not implemented"
630
- )
631
- sample["full_text"] = texts
632
-
633
- if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1:
634
- texts = random.choice(texts)
635
- sample["raw_text"] = texts
636
- sample["text"] = tokenizer(texts) # text shape: [num_token]
637
- if class_index_dict is not None:
638
- # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
639
- # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
640
- # key, val = class_index_dict
641
- # key = key[:].split('\n')
642
- # _dict = {k: v for k, v in zip(key, val)}
643
- sample["class_label"] = np.zeros(len(class_index_dict.keys()))
644
- for x in json_dict_raw["tag"]:
645
- sample["class_label"][class_index_dict[x]] = 1
646
- sample["class_label"] = torch.tensor(sample["class_label"]).float()
647
- del sample[text_ext]
648
- sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext
649
- sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext
650
- sample["audio_orig_sr"] = orig_sr
651
- return sample
652
-
653
-
654
- def collate_fn(batch):
655
- """
656
- Collate function for wdsdataloader.
657
- batch: a list of dict, each dict is a sample
658
- """
659
- # concatenate values in each dictionary. if it is a tensor, concatenate. if it is a list, extend.
660
- batch_dict = {}
661
- for k in batch[0].keys():
662
- if isinstance(batch[0][k], dict): # dealwith bert tokenizer output
663
- batch_dict[k] = {}
664
- for kk in batch[0][k].keys():
665
- tmp = []
666
- for i in range(len(batch)):
667
- tmp.append(batch[i][k][kk])
668
- batch_dict[k][kk] = torch.vstack(tmp)
669
- elif isinstance(batch[0][k], torch.Tensor):
670
- batch_dict[k] = torch.stack([sample[k] for sample in batch])
671
- elif isinstance(batch[0][k], np.ndarray):
672
- batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch]))
673
- else:
674
- batch_dict[k] = [sample[k] for sample in batch]
675
- return batch_dict
676
-
677
-
678
- def get_wds_dataset(
679
- args,
680
- model_cfg,
681
- is_train,
682
- audio_ext="flac",
683
- text_ext="json",
684
- max_len=480000,
685
- proportion=1.0,
686
- sizefilepath_=None,
687
- is_local=None,
688
- ):
689
- """
690
- Get a dataset for wdsdataloader.
691
- """
692
- if is_local is None and (not args.remotedata is None):
693
- is_local = not args.remotedata
694
-
695
- input_shards = args.train_data if is_train else args.val_data
696
- assert input_shards is not None
697
-
698
- if not sizefilepath_ is None:
699
- sizefilepath = sizefilepath_
700
- else:
701
- sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json")
702
-
703
- if proportion != 1.0:
704
- num_samples, num_shards, input_shards, _ = sample_prop(
705
- sizefilepath, input_shards, proportion, is_local=is_local
706
- )
707
- else:
708
- num_samples, num_shards = get_dataset_size(
709
- input_shards, sizefilepath_=sizefilepath_, is_local=is_local
710
- )
711
-
712
- if not num_samples:
713
- if is_train:
714
- num_samples = args.train_num_samples
715
- if not num_samples:
716
- raise RuntimeError(
717
- "Currently, number of dataset samples must be specified for training dataset. "
718
- "Please specify via `--train-num-samples` if no dataset length info present."
719
- )
720
- else:
721
- num_samples = (
722
- args.val_num_samples or 0
723
- ) # eval will just exhaust the iterator if not specified
724
-
725
- pipeline = [wds.SimpleShardList(input_shards)]
726
- # at this point we have an iterator over all the shards
727
- # TODO: (yusong): add a if statement of distributed. If not, we don't need to split_by_node
728
- if is_train or args.parallel_eval:
729
- pipeline.extend(
730
- [
731
- wds.detshuffle(
732
- bufsize=_SHARD_SHUFFLE_SIZE,
733
- initial=_SHARD_SHUFFLE_INITIAL,
734
- seed=args.seed,
735
- ),
736
- wds.split_by_node,
737
- wds.split_by_worker,
738
- # at this point, we have an iterator over the shards assigned to each worker at each node
739
- wds.tarfile_to_samples(handler=log_and_continue),
740
- wds.shuffle(
741
- bufsize=_SAMPLE_SHUFFLE_SIZE,
742
- initial=_SAMPLE_SHUFFLE_INITIAL,
743
- rng=random.Random(args.seed),
744
- ),
745
- # wds.repeatedly, # FIXME determine if this is beneficial
746
- ]
747
- )
748
- else:
749
- pipeline.extend(
750
- [
751
- wds.split_by_worker,
752
- # at this point, we have an iterator over the shards assigned to each worker
753
- wds.tarfile_to_samples(handler=log_and_continue),
754
- ]
755
- )
756
- pipeline.append(
757
- wds.map(
758
- partial(
759
- preprocess,
760
- audio_ext=audio_ext,
761
- text_ext=text_ext,
762
- max_len=max_len,
763
- audio_cfg=model_cfg["audio_cfg"],
764
- class_index_dict=copy.deepcopy(args.class_index_dict),
765
- data_filling=args.data_filling,
766
- data_truncating=args.data_truncating,
767
- text_augment_selection=args.text_augment_selection,
768
- )
769
- ),
770
- )
771
-
772
- pipeline.append(
773
- wds.batched(
774
- args.batch_size,
775
- partial=not (is_train or args.parallel_eval),
776
- collation_fn=collate_fn,
777
- )
778
- )
779
-
780
- dataset = wds.DataPipeline(*pipeline)
781
- if is_train or args.parallel_eval:
782
- # (yusong): Currently parallel evaluation will be not precise as we are repeat the last few samples.
783
- # (yusong): See comments below.
784
- # roll over and repeat a few samples to get same number of full batches on each node
785
- global_batch_size = args.batch_size * args.world_size
786
- num_batches = math.ceil(num_samples / global_batch_size)
787
- num_workers = max(1, args.workers)
788
- num_worker_batches = math.ceil(
789
- num_batches / num_workers
790
- ) # per dataloader worker
791
- num_batches = num_worker_batches * num_workers
792
- num_samples = num_batches * global_batch_size
793
- dataset = dataset.with_epoch(
794
- num_worker_batches
795
- ) # each worker is iterating over this
796
- else:
797
- # last batches are partial, eval is done on single (master) node
798
- num_batches = math.ceil(num_samples / args.batch_size)
799
-
800
- kwargs = {}
801
- if args.horovod: # multi-node training on summit
802
- kwargs["multiprocessing_context"] = "forkserver"
803
-
804
- dataloader = wds.WebLoader(
805
- dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs
806
- )
807
-
808
- # FIXME not clear which approach is better, with_epoch before vs after dataloader?
809
- # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
810
- # if is_train:
811
- # # roll over and repeat a few samples to get same number of full batches on each node
812
- # global_batch_size = args.batch_size * args.world_size
813
- # num_batches = math.ceil(num_samples / global_batch_size)
814
- # num_workers = max(1, args.workers)
815
- # num_batches = math.ceil(num_batches / num_workers) * num_workers
816
- # num_samples = num_batches * global_batch_size
817
- # dataloader = dataloader.with_epoch(num_batches)
818
- # else:
819
- # # last batches are partial, eval is done on single (master) node
820
- # num_batches = math.ceil(num_samples / args.batch_size)
821
-
822
- # add meta-data to dataloader instance for convenience
823
- dataloader.num_batches = num_batches
824
- dataloader.num_samples = num_samples
825
-
826
- return DataInfo(dataloader, None)
827
-
828
-
829
- def wds_batch_list2dict(
830
- batch,
831
- keys=[
832
- "__url__",
833
- "__key__",
834
- "waveform",
835
- "text",
836
- "raw_text",
837
- "audio_name",
838
- "text_name",
839
- "audio_orig_sr",
840
- ],
841
- ):
842
- """
843
- Return a dictionary of the batch, with keys as the names of the fields.
844
- """
845
- assert len(keys) == len(
846
- batch
847
- ), "batch must have same number of keys as keys argument"
848
- return {keys[i]: batch[i] for i in range(len(batch))}
849
-
850
-
851
- def get_csv_dataset(args, preprocess_fn, is_train):
852
- input_filename = args.train_data if is_train else args.val_data
853
- assert input_filename
854
- dataset = CsvDataset(
855
- input_filename,
856
- preprocess_fn,
857
- img_key=args.csv_img_key,
858
- caption_key=args.csv_caption_key,
859
- sep=args.csv_separator,
860
- )
861
- num_samples = len(dataset)
862
- sampler = DistributedSampler(dataset) if args.distributed and is_train else None
863
- shuffle = is_train and sampler is None
864
-
865
- dataloader = DataLoader(
866
- dataset,
867
- batch_size=args.batch_size,
868
- shuffle=shuffle,
869
- num_workers=args.workers,
870
- pin_memory=True,
871
- sampler=sampler,
872
- drop_last=is_train,
873
- )
874
- dataloader.num_samples = num_samples
875
- dataloader.num_batches = len(dataloader)
876
-
877
- return DataInfo(dataloader, sampler)
878
-
879
-
880
- def get_toy_dataset(args, model_cfg, is_train):
881
- index_path = args.train_data if is_train else args.val_data
882
- ipc_path = args.train_ipc if is_train else args.val_ipc
883
- assert index_path and ipc_path
884
- eval_mode = not is_train
885
- dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode)
886
-
887
- num_samples = len(dataset)
888
- sampler = (
889
- DistributedSampler(dataset, shuffle=False)
890
- if args.distributed and is_train
891
- else None
892
- )
893
-
894
- dataloader = DataLoader(
895
- dataset,
896
- batch_size=args.batch_size,
897
- shuffle=False,
898
- num_workers=args.workers,
899
- sampler=sampler,
900
- drop_last=is_train,
901
- )
902
- dataloader.num_samples = num_samples
903
- dataloader.num_batches = len(dataloader)
904
-
905
- return DataInfo(dataloader, sampler)
906
-
907
-
908
- def get_dataset_fn(data_path, dataset_type):
909
- if dataset_type == "webdataset":
910
- return get_wds_dataset
911
- elif dataset_type == "csv":
912
- return get_csv_dataset
913
- elif dataset_type == "auto":
914
- ext = data_path.split(".")[-1]
915
- if ext in ["csv", "tsv"]:
916
- return get_csv_dataset
917
- elif ext in ["tar"]:
918
- return get_wds_dataset
919
- else:
920
- raise ValueError(
921
- f"Tried to figure out dataset type, but failed for extention {ext}."
922
- )
923
- elif dataset_type == "toy":
924
- return get_toy_dataset
925
- else:
926
- raise ValueError(f"Unsupported dataset type: {dataset_type}")
927
-
928
-
929
- def get_data(args, model_cfg):
930
- data = {}
931
-
932
- args.class_index_dict = load_class_label(args.class_label_path)
933
-
934
- if args.datasetinfos is None:
935
- args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]
936
- if args.dataset_type == "webdataset":
937
- args.train_data = get_tar_path_from_dataset_name(
938
- args.datasetnames,
939
- args.datasetinfos,
940
- islocal=not args.remotedata,
941
- proportion=args.dataset_proportion,
942
- dataset_path=args.datasetpath,
943
- full_dataset=args.full_train_dataset,
944
- )
945
-
946
- if args.full_train_dataset is None:
947
- args.full_train_dataset = []
948
- if args.exclude_eval_dataset is None:
949
- args.exclude_eval_dataset = []
950
- excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset
951
-
952
- val_dataset_names = (
953
- [n for n in args.datasetnames if n not in excluded_eval_datasets]
954
- if excluded_eval_datasets
955
- else args.datasetnames
956
- )
957
- args.val_dataset_names = val_dataset_names
958
- args.val_data = get_tar_path_from_dataset_name(
959
- val_dataset_names,
960
- ["valid", "test", "eval"],
961
- islocal=not args.remotedata,
962
- proportion=1,
963
- dataset_path=args.datasetpath,
964
- full_dataset=None,
965
- )
966
-
967
- if args.train_data:
968
- data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
969
- args, model_cfg, is_train=True
970
- )
971
-
972
- if args.val_data:
973
- data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
974
- args, model_cfg, is_train=False
975
- )
976
-
977
- return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/training/distributed.py DELETED
@@ -1,150 +0,0 @@
1
- import os
2
-
3
- import torch
4
- import socket
5
-
6
- try:
7
- import horovod.torch as hvd
8
- except ImportError:
9
- hvd = None
10
-
11
-
12
- def is_global_master(args):
13
- return args.rank == 0
14
-
15
-
16
- def is_local_master(args):
17
- return args.local_rank == 0
18
-
19
-
20
- def is_master(args, local=False):
21
- return is_local_master(args) if local else is_global_master(args)
22
-
23
-
24
- def is_using_horovod():
25
- # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
26
- # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
27
- ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
28
- pmi_vars = ["PMI_RANK", "PMI_SIZE"]
29
- if all([var in os.environ for var in ompi_vars]) or all(
30
- [var in os.environ for var in pmi_vars]
31
- ):
32
- return True
33
- else:
34
- return False
35
-
36
-
37
- def is_using_distributed():
38
- if "WORLD_SIZE" in os.environ:
39
- return int(os.environ["WORLD_SIZE"]) > 1
40
- if "SLURM_NTASKS" in os.environ:
41
- return int(os.environ["SLURM_NTASKS"]) > 1
42
- return False
43
-
44
-
45
- def world_info_from_env():
46
- local_rank = 0
47
- for v in (
48
- "SLURM_LOCALID",
49
- "MPI_LOCALRANKID",
50
- "OMPI_COMM_WORLD_LOCAL_RANK",
51
- "LOCAL_RANK",
52
- ):
53
- if v in os.environ:
54
- local_rank = int(os.environ[v])
55
- break
56
- global_rank = 0
57
- for v in ("SLURM_PROCID", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "RANK"):
58
- if v in os.environ:
59
- global_rank = int(os.environ[v])
60
- break
61
- world_size = 1
62
- for v in ("SLURM_NTASKS", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "WORLD_SIZE"):
63
- if v in os.environ:
64
- world_size = int(os.environ[v])
65
- break
66
-
67
- return local_rank, global_rank, world_size
68
-
69
-
70
- def init_distributed_device(args):
71
- # Distributed training = training on more than one GPU.
72
- # Works in both single and multi-node scenarios.
73
- args.distributed = False
74
- args.world_size = 1
75
- args.rank = 0 # global rank
76
- args.local_rank = 0
77
- if args.horovod:
78
- assert hvd is not None, "Horovod is not installed"
79
- hvd.init()
80
- world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
81
- world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
82
- local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
83
- args.local_rank = local_rank
84
- args.rank = world_rank
85
- args.world_size = world_size
86
- # args.local_rank = int(hvd.local_rank())
87
- # args.rank = hvd.rank()
88
- # args.world_size = hvd.size()
89
- args.distributed = True
90
- os.environ["LOCAL_RANK"] = str(args.local_rank)
91
- os.environ["RANK"] = str(args.rank)
92
- os.environ["WORLD_SIZE"] = str(args.world_size)
93
- print(
94
- f"Distributed training: local_rank={args.local_rank}, "
95
- f"rank={args.rank}, world_size={args.world_size}, "
96
- f"hostname={socket.gethostname()}, pid={os.getpid()}"
97
- )
98
- elif is_using_distributed():
99
- if "SLURM_PROCID" in os.environ:
100
- # DDP via SLURM
101
- args.local_rank, args.rank, args.world_size = world_info_from_env()
102
- # SLURM var -> torch.distributed vars in case needed
103
- os.environ["LOCAL_RANK"] = str(args.local_rank)
104
- os.environ["RANK"] = str(args.rank)
105
- os.environ["WORLD_SIZE"] = str(args.world_size)
106
- torch.distributed.init_process_group(
107
- backend=args.dist_backend,
108
- init_method=args.dist_url,
109
- world_size=args.world_size,
110
- rank=args.rank,
111
- )
112
- elif "OMPI_COMM_WORLD_SIZE" in os.environ: # using Summit cluster
113
- world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
114
- world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
115
- local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
116
- args.local_rank = local_rank
117
- args.rank = world_rank
118
- args.world_size = world_size
119
- torch.distributed.init_process_group(
120
- backend=args.dist_backend,
121
- init_method=args.dist_url,
122
- world_size=args.world_size,
123
- rank=args.rank,
124
- )
125
- else:
126
- # DDP via torchrun, torch.distributed.launch
127
- args.local_rank, _, _ = world_info_from_env()
128
- torch.distributed.init_process_group(
129
- backend=args.dist_backend, init_method=args.dist_url
130
- )
131
- args.world_size = torch.distributed.get_world_size()
132
- args.rank = torch.distributed.get_rank()
133
- args.distributed = True
134
- print(
135
- f"Distributed training: local_rank={args.local_rank}, "
136
- f"rank={args.rank}, world_size={args.world_size}, "
137
- f"hostname={socket.gethostname()}, pid={os.getpid()}"
138
- )
139
-
140
- if torch.cuda.is_available():
141
- if args.distributed and not args.no_set_device_rank:
142
- device = "cuda:%d" % args.local_rank
143
- else:
144
- device = "cuda:0"
145
- torch.cuda.set_device(device)
146
- else:
147
- device = "cpu"
148
- args.device = device
149
- device = torch.device(device)
150
- return device