sino commited on
Commit
cee9fbc
·
1 Parent(s): 183e91f

Upload 21 files

Browse files
src/LMdecoder.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from doctest import ELLIPSIS_MARKER
3
+ from functools import partial
4
+ import json
5
+ from turtle import forward, shape
6
+ import einops
7
+ import torch
8
+ from torch import nn
9
+
10
+ from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
11
+ from transformers import GPT2Model, GPT2Config,GPT2LMHeadModel,GPTNeoForCausalLM,GPTNeoModel, \
12
+ BartModel, BartConfig, BartForCausalLM, BertForMaskedLM, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
13
+ from transformers import BitsAndBytesConfig
14
+
15
+ from peft import prepare_model_for_kbit_training
16
+ from peft import LoraConfig
17
+ from peft import get_peft_model
18
+
19
+
20
+ from mmcv.cnn import build_norm_layer
21
+ from mmcv.runner import BaseModule
22
+ import math
23
+ from ipdb import set_trace
24
+
25
+ class mixEmbed(nn.Module):
26
+ def __init__(self, lm_embed: nn.Embedding , audio_embeddings, *args, **kwargs) -> None:
27
+ super().__init__(*args, **kwargs)
28
+ self.lm_embed = lm_embed
29
+ self.audio_embeddings = audio_embeddings # ugly but works without modifying raw model codes
30
+
31
+ def forward(self, input_ids):
32
+ text_ids = torch.clamp(input_ids.clone(), 0).long()
33
+
34
+ au_ids = torch.clamp(-(input_ids.clone() + 1), 0).long()
35
+ text_embeds = self.lm_embed(text_ids)
36
+ au_embeds = self.audio_embeddings[au_ids]
37
+ with torch.no_grad():
38
+ embed_mask = (input_ids > 0)
39
+ mix_embeds = au_embeds.clone()
40
+ mix_embeds[embed_mask] = text_embeds[embed_mask]
41
+ return mix_embeds
42
+
43
+
44
+ class LMDecoder(nn.Module):
45
+ def __init__(self,
46
+ # num_patches=196,
47
+ img_size=(80,512),
48
+ patch_size:int=16,
49
+ in_chans:int=3,
50
+ embed_dim=1024, # encoder embed dim
51
+ decoder_embed_dim=512,
52
+ norm_cfg=dict(type='LN', eps=1e-6),
53
+ # patch_resolution=14,
54
+ decoder_type='gpt2',
55
+ freeze_decoder=True,
56
+ additional_layer:int=0,
57
+ ):
58
+ super().__init__()
59
+ self.decoder_type = decoder_type
60
+ self.load_lm()
61
+
62
+ self.lm_embed = self.lm.get_input_embeddings()
63
+ try:
64
+ self.lm_pos_embed = self.lm.get_position_embeddings()
65
+ except NotImplementedError:
66
+ self.lm_pos_embed = None # rotrary embeds
67
+
68
+
69
+ if hasattr(self.lm,'embed_dim'):
70
+ self.embed_dim = self.lm.embed_dim
71
+ else:
72
+ self.embed_dim = decoder_embed_dim
73
+
74
+ # self.asLM = asLM # if generates tokens rather than hidden states
75
+ # if self.asLM: # TODO: 当年写这个是为啥?
76
+ # self.lm.set_output_embeddings(nn.Linear(self.embed_dim, self.self.LMconfig.vocab_size, bias=False))
77
+ self.freeze_decoder = False
78
+ if True:
79
+ for para in self.lm.parameters():
80
+ para.requires_grad = False
81
+
82
+ def load_lm(self):
83
+ ## ---------------------LM setting----------------------
84
+ self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type)
85
+ if self.tokenizer.pad_token is None:
86
+ self.tokenizer.pad_token = self.tokenizer.eos_token
87
+ self.LMconfig = AutoConfig.from_pretrained(self.decoder_type, trust_remote_code=True )
88
+ self.lm = AutoModelForCausalLM.from_pretrained(self.decoder_type, trust_remote_code=True)
89
+
90
+
91
+ def forward(self, input_ids, flatten_embs, attention_mask, labels, **kwargs):
92
+ mix_embed = mixEmbed(self.lm_embed, flatten_embs)
93
+ self.lm.set_input_embeddings(mix_embed) # modification of the lm embed
94
+ output = self.lm(input_ids=input_ids, attention_mask=attention_mask, labels=labels, output_hidden_states=True, **kwargs)
95
+ self.lm.set_input_embeddings(self.lm_embed) # modification of the lm embed
96
+ return output
97
+
98
+ def generate(self, input_ids, flatten_embs):
99
+ mix_embed = mixEmbed(self.lm_embed, flatten_embs)
100
+ self.lm.set_input_embeddings(mix_embed) # modification of the lm embed
101
+ outputs = self.lm.generate(input_ids=input_ids, max_new_tokens=256, use_cache=False)
102
+ # outputs = self.lm.generate(input_ids=input_ids,
103
+ # max_new_tokens=1024,
104
+ # do_sample=True,
105
+ # temperature=1.5,
106
+ # num_beams=1,
107
+ # top_p=0.9,
108
+ # top_k=3,
109
+ # use_cache=False)
110
+ self.lm.set_input_embeddings(self.lm_embed) # modification of the lm embed
111
+ return outputs
112
+ '''
113
+ ## infer params
114
+ max_input_tokens: 40
115
+ batch_size_test: 16
116
+ max_new_tokens: 64
117
+ min_length: 2
118
+ num_beams: 5
119
+ length_penalty: -2.0
120
+ top_p: 0.9
121
+ top_k: 3
122
+ no_repeat_ngram_size: 2
123
+ apply_lemmatizer: False
124
+ use_nucleus_sampling: True
125
+ '''
126
+
127
+ class LMDecoder_qlora(LMDecoder):
128
+ def __init__(self,
129
+ # num_patches=196,
130
+ img_size=(80,512),
131
+ patch_size:int=16,
132
+ in_chans:int=3,
133
+ embed_dim=1024, # encoder embed dim
134
+ decoder_embed_dim=512,
135
+ norm_cfg=dict(type='LN', eps=1e-6),
136
+ # patch_resolution=14,
137
+ decoder_type='gpt2',
138
+ freeze_decoder=True,
139
+ additional_layer:int=0,
140
+ ):
141
+ super().__init__( img_size, patch_size, in_chans, embed_dim, decoder_embed_dim, norm_cfg, decoder_type, freeze_decoder, additional_layer)
142
+
143
+ def load_lm(self):
144
+ self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type)
145
+ self.LMconfig = AutoConfig.from_pretrained(self.decoder_type, trust_remote_code=True )
146
+ double_quant_config = BitsAndBytesConfig(
147
+ load_in_4bit=True,
148
+ bnb_4bit_use_double_quant=True,
149
+ )
150
+ model = AutoModelForCausalLM.from_pretrained(self.decoder_type,
151
+ # device_map='auto', # if remove, can not add lora
152
+ # load_in_4bit=True,# if remove, can not add lora
153
+ # # torch_dtype=torch.bfloat16,
154
+ # quantization_config=double_quant_config, # if remove, can not add lora
155
+ trust_remote_code=True )
156
+
157
+ model.gradient_checkpointing_enable()
158
+ model = prepare_model_for_kbit_training(model)
159
+ lora_config = LoraConfig(
160
+ r=8,
161
+ lora_alpha=32,
162
+ target_modules=["query_key_value"],
163
+ lora_dropout=0.05,
164
+ bias="none",
165
+ task_type="CAUSAL_LM"
166
+ )
167
+
168
+ self.lm = get_peft_model(model, lora_config)
169
+ self.lm.print_trainable_parameters()
src/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .spectprompt import SpectPrompt
2
+ from .LMdecoder import LMDecoder
3
+ from .mae_vit import MAEViT
4
+ from .vision_transformer import VisionTransformer
5
+ from .htsat import HTSAT_Swin_Transformer, create_htsat_model
src/__pycache__/LMdecoder.cpython-310.pyc ADDED
Binary file (7.71 kB). View file
 
src/__pycache__/LMdecoder.cpython-39.pyc ADDED
Binary file (4.82 kB). View file
 
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (344 Bytes). View file
 
src/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (432 Bytes). View file
 
src/__pycache__/comm_utils.cpython-39.pyc ADDED
Binary file (7.22 kB). View file
 
src/__pycache__/htsat.cpython-39.pyc ADDED
Binary file (34.7 kB). View file
 
src/__pycache__/mae_vit.cpython-310.pyc ADDED
Binary file (7.32 kB). View file
 
src/__pycache__/mae_vit.cpython-39.pyc ADDED
Binary file (7.88 kB). View file
 
src/__pycache__/spectprompt.cpython-310.pyc ADDED
Binary file (3.57 kB). View file
 
src/__pycache__/spectprompt.cpython-39.pyc ADDED
Binary file (14.6 kB). View file
 
src/__pycache__/vision_transformer.cpython-310.pyc ADDED
Binary file (5 kB). View file
 
src/__pycache__/vision_transformer.cpython-39.pyc ADDED
Binary file (4.97 kB). View file
 
src/comm_utils.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file contains primitives for multi-gpu communication.
3
+ This is useful when doing distributed training.
4
+ """
5
+
6
+ import functools
7
+ import logging
8
+ import numpy as np
9
+ import pickle
10
+ import torch
11
+ import torch.distributed as dist
12
+
13
+ _LOCAL_PROCESS_GROUP = None
14
+ """
15
+ A torch process group which only includes processes that on the same machine as the current process.
16
+ This variable is set when processes are spawned by `launch()` in "engine/launch.py".
17
+ """
18
+
19
+
20
+ def get_world_size() -> int:
21
+ if not dist.is_available():
22
+ return 1
23
+ if not dist.is_initialized():
24
+ return 1
25
+ return dist.get_world_size()
26
+
27
+
28
+ def get_rank() -> int:
29
+ if not dist.is_available():
30
+ return 0
31
+ if not dist.is_initialized():
32
+ return 0
33
+ return dist.get_rank()
34
+
35
+
36
+ def get_local_rank() -> int:
37
+ """
38
+ Returns:
39
+ The rank of the current process within the local (per-machine) process group.
40
+ """
41
+ if not dist.is_available():
42
+ return 0
43
+ if not dist.is_initialized():
44
+ return 0
45
+ assert _LOCAL_PROCESS_GROUP is not None
46
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
47
+
48
+
49
+ def get_local_size() -> int:
50
+ """
51
+ Returns:
52
+ The size of the per-machine process group,
53
+ i.e. the number of processes per machine.
54
+ """
55
+ if not dist.is_available():
56
+ return 1
57
+ if not dist.is_initialized():
58
+ return 1
59
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
60
+
61
+
62
+ def is_main_process() -> bool:
63
+ return get_rank() == 0
64
+
65
+
66
+ def synchronize():
67
+ """
68
+ Helper function to synchronize (barrier) among all processes when
69
+ using distributed training
70
+ """
71
+ if not dist.is_available():
72
+ return
73
+ if not dist.is_initialized():
74
+ return
75
+ world_size = dist.get_world_size()
76
+ if world_size == 1:
77
+ return
78
+ dist.barrier()
79
+
80
+
81
+ @functools.lru_cache()
82
+ def _get_global_gloo_group():
83
+ """
84
+ Return a process group based on gloo backend, containing all the ranks
85
+ The result is cached.
86
+ """
87
+ if dist.get_backend() == "nccl":
88
+ return dist.new_group(backend="gloo")
89
+ else:
90
+ return dist.group.WORLD
91
+
92
+
93
+ def _serialize_to_tensor(data, group):
94
+ backend = dist.get_backend(group)
95
+ assert backend in ["gloo", "nccl"]
96
+ device = torch.device("cpu" if backend == "gloo" else "cuda")
97
+
98
+ buffer = pickle.dumps(data)
99
+ if len(buffer) > 1024 ** 3:
100
+ logger = logging.getLogger(__name__)
101
+ logger.warning(
102
+ "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
103
+ get_rank(), len(buffer) / (1024 ** 3), device
104
+ )
105
+ )
106
+ storage = torch.ByteStorage.from_buffer(buffer)
107
+ tensor = torch.ByteTensor(storage).to(device=device)
108
+ return tensor
109
+
110
+
111
+ def _pad_to_largest_tensor(tensor, group):
112
+ """
113
+ Returns:
114
+ list[int]: size of the tensor, on each rank
115
+ Tensor: padded tensor that has the max size
116
+ """
117
+ world_size = dist.get_world_size(group=group)
118
+ assert (
119
+ world_size >= 1
120
+ ), "comm.gather/all_gather must be called from ranks within the given group!"
121
+ local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
122
+ size_list = [
123
+ torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
124
+ ]
125
+ dist.all_gather(size_list, local_size, group=group)
126
+ size_list = [int(size.item()) for size in size_list]
127
+
128
+ max_size = max(size_list)
129
+
130
+ # we pad the tensor because torch all_gather does not support
131
+ # gathering tensors of different shapes
132
+ if local_size != max_size:
133
+ padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
134
+ tensor = torch.cat((tensor, padding), dim=0)
135
+ return size_list, tensor
136
+
137
+
138
+ def all_gather(data, group=None):
139
+ """
140
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
141
+ Args:
142
+ data: any picklable object
143
+ group: a torch process group. By default, will use a group which
144
+ contains all ranks on gloo backend.
145
+ Returns:
146
+ list[data]: list of data gathered from each rank
147
+ """
148
+ if get_world_size() == 1:
149
+ return [data]
150
+ if group is None:
151
+ group = _get_global_gloo_group()
152
+ if dist.get_world_size(group) == 1:
153
+ return [data]
154
+
155
+ tensor = _serialize_to_tensor(data, group)
156
+
157
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
158
+ max_size = max(size_list)
159
+
160
+ # receiving Tensor from all ranks
161
+ tensor_list = [
162
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
163
+ ]
164
+ dist.all_gather(tensor_list, tensor, group=group)
165
+
166
+ data_list = []
167
+ for size, tensor in zip(size_list, tensor_list):
168
+ buffer = tensor.cpu().numpy().tobytes()[:size]
169
+ data_list.append(pickle.loads(buffer))
170
+
171
+ return data_list
172
+
173
+
174
+ def gather(data, dst=0, group=None):
175
+ """
176
+ Run gather on arbitrary picklable data (not necessarily tensors).
177
+ Args:
178
+ data: any picklable object
179
+ dst (int): destination rank
180
+ group: a torch process group. By default, will use a group which
181
+ contains all ranks on gloo backend.
182
+ Returns:
183
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
184
+ an empty list.
185
+ """
186
+ if get_world_size() == 1:
187
+ return [data]
188
+ if group is None:
189
+ group = _get_global_gloo_group()
190
+ if dist.get_world_size(group=group) == 1:
191
+ return [data]
192
+ rank = dist.get_rank(group=group)
193
+
194
+ tensor = _serialize_to_tensor(data, group)
195
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
196
+
197
+ # receiving Tensor from all ranks
198
+ if rank == dst:
199
+ max_size = max(size_list)
200
+ tensor_list = [
201
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
202
+ ]
203
+ dist.gather(tensor, tensor_list, dst=dst, group=group)
204
+
205
+ data_list = []
206
+ for size, tensor in zip(size_list, tensor_list):
207
+ buffer = tensor.cpu().numpy().tobytes()[:size]
208
+ data_list.append(pickle.loads(buffer))
209
+ return data_list
210
+ else:
211
+ dist.gather(tensor, [], dst=dst, group=group)
212
+ return []
213
+
214
+
215
+ def shared_random_seed():
216
+ """
217
+ Returns:
218
+ int: a random number that is the same across all workers.
219
+ If workers need a shared RNG, they can use this shared seed to
220
+ create one.
221
+ All workers must call this function, otherwise it will deadlock.
222
+ """
223
+ ints = np.random.randint(2 ** 31)
224
+ all_ints = all_gather(ints)
225
+ return all_ints[0]
226
+
227
+
228
+ def reduce_dict(input_dict, average=True):
229
+ """
230
+ Reduce the values in the dictionary from all processes so that process with rank
231
+ 0 has the reduced results.
232
+ Args:
233
+ input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
234
+ average (bool): whether to do average or sum
235
+ Returns:
236
+ a dict with the same keys as input_dict, after reduction.
237
+ """
238
+ world_size = get_world_size()
239
+ if world_size < 2:
240
+ return input_dict
241
+ with torch.no_grad():
242
+ names = []
243
+ values = []
244
+ # sort the keys so that they are consistent across processes
245
+ for k in sorted(input_dict.keys()):
246
+ names.append(k)
247
+ values.append(input_dict[k])
248
+ values = torch.stack(values, dim=0)
249
+ dist.reduce(values, dst=0)
250
+ if dist.get_rank() == 0 and average:
251
+ # only main process gets accumulated, so only divide by
252
+ # world_size in this case
253
+ values /= world_size
254
+ reduced_dict = {k: v for k, v in zip(names, values)}
255
+ return reduced_dict
src/htsat.py ADDED
@@ -0,0 +1,1249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ke Chen
2
3
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
+ # Some layers designed on the model
5
+ # below codes are based and referred from https://github.com/microsoft/Swin-Transformer
6
+ # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from itertools import repeat
12
+ import collections.abc
13
+ import math
14
+ import warnings
15
+
16
+ from torch.nn.init import _calculate_fan_in_and_fan_out
17
+ import torch.utils.checkpoint as checkpoint
18
+
19
+ import random
20
+
21
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
22
+ from torchlibrosa.augmentation import SpecAugmentation
23
+ from einops import rearrange
24
+ from itertools import repeat
25
+ # from .utils import interpolate
26
+
27
+ # from .feature_fusion import iAFF, AFF, DAF
28
+
29
+
30
+ '''
31
+ Feature Fusion for Varible-Length Data Processing
32
+ AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
33
+ According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
34
+ '''
35
+
36
+ class DAF(nn.Module):
37
+ '''
38
+ 直接相加 DirectAddFuse
39
+ '''
40
+
41
+ def __init__(self):
42
+ super(DAF, self).__init__()
43
+
44
+ def forward(self, x, residual):
45
+ return x + residual
46
+
47
+
48
+ class iAFF(nn.Module):
49
+ '''
50
+ 多特征融合 iAFF
51
+ '''
52
+
53
+ def __init__(self, channels=64, r=4, type='2D'):
54
+ super(iAFF, self).__init__()
55
+ inter_channels = int(channels // r)
56
+
57
+ if type == '1D':
58
+ # 本地注意力
59
+ self.local_att = nn.Sequential(
60
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
61
+ nn.BatchNorm1d(inter_channels),
62
+ nn.ReLU(inplace=True),
63
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
64
+ nn.BatchNorm1d(channels),
65
+ )
66
+
67
+ # 全局注意力
68
+ self.global_att = nn.Sequential(
69
+ nn.AdaptiveAvgPool1d(1),
70
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
71
+ nn.BatchNorm1d(inter_channels),
72
+ nn.ReLU(inplace=True),
73
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
74
+ nn.BatchNorm1d(channels),
75
+ )
76
+
77
+ # 第二次本地注意力
78
+ self.local_att2 = nn.Sequential(
79
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
80
+ nn.BatchNorm1d(inter_channels),
81
+ nn.ReLU(inplace=True),
82
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
83
+ nn.BatchNorm1d(channels),
84
+ )
85
+ # 第二次全局注意力
86
+ self.global_att2 = nn.Sequential(
87
+ nn.AdaptiveAvgPool1d(1),
88
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
89
+ nn.BatchNorm1d(inter_channels),
90
+ nn.ReLU(inplace=True),
91
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
92
+ nn.BatchNorm1d(channels),
93
+ )
94
+ elif type == '2D':
95
+ # 本地注意力
96
+ self.local_att = nn.Sequential(
97
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
98
+ nn.BatchNorm2d(inter_channels),
99
+ nn.ReLU(inplace=True),
100
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
101
+ nn.BatchNorm2d(channels),
102
+ )
103
+
104
+ # 全局注意力
105
+ self.global_att = nn.Sequential(
106
+ nn.AdaptiveAvgPool2d(1),
107
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
108
+ nn.BatchNorm2d(inter_channels),
109
+ nn.ReLU(inplace=True),
110
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
111
+ nn.BatchNorm2d(channels),
112
+ )
113
+
114
+ # 第二次本地注意力
115
+ self.local_att2 = nn.Sequential(
116
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
117
+ nn.BatchNorm2d(inter_channels),
118
+ nn.ReLU(inplace=True),
119
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
120
+ nn.BatchNorm2d(channels),
121
+ )
122
+ # 第二次全局注意力
123
+ self.global_att2 = nn.Sequential(
124
+ nn.AdaptiveAvgPool2d(1),
125
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
126
+ nn.BatchNorm2d(inter_channels),
127
+ nn.ReLU(inplace=True),
128
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
129
+ nn.BatchNorm2d(channels),
130
+ )
131
+ else:
132
+ raise f'the type is not supported'
133
+
134
+ self.sigmoid = nn.Sigmoid()
135
+
136
+ def forward(self, x, residual):
137
+ flag = False
138
+ xa = x + residual
139
+ if xa.size(0) == 1:
140
+ xa = torch.cat([xa,xa],dim=0)
141
+ flag = True
142
+ xl = self.local_att(xa)
143
+ xg = self.global_att(xa)
144
+ xlg = xl + xg
145
+ wei = self.sigmoid(xlg)
146
+ xi = x * wei + residual * (1 - wei)
147
+
148
+ xl2 = self.local_att2(xi)
149
+ xg2 = self.global_att(xi)
150
+ xlg2 = xl2 + xg2
151
+ wei2 = self.sigmoid(xlg2)
152
+ xo = x * wei2 + residual * (1 - wei2)
153
+ if flag:
154
+ xo = xo[0].unsqueeze(0)
155
+ return xo
156
+
157
+
158
+ class AFF(nn.Module):
159
+ '''
160
+ 多特征融合 AFF
161
+ '''
162
+
163
+ def __init__(self, channels=64, r=4, type='2D'):
164
+ super(AFF, self).__init__()
165
+ inter_channels = int(channels // r)
166
+
167
+ if type == '1D':
168
+ self.local_att = nn.Sequential(
169
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
170
+ nn.BatchNorm1d(inter_channels),
171
+ nn.ReLU(inplace=True),
172
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
173
+ nn.BatchNorm1d(channels),
174
+ )
175
+ self.global_att = nn.Sequential(
176
+ nn.AdaptiveAvgPool1d(1),
177
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
178
+ nn.BatchNorm1d(inter_channels),
179
+ nn.ReLU(inplace=True),
180
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
181
+ nn.BatchNorm1d(channels),
182
+ )
183
+ elif type == '2D':
184
+ self.local_att = nn.Sequential(
185
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
186
+ nn.BatchNorm2d(inter_channels),
187
+ nn.ReLU(inplace=True),
188
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
189
+ nn.BatchNorm2d(channels),
190
+ )
191
+ self.global_att = nn.Sequential(
192
+ nn.AdaptiveAvgPool2d(1),
193
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
194
+ nn.BatchNorm2d(inter_channels),
195
+ nn.ReLU(inplace=True),
196
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
197
+ nn.BatchNorm2d(channels),
198
+ )
199
+ else:
200
+ raise f'the type is not supported.'
201
+
202
+ self.sigmoid = nn.Sigmoid()
203
+
204
+ def forward(self, x, residual):
205
+ flag = False
206
+ xa = x + residual
207
+ if xa.size(0) == 1:
208
+ xa = torch.cat([xa,xa],dim=0)
209
+ flag = True
210
+ xl = self.local_att(xa)
211
+ xg = self.global_att(xa)
212
+ xlg = xl + xg
213
+ wei = self.sigmoid(xlg)
214
+ xo = 2 * x * wei + 2 * residual * (1 - wei)
215
+ if flag:
216
+ xo = xo[0].unsqueeze(0)
217
+ return xo
218
+
219
+
220
+ # .utils
221
+
222
+ def interpolate(x, ratio):
223
+ """Interpolate data in time domain. This is used to compensate the
224
+ resolution reduction in downsampling of a CNN.
225
+
226
+ Args:
227
+ x: (batch_size, time_steps, classes_num)
228
+ ratio: int, ratio to interpolate
229
+ Returns:
230
+ upsampled: (batch_size, time_steps * ratio, classes_num)
231
+ """
232
+ (batch_size, time_steps, classes_num) = x.shape
233
+ upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
234
+ upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
235
+ return upsampled
236
+
237
+ def do_mixup(x, mixup_lambda):
238
+ """
239
+ Args:
240
+ x: (batch_size , ...)
241
+ mixup_lambda: (batch_size,)
242
+ Returns:
243
+ out: (batch_size, ...)
244
+ """
245
+ out = (
246
+ x.transpose(0, -1) * mixup_lambda
247
+ + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda)
248
+ ).transpose(0, -1)
249
+ return out
250
+
251
+ # from PyTorch internals
252
+ def _ntuple(n):
253
+ def parse(x):
254
+ if isinstance(x, collections.abc.Iterable):
255
+ return x
256
+ return tuple(repeat(x, n))
257
+ return parse
258
+
259
+ to_1tuple = _ntuple(1)
260
+ to_2tuple = _ntuple(2)
261
+ to_3tuple = _ntuple(3)
262
+ to_4tuple = _ntuple(4)
263
+ to_ntuple = _ntuple
264
+
265
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
266
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
267
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
268
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
269
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
270
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
271
+ 'survival rate' as the argument.
272
+ """
273
+ if drop_prob == 0. or not training:
274
+ return x
275
+ keep_prob = 1 - drop_prob
276
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
277
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
278
+ random_tensor.floor_() # binarize
279
+ output = x.div(keep_prob) * random_tensor
280
+ return output
281
+
282
+
283
+ class DropPath(nn.Module):
284
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
285
+ """
286
+ def __init__(self, drop_prob=None):
287
+ super(DropPath, self).__init__()
288
+ self.drop_prob = drop_prob
289
+
290
+ def forward(self, x):
291
+ return drop_path(x, self.drop_prob, self.training)
292
+
293
+ class PatchEmbed(nn.Module):
294
+ """ 2D Image to Patch Embedding
295
+ """
296
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16,
297
+ enable_fusion=False, fusion_type='None'):
298
+ super().__init__()
299
+ img_size = to_2tuple(img_size)
300
+ patch_size = to_2tuple(patch_size)
301
+ patch_stride = to_2tuple(patch_stride)
302
+ self.img_size = img_size
303
+ self.patch_size = patch_size
304
+ self.patch_stride = patch_stride
305
+ self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
306
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
307
+ self.flatten = flatten
308
+ self.in_chans = in_chans
309
+ self.embed_dim = embed_dim
310
+
311
+ self.enable_fusion = enable_fusion
312
+ self.fusion_type = fusion_type
313
+
314
+ padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)
315
+
316
+ if (self.enable_fusion) and (self.fusion_type == 'channel_map'):
317
+ self.proj = nn.Conv2d(in_chans*4, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
318
+ else:
319
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
320
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
321
+
322
+ if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']):
323
+ self.mel_conv2d = nn.Conv2d(in_chans, embed_dim, kernel_size=(patch_size[0], patch_size[1]*3), stride=(patch_stride[0], patch_stride[1] * 3), padding=padding)
324
+ if self.fusion_type == 'daf_2d':
325
+ self.fusion_model = DAF()
326
+ elif self.fusion_type == 'aff_2d':
327
+ self.fusion_model = AFF(channels=embed_dim, type='2D')
328
+ elif self.fusion_type == 'iaff_2d':
329
+ self.fusion_model = iAFF(channels=embed_dim, type='2D')
330
+ def forward(self, x, longer_idx = None):
331
+ if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']):
332
+ global_x = x[:,0:1,:,:]
333
+
334
+
335
+ # global processing
336
+ B, C, H, W = global_x.shape
337
+ assert H == self.img_size[0] and W == self.img_size[1], \
338
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
339
+ global_x = self.proj(global_x)
340
+ TW = global_x.size(-1)
341
+ if len(longer_idx) > 0:
342
+ # local processing
343
+ local_x = x[longer_idx,1:,:,:].contiguous()
344
+ B, C, H, W = local_x.shape
345
+ local_x = local_x.view(B*C,1,H,W)
346
+ local_x = self.mel_conv2d(local_x)
347
+ local_x = local_x.view(B,C,local_x.size(1),local_x.size(2),local_x.size(3))
348
+ local_x = local_x.permute((0,2,3,1,4)).contiguous().flatten(3)
349
+ TB,TC,TH,_ = local_x.size()
350
+ if local_x.size(-1) < TW:
351
+ local_x = torch.cat([local_x, torch.zeros((TB,TC,TH,TW-local_x.size(-1)), device=global_x.device)], dim=-1)
352
+ else:
353
+ local_x = local_x[:,:,:,:TW]
354
+
355
+ global_x[longer_idx] = self.fusion_model(global_x[longer_idx],local_x)
356
+ x = global_x
357
+ else:
358
+ B, C, H, W = x.shape
359
+ assert H == self.img_size[0] and W == self.img_size[1], \
360
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
361
+ x = self.proj(x)
362
+
363
+ if self.flatten:
364
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
365
+ x = self.norm(x)
366
+ return x
367
+
368
+ class Mlp(nn.Module):
369
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
370
+ """
371
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
372
+ super().__init__()
373
+ out_features = out_features or in_features
374
+ hidden_features = hidden_features or in_features
375
+ self.fc1 = nn.Linear(in_features, hidden_features)
376
+ self.act = act_layer()
377
+ self.fc2 = nn.Linear(hidden_features, out_features)
378
+ self.drop = nn.Dropout(drop)
379
+
380
+ def forward(self, x):
381
+ x = self.fc1(x)
382
+ x = self.act(x)
383
+ x = self.drop(x)
384
+ x = self.fc2(x)
385
+ x = self.drop(x)
386
+ return x
387
+
388
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
389
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
390
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
391
+ def norm_cdf(x):
392
+ # Computes standard normal cumulative distribution function
393
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
394
+
395
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
396
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
397
+ "The distribution of values may be incorrect.",
398
+ stacklevel=2)
399
+
400
+ with torch.no_grad():
401
+ # Values are generated by using a truncated uniform distribution and
402
+ # then using the inverse CDF for the normal distribution.
403
+ # Get upper and lower cdf values
404
+ l = norm_cdf((a - mean) / std)
405
+ u = norm_cdf((b - mean) / std)
406
+
407
+ # Uniformly fill tensor with values from [l, u], then translate to
408
+ # [2l-1, 2u-1].
409
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
410
+
411
+ # Use inverse cdf transform for normal distribution to get truncated
412
+ # standard normal
413
+ tensor.erfinv_()
414
+
415
+ # Transform to proper mean, std
416
+ tensor.mul_(std * math.sqrt(2.))
417
+ tensor.add_(mean)
418
+
419
+ # Clamp to ensure it's in the proper range
420
+ tensor.clamp_(min=a, max=b)
421
+ return tensor
422
+
423
+
424
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
425
+ # type: (Tensor, float, float, float, float) -> Tensor
426
+ r"""Fills the input Tensor with values drawn from a truncated
427
+ normal distribution. The values are effectively drawn from the
428
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
429
+ with values outside :math:`[a, b]` redrawn until they are within
430
+ the bounds. The method used for generating the random values works
431
+ best when :math:`a \leq \text{mean} \leq b`.
432
+ Args:
433
+ tensor: an n-dimensional `torch.Tensor`
434
+ mean: the mean of the normal distribution
435
+ std: the standard deviation of the normal distribution
436
+ a: the minimum cutoff value
437
+ b: the maximum cutoff value
438
+ Examples:
439
+ >>> w = torch.empty(3, 5)
440
+ >>> nn.init.trunc_normal_(w)
441
+ """
442
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
443
+
444
+
445
+ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
446
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
447
+ if mode == 'fan_in':
448
+ denom = fan_in
449
+ elif mode == 'fan_out':
450
+ denom = fan_out
451
+ elif mode == 'fan_avg':
452
+ denom = (fan_in + fan_out) / 2
453
+
454
+ variance = scale / denom
455
+
456
+ if distribution == "truncated_normal":
457
+ # constant is stddev of standard normal truncated to (-2, 2)
458
+ trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
459
+ elif distribution == "normal":
460
+ tensor.normal_(std=math.sqrt(variance))
461
+ elif distribution == "uniform":
462
+ bound = math.sqrt(3 * variance)
463
+ tensor.uniform_(-bound, bound)
464
+ else:
465
+ raise ValueError(f"invalid distribution {distribution}")
466
+
467
+
468
+ def lecun_normal_(tensor):
469
+ variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
470
+
471
+ def window_partition(x, window_size):
472
+ """
473
+ Args:
474
+ x: (B, H, W, C)
475
+ window_size (int): window size
476
+ Returns:
477
+ windows: (num_windows*B, window_size, window_size, C)
478
+ """
479
+ B, H, W, C = x.shape
480
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
481
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
482
+ return windows
483
+
484
+
485
+ def window_reverse(windows, window_size, H, W):
486
+ """
487
+ Args:
488
+ windows: (num_windows*B, window_size, window_size, C)
489
+ window_size (int): Window size
490
+ H (int): Height of image
491
+ W (int): Width of image
492
+ Returns:
493
+ x: (B, H, W, C)
494
+ """
495
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
496
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
497
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
498
+ return x
499
+
500
+
501
+ class WindowAttention(nn.Module):
502
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
503
+ It supports both of shifted and non-shifted window.
504
+ Args:
505
+ dim (int): Number of input channels.
506
+ window_size (tuple[int]): The height and width of the window.
507
+ num_heads (int): Number of attention heads.
508
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
509
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
510
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
511
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
512
+ """
513
+
514
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
515
+
516
+ super().__init__()
517
+ self.dim = dim
518
+ self.window_size = window_size # Wh, Ww
519
+ self.num_heads = num_heads
520
+ head_dim = dim // num_heads
521
+ self.scale = qk_scale or head_dim ** -0.5
522
+
523
+ # define a parameter table of relative position bias
524
+ self.relative_position_bias_table = nn.Parameter(
525
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
526
+
527
+ # get pair-wise relative position index for each token inside the window
528
+ coords_h = torch.arange(self.window_size[0])
529
+ coords_w = torch.arange(self.window_size[1])
530
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
531
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
532
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
533
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
534
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
535
+ relative_coords[:, :, 1] += self.window_size[1] - 1
536
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
537
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
538
+ self.register_buffer("relative_position_index", relative_position_index)
539
+
540
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
541
+ self.attn_drop = nn.Dropout(attn_drop)
542
+ self.proj = nn.Linear(dim, dim)
543
+ self.proj_drop = nn.Dropout(proj_drop)
544
+
545
+ trunc_normal_(self.relative_position_bias_table, std=.02)
546
+ self.softmax = nn.Softmax(dim=-1)
547
+
548
+ def forward(self, x, mask=None):
549
+ """
550
+ Args:
551
+ x: input features with shape of (num_windows*B, N, C)
552
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
553
+ """
554
+ B_, N, C = x.shape
555
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
556
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
557
+
558
+ q = q * self.scale
559
+ attn = (q @ k.transpose(-2, -1))
560
+
561
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
562
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
563
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
564
+ attn = attn + relative_position_bias.unsqueeze(0)
565
+
566
+ if mask is not None:
567
+ nW = mask.shape[0]
568
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
569
+ attn = attn.view(-1, self.num_heads, N, N)
570
+ attn = self.softmax(attn)
571
+ else:
572
+ attn = self.softmax(attn)
573
+
574
+ attn = self.attn_drop(attn)
575
+
576
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
577
+ x = self.proj(x)
578
+ x = self.proj_drop(x)
579
+ return x, attn
580
+
581
+ def extra_repr(self):
582
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
583
+
584
+
585
+ # We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
586
+ class SwinTransformerBlock(nn.Module):
587
+ r""" Swin Transformer Block.
588
+ Args:
589
+ dim (int): Number of input channels.
590
+ input_resolution (tuple[int]): Input resulotion.
591
+ num_heads (int): Number of attention heads.
592
+ window_size (int): Window size.
593
+ shift_size (int): Shift size for SW-MSA.
594
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
595
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
596
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
597
+ drop (float, optional): Dropout rate. Default: 0.0
598
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
599
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
600
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
601
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
602
+ """
603
+
604
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
605
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
606
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
607
+ super().__init__()
608
+ self.dim = dim
609
+ self.input_resolution = input_resolution
610
+ self.num_heads = num_heads
611
+ self.window_size = window_size
612
+ self.shift_size = shift_size
613
+ self.mlp_ratio = mlp_ratio
614
+ self.norm_before_mlp = norm_before_mlp
615
+ if min(self.input_resolution) <= self.window_size:
616
+ # if window size is larger than input resolution, we don't partition windows
617
+ self.shift_size = 0
618
+ self.window_size = min(self.input_resolution)
619
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
620
+
621
+ self.norm1 = norm_layer(dim)
622
+ self.attn = WindowAttention(
623
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
624
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
625
+
626
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
627
+ if self.norm_before_mlp == 'ln':
628
+ self.norm2 = nn.LayerNorm(dim)
629
+ elif self.norm_before_mlp == 'bn':
630
+ self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)
631
+ else:
632
+ raise NotImplementedError
633
+ mlp_hidden_dim = int(dim * mlp_ratio)
634
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
635
+
636
+ if self.shift_size > 0:
637
+ # calculate attention mask for SW-MSA
638
+ H, W = self.input_resolution
639
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
640
+ h_slices = (slice(0, -self.window_size),
641
+ slice(-self.window_size, -self.shift_size),
642
+ slice(-self.shift_size, None))
643
+ w_slices = (slice(0, -self.window_size),
644
+ slice(-self.window_size, -self.shift_size),
645
+ slice(-self.shift_size, None))
646
+ cnt = 0
647
+ for h in h_slices:
648
+ for w in w_slices:
649
+ img_mask[:, h, w, :] = cnt
650
+ cnt += 1
651
+
652
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
653
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
654
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
655
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
656
+ else:
657
+ attn_mask = None
658
+
659
+ self.register_buffer("attn_mask", attn_mask)
660
+
661
+ def forward(self, x):
662
+ # pdb.set_trace()
663
+ H, W = self.input_resolution
664
+ # print("H: ", H)
665
+ # print("W: ", W)
666
+ # pdb.set_trace()
667
+ B, L, C = x.shape
668
+ # assert L == H * W, "input feature has wrong size"
669
+
670
+ shortcut = x
671
+ x = self.norm1(x)
672
+ x = x.view(B, H, W, C)
673
+
674
+ # cyclic shift
675
+ if self.shift_size > 0:
676
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
677
+ else:
678
+ shifted_x = x
679
+
680
+ # partition windows
681
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
682
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
683
+
684
+ # W-MSA/SW-MSA
685
+ attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
686
+
687
+ # merge windows
688
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
689
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
690
+
691
+ # reverse cyclic shift
692
+ if self.shift_size > 0:
693
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
694
+ else:
695
+ x = shifted_x
696
+ x = x.view(B, H * W, C)
697
+
698
+ # FFN
699
+ x = shortcut + self.drop_path(x)
700
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
701
+
702
+ return x, attn
703
+
704
+ def extra_repr(self):
705
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
706
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
707
+
708
+
709
+
710
+ class PatchMerging(nn.Module):
711
+ r""" Patch Merging Layer.
712
+ Args:
713
+ input_resolution (tuple[int]): Resolution of input feature.
714
+ dim (int): Number of input channels.
715
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
716
+ """
717
+
718
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
719
+ super().__init__()
720
+ self.input_resolution = input_resolution
721
+ self.dim = dim
722
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
723
+ self.norm = norm_layer(4 * dim)
724
+
725
+ def forward(self, x):
726
+ """
727
+ x: B, H*W, C
728
+ """
729
+ H, W = self.input_resolution
730
+ B, L, C = x.shape
731
+ assert L == H * W, "input feature has wrong size"
732
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
733
+
734
+ x = x.view(B, H, W, C)
735
+
736
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
737
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
738
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
739
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
740
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
741
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
742
+
743
+ x = self.norm(x)
744
+ x = self.reduction(x)
745
+
746
+ return x
747
+
748
+ def extra_repr(self):
749
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
750
+
751
+
752
+ class BasicLayer(nn.Module):
753
+ """ A basic Swin Transformer layer for one stage.
754
+ Args:
755
+ dim (int): Number of input channels.
756
+ input_resolution (tuple[int]): Input resolution.
757
+ depth (int): Number of blocks.
758
+ num_heads (int): Number of attention heads.
759
+ window_size (int): Local window size.
760
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
761
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
762
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
763
+ drop (float, optional): Dropout rate. Default: 0.0
764
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
765
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
766
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
767
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
768
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
769
+ """
770
+
771
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
772
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
773
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
774
+ norm_before_mlp='ln'):
775
+
776
+ super().__init__()
777
+ self.dim = dim
778
+ self.input_resolution = input_resolution
779
+ self.depth = depth
780
+ self.use_checkpoint = use_checkpoint
781
+
782
+ # build blocks
783
+ self.blocks = nn.ModuleList([
784
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
785
+ num_heads=num_heads, window_size=window_size,
786
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
787
+ mlp_ratio=mlp_ratio,
788
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
789
+ drop=drop, attn_drop=attn_drop,
790
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
791
+ norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
792
+ for i in range(depth)])
793
+
794
+ # patch merging layer
795
+ if downsample is not None:
796
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
797
+ else:
798
+ self.downsample = None
799
+
800
+ def forward(self, x):
801
+ attns = []
802
+ for blk in self.blocks:
803
+ if self.use_checkpoint:
804
+ x = checkpoint.checkpoint(blk, x)
805
+ else:
806
+ x, attn = blk(x)
807
+ if not self.training:
808
+ attns.append(attn.unsqueeze(0))
809
+ if self.downsample is not None:
810
+ x = self.downsample(x)
811
+ if not self.training:
812
+ attn = torch.cat(attns, dim = 0)
813
+ attn = torch.mean(attn, dim = 0)
814
+ return x, attn
815
+
816
+ # if self.downsample is not None:
817
+ # x = self.downsample(x)
818
+ # if not self.training:
819
+ # attn = torch.cat(attns, dim = 0)
820
+ # attn = torch.mean(attn, dim = 0)
821
+ # return x, attn
822
+
823
+ def extra_repr(self):
824
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
825
+
826
+
827
+ # The Core of HTSAT
828
+ class HTSAT_Swin_Transformer(nn.Module):
829
+ r"""HTSAT based on the Swin Transformer
830
+ Args:
831
+ spec_size (int | tuple(int)): Input Spectrogram size. Default 256
832
+ patch_size (int | tuple(int)): Patch size. Default: 4
833
+ path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4
834
+ in_chans (int): Number of input image channels. Default: 1 (mono)
835
+ num_classes (int): Number of classes for classification head. Default: 527
836
+ embed_dim (int): Patch embedding dimension. Default: 96
837
+ depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
838
+ num_heads (tuple(int)): Number of attention heads in different layers.
839
+ window_size (int): Window size. Default: 8
840
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
841
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
842
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
843
+ drop_rate (float): Dropout rate. Default: 0
844
+ attn_drop_rate (float): Attention dropout rate. Default: 0
845
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
846
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
847
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
848
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
849
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
850
+ config (module): The configuration Module from config.py
851
+ """
852
+
853
+ def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4),
854
+ in_chans=1, num_classes=527,
855
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
856
+ window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
857
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
858
+ norm_layer=nn.LayerNorm,
859
+ ape=False, patch_norm=True,
860
+ use_checkpoint=False, norm_before_mlp='ln', config = None,
861
+ enable_fusion = False, fusion_type = 'None', **kwargs):
862
+ super(HTSAT_Swin_Transformer, self).__init__()
863
+
864
+ self.config = config
865
+ self.spec_size = spec_size
866
+ self.patch_stride = patch_stride
867
+ self.patch_size = patch_size
868
+ self.window_size = window_size
869
+ self.embed_dim = embed_dim
870
+ self.depths = depths
871
+ self.ape = ape
872
+ self.in_chans = in_chans
873
+ self.num_classes = num_classes
874
+ self.num_heads = num_heads
875
+ self.num_layers = len(self.depths)
876
+ self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
877
+
878
+ self.drop_rate = drop_rate
879
+ self.attn_drop_rate = attn_drop_rate
880
+ self.drop_path_rate = drop_path_rate
881
+
882
+ self.qkv_bias = qkv_bias
883
+ self.qk_scale = None
884
+
885
+ self.patch_norm = patch_norm
886
+ self.norm_layer = norm_layer if self.patch_norm else None
887
+ self.norm_before_mlp = norm_before_mlp
888
+ self.mlp_ratio = mlp_ratio
889
+
890
+ self.use_checkpoint = use_checkpoint
891
+
892
+ self.enable_fusion = enable_fusion
893
+ self.fusion_type = fusion_type
894
+
895
+ # process mel-spec ; used only once
896
+ self.freq_ratio = self.spec_size // self.config.mel_bins
897
+ window = 'hann'
898
+ center = True
899
+ pad_mode = 'reflect'
900
+ ref = 1.0
901
+ amin = 1e-10
902
+ top_db = None
903
+ self.interpolate_ratio = 32 # Downsampled ratio
904
+ # Spectrogram extractor
905
+ self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size,
906
+ win_length=config.window_size, window=window, center=center, pad_mode=pad_mode,
907
+ freeze_parameters=True)
908
+ # Logmel feature extractor
909
+ self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size,
910
+ n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db,
911
+ freeze_parameters=True)
912
+ # Spec augmenter
913
+ self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
914
+ freq_drop_width=8, freq_stripes_num=2) # 2 2
915
+ self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
916
+
917
+
918
+ # split spctrogram into non-overlapping patches
919
+ self.patch_embed = PatchEmbed(
920
+ img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans,
921
+ embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride,
922
+ enable_fusion=self.enable_fusion, fusion_type=self.fusion_type
923
+ )
924
+
925
+ num_patches = self.patch_embed.num_patches
926
+ patches_resolution = self.patch_embed.grid_size
927
+ self.patches_resolution = patches_resolution
928
+
929
+ # absolute position embedding
930
+ if self.ape:
931
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
932
+ trunc_normal_(self.absolute_pos_embed, std=.02)
933
+
934
+ self.pos_drop = nn.Dropout(p=self.drop_rate)
935
+
936
+ # stochastic depth
937
+ dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule
938
+
939
+ # build layers
940
+ self.layers = nn.ModuleList()
941
+ for i_layer in range(self.num_layers):
942
+ layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),
943
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
944
+ patches_resolution[1] // (2 ** i_layer)),
945
+ depth=self.depths[i_layer],
946
+ num_heads=self.num_heads[i_layer],
947
+ window_size=self.window_size,
948
+ mlp_ratio=self.mlp_ratio,
949
+ qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
950
+ drop=self.drop_rate, attn_drop=self.attn_drop_rate,
951
+ drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
952
+ norm_layer=self.norm_layer,
953
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
954
+ use_checkpoint=use_checkpoint,
955
+ norm_before_mlp=self.norm_before_mlp)
956
+ self.layers.append(layer)
957
+
958
+ self.norm = self.norm_layer(self.num_features)
959
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
960
+ self.maxpool = nn.AdaptiveMaxPool1d(1)
961
+
962
+ SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio
963
+ self.tscam_conv = nn.Conv2d(
964
+ in_channels = self.num_features,
965
+ out_channels = self.num_classes,
966
+ kernel_size = (SF,3),
967
+ padding = (0,1)
968
+ )
969
+ self.head = nn.Linear(num_classes, num_classes)
970
+
971
+ if (self.enable_fusion) and (self.fusion_type in ['daf_1d','aff_1d','iaff_1d']):
972
+ self.mel_conv1d = nn.Sequential(
973
+ nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
974
+ nn.BatchNorm1d(64)
975
+ )
976
+ if self.fusion_type == 'daf_1d':
977
+ self.fusion_model = DAF()
978
+ elif self.fusion_type == 'aff_1d':
979
+ self.fusion_model = AFF(channels=64, type='1D')
980
+ elif self.fusion_type == 'iaff_1d':
981
+ self.fusion_model = iAFF(channels=64, type='1D')
982
+
983
+ self.apply(self._init_weights)
984
+
985
+ def _init_weights(self, m):
986
+ if isinstance(m, nn.Linear):
987
+ trunc_normal_(m.weight, std=.02)
988
+ if isinstance(m, nn.Linear) and m.bias is not None:
989
+ nn.init.constant_(m.bias, 0)
990
+ elif isinstance(m, nn.LayerNorm):
991
+ nn.init.constant_(m.bias, 0)
992
+ nn.init.constant_(m.weight, 1.0)
993
+
994
+ @torch.jit.ignore
995
+ def no_weight_decay(self):
996
+ return {'absolute_pos_embed'}
997
+
998
+ @torch.jit.ignore
999
+ def no_weight_decay_keywords(self):
1000
+ return {'relative_position_bias_table'}
1001
+
1002
+
1003
+ def forward_features(self, x, longer_idx = None):
1004
+ # A deprecated optimization for using a hierarchical output from different blocks
1005
+
1006
+ frames_num = x.shape[2]
1007
+ x = self.patch_embed(x, longer_idx = longer_idx)
1008
+ if self.ape:
1009
+ x = x + self.absolute_pos_embed
1010
+ x = self.pos_drop(x)
1011
+ for i, layer in enumerate(self.layers):
1012
+ x, attn = layer(x)
1013
+ # for x
1014
+ x = self.norm(x)
1015
+ B, N, C = x.shape
1016
+ SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
1017
+ ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
1018
+ x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
1019
+ B, C, F, T = x.shape
1020
+ # group 2D CNN
1021
+ c_freq_bin = F // self.freq_ratio
1022
+ x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
1023
+ x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
1024
+ # get latent_output
1025
+ fine_grained_latent_output = torch.mean(x, dim = 2)
1026
+ fine_grained_latent_output = interpolate(fine_grained_latent_output.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
1027
+
1028
+ latent_output = self.avgpool(torch.flatten(x,2))
1029
+ latent_output = torch.flatten(latent_output, 1)
1030
+
1031
+ # display the attention map, if needed
1032
+
1033
+ x = self.tscam_conv(x)
1034
+ x = torch.flatten(x, 2) # B, C, T
1035
+
1036
+ fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
1037
+
1038
+ x = self.avgpool(x)
1039
+ x = torch.flatten(x, 1)
1040
+
1041
+ output_dict = {
1042
+ 'framewise_output': fpx, # already sigmoided
1043
+ 'clipwise_output': torch.sigmoid(x),
1044
+ 'fine_grained_embedding': fine_grained_latent_output,
1045
+ 'embedding': latent_output
1046
+ }
1047
+
1048
+ return output_dict
1049
+
1050
+ def crop_wav(self, x, crop_size, spe_pos = None):
1051
+ time_steps = x.shape[2]
1052
+ tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
1053
+ for i in range(len(x)):
1054
+ if spe_pos is None:
1055
+ crop_pos = random.randint(0, time_steps - crop_size - 1)
1056
+ else:
1057
+ crop_pos = spe_pos
1058
+ tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]
1059
+ return tx
1060
+
1061
+ # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model
1062
+ def reshape_wav2img(self, x):
1063
+ B, C, T, F = x.shape
1064
+ target_T = int(self.spec_size * self.freq_ratio)
1065
+ target_F = self.spec_size // self.freq_ratio
1066
+ assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size"
1067
+ # to avoid bicubic zero error
1068
+ if T < target_T:
1069
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
1070
+ if F < target_F:
1071
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
1072
+ x = x.permute(0,1,3,2).contiguous()
1073
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)
1074
+ # print(x.shape)
1075
+ x = x.permute(0,1,3,2,4).contiguous()
1076
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
1077
+ return x
1078
+
1079
+ # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model
1080
+ def repeat_wat2img(self, x, cur_pos):
1081
+ B, C, T, F = x.shape
1082
+ target_T = int(self.spec_size * self.freq_ratio)
1083
+ target_F = self.spec_size // self.freq_ratio
1084
+ assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size"
1085
+ # to avoid bicubic zero error
1086
+ if T < target_T:
1087
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
1088
+ if F < target_F:
1089
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
1090
+ x = x.permute(0,1,3,2).contiguous() # B C F T
1091
+ x = x[:,:,:,cur_pos:cur_pos + self.spec_size]
1092
+ x = x.repeat(repeats = (1,1,4,1))
1093
+ return x
1094
+
1095
+ def forward_generator(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False, device=None):# out_feat_keys: List[str] = None):
1096
+
1097
+ n = int(x.shape[1]/480000)
1098
+ assert n * 480000 == x.shape[1]
1099
+ x = rearrange(x, 'b (n t) -> (b n) t', n=n)
1100
+ if not self.enable_fusion:
1101
+ # x = x["waveform"].to(device=device, non_blocking=True)
1102
+ x = x.to(device=device, non_blocking=True)
1103
+ x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
1104
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
1105
+ x = x.transpose(1, 3)
1106
+ x = self.bn0(x)
1107
+ x = x.transpose(1, 3)
1108
+ if self.training:
1109
+ x = self.spec_augmenter(x)
1110
+
1111
+ if self.training and mixup_lambda is not None:
1112
+ x = do_mixup(x, mixup_lambda)
1113
+
1114
+ x = self.reshape_wav2img(x)
1115
+ # output_dict = self.forward_features(x)
1116
+
1117
+ # A deprecated optimization for using a hierarchical output from different blocks
1118
+ longer_idx = None
1119
+ frames_num = x.shape[2]
1120
+ x = self.patch_embed(x, longer_idx = longer_idx)
1121
+ if self.ape:
1122
+ x = x + self.absolute_pos_embed
1123
+ x = self.pos_drop(x)
1124
+ for i, layer in enumerate(self.layers[:3]): # depth: [2,2,12,2]
1125
+ if i == 2:
1126
+ for blk in layer.blocks:
1127
+ x, attn = blk(x)
1128
+ # 512
1129
+ x = rearrange(x, '(b n) t c -> b (n t) c', n=n)
1130
+ x = x if (new_x:=(yield x)) is None else new_x
1131
+ x = rearrange(x, 'b (n t) c -> (b n) t c', n=n)
1132
+ else:
1133
+ x, attn = layer(x)
1134
+
1135
+
1136
+
1137
+ def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False, device=None):# out_feat_keys: List[str] = None):
1138
+
1139
+ n = int(x.shape[1] / 480000)
1140
+ assert n * 480000 == x.shape[1]
1141
+ x = rearrange(x, 'b (n t) -> (b n) t', n = n)
1142
+ if not self.enable_fusion:
1143
+ # x = x["waveform"].to(device=device, non_blocking=True)
1144
+ x = x.to(device=device, non_blocking=True)
1145
+ x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
1146
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
1147
+ x = x.transpose(1, 3)
1148
+ x = self.bn0(x)
1149
+ x = x.transpose(1, 3)
1150
+ if self.training:
1151
+ x = self.spec_augmenter(x)
1152
+
1153
+ if self.training and mixup_lambda is not None:
1154
+ x = do_mixup(x, mixup_lambda)
1155
+
1156
+ x = self.reshape_wav2img(x)
1157
+ # x = self.forward_features(x)
1158
+
1159
+ longer_idx = None
1160
+ frames_num = x.shape[2]
1161
+ x = self.patch_embed(x, longer_idx = longer_idx)
1162
+ if self.ape:
1163
+ x = x + self.absolute_pos_embed
1164
+ x = self.pos_drop(x)
1165
+ for i, layer in enumerate(self.layers):
1166
+ x, attn = layer(x)
1167
+ # for x
1168
+ x = self.norm(x)
1169
+ x = rearrange(x, '(b n) t c -> b (n t) c', n = n)
1170
+ return x
1171
+
1172
+ # B, N, C = x.shape
1173
+ # SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
1174
+ # ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
1175
+ # x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
1176
+ # B, C, F, T = x.shape
1177
+ # # group 2D CNN
1178
+ # c_freq_bin = F // self.freq_ratio
1179
+ # x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
1180
+ # x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
1181
+ # # get latent_output
1182
+ # fine_grained_latent_output = torch.mean(x, dim = 2)
1183
+ # fine_grained_latent_output = interpolate(fine_grained_latent_output.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
1184
+
1185
+ # latent_output = self.avgpool(torch.flatten(x,2))
1186
+ # latent_output = torch.flatten(latent_output, 1)
1187
+
1188
+ # # display the attention map, if needed
1189
+
1190
+ # x = self.tscam_conv(x)
1191
+ # x = torch.flatten(x, 2) # B, C, T
1192
+
1193
+ # fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
1194
+
1195
+ # x = self.avgpool(x)
1196
+ # x = torch.flatten(x, 1)
1197
+ # return x
1198
+
1199
+ def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type='None'):
1200
+ try:
1201
+
1202
+ assert audio_cfg.model_name in ["tiny", "base", "large"], "model name for HTS-AT is wrong!"
1203
+ if audio_cfg.model_name == "tiny":
1204
+ model = HTSAT_Swin_Transformer(
1205
+ spec_size=256,
1206
+ patch_size=4,
1207
+ patch_stride=(4,4),
1208
+ num_classes=audio_cfg.class_num,
1209
+ embed_dim=96,
1210
+ depths=[2,2,6,2],
1211
+ num_heads=[4,8,16,32],
1212
+ window_size=8,
1213
+ config = audio_cfg,
1214
+ enable_fusion = enable_fusion,
1215
+ fusion_type = fusion_type
1216
+ )
1217
+ elif audio_cfg.model_name == "base":
1218
+ model = HTSAT_Swin_Transformer(
1219
+ spec_size=256,
1220
+ patch_size=4,
1221
+ patch_stride=(4,4),
1222
+ num_classes=audio_cfg.class_num,
1223
+ embed_dim=128,
1224
+ depths=[2,2,12,2],
1225
+ num_heads=[4,8,16,32],
1226
+ window_size=8,
1227
+ config = audio_cfg,
1228
+ enable_fusion = enable_fusion,
1229
+ fusion_type = fusion_type
1230
+ )
1231
+ elif audio_cfg.model_name == "large":
1232
+ model = HTSAT_Swin_Transformer(
1233
+ spec_size=256,
1234
+ patch_size=4,
1235
+ patch_stride=(4,4),
1236
+ num_classes=audio_cfg.class_num,
1237
+ embed_dim=256,
1238
+ depths=[2,2,12,2],
1239
+ num_heads=[4,8,16,32],
1240
+ window_size=8,
1241
+ config = audio_cfg,
1242
+ enable_fusion = enable_fusion,
1243
+ fusion_type = fusion_type
1244
+ )
1245
+
1246
+ return model
1247
+ except:
1248
+ raise RuntimeError(f'Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough.')
1249
+
src/mae_vit.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from mmcls.models import VisionTransformer
3
+ from torch import nn
4
+ from torch.utils.checkpoint import checkpoint
5
+ import copy
6
+
7
+ def build_2d_sincos_position_embedding(patches_resolution,
8
+ embed_dims,
9
+ temperature=10000.,
10
+ cls_token=False):
11
+ """The function is to build position embedding for model to obtain the
12
+ position information of the image patches."""
13
+
14
+ if isinstance(patches_resolution, int):
15
+ patches_resolution = (patches_resolution, patches_resolution)
16
+
17
+ h, w = patches_resolution
18
+ grid_w = torch.arange(w, dtype=torch.float32)
19
+ grid_h = torch.arange(h, dtype=torch.float32)
20
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
21
+ assert embed_dims % 4 == 0, \
22
+ 'Embed dimension must be divisible by 4.'
23
+ pos_dim = embed_dims // 4
24
+
25
+ omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
26
+ omega = 1. / (temperature**omega)
27
+ out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
28
+ out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
29
+
30
+ pos_emb = torch.cat(
31
+ [
32
+ torch.sin(out_w),
33
+ torch.cos(out_w),
34
+ torch.sin(out_h),
35
+ torch.cos(out_h)
36
+ ],
37
+ dim=1,
38
+ )[None, :, :]
39
+
40
+ if cls_token:
41
+ cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32)
42
+ pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1)
43
+
44
+ return pos_emb
45
+
46
+
47
+
48
+ class MAEViT(VisionTransformer):
49
+ """Vision Transformer for MAE pre-training.
50
+
51
+ A PyTorch implement of: `An Image is Worth 16x16 Words: Transformers
52
+ for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_
53
+
54
+ Args:
55
+ arch (str | dict): Vision Transformer architecture
56
+ Default: 'b'
57
+ img_size (int | tuple): Input image size
58
+ patch_size (int | tuple): The patch size
59
+ out_indices (Sequence | int): Output from which stages.
60
+ Defaults to -1, means the last stage.
61
+ drop_rate (float): Probability of an element to be zeroed.
62
+ Defaults to 0.
63
+ drop_path_rate (float): stochastic depth rate. Defaults to 0.
64
+ norm_cfg (dict): Config dict for normalization layer.
65
+ Defaults to ``dict(type='LN')``.
66
+ final_norm (bool): Whether to add a additional layer to normalize
67
+ final feature map. Defaults to True.
68
+ output_cls_token (bool): Whether output the cls_token. If set True,
69
+ `with_cls_token` must be True. Defaults to True.
70
+ interpolate_mode (str): Select the interpolate mode for position
71
+ embeding vector resize. Defaults to "bicubic".
72
+ patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
73
+ layer_cfgs (Sequence | dict): Configs of each transformer layer in
74
+ encoder. Defaults to an empty dict.
75
+ mask_ratio (bool): The ratio of total number of patches to be masked.
76
+ Defaults to 0.75.
77
+ init_cfg (dict, optional): Initialization config dict.
78
+ Defaults to None.
79
+ """
80
+
81
+ arch_zoo = {
82
+ **dict.fromkeys(
83
+ ['mocov3-s', 'mocov3-small'], {
84
+ 'embed_dims': 384,
85
+ 'num_layers': 12,
86
+ 'num_heads': 12,
87
+ 'feedforward_channels': 1536,
88
+ }),
89
+ **dict.fromkeys(
90
+ ['b', 'base'], {
91
+ 'embed_dims': 768,
92
+ 'num_layers': 12,
93
+ 'num_heads': 12,
94
+ 'feedforward_channels': 3072
95
+ }),
96
+ }
97
+
98
+
99
+
100
+ def __init__(self,
101
+ arch='b',
102
+ img_size=224,
103
+ patch_size=16,
104
+ out_indices=-1,
105
+ drop_rate=0,
106
+ drop_path_rate=0,
107
+ norm_cfg=dict(type='LN', eps=1e-6),
108
+ final_norm=True,
109
+ output_cls_token=False,
110
+ interpolate_mode='bicubic',
111
+ patch_cfg=dict(),
112
+ layer_cfgs=dict(),
113
+ gradientCKPT=False,
114
+ mask_ratio=0.75,
115
+ init_cfg=None):
116
+ super().__init__(
117
+ arch=arch,
118
+ img_size=img_size,
119
+ patch_size=patch_size,
120
+ out_indices=out_indices,
121
+ drop_rate=drop_rate,
122
+ drop_path_rate=drop_path_rate,
123
+ norm_cfg=norm_cfg,
124
+ final_norm=final_norm,
125
+ output_cls_token=output_cls_token,
126
+ interpolate_mode=interpolate_mode,
127
+ patch_cfg=patch_cfg,
128
+ layer_cfgs=layer_cfgs,
129
+ init_cfg=init_cfg)
130
+ self.gradientCKPT = gradientCKPT
131
+ self.pos_embed.requires_grad = False
132
+ self.mask_ratio = mask_ratio
133
+ self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]
134
+ # self.mask_embedding = copy.deepcopy(self.patch_embed)
135
+ # self.mask_embedding.norm = None
136
+
137
+ def init_weights(self):
138
+ super(MAEViT, self).init_weights()
139
+ if not (isinstance(self.init_cfg, dict)
140
+ and self.init_cfg['type'] == 'Pretrained'):
141
+ # initialize position embedding in backbone
142
+ pos_embed = build_2d_sincos_position_embedding(
143
+ self.patch_resolution,
144
+ self.pos_embed.shape[-1],
145
+ cls_token=True)
146
+ self.pos_embed.data.copy_(pos_embed.float())
147
+
148
+ w = self.patch_embed.projection.weight.data
149
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
150
+
151
+ torch.nn.init.normal_(self.cls_token, std=.02)
152
+
153
+ self.apply(self._init_weights)
154
+
155
+ # mask_embedding transfers pixel level mask to token level
156
+ # self.mask_embedding.apply(self._init_mask_embedding)
157
+ # for para in self.mask_embedding.parameters():
158
+ # para.requires_grad = False
159
+
160
+ def _init_mask_embedding(self,m):
161
+ if hasattr(m,'weight'):
162
+ nn.init.constant_(m.weight,1.0)
163
+ if hasattr(m, 'bias'):
164
+ nn.init.constant_(m.bias,0)
165
+
166
+ def _init_weights(self, m):
167
+
168
+ if isinstance(m, nn.Linear):
169
+ torch.nn.init.xavier_uniform_(m.weight)
170
+ if isinstance(m, nn.Linear) and m.bias is not None:
171
+ nn.init.constant_(m.bias, 0)
172
+ elif isinstance(m, nn.LayerNorm):
173
+ nn.init.constant_(m.bias, 0)
174
+ nn.init.constant_(m.weight, 1.0)
175
+
176
+ def random_masking(self, x, mask_ratio=0.75, attn_mask=None):
177
+ """Generate the mask for MAE Pre-training.
178
+
179
+ Args:
180
+ x (torch.tensor): Image with data augmentation applied.
181
+ mask_ratio (float): The mask ratio of total patches.
182
+ Defaults to 0.75.
183
+
184
+ Returns:
185
+ tuple[Tensor, Tensor, Tensor]: masked image, mask and the ids
186
+ to restore original image.
187
+
188
+ - x_masked (Tensor): masked image.
189
+ - mask (Tensor): mask used to mask image.
190
+ - ids_restore (Tensor): ids to restore original image.
191
+ """
192
+ N, L, D = x.shape # batch, length, dim
193
+ len_keep = int(L * (1 - mask_ratio))
194
+
195
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
196
+
197
+ # sort noise for each sample
198
+ ids_shuffle = torch.argsort(
199
+ noise, dim=1) # ascend: small is keep, large is remove
200
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
201
+
202
+ # keep the first subset
203
+ ids_keep = ids_shuffle[:, :len_keep]
204
+ x_masked = torch.gather(
205
+ x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
206
+ # modified_attn_mask = None if attn_mask is None else torch.gather(attn_mask,dim=1, index=ids_keep)
207
+
208
+ # generate the binary mask: 0 is keep, 1 is remove
209
+ mask = torch.ones([N, L], device=x.device)
210
+ mask[:, :len_keep] = 0
211
+ # unshuffle to get the binary mask
212
+ mask = torch.gather(mask, dim=1, index=ids_restore)
213
+
214
+ return x_masked, mask, ids_restore #, modified_attn_mask
215
+
216
+ def generate_mask(self, pixel_level_attn_mask):
217
+ '''
218
+ pixel_level_attn_mask: (0,1) attn mask with the same shape as img
219
+ '''
220
+ if pixel_level_attn_mask is None: return None
221
+ # H, W = patch_resolution
222
+ # B, C = pixel_level_attn_mask.shape[:2]
223
+ # attn_mask = torch.ones((B,C,H,W),device=pixel_level_attn_mask)
224
+ # H_splited = torch.chunk(pixel_level_attn_mask, H, -2)
225
+ # HW_splited_mask = (torch.chunk(Hs, W, -1) for Hs in H_splited)
226
+
227
+ # if HW_splited_mask[:,:,hi,wi].sum().item() == 0:
228
+ # attn_mask[:,:,hi,wi] = 0
229
+
230
+ # mask_patches = self.mask_embedding(pixel_level_attn_mask)[0]
231
+ # attn_mask = mask_patches.sum(-1) != 0
232
+
233
+ # return attn_mask
234
+
235
+ def extract_feat(self, img ,attn_mask=None):
236
+ x, *_ = self.forward(img,attn_mask)
237
+ if self.output_cls_token:
238
+ return x[:,0,:]
239
+ else:
240
+ return torch.mean(x,dim=1)
241
+
242
+ def forward(self, x, attn_mask=None):
243
+ if attn_mask is not None: assert self.output_cls_token
244
+
245
+ B = x.shape[0]
246
+ x = self.patch_embed(x)[0]
247
+ # add pos embed w/o cls token
248
+ x = x + self.pos_embed[:, 1:1+x.shape[1], :]
249
+ # masking: length -> length * mask_ratio
250
+ if True:
251
+ assert self.mask_ratio == 0.
252
+ else:
253
+ x, mask, ids_restore = self.random_masking(x, self.mask_ratio)
254
+
255
+ # append cls token
256
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
257
+ cls_tokens = cls_token.expand(B, -1, -1)
258
+ x = torch.cat((cls_tokens, x), dim=1)
259
+ x = self.drop_after_pos(x)
260
+ # if attn_mask is not None:
261
+ # attn_mask = torch.concat((torch.ones((B,1),device=attn_mask.device) , attn_mask),dim=1)
262
+
263
+ for i, layer in enumerate(self.layers):
264
+ if self.gradientCKPT:
265
+ x = checkpoint(layer,x) # ,attn_mask
266
+ else:
267
+ x = layer(x) # ,attn_mask
268
+ if i == len(self.layers) - 1 and self.final_norm:
269
+ x = self.norm1(x)
270
+ if True:
271
+ return x
272
+ else:
273
+ return (x, mask, ids_restore)
274
+
275
+ def forward_generator(self, x, attn_mask=None):
276
+ if attn_mask is not None: assert self.output_cls_token
277
+
278
+ B = x.shape[0]
279
+ x = self.patch_embed(x)[0]
280
+ # add pos embed w/o cls token
281
+ x = x + self.pos_embed[:, 1:1+x.shape[1], :]
282
+
283
+ # append cls token
284
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
285
+ cls_tokens = cls_token.expand(B, -1, -1)
286
+ x = torch.cat((cls_tokens, x), dim=1)
287
+ x = self.drop_after_pos(x)
288
+
289
+ for i, layer in enumerate(self.layers):
290
+ if self.gradientCKPT:
291
+ x = checkpoint(layer,x) # ,attn_mask
292
+ else:
293
+ x = layer(x) # ,attn_mask
294
+
295
+ if i == len(self.layers) - 1 and self.final_norm:
296
+ x = self.norm1(x)
297
+
298
+ x = x if (new_x:=(yield x)) is None else new_x
299
+
300
+ debug = False
301
+ if debug:
302
+ print(f'layer {i}-th forwarded')
303
+
src/resampler.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
2
+ # All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
3
+
4
+ import torch
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+ from einops_exts import rearrange_many, repeat_many
8
+
9
+
10
+ def FeedForward(dim, mult=4):
11
+ inner_dim = int(dim * mult)
12
+ return nn.Sequential(
13
+ nn.LayerNorm(dim),
14
+ nn.Linear(dim, inner_dim, bias=False),
15
+ nn.GELU(),
16
+ nn.Linear(inner_dim, dim, bias=False)
17
+ )
18
+
19
+
20
+ class PerceiverAttention(nn.Module):
21
+ def __init__(
22
+ self,
23
+ vision_width,
24
+ text_width,
25
+ dim_head=64,
26
+ heads=8
27
+ ):
28
+ super().__init__()
29
+
30
+ self.vision_width = vision_width
31
+ self.text_width = text_width
32
+
33
+ self.scale = dim_head ** -0.5
34
+ self.heads = heads
35
+ inner_dim = dim_head * heads
36
+
37
+ self.norm_media = nn.LayerNorm(vision_width)
38
+ self.norm_latents = nn.LayerNorm(text_width)
39
+
40
+ self.to_q = nn.Linear(text_width, inner_dim, bias=False)
41
+ self.to_kv = nn.Linear(vision_width, inner_dim * 2, bias=False)
42
+ self.to_out = nn.Linear(inner_dim, text_width, bias=False)
43
+
44
+ def forward(self, x, latents):
45
+ """
46
+ einstein notation
47
+ b - batch
48
+ t - time
49
+ n - sequence
50
+ d - dimension
51
+ """
52
+ x = self.norm_media(x)
53
+ latents = self.norm_latents(latents)
54
+
55
+ b, m, h = *x.shape[:2], self.heads
56
+
57
+ q = self.to_q(latents)
58
+
59
+ kv_input = x
60
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
61
+
62
+ q, k, v = rearrange_many((q, k, v), 'b t n (h d) -> b h t n d', h=h)
63
+
64
+ q = q * self.scale
65
+
66
+ # attention
67
+ sim = einsum('... i d, ... j d -> ... i j', q, k)
68
+
69
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
70
+ attn = sim.softmax(dim=-1)
71
+
72
+ out = einsum('... i j, ... j d -> ... i d', attn, v)
73
+ out = rearrange(out, 'b h t n d -> b t n (h d)', h=h)
74
+ return self.to_out(out)
75
+
76
+
77
+ class PerceiverResampler(nn.Module):
78
+ def __init__(
79
+ self,
80
+ vision_width,
81
+ text_width,
82
+ depth,
83
+ dim_head=64,
84
+ heads=8,
85
+ num_latents=64,
86
+ ff_mult=4,
87
+ ):
88
+ super().__init__()
89
+ self.latents = nn.Parameter(torch.randn(num_latents, text_width))
90
+
91
+ self.layers = nn.ModuleList([])
92
+ for _ in range(depth):
93
+ self.layers.append(nn.ModuleList([
94
+ PerceiverAttention(vision_width=vision_width, text_width=text_width, dim_head=dim_head, heads=heads),
95
+ FeedForward(dim=text_width, mult=ff_mult)
96
+ ]))
97
+
98
+ self.norm = nn.LayerNorm(text_width)
99
+
100
+ def forward(self, vision_embeds=None, vision_atts=None):
101
+ x = vision_embeds
102
+
103
+ if x.ndim == 3:
104
+ x = rearrange(x, 'b n d -> b 1 n d')
105
+
106
+ latents = repeat(self.latents, 'n d -> b m n d', b=x.shape[0], m=x.shape[1])
107
+
108
+ for attn, ff in self.layers:
109
+ latents = attn(x, latents) + latents
110
+ latents = ff(latents) + latents
111
+
112
+ v2t_feats = self.norm(latents).squeeze(dim=1) # for image, squeeze dim=1
113
+ v2t_atts = torch.ones(v2t_feats.shape[:2], dtype=torch.long, device=v2t_feats.device)
114
+
115
+ return v2t_feats, v2t_atts
src/spectprompt.py ADDED
@@ -0,0 +1,577 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import pdb
4
+ from mmcv.cnn.bricks import padding
5
+ import torch
6
+ from torch import nn, einsum
7
+ from typing import Optional, Dict, Tuple
8
+ from src.mae_vit import MAEViT
9
+ from src.htsat import HTSAT_Swin_Transformer, create_htsat_model
10
+ from src.LMdecoder import LMDecoder, LMDecoder_qlora
11
+ from src.vision_transformer import VisionTransformer
12
+ from einops import rearrange, repeat
13
+ from einops_exts import rearrange_many
14
+ import inspect
15
+
16
+ class ArgsHandler:
17
+ def __init__(self, module, funcname, fargs, fkargs):
18
+ self.fargs = list(fargs)
19
+ self.fkargs = fkargs
20
+ func = getattr(module, funcname)
21
+ fal_repr = f"{funcname}_argnames_list"
22
+ if (argns_list:=getattr(module, fal_repr, None)) is None:
23
+ self.func_sig = inspect.signature(func)
24
+ self.argnames_list = list(self.func_sig.parameters.keys())
25
+ setattr(module, fal_repr, self.argnames_list)
26
+ else:
27
+ self.argnames_list = argns_list
28
+
29
+ def get_arg(self, arg_name):
30
+ if arg_name in self.fkargs:
31
+ arg = self.fkargs[arg_name]
32
+ else:
33
+ arg = self.fargs[self.argnames_list.index(arg_name)]
34
+ return arg
35
+
36
+ def set_arg(self, arg_name, arg_value):
37
+ if arg_name in self.fkargs:
38
+ self.fkargs[arg_name] = arg_value
39
+ else:
40
+ self.fargs[self.argnames_list.index(arg_name)] = arg_value
41
+
42
+ def return_all_args(self,):
43
+ return tuple(self.fargs), self.fkargs
44
+
45
+ class SquaredReLU(nn.Module):
46
+ """ squared ReLU activation function"""
47
+ def __init__(self):
48
+ super().__init__()
49
+
50
+ def forward(self, x):
51
+ return torch.pow(torch.relu(x), 2)
52
+
53
+ def FeedForward(dim, out_dim, mult=4, act='gelu'):
54
+ """
55
+ lucidrains implementation, slightly modified with the act parameter.
56
+ """
57
+
58
+ acts = dict(
59
+ gelu=nn.GELU,
60
+ sqrelu=SquaredReLU,
61
+ relu=nn.ReLU
62
+ )
63
+
64
+ assert act in acts, f"act. can only be one of {acts.keys()}"
65
+
66
+ inner_dim = int(dim * mult)
67
+ return nn.Sequential(
68
+ nn.LayerNorm(dim),
69
+ nn.Linear(dim, inner_dim, bias=False),
70
+ acts[act](),
71
+ nn.Linear(inner_dim, out_dim, bias=False)
72
+ )
73
+
74
+
75
+ class PerceiverAttentionLayer(nn.Module):
76
+ def __init__(
77
+ self,
78
+ *,
79
+ feat_dim,
80
+ latent_dim,
81
+ dim_head=64,
82
+ heads=8
83
+ ):
84
+ super().__init__()
85
+ self.scale = dim_head ** -0.5
86
+ self.heads = heads
87
+ self.dim_head = dim_head
88
+
89
+ inner_dim = dim_head * heads
90
+
91
+ # trainable components of PerceiverAttentionLayer
92
+ self.norm_media = nn.LayerNorm(feat_dim)
93
+ self.norm_latents = nn.LayerNorm(latent_dim)
94
+
95
+ self.to_q = nn.Linear(latent_dim, inner_dim, bias=False)
96
+ self.to_k = nn.Linear(feat_dim, inner_dim, bias=False)
97
+ self.to_v = nn.Linear(feat_dim, inner_dim, bias=False)
98
+ self.to_out = nn.Linear(inner_dim, latent_dim, bias=False)
99
+
100
+ def forward(self, features, latents):
101
+ """
102
+ Latent vectors are cross-attending to the visual features x.
103
+ :param x: Tensor (n_batch, n_features, dim)
104
+ visual features
105
+ :param latents: Tensor (n_batch, n_latents, dim)
106
+ latent learnt vectors from which the queries are computed.
107
+ Actually the same, just replicated in n_batch and n_frames dimension.
108
+ :return: Tensor (n_batch, n_latents, dim)
109
+ """
110
+ assert features.ndim == 3
111
+ assert latents.ndim == 3
112
+ assert features.shape[0] == latents.shape[0]
113
+ #assert features.shape[2] == latents.shape[2]
114
+
115
+ n_heads = self.heads
116
+ n_batch, n_features, dim = features.shape
117
+ n_queries = latents.shape[1]
118
+
119
+ # layer normalization, as usual
120
+ x = self.norm_media(features)
121
+ latents = self.norm_latents(latents)
122
+
123
+ # queries
124
+ # compute the queries from the latents, for all attention heads simultaneously.
125
+ q = self.to_q(latents)
126
+ q = rearrange(q, 'b q (h d) -> b h q d', h=n_heads)
127
+ assert q.shape == torch.Size([n_batch, n_heads, n_queries, self.dim_head])
128
+
129
+ # keys and values for all attention heads
130
+
131
+ '''
132
+ kv_input = torch.cat((x, latents), dim=-2)
133
+ n_features_latents = n_features + n_queries
134
+ '''
135
+
136
+ kv_input = x
137
+ n_features_latents = n_features
138
+
139
+ # keys, values
140
+ k = self.to_k(kv_input)
141
+ v = self.to_v(kv_input)
142
+ # batch, features, (heads, dim)
143
+
144
+ # split so we have an extra dimension for the heads
145
+ # q, k, v = rearrange_many((q, k, v), 'b t n (h d) -> b h t n d', h=h)
146
+ k, v = rearrange_many((k, v), 'b f (h d) -> b h f d', h=n_heads)
147
+ assert v.shape == torch.Size([n_batch, n_heads, n_features_latents, self.dim_head])
148
+
149
+ # scale queries?
150
+ q = q * self.scale
151
+
152
+ # attention
153
+
154
+ # attention scores
155
+ # sim = einsum('... i d, ... j d -> ... i j', q, k)
156
+ sim = einsum('b h q d, b h f d -> b h q f', q, k)
157
+
158
+ # Is this for numerical stability? Does not affect the result of the softmax operation
159
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
160
+ alphas = sim.softmax(dim=-1)
161
+
162
+ # out = einsum('... i j, ... j d -> ... i d', alphas, v)
163
+ out = einsum('b h q f, b h f v -> b h q v', alphas, v)
164
+
165
+ # out = rearrange(out, 'b h t n d -> b t n (h d)', h=h)
166
+ out = rearrange(out, 'b h q v -> b q (h v)')
167
+ return self.to_out(out)
168
+
169
+
170
+ class SpectPrompt(nn.Module):
171
+ """
172
+
173
+ Args:
174
+ backbone (dict): Config dict for encoder. Defaults to None.
175
+ neck (dict): Config dict for encoder. Defaults to None.
176
+ head (dict): Config dict for loss functions. Defaults to None.
177
+ init_cfg (dict, optional): Config dict for weight initialization.
178
+ Defaults to None.
179
+ """
180
+
181
+ def __init__(self,
182
+ backbone: dict,
183
+ neck: dict,
184
+ live_long_learning:bool=False, # TODO: costumize para or module
185
+ ) -> None:
186
+ super().__init__()
187
+ assert backbone is not None
188
+ bk_name = backbone.pop('name')
189
+ self.bk_name = bk_name
190
+ if bk_name == 'MAEViT':
191
+ ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
192
+ self.backbone = MAEViT(**backbone)
193
+ if ckpt_path is not None:
194
+ ckpt = torch.load( ckpt_path,'cpu')
195
+ self.backbone.load_state_dict(ckpt['state_dict'])
196
+
197
+ elif bk_name == 'HTSAT':
198
+ ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
199
+ self.backbone = create_htsat_model(backbone)
200
+ if ckpt_path is not None:
201
+ ckpt = torch.load( ckpt_path,'cpu')
202
+ self.backbone.load_state_dict(ckpt['state_dict'])
203
+ elif bk_name == 'qformer':
204
+ raise NotImplemented
205
+ else:
206
+ raise NotImplemented
207
+
208
+
209
+
210
+ # neck["num_patches"] = self.backbone.num_patches
211
+ # neck["patch_resolution"] = self.backbone.patch_resolution
212
+ assert neck is not None
213
+ nk_name = neck.pop('name')
214
+ if nk_name == 'LMDecoder':
215
+ self.neck = LMDecoder(**neck)
216
+ elif nk_name == 'LMDecoder_qlora':
217
+ self.neck = LMDecoder_qlora(**neck)
218
+ else:
219
+ raise NotImplemented
220
+ self.config = self.neck.LMconfig # TODO
221
+
222
+ '''
223
+ self.ae_proj = nn.Linear(
224
+ 768, self.config.hidden_size
225
+ )
226
+ '''
227
+
228
+ ## TODO
229
+
230
+ #self.neck.lm.apply(lambda m:m.gradient_checkpointing=True)
231
+ self.neck.lm.model.gradient_checkpointing = False
232
+
233
+ self.register_buffer('ones', torch.ones((1,4096), dtype=torch.long), persistent=False)
234
+ self.graft_adapter()
235
+ self.init_weights()
236
+
237
+ if False:
238
+ self.patch_llm()
239
+ self.first_run = True
240
+
241
+ def graft_adapter(self):
242
+ adapter_latent_len = 32
243
+ self.adapter_latent_len = adapter_latent_len
244
+ self.adapter_latent = nn.Parameter(torch.rand((1,adapter_latent_len, self.config.hidden_size), \
245
+ dtype=torch.float))
246
+ resampler_latent_len = 32
247
+ self.resampler_latent_len = resampler_latent_len
248
+ self.resampler_latent = nn.Parameter(torch.rand((1,resampler_latent_len, self.config.hidden_size), \
249
+ dtype=torch.float))
250
+ ## TODO
251
+ # self.adapter.pre_bn = torch.nn.BatchNorm1d(4096, affine=True)
252
+
253
+ self.adapter = nn.ModuleList([])
254
+
255
+ ff_mult = 4
256
+ heads=8
257
+ dim_head=512
258
+ act='gelu'
259
+
260
+ lm_dim = self.config.hidden_size
261
+ if self.bk_name == 'HTSAT':
262
+ feat_dim = 1024
263
+ depth = len(self.backbone.layers[2].blocks)
264
+ else:
265
+ feat_dim = 768
266
+ depth = int(len(self.neck.lm.model.layers)/2) # 16
267
+ for idx in range(depth):
268
+ self.adapter.append(nn.ModuleList([
269
+ Adapter(input_size=self.config.hidden_size),
270
+ # PerceiverAttentionLayer(feat_dim=feat_dim, latent_dim=lm_dim, dim_head=dim_head, heads=heads),
271
+ # FeedForward(dim=lm_dim, out_dim=lm_dim, mult=1, act=act),
272
+ #FeedForward(dim=self.dim, out_dim=768, mult=ff_mult, act=act) if idx != depth-1 else nn.Identity()
273
+ ]))
274
+
275
+ self.samplers = nn.ModuleList([]) # add
276
+ for _ in range(3):
277
+ self.samplers.append(nn.ModuleList([
278
+ PerceiverAttentionLayer(feat_dim=feat_dim, latent_dim=lm_dim, dim_head=64, heads=heads),
279
+ FeedForward(dim=lm_dim, out_dim=lm_dim, mult=4),
280
+ ]))
281
+ self.norm = nn.LayerNorm(lm_dim)
282
+
283
+ # self.agate_list = nn.ParameterList([])
284
+ # for i in range(len(self.neck.lm.model.layers)):
285
+ # self.agate_list.append(nn.Parameter(torch.zeros(lm_dim)))
286
+
287
+
288
+
289
+ def init_weights(self):
290
+ try:
291
+ super().init_weights()
292
+ except:
293
+ pass
294
+ # import traceback
295
+ # traceback.print_exc()
296
+ if getattr(self, 'adapter_latent', None) is not None:
297
+ self.adapter_latent.data.normal_(mean=0.0, std=0.02)
298
+ if getattr(self, 'resampler_latent', None) is not None:
299
+ self.adapter_latent.data.normal_(mean=0.0, std=0.02)
300
+
301
+ def forward_resampler(self, x):
302
+ # b, 768, 512
303
+ latents = repeat(self.resampler_latent, 'b n d -> (bs b) n d', bs=x.shape[0])
304
+ for attn, ff in self.samplers:
305
+ latents = attn(x, latents) + latents
306
+ latents = ff(latents) + latents
307
+ v2t_feats = self.norm(latents) #
308
+ # v2t_atts = torch.ones(v2t_feats.shape[:2], dtype=torch.long, device=v2t_feats.device)
309
+ return v2t_feats # bs, 32, dim_llm
310
+
311
+
312
+ def hook_adapter(self, audio_embedding, lm, v2t_feats):
313
+
314
+ class PHooker:
315
+ # model = self.backbone
316
+ # mgtr = self.backbone.forward_generator(spectrogram)
317
+ adapter = self.adapter
318
+ y = v2t_feats
319
+ handles_list = list()
320
+ cnter = 0
321
+ def layer_prehook(self, m, margs, mkargs):
322
+ ahl = ArgsHandler(m, 'forward', margs, mkargs)
323
+
324
+ # print(self.cnter)
325
+
326
+ # if self.cnter>=16:
327
+ # self.cnter+=1
328
+ # return None
329
+ adapt = self.adapter[self.cnter][0]
330
+
331
+ hs = ahl.get_arg("hidden_states")
332
+ adapter_residual = hs
333
+ neo_hs = adapt(hs, adapter_residual)
334
+
335
+ self.cnter+=1
336
+ ahl.set_arg("hidden_states", neo_hs)
337
+ return ahl.return_all_args()
338
+ def first_layer_prehook(self, m, margs, mkargs):
339
+ ahl = ArgsHandler(m, 'forward', margs, mkargs)
340
+ neo_lm_latents = self.y # torch.Size([128, 32, 4096])
341
+ hs = ahl.get_arg("hidden_states") # torch.Size([128, 87, 4096])
342
+ hs_msk = self.lm_ahl.get_arg("input_ids") < 0 # torch.Size([128, 87]) [False,, True*32, False,,]
343
+ # __import__('pdb').set_trace()
344
+ neo_hs = hs.masked_scatter(hs_msk.unsqueeze(-1), neo_lm_latents) # resampler hooker直接替换
345
+ ahl.set_arg("hidden_states", neo_hs)
346
+ return ahl.return_all_args()
347
+
348
+ def lm_prehook(self, m, margs, mkargs):
349
+ self.lm_ahl = ArgsHandler(m, 'forward', margs, mkargs)
350
+ return None
351
+ def last_layer_hook(self, m, margs, mkargs):
352
+ # __import__('pdb').set_trace()
353
+ self.cnter = 0
354
+
355
+ if getattr(lm,'phooker',False):
356
+ for _ in lm.phooker.handles_list:
357
+ _.remove()
358
+ del lm.phooker
359
+ lm.phooker = None
360
+ phooker = PHooker()
361
+ phooker.handles_list.append(lm.register_forward_pre_hook(phooker.lm_prehook, with_kwargs=True))
362
+ # 第一层插入
363
+ phooker.handles_list.append(lm.model.layers[0].register_forward_pre_hook(phooker.first_layer_prehook, with_kwargs=True))
364
+
365
+ for ii in range(1,len(lm.model.layers),2):
366
+ l = lm.model.layers[ii]
367
+ handle = l.register_forward_pre_hook(phooker.layer_prehook, with_kwargs=True)
368
+ phooker.handles_list.append(handle)
369
+ phooker.handles_list.append(lm.model.layers[-1].register_forward_pre_hook(phooker.last_layer_hook, with_kwargs=True))
370
+ lm.phooker = phooker
371
+ return None
372
+
373
+
374
+
375
+ def prepare_ids(self, batch, audio_ids):
376
+ toker = self.neck.tokenizer
377
+ # for idx, l in enumerate(self.neck.lm.model.layers):
378
+ # l.agate = self.agate_list[idx].clone() ## should clone the parameter
379
+
380
+ with torch.no_grad():
381
+
382
+ input_ids = batch['input_ids']
383
+ att_msk = batch['attention_mask']
384
+ au_crds = batch['audio_crds']
385
+ ans_crds = batch['ans_crds']
386
+ bsz = input_ids.shape[0]
387
+ # __import__('pdb').set_trace()
388
+ ## TODO
389
+ merged_ids, merged_msk, label_ids = list(), list(), list()
390
+ for i in range(bsz):
391
+ # cur_merged_ids = torch.cat([input_ids[i,:au_crds[i]], -1 * audio_ids[i] -1, input_ids[i,au_crds[i]:]])
392
+ cur_merged_ids = torch.cat([ -1 * audio_ids[i] -1, input_ids[i,au_crds[i]:]])
393
+
394
+ # cur_au_msk = self.ones[:,:audio_ids.shape[1]][0].clone().type_as(att_msk).detach()
395
+ cur_au_msk = torch.ones(audio_ids.shape[1], device=audio_ids.device)
396
+ # cur_merged_msk = torch.cat([att_msk[i,:au_crds[i]], cur_au_msk, att_msk[i,au_crds[i]:]])
397
+ cur_merged_msk = torch.cat([ cur_au_msk, att_msk[i,au_crds[i]:]])
398
+ cur_label_ids = cur_merged_ids.clone().detach()
399
+ cur_label_ids[:audio_ids.shape[1]+ans_crds[i]] = -100
400
+
401
+ merged_ids.append(cur_merged_ids)
402
+ merged_msk.append(cur_merged_msk)
403
+ label_ids.append(cur_label_ids)
404
+
405
+ merged_ids = torch.stack(merged_ids, dim=0)
406
+ merged_msk = torch.stack(merged_msk, dim=0)
407
+ label_ids = torch.stack(label_ids, dim=0)
408
+
409
+ assert merged_ids.shape[0] == bsz
410
+ assert merged_ids.shape == merged_msk.shape
411
+
412
+ label_msk = merged_msk.clone()
413
+ assert label_msk.shape == merged_msk.shape
414
+ assert merged_msk[:,-1].max() == 1
415
+
416
+ for i in range(len(ans_crds)):
417
+ label_ids[i,:audio_ids.shape[1]+ans_crds[i]].fill_(-100)
418
+
419
+
420
+ merged_labels = label_ids
421
+ merged_ids[merged_ids.eq(-100)] = toker.pad_token_id
422
+
423
+ return merged_ids, merged_msk, merged_labels
424
+
425
+ def forward(self, batch, **kwargs):
426
+ """Forward computation during training.
427
+
428
+ Args:
429
+ img (torch.Tensor): Input images of shape (N, C, H, W).
430
+ kwargs: Any keyword arguments to be used to forward.
431
+ Returns:
432
+ Dict[str, torch.Tensor]: A dictionary of loss components.
433
+ """
434
+
435
+ bsz = len(batch['input_ids'])
436
+ device = batch['input_ids'].device
437
+ float_type = next(self.parameters()).dtype
438
+ spectrogram = batch['spectrogram'].type(float_type)
439
+ audio_embedding = self.backbone(spectrogram).detach() # b, 768, 512
440
+ resampler_feats = self.forward_resampler(audio_embedding)
441
+ self.hook_adapter(audio_embedding, self.neck.lm, resampler_feats) # add hook
442
+
443
+ # self.hook_resapmler(resampler_feats, self.neck.lm)
444
+
445
+ audio_ids = torch.arange(self.adapter_latent.shape[1]).unsqueeze(0).repeat((bsz, 1)).long().to(device)
446
+ assert audio_ids.max() < 100
447
+ merged_ids, merged_msk, merged_labels = self.prepare_ids(batch, audio_ids)
448
+
449
+ try:
450
+ assert merged_ids.shape == merged_labels.shape
451
+ outs = self.neck(input_ids=merged_ids.contiguous().long(),
452
+ flatten_embs=self.adapter_latent.flatten(0,1), # 32, 4096
453
+ # flatten_embs = resampler_feats.flatten(0,1), # b, 32, 4096
454
+ attention_mask=merged_msk.contiguous().long(),
455
+ labels=merged_labels.contiguous().long(), use_cache=False)
456
+ except Exception as e:
457
+ import traceback
458
+ traceback.print_exc()
459
+ __import__('remote_pdb').set_trace()
460
+ #outs.hidden_logits = self.hidden_logits
461
+
462
+ ## TODO
463
+ if eval(os.environ.get("doing_eval", 'False')):
464
+ outs.merged_ids = merged_ids.cpu()
465
+ outs.merged_labels = merged_labels.cpu()
466
+
467
+ return outs
468
+
469
+
470
+ def forward_test(self, batch, **kwargs):
471
+ """Forward computation during training.
472
+
473
+ Args:
474
+ img (torch.Tensor): Input images of shape (N, C, H, W).
475
+ kwargs: Any keyword arguments to be used to forward.
476
+ Returns:
477
+ Dict[str, torch.Tensor]: A dictionary of loss components.
478
+ """
479
+
480
+ assert self.training == False
481
+
482
+ bsz = len(batch['input_ids'])
483
+ device = batch['input_ids'].device
484
+ float_type = next(self.parameters()).dtype
485
+ spectrogram = batch['spectrogram'].type(float_type)
486
+ audio_embedding = self.backbone(spectrogram).detach() # b, 768, 512
487
+ resampler_feats = self.forward_resampler(audio_embedding)
488
+ self.hook_adapter(audio_embedding, self.neck.lm, resampler_feats) # add hook
489
+ # self.extract_features(batch, self.neck.lm)
490
+ audio_ids = torch.arange(self.adapter_latent.shape[1]).unsqueeze(0).repeat((bsz, 1)).long().to(device)
491
+ assert audio_ids.max() < 100
492
+
493
+ merged_ids, merged_msk, merged_labels = self.prepare_ids(batch, audio_ids)
494
+ au_crds = batch['audio_crds']
495
+ ans_crds = batch['ans_crds']
496
+
497
+ aid_len = audio_ids.shape[-1]
498
+
499
+
500
+ toker = self.neck.tokenizer
501
+ with torch.no_grad():
502
+
503
+ ## TODO
504
+ pad_token = toker.encode(self.neck.tokenizer.eos_token)[0]
505
+ padded_merged_ids = self.ones[:, :aid_len+max(ans_crds)].repeat(bsz, 1).clone().detach() * pad_token
506
+ for i in range(bsz):
507
+ # for i in range(1):
508
+ assert au_crds[i] <= ans_crds[i]
509
+ cur_ids = merged_ids[i][:aid_len+ans_crds[i]]
510
+ padded_merged_ids[i][max(ans_crds)-ans_crds[i]:] = cur_ids
511
+ # __import__('pdb').set_trace()
512
+ outs = self.neck.generate(padded_merged_ids, self.adapter_latent.flatten(0,1))
513
+ #outs.hidden_logits = self.hidden_logits
514
+
515
+ return outs
516
+
517
+
518
+
519
+ import torch
520
+ from torch import nn
521
+
522
+ from transformers.activations import ACT2FN
523
+
524
+ class Adapter(nn.Module):
525
+ """
526
+ Implementation of a sequential bottleneck adapter block.
527
+ """
528
+ def __init__(
529
+ self,
530
+ input_size,
531
+ down_sample=None,
532
+ ):
533
+ super().__init__()
534
+
535
+ self.input_size = input_size
536
+
537
+ # if a downsample size is not passed, we just half the size of the original input
538
+ self.down_sample = down_sample
539
+ if down_sample is None:
540
+ self.down_sample = self.input_size // 2
541
+
542
+ self.adapter_norm_before = nn.LayerNorm(self.input_size)
543
+ self.adapter_down = nn.Linear(self.input_size, self.down_sample)
544
+ self.non_linearity = ACT2FN["silu"]
545
+
546
+ # Up projection to input size
547
+ self.adapter_up = nn.Linear(self.down_sample, self.input_size)
548
+
549
+ # Additional scaling factor (from He et al. (2021))
550
+ self.scaling = nn.Parameter(torch.ones(1))
551
+
552
+ self.adapter_down.apply(self._init_weights)
553
+ self.adapter_up.apply(self._init_weights)
554
+
555
+ def forward(self, x, residual_input): # , residual_input=None):
556
+
557
+ down = self.non_linearity(self.adapter_down(self.adapter_norm_before(x)))
558
+
559
+ up = self.adapter_up(down)
560
+ up = up * self.scaling
561
+ output = up
562
+
563
+ output = output + residual_input
564
+
565
+ return output
566
+
567
+ @staticmethod
568
+ def _init_weights(module):
569
+ """Initialize the weights."""
570
+ if isinstance(module, (nn.Linear, nn.Embedding)):
571
+ # std defaults to 0.02, this might need to be changed
572
+ module.weight.data.normal_(mean=0.0, std=0.02)
573
+ elif isinstance(module, nn.LayerNorm):
574
+ module.bias.data.zero_()
575
+ module.weight.data.fill_(1.0)
576
+ if isinstance(module, nn.Linear) and module.bias is not None:
577
+ module.bias.data.zero_()
src/stft.py ADDED
@@ -0,0 +1,1111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import argparse
3
+
4
+ import librosa
5
+ import numpy as np
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn.parameter import Parameter
11
+
12
+
13
+ class DFTBase(nn.Module):
14
+ def __init__(self):
15
+ r"""Base class for DFT and IDFT matrix.
16
+ """
17
+ super(DFTBase, self).__init__()
18
+
19
+ def dft_matrix(self, n):
20
+ (x, y) = np.meshgrid(np.arange(n), np.arange(n))
21
+ omega = np.exp(-2 * np.pi * 1j / n)
22
+ W = np.power(omega, x * y) # shape: (n, n)
23
+ return W
24
+
25
+ def idft_matrix(self, n):
26
+ (x, y) = np.meshgrid(np.arange(n), np.arange(n))
27
+ omega = np.exp(2 * np.pi * 1j / n)
28
+ W = np.power(omega, x * y) # shape: (n, n)
29
+ return W
30
+
31
+
32
+ class DFT(DFTBase):
33
+ def __init__(self, n, norm):
34
+ r"""Calculate discrete Fourier transform (DFT), inverse DFT (IDFT,
35
+ right DFT (RDFT) RDFT, and inverse RDFT (IRDFT.)
36
+
37
+ Args:
38
+ n: fft window size
39
+ norm: None | 'ortho'
40
+ """
41
+ super(DFT, self).__init__()
42
+
43
+ self.W = self.dft_matrix(n)
44
+ self.inv_W = self.idft_matrix(n)
45
+
46
+ self.W_real = torch.Tensor(np.real(self.W))
47
+ self.W_imag = torch.Tensor(np.imag(self.W))
48
+ self.inv_W_real = torch.Tensor(np.real(self.inv_W))
49
+ self.inv_W_imag = torch.Tensor(np.imag(self.inv_W))
50
+
51
+ self.n = n
52
+ self.norm = norm
53
+
54
+ def dft(self, x_real, x_imag):
55
+ r"""Calculate DFT of a signal.
56
+
57
+ Args:
58
+ x_real: (n,), real part of a signal
59
+ x_imag: (n,), imag part of a signal
60
+
61
+ Returns:
62
+ z_real: (n,), real part of output
63
+ z_imag: (n,), imag part of output
64
+ """
65
+ z_real = torch.matmul(x_real, self.W_real) - torch.matmul(x_imag, self.W_imag)
66
+ z_imag = torch.matmul(x_imag, self.W_real) + torch.matmul(x_real, self.W_imag)
67
+ # shape: (n,)
68
+
69
+ if self.norm is None:
70
+ pass
71
+ elif self.norm == 'ortho':
72
+ z_real /= math.sqrt(self.n)
73
+ z_imag /= math.sqrt(self.n)
74
+
75
+ return z_real, z_imag
76
+
77
+ def idft(self, x_real, x_imag):
78
+ r"""Calculate IDFT of a signal.
79
+
80
+ Args:
81
+ x_real: (n,), real part of a signal
82
+ x_imag: (n,), imag part of a signal
83
+ Returns:
84
+ z_real: (n,), real part of output
85
+ z_imag: (n,), imag part of output
86
+ """
87
+ z_real = torch.matmul(x_real, self.inv_W_real) - torch.matmul(x_imag, self.inv_W_imag)
88
+ z_imag = torch.matmul(x_imag, self.inv_W_real) + torch.matmul(x_real, self.inv_W_imag)
89
+ # shape: (n,)
90
+
91
+ if self.norm is None:
92
+ z_real /= self.n
93
+ elif self.norm == 'ortho':
94
+ z_real /= math.sqrt(n)
95
+ z_imag /= math.sqrt(n)
96
+
97
+ return z_real, z_imag
98
+
99
+ def rdft(self, x_real):
100
+ r"""Calculate right RDFT of signal.
101
+
102
+ Args:
103
+ x_real: (n,), real part of a signal
104
+ x_imag: (n,), imag part of a signal
105
+
106
+ Returns:
107
+ z_real: (n // 2 + 1,), real part of output
108
+ z_imag: (n // 2 + 1,), imag part of output
109
+ """
110
+ n_rfft = self.n // 2 + 1
111
+ z_real = torch.matmul(x_real, self.W_real[..., 0 : n_rfft])
112
+ z_imag = torch.matmul(x_real, self.W_imag[..., 0 : n_rfft])
113
+ # shape: (n // 2 + 1,)
114
+
115
+ if self.norm is None:
116
+ pass
117
+ elif self.norm == 'ortho':
118
+ z_real /= math.sqrt(self.n)
119
+ z_imag /= math.sqrt(self.n)
120
+
121
+ return z_real, z_imag
122
+
123
+ def irdft(self, x_real, x_imag):
124
+ r"""Calculate IRDFT of signal.
125
+
126
+ Args:
127
+ x_real: (n // 2 + 1,), real part of a signal
128
+ x_imag: (n // 2 + 1,), imag part of a signal
129
+
130
+ Returns:
131
+ z_real: (n,), real part of output
132
+ z_imag: (n,), imag part of output
133
+ """
134
+ n_rfft = self.n // 2 + 1
135
+
136
+ flip_x_real = torch.flip(x_real, dims=(-1,))
137
+ flip_x_imag = torch.flip(x_imag, dims=(-1,))
138
+ # shape: (n // 2 + 1,)
139
+
140
+ x_real = torch.cat((x_real, flip_x_real[..., 1 : n_rfft - 1]), dim=-1)
141
+ x_imag = torch.cat((x_imag, -1. * flip_x_imag[..., 1 : n_rfft - 1]), dim=-1)
142
+ # shape: (n,)
143
+
144
+ z_real = torch.matmul(x_real, self.inv_W_real) - torch.matmul(x_imag, self.inv_W_imag)
145
+ # shape: (n,)
146
+
147
+ if self.norm is None:
148
+ z_real /= self.n
149
+ elif self.norm == 'ortho':
150
+ z_real /= math.sqrt(n)
151
+
152
+ return z_real
153
+
154
+
155
+ class STFT(DFTBase):
156
+ def __init__(self, n_fft=2048, hop_length=None, win_length=None,
157
+ window='hann', center=True, pad_mode='reflect', freeze_parameters=True):
158
+ r"""PyTorch implementation of STFT with Conv1d. The function has the
159
+ same output as librosa.stft.
160
+
161
+ Args:
162
+ n_fft: int, fft window size, e.g., 2048
163
+ hop_length: int, hop length samples, e.g., 441
164
+ win_length: int, window length e.g., 2048
165
+ window: str, window function name, e.g., 'hann'
166
+ center: bool
167
+ pad_mode: str, e.g., 'reflect'
168
+ freeze_parameters: bool, set to True to freeze all parameters. Set
169
+ to False to finetune all parameters.
170
+ """
171
+ super(STFT, self).__init__()
172
+
173
+ assert pad_mode in ['constant', 'reflect']
174
+
175
+ self.n_fft = n_fft
176
+ self.hop_length = hop_length
177
+ self.win_length = win_length
178
+ self.window = window
179
+ self.center = center
180
+ self.pad_mode = pad_mode
181
+
182
+ # By default, use the entire frame.
183
+ if self.win_length is None:
184
+ self.win_length = n_fft
185
+
186
+ # Set the default hop, if it's not already specified.
187
+ if self.hop_length is None:
188
+ self.hop_length = int(self.win_length // 4)
189
+
190
+ fft_window = librosa.filters.get_window(window, self.win_length, fftbins=True)
191
+
192
+ # Pad the window out to n_fft size.
193
+ fft_window = librosa.util.pad_center(fft_window, size=n_fft)
194
+
195
+ # DFT & IDFT matrix.
196
+ self.W = self.dft_matrix(n_fft)
197
+
198
+ out_channels = n_fft // 2 + 1
199
+
200
+ self.conv_real = nn.Conv1d(in_channels=1, out_channels=out_channels,
201
+ kernel_size=n_fft, stride=self.hop_length, padding=0, dilation=1,
202
+ groups=1, bias=False)
203
+
204
+ self.conv_imag = nn.Conv1d(in_channels=1, out_channels=out_channels,
205
+ kernel_size=n_fft, stride=self.hop_length, padding=0, dilation=1,
206
+ groups=1, bias=False)
207
+
208
+ # Initialize Conv1d weights.
209
+ self.conv_real.weight.data.copy_(torch.Tensor(
210
+ np.real(self.W[:, 0 : out_channels] * fft_window[:, None]).T)[:, None, :])
211
+ # (n_fft // 2 + 1, 1, n_fft)
212
+
213
+ self.conv_imag.weight.data.copy_(torch.Tensor(
214
+ np.imag(self.W[:, 0 : out_channels] * fft_window[:, None]).T)[:, None, :])
215
+ # (n_fft // 2 + 1, 1, n_fft)
216
+
217
+ if freeze_parameters:
218
+ for param in self.parameters():
219
+ param.requires_grad = False
220
+
221
+ def forward(self, input):
222
+ r"""Calculate STFT of batch of signals.
223
+
224
+ Args:
225
+ input: (batch_size, data_length), input signals.
226
+
227
+ Returns:
228
+ real: (batch_size, 1, time_steps, n_fft // 2 + 1)
229
+ imag: (batch_size, 1, time_steps, n_fft // 2 + 1)
230
+ """
231
+
232
+ x = input[:, None, :] # (batch_size, channels_num, data_length)
233
+
234
+ if self.center:
235
+ x = F.pad(x, pad=(self.n_fft // 2, self.n_fft // 2), mode=self.pad_mode)
236
+
237
+ real = self.conv_real(x)
238
+ imag = self.conv_imag(x)
239
+ # (batch_size, n_fft // 2 + 1, time_steps)
240
+
241
+ real = real[:, None, :, :].transpose(2, 3)
242
+ imag = imag[:, None, :, :].transpose(2, 3)
243
+ # (batch_size, 1, time_steps, n_fft // 2 + 1)
244
+
245
+ return real, imag
246
+
247
+
248
+ def magphase(real, imag):
249
+ r"""Calculate magnitude and phase from real and imag part of signals.
250
+
251
+ Args:
252
+ real: tensor, real part of signals
253
+ imag: tensor, imag part of signals
254
+
255
+ Returns:
256
+ mag: tensor, magnitude of signals
257
+ cos: tensor, cosine of phases of signals
258
+ sin: tensor, sine of phases of signals
259
+ """
260
+ mag = (real ** 2 + imag ** 2) ** 0.5
261
+ cos = real / torch.clamp(mag, 1e-10, np.inf)
262
+ sin = imag / torch.clamp(mag, 1e-10, np.inf)
263
+
264
+ return mag, cos, sin
265
+
266
+
267
+ class ISTFT(DFTBase):
268
+ def __init__(self, n_fft=2048, hop_length=None, win_length=None,
269
+ window='hann', center=True, pad_mode='reflect', freeze_parameters=True,
270
+ onnx=False, frames_num=None, device=None):
271
+ """PyTorch implementation of ISTFT with Conv1d. The function has the
272
+ same output as librosa.istft.
273
+
274
+ Args:
275
+ n_fft: int, fft window size, e.g., 2048
276
+ hop_length: int, hop length samples, e.g., 441
277
+ win_length: int, window length e.g., 2048
278
+ window: str, window function name, e.g., 'hann'
279
+ center: bool
280
+ pad_mode: str, e.g., 'reflect'
281
+ freeze_parameters: bool, set to True to freeze all parameters. Set
282
+ to False to finetune all parameters.
283
+ onnx: bool, set to True when exporting trained model to ONNX. This
284
+ will replace several operations to operators supported by ONNX.
285
+ frames_num: None | int, number of frames of audio clips to be
286
+ inferneced. Only useable when onnx=True.
287
+ device: None | str, device of ONNX. Only useable when onnx=True.
288
+ """
289
+ super(ISTFT, self).__init__()
290
+
291
+ assert pad_mode in ['constant', 'reflect']
292
+
293
+ if not onnx:
294
+ assert frames_num is None, "When onnx=False, frames_num must be None!"
295
+ assert device is None, "When onnx=False, device must be None!"
296
+
297
+ self.n_fft = n_fft
298
+ self.hop_length = hop_length
299
+ self.win_length = win_length
300
+ self.window = window
301
+ self.center = center
302
+ self.pad_mode = pad_mode
303
+ self.onnx = onnx
304
+
305
+ # By default, use the entire frame.
306
+ if self.win_length is None:
307
+ self.win_length = self.n_fft
308
+
309
+ # Set the default hop, if it's not already specified.
310
+ if self.hop_length is None:
311
+ self.hop_length = int(self.win_length // 4)
312
+
313
+ # Initialize Conv1d modules for calculating real and imag part of DFT.
314
+ self.init_real_imag_conv()
315
+
316
+ # Initialize overlap add window for reconstruct time domain signals.
317
+ self.init_overlap_add_window()
318
+
319
+ if self.onnx:
320
+ # Initialize ONNX modules.
321
+ self.init_onnx_modules(frames_num, device)
322
+
323
+ if freeze_parameters:
324
+ for param in self.parameters():
325
+ param.requires_grad = False
326
+
327
+ def init_real_imag_conv(self):
328
+ r"""Initialize Conv1d for calculating real and imag part of DFT.
329
+ """
330
+ self.W = self.idft_matrix(self.n_fft) / self.n_fft
331
+
332
+ self.conv_real = nn.Conv1d(in_channels=self.n_fft, out_channels=self.n_fft,
333
+ kernel_size=1, stride=1, padding=0, dilation=1,
334
+ groups=1, bias=False)
335
+
336
+ self.conv_imag = nn.Conv1d(in_channels=self.n_fft, out_channels=self.n_fft,
337
+ kernel_size=1, stride=1, padding=0, dilation=1,
338
+ groups=1, bias=False)
339
+
340
+ ifft_window = librosa.filters.get_window(self.window, self.win_length, fftbins=True)
341
+ # (win_length,)
342
+
343
+ # Pad the window to n_fft
344
+ ifft_window = librosa.util.pad_center(ifft_window, size=self.n_fft)
345
+
346
+ self.conv_real.weight.data = torch.Tensor(
347
+ np.real(self.W * ifft_window[None, :]).T)[:, :, None]
348
+ # (n_fft // 2 + 1, 1, n_fft)
349
+
350
+ self.conv_imag.weight.data = torch.Tensor(
351
+ np.imag(self.W * ifft_window[None, :]).T)[:, :, None]
352
+ # (n_fft // 2 + 1, 1, n_fft)
353
+
354
+ def init_overlap_add_window(self):
355
+ r"""Initialize overlap add window for reconstruct time domain signals.
356
+ """
357
+
358
+ ola_window = librosa.filters.get_window(self.window, self.win_length, fftbins=True)
359
+ # (win_length,)
360
+
361
+ ola_window = librosa.util.normalize(ola_window, norm=None) ** 2
362
+ ola_window = librosa.util.pad_center(ola_window, size=self.n_fft)
363
+ ola_window = torch.Tensor(ola_window)
364
+
365
+ self.register_buffer('ola_window', ola_window)
366
+ # (win_length,)
367
+
368
+ def init_onnx_modules(self, frames_num, device):
369
+ r"""Initialize ONNX modules.
370
+
371
+ Args:
372
+ frames_num: int
373
+ device: str | None
374
+ """
375
+
376
+ # Use Conv1d to implement torch.flip(), because torch.flip() is not
377
+ # supported by ONNX.
378
+ self.reverse = nn.Conv1d(in_channels=self.n_fft // 2 + 1,
379
+ out_channels=self.n_fft // 2 - 1, kernel_size=1, bias=False)
380
+
381
+ tmp = np.zeros((self.n_fft // 2 - 1, self.n_fft // 2 + 1, 1))
382
+ tmp[:, 1 : -1, 0] = np.array(np.eye(self.n_fft // 2 - 1)[::-1])
383
+ self.reverse.weight.data = torch.Tensor(tmp)
384
+ # (n_fft // 2 - 1, n_fft // 2 + 1, 1)
385
+
386
+ # Use nn.ConvTranspose2d to implement torch.nn.functional.fold(),
387
+ # because torch.nn.functional.fold() is not supported by ONNX.
388
+ self.overlap_add = nn.ConvTranspose2d(in_channels=self.n_fft,
389
+ out_channels=1, kernel_size=(self.n_fft, 1), stride=(self.hop_length, 1), bias=False)
390
+
391
+ self.overlap_add.weight.data = torch.Tensor(np.eye(self.n_fft)[:, None, :, None])
392
+ # (n_fft, 1, n_fft, 1)
393
+
394
+ if frames_num:
395
+ # Pre-calculate overlap-add window sum for reconstructing signals
396
+ # when using ONNX.
397
+ self.ifft_window_sum = self._get_ifft_window_sum_onnx(frames_num, device)
398
+ else:
399
+ self.ifft_window_sum = []
400
+
401
+ def forward(self, real_stft, imag_stft, length):
402
+ r"""Calculate inverse STFT.
403
+
404
+ Args:
405
+ real_stft: (batch_size, channels=1, time_steps, n_fft // 2 + 1)
406
+ imag_stft: (batch_size, channels=1, time_steps, n_fft // 2 + 1)
407
+ length: int
408
+
409
+ Returns:
410
+ real: (batch_size, data_length), output signals.
411
+ """
412
+ assert real_stft.ndimension() == 4 and imag_stft.ndimension() == 4
413
+ batch_size, _, frames_num, _ = real_stft.shape
414
+
415
+ real_stft = real_stft[:, 0, :, :].transpose(1, 2)
416
+ imag_stft = imag_stft[:, 0, :, :].transpose(1, 2)
417
+ # (batch_size, n_fft // 2 + 1, time_steps)
418
+
419
+ # Get full stft representation from spectrum using symmetry attribute.
420
+ if self.onnx:
421
+ full_real_stft, full_imag_stft = self._get_full_stft_onnx(real_stft, imag_stft)
422
+ else:
423
+ full_real_stft, full_imag_stft = self._get_full_stft(real_stft, imag_stft)
424
+ # full_real_stft: (batch_size, n_fft, time_steps)
425
+ # full_imag_stft: (batch_size, n_fft, time_steps)
426
+
427
+ # Calculate IDFT frame by frame.
428
+ s_real = self.conv_real(full_real_stft) - self.conv_imag(full_imag_stft)
429
+ # (batch_size, n_fft, time_steps)
430
+
431
+ # Overlap add signals in frames to reconstruct signals.
432
+ if self.onnx:
433
+ y = self._overlap_add_divide_window_sum_onnx(s_real, frames_num)
434
+ else:
435
+ y = self._overlap_add_divide_window_sum(s_real, frames_num)
436
+ # y: (batch_size, audio_samples + win_length,)
437
+
438
+ y = self._trim_edges(y, length)
439
+ # (batch_size, audio_samples,)
440
+
441
+ return y
442
+
443
+ def _get_full_stft(self, real_stft, imag_stft):
444
+ r"""Get full stft representation from spectrum using symmetry attribute.
445
+
446
+ Args:
447
+ real_stft: (batch_size, n_fft // 2 + 1, time_steps)
448
+ imag_stft: (batch_size, n_fft // 2 + 1, time_steps)
449
+
450
+ Returns:
451
+ full_real_stft: (batch_size, n_fft, time_steps)
452
+ full_imag_stft: (batch_size, n_fft, time_steps)
453
+ """
454
+ full_real_stft = torch.cat((real_stft, torch.flip(real_stft[:, 1 : -1, :], dims=[1])), dim=1)
455
+ full_imag_stft = torch.cat((imag_stft, - torch.flip(imag_stft[:, 1 : -1, :], dims=[1])), dim=1)
456
+
457
+ return full_real_stft, full_imag_stft
458
+
459
+ def _get_full_stft_onnx(self, real_stft, imag_stft):
460
+ r"""Get full stft representation from spectrum using symmetry attribute
461
+ for ONNX. Replace several pytorch operations in self._get_full_stft()
462
+ that are not supported by ONNX.
463
+
464
+ Args:
465
+ real_stft: (batch_size, n_fft // 2 + 1, time_steps)
466
+ imag_stft: (batch_size, n_fft // 2 + 1, time_steps)
467
+
468
+ Returns:
469
+ full_real_stft: (batch_size, n_fft, time_steps)
470
+ full_imag_stft: (batch_size, n_fft, time_steps)
471
+ """
472
+
473
+ # Implement torch.flip() with Conv1d.
474
+ full_real_stft = torch.cat((real_stft, self.reverse(real_stft)), dim=1)
475
+ full_imag_stft = torch.cat((imag_stft, - self.reverse(imag_stft)), dim=1)
476
+
477
+ return full_real_stft, full_imag_stft
478
+
479
+ def _overlap_add_divide_window_sum(self, s_real, frames_num):
480
+ r"""Overlap add signals in frames to reconstruct signals.
481
+
482
+ Args:
483
+ s_real: (batch_size, n_fft, time_steps), signals in frames
484
+ frames_num: int
485
+
486
+ Returns:
487
+ y: (batch_size, audio_samples)
488
+ """
489
+
490
+ output_samples = (s_real.shape[-1] - 1) * self.hop_length + self.win_length
491
+ # (audio_samples,)
492
+
493
+ # Overlap-add signals in frames to signals. Ref:
494
+ # asteroid_filterbanks.torch_stft_fb.torch_stft_fb() from
495
+ # https://github.com/asteroid-team/asteroid-filterbanks
496
+ y = torch.nn.functional.fold(input=s_real, output_size=(1, output_samples),
497
+ kernel_size=(1, self.win_length), stride=(1, self.hop_length))
498
+ # (batch_size, 1, 1, audio_samples,)
499
+
500
+ y = y[:, 0, 0, :]
501
+ # (batch_size, audio_samples)
502
+
503
+ # Get overlap-add window sum to be divided.
504
+ ifft_window_sum = self._get_ifft_window(frames_num)
505
+ # (audio_samples,)
506
+
507
+ # Following code is abandaned for divide overlap-add window, because
508
+ # not supported by half precision training and ONNX.
509
+ # min_mask = ifft_window_sum.abs() < 1e-11
510
+ # y[:, ~min_mask] = y[:, ~min_mask] / ifft_window_sum[None, ~min_mask]
511
+ # # (batch_size, audio_samples)
512
+
513
+ ifft_window_sum = torch.clamp(ifft_window_sum, 1e-11, np.inf)
514
+ # (audio_samples,)
515
+
516
+ y = y / ifft_window_sum[None, :]
517
+ # (batch_size, audio_samples,)
518
+
519
+ return y
520
+
521
+ def _get_ifft_window(self, frames_num):
522
+ r"""Get overlap-add window sum to be divided.
523
+
524
+ Args:
525
+ frames_num: int
526
+
527
+ Returns:
528
+ ifft_window_sum: (audio_samlpes,), overlap-add window sum to be
529
+ divided.
530
+ """
531
+
532
+ output_samples = (frames_num - 1) * self.hop_length + self.win_length
533
+ # (audio_samples,)
534
+
535
+ window_matrix = self.ola_window[None, :, None].repeat(1, 1, frames_num)
536
+ # (batch_size, win_length, time_steps)
537
+
538
+ ifft_window_sum = F.fold(input=window_matrix,
539
+ output_size=(1, output_samples), kernel_size=(1, self.win_length),
540
+ stride=(1, self.hop_length))
541
+ # (1, 1, 1, audio_samples)
542
+
543
+ ifft_window_sum = ifft_window_sum.squeeze()
544
+ # (audio_samlpes,)
545
+
546
+ return ifft_window_sum
547
+
548
+ def _overlap_add_divide_window_sum_onnx(self, s_real, frames_num):
549
+ r"""Overlap add signals in frames to reconstruct signals for ONNX.
550
+ Replace several pytorch operations in
551
+ self._overlap_add_divide_window_sum() that are not supported by ONNX.
552
+
553
+ Args:
554
+ s_real: (batch_size, n_fft, time_steps), signals in frames
555
+ frames_num: int
556
+
557
+ Returns:
558
+ y: (batch_size, audio_samples)
559
+ """
560
+
561
+ s_real = s_real[..., None]
562
+ # (batch_size, n_fft, time_steps, 1)
563
+
564
+ # Implement overlap-add with Conv1d, because torch.nn.functional.fold()
565
+ # is not supported by ONNX.
566
+ y = self.overlap_add(s_real)[:, 0, :, 0]
567
+ # y: (batch_size, samples_num)
568
+
569
+ if len(self.ifft_window_sum) != y.shape[1]:
570
+ device = s_real.device
571
+
572
+ self.ifft_window_sum = self._get_ifft_window_sum_onnx(frames_num, device)
573
+ # (audio_samples,)
574
+
575
+ # Use torch.clamp() to prevent from underflow to make sure all
576
+ # operations are supported by ONNX.
577
+ ifft_window_sum = torch.clamp(self.ifft_window_sum, 1e-11, np.inf)
578
+ # (audio_samples,)
579
+
580
+ y = y / ifft_window_sum[None, :]
581
+ # (batch_size, audio_samples,)
582
+
583
+ return y
584
+
585
+ def _get_ifft_window_sum_onnx(self, frames_num, device):
586
+ r"""Pre-calculate overlap-add window sum for reconstructing signals when
587
+ using ONNX.
588
+
589
+ Args:
590
+ frames_num: int
591
+ device: str | None
592
+
593
+ Returns:
594
+ ifft_window_sum: (audio_samples,)
595
+ """
596
+
597
+ ifft_window_sum = librosa.filters.window_sumsquare(window=self.window,
598
+ n_frames=frames_num, win_length=self.win_length, n_fft=self.n_fft,
599
+ hop_length=self.hop_length)
600
+ # (audio_samples,)
601
+
602
+ ifft_window_sum = torch.Tensor(ifft_window_sum)
603
+
604
+ if device:
605
+ ifft_window_sum = ifft_window_sum.to(device)
606
+
607
+ return ifft_window_sum
608
+
609
+ def _trim_edges(self, y, length):
610
+ r"""Trim audio.
611
+
612
+ Args:
613
+ y: (audio_samples,)
614
+ length: int
615
+
616
+ Returns:
617
+ (trimmed_audio_samples,)
618
+ """
619
+ # Trim or pad to length
620
+ if length is None:
621
+ if self.center:
622
+ y = y[:, self.n_fft // 2 : -self.n_fft // 2]
623
+ else:
624
+ if self.center:
625
+ start = self.n_fft // 2
626
+ else:
627
+ start = 0
628
+
629
+ y = y[:, start : start + length]
630
+
631
+ return y
632
+
633
+
634
+ class Spectrogram(nn.Module):
635
+ def __init__(self, n_fft=2048, hop_length=None, win_length=None,
636
+ window='hann', center=True, pad_mode='reflect', power=2.0,
637
+ freeze_parameters=True):
638
+ r"""Calculate spectrogram using pytorch. The STFT is implemented with
639
+ Conv1d. The function has the same output of librosa.stft
640
+ """
641
+ super(Spectrogram, self).__init__()
642
+
643
+ self.power = power
644
+
645
+ self.stft = STFT(n_fft=n_fft, hop_length=hop_length,
646
+ win_length=win_length, window=window, center=center,
647
+ pad_mode=pad_mode, freeze_parameters=True)
648
+
649
+ def forward(self, input):
650
+ r"""Calculate spectrogram of input signals.
651
+ Args:
652
+ input: (batch_size, data_length)
653
+
654
+ Returns:
655
+ spectrogram: (batch_size, 1, time_steps, n_fft // 2 + 1)
656
+ """
657
+
658
+ (real, imag) = self.stft.forward(input)
659
+ # (batch_size, n_fft // 2 + 1, time_steps)
660
+
661
+ spectrogram = real ** 2 + imag ** 2
662
+
663
+ if self.power == 2.0:
664
+ pass
665
+ else:
666
+ spectrogram = spectrogram ** (self.power / 2.0)
667
+
668
+ return spectrogram
669
+
670
+
671
+ class LogmelFilterBank(nn.Module):
672
+ def __init__(self, sr=22050, n_fft=2048, n_mels=64, fmin=0.0, fmax=None,
673
+ is_log=True, ref=1.0, amin=1e-10, top_db=80.0, freeze_parameters=True):
674
+ r"""Calculate logmel spectrogram using pytorch. The mel filter bank is
675
+ the pytorch implementation of as librosa.filters.mel
676
+ """
677
+ super(LogmelFilterBank, self).__init__()
678
+
679
+ self.is_log = is_log
680
+ self.ref = ref
681
+ self.amin = amin
682
+ self.top_db = top_db
683
+ if fmax == None:
684
+ fmax = sr//2
685
+
686
+ self.melW = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels,
687
+ fmin=fmin, fmax=fmax).T
688
+ # (n_fft // 2 + 1, mel_bins)
689
+
690
+ self.melW = nn.Parameter(torch.Tensor(self.melW).contiguous())
691
+
692
+ if freeze_parameters:
693
+ for param in self.parameters():
694
+ param.requires_grad = False
695
+
696
+ def forward(self, input):
697
+ r"""Calculate (log) mel spectrogram from spectrogram.
698
+
699
+ Args:
700
+ input: (*, n_fft), spectrogram
701
+
702
+ Returns:
703
+ output: (*, mel_bins), (log) mel spectrogram
704
+ """
705
+
706
+ # Mel spectrogram
707
+ mel_spectrogram = torch.matmul(input, self.melW)
708
+ # (*, mel_bins)
709
+
710
+ # Logmel spectrogram
711
+ if self.is_log:
712
+ output = self.power_to_db(mel_spectrogram)
713
+ else:
714
+ output = mel_spectrogram
715
+
716
+ return output
717
+
718
+
719
+ def power_to_db(self, input):
720
+ r"""Power to db, this function is the pytorch implementation of
721
+ librosa.power_to_lb
722
+ """
723
+ ref_value = self.ref
724
+ log_spec = 10.0 * torch.log10(torch.clamp(input, min=self.amin, max=np.inf))
725
+ log_spec -= 10.0 * np.log10(np.maximum(self.amin, ref_value))
726
+
727
+ if self.top_db is not None:
728
+ if self.top_db < 0:
729
+ raise librosa.util.exceptions.ParameterError('top_db must be non-negative')
730
+ log_spec = torch.clamp(log_spec, min=log_spec.max().item() - self.top_db, max=np.inf)
731
+
732
+ return log_spec
733
+
734
+
735
+ class Enframe(nn.Module):
736
+ def __init__(self, frame_length=2048, hop_length=512):
737
+ r"""Enframe a time sequence. This function is the pytorch implementation
738
+ of librosa.util.frame
739
+ """
740
+ super(Enframe, self).__init__()
741
+
742
+ self.enframe_conv = nn.Conv1d(in_channels=1, out_channels=frame_length,
743
+ kernel_size=frame_length, stride=hop_length,
744
+ padding=0, bias=False)
745
+
746
+ self.enframe_conv.weight.data = torch.Tensor(torch.eye(frame_length)[:, None, :])
747
+ self.enframe_conv.weight.requires_grad = False
748
+
749
+ def forward(self, input):
750
+ r"""Enframe signals into frames.
751
+ Args:
752
+ input: (batch_size, samples)
753
+
754
+ Returns:
755
+ output: (batch_size, window_length, frames_num)
756
+ """
757
+ output = self.enframe_conv(input[:, None, :])
758
+ return output
759
+
760
+
761
+ def power_to_db(self, input):
762
+ r"""Power to db, this function is the pytorch implementation of
763
+ librosa.power_to_lb.
764
+ """
765
+ ref_value = self.ref
766
+ log_spec = 10.0 * torch.log10(torch.clamp(input, min=self.amin, max=np.inf))
767
+ log_spec -= 10.0 * np.log10(np.maximum(self.amin, ref_value))
768
+
769
+ if self.top_db is not None:
770
+ if self.top_db < 0:
771
+ raise librosa.util.exceptions.ParameterError('top_db must be non-negative')
772
+ log_spec = torch.clamp(log_spec, min=log_spec.max() - self.top_db, max=np.inf)
773
+
774
+ return log_spec
775
+
776
+
777
+ class Scalar(nn.Module):
778
+ def __init__(self, scalar, freeze_parameters):
779
+ super(Scalar, self).__init__()
780
+
781
+ self.scalar_mean = Parameter(torch.Tensor(scalar['mean']))
782
+ self.scalar_std = Parameter(torch.Tensor(scalar['std']))
783
+
784
+ if freeze_parameters:
785
+ for param in self.parameters():
786
+ param.requires_grad = False
787
+
788
+ def forward(self, input):
789
+ return (input - self.scalar_mean) / self.scalar_std
790
+
791
+
792
+ def debug(select, device):
793
+ """Compare numpy + librosa and torchlibrosa results. For debug.
794
+
795
+ Args:
796
+ select: 'dft' | 'logmel'
797
+ device: 'cpu' | 'cuda'
798
+ """
799
+
800
+ if select == 'dft':
801
+ n = 10
802
+ norm = None # None | 'ortho'
803
+ np.random.seed(0)
804
+
805
+ # Data
806
+ np_data = np.random.uniform(-1, 1, n)
807
+ pt_data = torch.Tensor(np_data)
808
+
809
+ # Numpy FFT
810
+ np_fft = np.fft.fft(np_data, norm=norm)
811
+ np_ifft = np.fft.ifft(np_fft, norm=norm)
812
+ np_rfft = np.fft.rfft(np_data, norm=norm)
813
+ np_irfft = np.fft.ifft(np_rfft, norm=norm)
814
+
815
+ # Pytorch FFT
816
+ obj = DFT(n, norm)
817
+ pt_dft = obj.dft(pt_data, torch.zeros_like(pt_data))
818
+ pt_idft = obj.idft(pt_dft[0], pt_dft[1])
819
+ pt_rdft = obj.rdft(pt_data)
820
+ pt_irdft = obj.irdft(pt_rdft[0], pt_rdft[1])
821
+
822
+ print('Comparing librosa and pytorch implementation of DFT. All numbers '
823
+ 'below should be close to 0.')
824
+ print(np.mean((np.abs(np.real(np_fft) - pt_dft[0].cpu().numpy()))))
825
+ print(np.mean((np.abs(np.imag(np_fft) - pt_dft[1].cpu().numpy()))))
826
+
827
+ print(np.mean((np.abs(np.real(np_ifft) - pt_idft[0].cpu().numpy()))))
828
+ print(np.mean((np.abs(np.imag(np_ifft) - pt_idft[1].cpu().numpy()))))
829
+
830
+ print(np.mean((np.abs(np.real(np_rfft) - pt_rdft[0].cpu().numpy()))))
831
+ print(np.mean((np.abs(np.imag(np_rfft) - pt_rdft[1].cpu().numpy()))))
832
+
833
+ print(np.mean(np.abs(np_data - pt_irdft.cpu().numpy())))
834
+
835
+ elif select == 'stft':
836
+ device = torch.device(device)
837
+ np.random.seed(0)
838
+
839
+ # Spectrogram parameters (the same as librosa.stft)
840
+ sample_rate = 22050
841
+ data_length = sample_rate * 1
842
+ n_fft = 2048
843
+ hop_length = 512
844
+ win_length = 2048
845
+ window = 'hann'
846
+ center = True
847
+ pad_mode = 'reflect'
848
+
849
+ # Data
850
+ np_data = np.random.uniform(-1, 1, data_length)
851
+ pt_data = torch.Tensor(np_data).to(device)
852
+
853
+ # Numpy stft matrix
854
+ np_stft_matrix = librosa.stft(y=np_data, n_fft=n_fft,
855
+ hop_length=hop_length, window=window, center=center).T
856
+
857
+ # Pytorch stft matrix
858
+ pt_stft_extractor = STFT(n_fft=n_fft, hop_length=hop_length,
859
+ win_length=win_length, window=window, center=center, pad_mode=pad_mode,
860
+ freeze_parameters=True)
861
+
862
+ pt_stft_extractor.to(device)
863
+
864
+ (pt_stft_real, pt_stft_imag) = pt_stft_extractor.forward(pt_data[None, :])
865
+
866
+ print('Comparing librosa and pytorch implementation of STFT & ISTFT. \
867
+ All numbers below should be close to 0.')
868
+ print(np.mean(np.abs(np.real(np_stft_matrix) - pt_stft_real.data.cpu().numpy()[0, 0])))
869
+ print(np.mean(np.abs(np.imag(np_stft_matrix) - pt_stft_imag.data.cpu().numpy()[0, 0])))
870
+
871
+ # Numpy istft
872
+ np_istft_s = librosa.istft(stft_matrix=np_stft_matrix.T,
873
+ hop_length=hop_length, window=window, center=center, length=data_length)
874
+
875
+ # Pytorch istft
876
+ pt_istft_extractor = ISTFT(n_fft=n_fft, hop_length=hop_length,
877
+ win_length=win_length, window=window, center=center, pad_mode=pad_mode,
878
+ freeze_parameters=True)
879
+ pt_istft_extractor.to(device)
880
+
881
+ # Recover from real and imag part
882
+ pt_istft_s = pt_istft_extractor.forward(pt_stft_real, pt_stft_imag, data_length)[0, :]
883
+
884
+ # Recover from magnitude and phase
885
+ (pt_stft_mag, cos, sin) = magphase(pt_stft_real, pt_stft_imag)
886
+ pt_istft_s2 = pt_istft_extractor.forward(pt_stft_mag * cos, pt_stft_mag * sin, data_length)[0, :]
887
+
888
+ print(np.mean(np.abs(np_istft_s - pt_istft_s.data.cpu().numpy())))
889
+ print(np.mean(np.abs(np_data - pt_istft_s.data.cpu().numpy())))
890
+ print(np.mean(np.abs(np_data - pt_istft_s2.data.cpu().numpy())))
891
+
892
+ elif select == 'logmel':
893
+ dtype = np.complex64
894
+ device = torch.device(device)
895
+ np.random.seed(0)
896
+
897
+ # Spectrogram parameters (the same as librosa.stft)
898
+ sample_rate = 22050
899
+ data_length = sample_rate * 1
900
+ n_fft = 2048
901
+ hop_length = 512
902
+ win_length = 2048
903
+ window = 'hann'
904
+ center = True
905
+ pad_mode = 'reflect'
906
+
907
+ # Mel parameters (the same as librosa.feature.melspectrogram)
908
+ n_mels = 128
909
+ fmin = 0.
910
+ fmax = sample_rate / 2.0
911
+
912
+ # Power to db parameters (the same as default settings of librosa.power_to_db
913
+ ref = 1.0
914
+ amin = 1e-10
915
+ top_db = 80.0
916
+
917
+ # Data
918
+ np_data = np.random.uniform(-1, 1, data_length)
919
+ pt_data = torch.Tensor(np_data).to(device)
920
+
921
+ print('Comparing librosa and pytorch implementation of logmel '
922
+ 'spectrogram. All numbers below should be close to 0.')
923
+
924
+ # Numpy librosa
925
+ np_stft_matrix = librosa.stft(y=np_data, n_fft=n_fft, hop_length=hop_length,
926
+ win_length=win_length, window=window, center=center, dtype=dtype,
927
+ pad_mode=pad_mode)
928
+
929
+ np_pad = np.pad(np_data, int(n_fft // 2), mode=pad_mode)
930
+
931
+ np_melW = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels,
932
+ fmin=fmin, fmax=fmax).T
933
+
934
+ np_mel_spectrogram = np.dot(np.abs(np_stft_matrix.T) ** 2, np_melW)
935
+
936
+ np_logmel_spectrogram = librosa.power_to_db(
937
+ np_mel_spectrogram, ref=ref, amin=amin, top_db=top_db)
938
+
939
+ # Pytorch
940
+ stft_extractor = STFT(n_fft=n_fft, hop_length=hop_length,
941
+ win_length=win_length, window=window, center=center, pad_mode=pad_mode,
942
+ freeze_parameters=True)
943
+
944
+ logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=n_fft,
945
+ n_mels=n_mels, fmin=fmin, fmax=fmax, ref=ref, amin=amin,
946
+ top_db=top_db, freeze_parameters=True)
947
+
948
+ stft_extractor.to(device)
949
+ logmel_extractor.to(device)
950
+
951
+ pt_pad = F.pad(pt_data[None, None, :], pad=(n_fft // 2, n_fft // 2), mode=pad_mode)[0, 0]
952
+ print(np.mean(np.abs(np_pad - pt_pad.cpu().numpy())))
953
+
954
+ pt_stft_matrix_real = stft_extractor.conv_real(pt_pad[None, None, :])[0]
955
+ pt_stft_matrix_imag = stft_extractor.conv_imag(pt_pad[None, None, :])[0]
956
+ print(np.mean(np.abs(np.real(np_stft_matrix) - pt_stft_matrix_real.data.cpu().numpy())))
957
+ print(np.mean(np.abs(np.imag(np_stft_matrix) - pt_stft_matrix_imag.data.cpu().numpy())))
958
+
959
+ # Spectrogram
960
+ spectrogram_extractor = Spectrogram(n_fft=n_fft, hop_length=hop_length,
961
+ win_length=win_length, window=window, center=center, pad_mode=pad_mode,
962
+ freeze_parameters=True)
963
+
964
+ spectrogram_extractor.to(device)
965
+
966
+ pt_spectrogram = spectrogram_extractor.forward(pt_data[None, :])
967
+ pt_mel_spectrogram = torch.matmul(pt_spectrogram, logmel_extractor.melW)
968
+ print(np.mean(np.abs(np_mel_spectrogram - pt_mel_spectrogram.data.cpu().numpy()[0, 0])))
969
+
970
+ # Log mel spectrogram
971
+ pt_logmel_spectrogram = logmel_extractor.forward(pt_spectrogram)
972
+ print(np.mean(np.abs(np_logmel_spectrogram - pt_logmel_spectrogram[0, 0].data.cpu().numpy())))
973
+
974
+ elif select == 'enframe':
975
+ device = torch.device(device)
976
+ np.random.seed(0)
977
+
978
+ # Spectrogram parameters (the same as librosa.stft)
979
+ sample_rate = 22050
980
+ data_length = sample_rate * 1
981
+ hop_length = 512
982
+ win_length = 2048
983
+
984
+ # Data
985
+ np_data = np.random.uniform(-1, 1, data_length)
986
+ pt_data = torch.Tensor(np_data).to(device)
987
+
988
+ print('Comparing librosa and pytorch implementation of '
989
+ 'librosa.util.frame. All numbers below should be close to 0.')
990
+
991
+ # Numpy librosa
992
+ np_frames = librosa.util.frame(np_data, frame_length=win_length,
993
+ hop_length=hop_length)
994
+
995
+ # Pytorch
996
+ pt_frame_extractor = Enframe(frame_length=win_length, hop_length=hop_length)
997
+ pt_frame_extractor.to(device)
998
+
999
+ pt_frames = pt_frame_extractor(pt_data[None, :])
1000
+ print(np.mean(np.abs(np_frames - pt_frames.data.cpu().numpy())))
1001
+
1002
+ elif select == 'default':
1003
+ device = torch.device(device)
1004
+ np.random.seed(0)
1005
+
1006
+ # Spectrogram parameters (the same as librosa.stft)
1007
+ sample_rate = 22050
1008
+ data_length = sample_rate * 1
1009
+ hop_length = 512
1010
+ win_length = 2048
1011
+
1012
+ # Mel parameters (the same as librosa.feature.melspectrogram)
1013
+ n_mels = 128
1014
+
1015
+ # Data
1016
+ np_data = np.random.uniform(-1, 1, data_length)
1017
+ pt_data = torch.Tensor(np_data).to(device)
1018
+
1019
+ feature_extractor = nn.Sequential(
1020
+ Spectrogram(
1021
+ hop_length=hop_length,
1022
+ win_length=win_length,
1023
+ ), LogmelFilterBank(
1024
+ sr=sample_rate,
1025
+ n_mels=n_mels,
1026
+ is_log=False, #Default is true
1027
+ ))
1028
+
1029
+ feature_extractor.to(device)
1030
+
1031
+ print(
1032
+ 'Comparing default mel spectrogram from librosa to the pytorch implementation.'
1033
+ )
1034
+
1035
+ # Numpy librosa
1036
+ np_melspect = librosa.feature.melspectrogram(np_data,
1037
+ hop_length=hop_length,
1038
+ sr=sample_rate,
1039
+ win_length=win_length,
1040
+ n_mels=n_mels).T
1041
+ #Pytorch
1042
+ pt_melspect = feature_extractor(pt_data[None, :]).squeeze()
1043
+ passed = np.allclose(pt_melspect.data.to('cpu').numpy(), np_melspect)
1044
+ print(f"Passed? {passed}")
1045
+
1046
+
1047
+
1048
+ if __name__ == '__main__':
1049
+
1050
+ parser = argparse.ArgumentParser(description='')
1051
+ parser.add_argument('--device', type=str, default='cpu', choices=['cpu', 'cuda'])
1052
+ args = parser.parse_args()
1053
+
1054
+ device = args.device
1055
+ norm = None # None | 'ortho'
1056
+ np.random.seed(0)
1057
+
1058
+ # Spectrogram parameters (the same as librosa.stft)
1059
+ sample_rate = 22050
1060
+ data_length = sample_rate * 1
1061
+ n_fft = 2048
1062
+ hop_length = 512
1063
+ win_length = 2048
1064
+ window = 'hann'
1065
+ center = True
1066
+ pad_mode = 'reflect'
1067
+
1068
+ # Mel parameters (the same as librosa.feature.melspectrogram)
1069
+ n_mels = 128
1070
+ fmin = 0.
1071
+ fmax = sample_rate / 2.0
1072
+
1073
+ # Power to db parameters (the same as default settings of librosa.power_to_db
1074
+ ref = 1.0
1075
+ amin = 1e-10
1076
+ top_db = 80.0
1077
+
1078
+ # Data
1079
+ np_data = np.random.uniform(-1, 1, data_length)
1080
+ pt_data = torch.Tensor(np_data).to(device)
1081
+
1082
+ # Pytorch
1083
+ spectrogram_extractor = Spectrogram(n_fft=n_fft, hop_length=hop_length,
1084
+ win_length=win_length, window=window, center=center, pad_mode=pad_mode,
1085
+ freeze_parameters=True)
1086
+
1087
+ logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=n_fft,
1088
+ n_mels=n_mels, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db,
1089
+ freeze_parameters=True)
1090
+
1091
+ spectrogram_extractor.to(device)
1092
+ logmel_extractor.to(device)
1093
+
1094
+ # Spectrogram
1095
+ pt_spectrogram = spectrogram_extractor.forward(pt_data[None, :])
1096
+
1097
+ # Log mel spectrogram
1098
+ pt_logmel_spectrogram = logmel_extractor.forward(pt_spectrogram)
1099
+
1100
+ # Uncomment for debug
1101
+ if True:
1102
+ debug(select='dft', device=device)
1103
+ debug(select='stft', device=device)
1104
+ debug(select='logmel', device=device)
1105
+ debug(select='enframe', device=device)
1106
+
1107
+ try:
1108
+ debug(select='default', device=device)
1109
+ except:
1110
+ raise Exception('Torchlibrosa does support librosa>=0.6.0, for \
1111
+ comparison with librosa, please use librosa>=0.7.0!')
src/vision_transformer.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import reduce
3
+ from operator import mul
4
+ from ipdb import set_trace
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.nn as nn
9
+ from mmcls.models.backbones import VisionTransformer as _VisionTransformer
10
+ from mmcls.models.utils import to_2tuple
11
+ from mmcv.cnn.bricks.transformer import PatchEmbed
12
+ from torch.nn.modules.batchnorm import _BatchNorm
13
+
14
+
15
+ def build_2d_sincos_position_embedding(patches_resolution,
16
+ embed_dims,
17
+ temperature=10000.,
18
+ cls_token=False):
19
+ """The function is to build position embedding for model to obtain the
20
+ position information of the image patches."""
21
+
22
+ if isinstance(patches_resolution, int):
23
+ patches_resolution = (patches_resolution, patches_resolution)
24
+
25
+ h, w = patches_resolution
26
+ grid_w = torch.arange(w, dtype=torch.float32)
27
+ grid_h = torch.arange(h, dtype=torch.float32)
28
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
29
+ assert embed_dims % 4 == 0, \
30
+ 'Embed dimension must be divisible by 4.'
31
+ pos_dim = embed_dims // 4
32
+
33
+ omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
34
+ omega = 1. / (temperature**omega)
35
+ out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
36
+ out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
37
+
38
+ pos_emb = torch.cat(
39
+ [
40
+ torch.sin(out_w),
41
+ torch.cos(out_w),
42
+ torch.sin(out_h),
43
+ torch.cos(out_h)
44
+ ],
45
+ dim=1,
46
+ )[None, :, :]
47
+
48
+ if cls_token:
49
+ cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32)
50
+ pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1)
51
+
52
+ return pos_emb
53
+
54
+
55
+ class VisionTransformer(_VisionTransformer):
56
+ """Vision Transformer.
57
+
58
+ A pytorch implement of: `An Images is Worth 16x16 Words: Transformers for
59
+ Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.
60
+
61
+ Part of the code is modified from:
62
+ `<https://github.com/facebookresearch/moco-v3/blob/main/vits.py>`_.
63
+
64
+ Args:
65
+ stop_grad_conv1 (bool, optional): whether to stop the gradient of
66
+ convolution layer in `PatchEmbed`. Defaults to False.
67
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
68
+ -1 means not freezing any parameters. Defaults to -1.
69
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
70
+ freeze running stats (mean and var). Note: Effect on Batch Norm
71
+ and its variants only. Defaults to False.
72
+ init_cfg (dict or list[dict], optional): Initialization config dict.
73
+ Defaults to None.
74
+ """
75
+
76
+ arch_zoo = {
77
+ **dict.fromkeys(
78
+ ['mocov3-s', 'mocov3-small'], {
79
+ 'embed_dims': 384,
80
+ 'num_layers': 12,
81
+ 'num_heads': 12,
82
+ 'feedforward_channels': 1536,
83
+ }),
84
+ **dict.fromkeys(
85
+ ['b', 'base'], {
86
+ 'embed_dims': 768,
87
+ 'num_layers': 12,
88
+ 'num_heads': 12,
89
+ 'feedforward_channels': 3072
90
+ }),
91
+ }
92
+
93
+ def __init__(self,
94
+ stop_grad_conv1=False,
95
+ frozen_stages=-1,
96
+ norm_eval=False,
97
+ init_cfg=None,
98
+ **kwargs):
99
+ super(VisionTransformer, self).__init__(init_cfg=init_cfg,)
100
+ self.patch_size = kwargs['patch_size']
101
+ self.frozen_stages = frozen_stages
102
+ self.norm_eval = norm_eval
103
+ self.init_cfg = init_cfg
104
+
105
+
106
+ if isinstance(self.patch_embed, PatchEmbed):
107
+ if stop_grad_conv1:
108
+ self.patch_embed.projection.weight.requires_grad = False
109
+ self.patch_embed.projection.bias.requires_grad = False
110
+
111
+ self._freeze_stages()
112
+
113
+ def init_weights(self):
114
+ super(VisionTransformer, self).init_weights()
115
+
116
+ if not (isinstance(self.init_cfg, dict)
117
+ and self.init_cfg['type'] == 'Pretrained'):
118
+
119
+ # Use fixed 2D sin-cos position embedding
120
+ pos_emb = build_2d_sincos_position_embedding(
121
+ patches_resolution=self.patch_resolution,
122
+ embed_dims=self.embed_dims,
123
+ cls_token=True)
124
+ self.pos_embed.data.copy_(pos_emb)
125
+ self.pos_embed.requires_grad = False
126
+
127
+ # xavier_uniform initialization for PatchEmbed
128
+ if isinstance(self.patch_embed, PatchEmbed):
129
+ val = math.sqrt(
130
+ 6. / float(3 * reduce(mul, to_2tuple(self.patch_size), 1) +
131
+ self.embed_dims))
132
+ nn.init.uniform_(self.patch_embed.projection.weight, -val, val)
133
+ nn.init.zeros_(self.patch_embed.projection.bias)
134
+
135
+ # initialization for linear layers
136
+ for name, m in self.named_modules():
137
+ if isinstance(m, nn.Linear):
138
+ if 'qkv' in name:
139
+ # treat the weights of Q, K, V separately
140
+ val = math.sqrt(
141
+ 6. /
142
+ float(m.weight.shape[0] // 3 + m.weight.shape[1]))
143
+ nn.init.uniform_(m.weight, -val, val)
144
+ else:
145
+ nn.init.xavier_uniform_(m.weight)
146
+ nn.init.zeros_(m.bias)
147
+ nn.init.normal_(self.cls_token, std=1e-6)
148
+
149
+ def _freeze_stages(self):
150
+ """Freeze patch_embed layer, some parameters and stages."""
151
+ if self.frozen_stages >= 0:
152
+ self.patch_embed.eval()
153
+ for param in self.patch_embed.parameters():
154
+ param.requires_grad = False
155
+
156
+ self.cls_token.requires_grad = False
157
+ self.pos_embed.requires_grad = False
158
+
159
+ for i in range(1, self.frozen_stages + 1):
160
+ m = self.layers[i - 1]
161
+ m.eval()
162
+ for param in m.parameters():
163
+ param.requires_grad = False
164
+
165
+ if i == (self.num_layers) and self.final_norm:
166
+ for param in getattr(self, 'norm1').parameters():
167
+ param.requires_grad = False
168
+
169
+ def train(self, mode=True):
170
+ super(VisionTransformer, self).train(mode)
171
+ self._freeze_stages()
172
+ if mode and self.norm_eval:
173
+ for m in self.modules():
174
+ # trick: eval have effect on BatchNorm only
175
+ if isinstance(m, _BatchNorm):
176
+ m.eval()