sino commited on
Commit
ff4fdee
·
1 Parent(s): 354c8fa

Upload 4 files

Browse files
Files changed (4) hide show
  1. LMdecoder.py +169 -0
  2. htsat.py +1249 -0
  3. mae_vit.py +303 -0
  4. vision_transformer.py +176 -0
LMdecoder.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from doctest import ELLIPSIS_MARKER
3
+ from functools import partial
4
+ import json
5
+ from turtle import forward, shape
6
+ import einops
7
+ import torch
8
+ from torch import nn
9
+
10
+ from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
11
+ from transformers import GPT2Model, GPT2Config,GPT2LMHeadModel,GPTNeoForCausalLM,GPTNeoModel, \
12
+ BartModel, BartConfig, BartForCausalLM, BertForMaskedLM, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
13
+ from transformers import BitsAndBytesConfig
14
+
15
+ from peft import prepare_model_for_kbit_training
16
+ from peft import LoraConfig
17
+ from peft import get_peft_model
18
+
19
+
20
+ from mmcv.cnn import build_norm_layer
21
+ from mmcv.runner import BaseModule
22
+ import math
23
+ from ipdb import set_trace
24
+
25
+ class mixEmbed(nn.Module):
26
+ def __init__(self, lm_embed: nn.Embedding , audio_embeddings, *args, **kwargs) -> None:
27
+ super().__init__(*args, **kwargs)
28
+ self.lm_embed = lm_embed
29
+ self.audio_embeddings = audio_embeddings # ugly but works without modifying raw model codes
30
+
31
+ def forward(self, input_ids):
32
+ text_ids = torch.clamp(input_ids.clone(), 0).long()
33
+
34
+ au_ids = torch.clamp(-(input_ids.clone() + 1), 0).long()
35
+ text_embeds = self.lm_embed(text_ids)
36
+ au_embeds = self.audio_embeddings[au_ids]
37
+ with torch.no_grad():
38
+ embed_mask = (input_ids > 0)
39
+ mix_embeds = au_embeds.clone()
40
+ mix_embeds[embed_mask] = text_embeds[embed_mask]
41
+ return mix_embeds
42
+
43
+
44
+ class LMDecoder(nn.Module):
45
+ def __init__(self,
46
+ # num_patches=196,
47
+ img_size=(80,512),
48
+ patch_size:int=16,
49
+ in_chans:int=3,
50
+ embed_dim=1024, # encoder embed dim
51
+ decoder_embed_dim=512,
52
+ norm_cfg=dict(type='LN', eps=1e-6),
53
+ # patch_resolution=14,
54
+ decoder_type='gpt2',
55
+ freeze_decoder=True,
56
+ additional_layer:int=0,
57
+ ):
58
+ super().__init__()
59
+ self.decoder_type = decoder_type
60
+ self.load_lm()
61
+
62
+ self.lm_embed = self.lm.get_input_embeddings()
63
+ try:
64
+ self.lm_pos_embed = self.lm.get_position_embeddings()
65
+ except NotImplementedError:
66
+ self.lm_pos_embed = None # rotrary embeds
67
+
68
+
69
+ if hasattr(self.lm,'embed_dim'):
70
+ self.embed_dim = self.lm.embed_dim
71
+ else:
72
+ self.embed_dim = decoder_embed_dim
73
+
74
+ # self.asLM = asLM # if generates tokens rather than hidden states
75
+ # if self.asLM: # TODO: 当年写这个是为啥?
76
+ # self.lm.set_output_embeddings(nn.Linear(self.embed_dim, self.self.LMconfig.vocab_size, bias=False))
77
+ self.freeze_decoder = False
78
+ if True:
79
+ for para in self.lm.parameters():
80
+ para.requires_grad = False
81
+
82
+ def load_lm(self):
83
+ ## ---------------------LM setting----------------------
84
+ self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type)
85
+ if self.tokenizer.pad_token is None:
86
+ self.tokenizer.pad_token = self.tokenizer.eos_token
87
+ self.LMconfig = AutoConfig.from_pretrained(self.decoder_type, trust_remote_code=True )
88
+ self.lm = AutoModelForCausalLM.from_pretrained(self.decoder_type, trust_remote_code=True)
89
+
90
+
91
+ def forward(self, input_ids, flatten_embs, attention_mask, labels, **kwargs):
92
+ mix_embed = mixEmbed(self.lm_embed, flatten_embs)
93
+ self.lm.set_input_embeddings(mix_embed) # modification of the lm embed
94
+ output = self.lm(input_ids=input_ids, attention_mask=attention_mask, labels=labels, output_hidden_states=True, **kwargs)
95
+ self.lm.set_input_embeddings(self.lm_embed) # modification of the lm embed
96
+ return output
97
+
98
+ def generate(self, input_ids, flatten_embs):
99
+ mix_embed = mixEmbed(self.lm_embed, flatten_embs)
100
+ self.lm.set_input_embeddings(mix_embed) # modification of the lm embed
101
+ outputs = self.lm.generate(input_ids=input_ids, max_new_tokens=256, use_cache=False)
102
+ # outputs = self.lm.generate(input_ids=input_ids,
103
+ # max_new_tokens=1024,
104
+ # do_sample=True,
105
+ # temperature=1.5,
106
+ # num_beams=1,
107
+ # top_p=0.9,
108
+ # top_k=3,
109
+ # use_cache=False)
110
+ self.lm.set_input_embeddings(self.lm_embed) # modification of the lm embed
111
+ return outputs
112
+ '''
113
+ ## infer params
114
+ max_input_tokens: 40
115
+ batch_size_test: 16
116
+ max_new_tokens: 64
117
+ min_length: 2
118
+ num_beams: 5
119
+ length_penalty: -2.0
120
+ top_p: 0.9
121
+ top_k: 3
122
+ no_repeat_ngram_size: 2
123
+ apply_lemmatizer: False
124
+ use_nucleus_sampling: True
125
+ '''
126
+
127
+ class LMDecoder_qlora(LMDecoder):
128
+ def __init__(self,
129
+ # num_patches=196,
130
+ img_size=(80,512),
131
+ patch_size:int=16,
132
+ in_chans:int=3,
133
+ embed_dim=1024, # encoder embed dim
134
+ decoder_embed_dim=512,
135
+ norm_cfg=dict(type='LN', eps=1e-6),
136
+ # patch_resolution=14,
137
+ decoder_type='gpt2',
138
+ freeze_decoder=True,
139
+ additional_layer:int=0,
140
+ ):
141
+ super().__init__( img_size, patch_size, in_chans, embed_dim, decoder_embed_dim, norm_cfg, decoder_type, freeze_decoder, additional_layer)
142
+
143
+ def load_lm(self):
144
+ self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type)
145
+ self.LMconfig = AutoConfig.from_pretrained(self.decoder_type, trust_remote_code=True )
146
+ double_quant_config = BitsAndBytesConfig(
147
+ load_in_4bit=True,
148
+ bnb_4bit_use_double_quant=True,
149
+ )
150
+ model = AutoModelForCausalLM.from_pretrained(self.decoder_type,
151
+ # device_map='auto', # if remove, can not add lora
152
+ # load_in_4bit=True,# if remove, can not add lora
153
+ # # torch_dtype=torch.bfloat16,
154
+ # quantization_config=double_quant_config, # if remove, can not add lora
155
+ trust_remote_code=True )
156
+
157
+ model.gradient_checkpointing_enable()
158
+ model = prepare_model_for_kbit_training(model)
159
+ lora_config = LoraConfig(
160
+ r=8,
161
+ lora_alpha=32,
162
+ target_modules=["query_key_value"],
163
+ lora_dropout=0.05,
164
+ bias="none",
165
+ task_type="CAUSAL_LM"
166
+ )
167
+
168
+ self.lm = get_peft_model(model, lora_config)
169
+ self.lm.print_trainable_parameters()
htsat.py ADDED
@@ -0,0 +1,1249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ke Chen
2
3
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
+ # Some layers designed on the model
5
+ # below codes are based and referred from https://github.com/microsoft/Swin-Transformer
6
+ # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from itertools import repeat
12
+ import collections.abc
13
+ import math
14
+ import warnings
15
+
16
+ from torch.nn.init import _calculate_fan_in_and_fan_out
17
+ import torch.utils.checkpoint as checkpoint
18
+
19
+ import random
20
+
21
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
22
+ from torchlibrosa.augmentation import SpecAugmentation
23
+ from einops import rearrange
24
+ from itertools import repeat
25
+ # from .utils import interpolate
26
+
27
+ # from .feature_fusion import iAFF, AFF, DAF
28
+
29
+
30
+ '''
31
+ Feature Fusion for Varible-Length Data Processing
32
+ AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
33
+ According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
34
+ '''
35
+
36
+ class DAF(nn.Module):
37
+ '''
38
+ 直接相加 DirectAddFuse
39
+ '''
40
+
41
+ def __init__(self):
42
+ super(DAF, self).__init__()
43
+
44
+ def forward(self, x, residual):
45
+ return x + residual
46
+
47
+
48
+ class iAFF(nn.Module):
49
+ '''
50
+ 多特征融合 iAFF
51
+ '''
52
+
53
+ def __init__(self, channels=64, r=4, type='2D'):
54
+ super(iAFF, self).__init__()
55
+ inter_channels = int(channels // r)
56
+
57
+ if type == '1D':
58
+ # 本地注意力
59
+ self.local_att = nn.Sequential(
60
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
61
+ nn.BatchNorm1d(inter_channels),
62
+ nn.ReLU(inplace=True),
63
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
64
+ nn.BatchNorm1d(channels),
65
+ )
66
+
67
+ # 全局注意力
68
+ self.global_att = nn.Sequential(
69
+ nn.AdaptiveAvgPool1d(1),
70
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
71
+ nn.BatchNorm1d(inter_channels),
72
+ nn.ReLU(inplace=True),
73
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
74
+ nn.BatchNorm1d(channels),
75
+ )
76
+
77
+ # 第二次本地注意力
78
+ self.local_att2 = nn.Sequential(
79
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
80
+ nn.BatchNorm1d(inter_channels),
81
+ nn.ReLU(inplace=True),
82
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
83
+ nn.BatchNorm1d(channels),
84
+ )
85
+ # 第二次全局注意力
86
+ self.global_att2 = nn.Sequential(
87
+ nn.AdaptiveAvgPool1d(1),
88
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
89
+ nn.BatchNorm1d(inter_channels),
90
+ nn.ReLU(inplace=True),
91
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
92
+ nn.BatchNorm1d(channels),
93
+ )
94
+ elif type == '2D':
95
+ # 本地注意力
96
+ self.local_att = nn.Sequential(
97
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
98
+ nn.BatchNorm2d(inter_channels),
99
+ nn.ReLU(inplace=True),
100
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
101
+ nn.BatchNorm2d(channels),
102
+ )
103
+
104
+ # 全局注意力
105
+ self.global_att = nn.Sequential(
106
+ nn.AdaptiveAvgPool2d(1),
107
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
108
+ nn.BatchNorm2d(inter_channels),
109
+ nn.ReLU(inplace=True),
110
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
111
+ nn.BatchNorm2d(channels),
112
+ )
113
+
114
+ # 第二次本地注意力
115
+ self.local_att2 = nn.Sequential(
116
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
117
+ nn.BatchNorm2d(inter_channels),
118
+ nn.ReLU(inplace=True),
119
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
120
+ nn.BatchNorm2d(channels),
121
+ )
122
+ # 第二次全局注意力
123
+ self.global_att2 = nn.Sequential(
124
+ nn.AdaptiveAvgPool2d(1),
125
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
126
+ nn.BatchNorm2d(inter_channels),
127
+ nn.ReLU(inplace=True),
128
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
129
+ nn.BatchNorm2d(channels),
130
+ )
131
+ else:
132
+ raise f'the type is not supported'
133
+
134
+ self.sigmoid = nn.Sigmoid()
135
+
136
+ def forward(self, x, residual):
137
+ flag = False
138
+ xa = x + residual
139
+ if xa.size(0) == 1:
140
+ xa = torch.cat([xa,xa],dim=0)
141
+ flag = True
142
+ xl = self.local_att(xa)
143
+ xg = self.global_att(xa)
144
+ xlg = xl + xg
145
+ wei = self.sigmoid(xlg)
146
+ xi = x * wei + residual * (1 - wei)
147
+
148
+ xl2 = self.local_att2(xi)
149
+ xg2 = self.global_att(xi)
150
+ xlg2 = xl2 + xg2
151
+ wei2 = self.sigmoid(xlg2)
152
+ xo = x * wei2 + residual * (1 - wei2)
153
+ if flag:
154
+ xo = xo[0].unsqueeze(0)
155
+ return xo
156
+
157
+
158
+ class AFF(nn.Module):
159
+ '''
160
+ 多特征融合 AFF
161
+ '''
162
+
163
+ def __init__(self, channels=64, r=4, type='2D'):
164
+ super(AFF, self).__init__()
165
+ inter_channels = int(channels // r)
166
+
167
+ if type == '1D':
168
+ self.local_att = nn.Sequential(
169
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
170
+ nn.BatchNorm1d(inter_channels),
171
+ nn.ReLU(inplace=True),
172
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
173
+ nn.BatchNorm1d(channels),
174
+ )
175
+ self.global_att = nn.Sequential(
176
+ nn.AdaptiveAvgPool1d(1),
177
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
178
+ nn.BatchNorm1d(inter_channels),
179
+ nn.ReLU(inplace=True),
180
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
181
+ nn.BatchNorm1d(channels),
182
+ )
183
+ elif type == '2D':
184
+ self.local_att = nn.Sequential(
185
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
186
+ nn.BatchNorm2d(inter_channels),
187
+ nn.ReLU(inplace=True),
188
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
189
+ nn.BatchNorm2d(channels),
190
+ )
191
+ self.global_att = nn.Sequential(
192
+ nn.AdaptiveAvgPool2d(1),
193
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
194
+ nn.BatchNorm2d(inter_channels),
195
+ nn.ReLU(inplace=True),
196
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
197
+ nn.BatchNorm2d(channels),
198
+ )
199
+ else:
200
+ raise f'the type is not supported.'
201
+
202
+ self.sigmoid = nn.Sigmoid()
203
+
204
+ def forward(self, x, residual):
205
+ flag = False
206
+ xa = x + residual
207
+ if xa.size(0) == 1:
208
+ xa = torch.cat([xa,xa],dim=0)
209
+ flag = True
210
+ xl = self.local_att(xa)
211
+ xg = self.global_att(xa)
212
+ xlg = xl + xg
213
+ wei = self.sigmoid(xlg)
214
+ xo = 2 * x * wei + 2 * residual * (1 - wei)
215
+ if flag:
216
+ xo = xo[0].unsqueeze(0)
217
+ return xo
218
+
219
+
220
+ # .utils
221
+
222
+ def interpolate(x, ratio):
223
+ """Interpolate data in time domain. This is used to compensate the
224
+ resolution reduction in downsampling of a CNN.
225
+
226
+ Args:
227
+ x: (batch_size, time_steps, classes_num)
228
+ ratio: int, ratio to interpolate
229
+ Returns:
230
+ upsampled: (batch_size, time_steps * ratio, classes_num)
231
+ """
232
+ (batch_size, time_steps, classes_num) = x.shape
233
+ upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
234
+ upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
235
+ return upsampled
236
+
237
+ def do_mixup(x, mixup_lambda):
238
+ """
239
+ Args:
240
+ x: (batch_size , ...)
241
+ mixup_lambda: (batch_size,)
242
+ Returns:
243
+ out: (batch_size, ...)
244
+ """
245
+ out = (
246
+ x.transpose(0, -1) * mixup_lambda
247
+ + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda)
248
+ ).transpose(0, -1)
249
+ return out
250
+
251
+ # from PyTorch internals
252
+ def _ntuple(n):
253
+ def parse(x):
254
+ if isinstance(x, collections.abc.Iterable):
255
+ return x
256
+ return tuple(repeat(x, n))
257
+ return parse
258
+
259
+ to_1tuple = _ntuple(1)
260
+ to_2tuple = _ntuple(2)
261
+ to_3tuple = _ntuple(3)
262
+ to_4tuple = _ntuple(4)
263
+ to_ntuple = _ntuple
264
+
265
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
266
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
267
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
268
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
269
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
270
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
271
+ 'survival rate' as the argument.
272
+ """
273
+ if drop_prob == 0. or not training:
274
+ return x
275
+ keep_prob = 1 - drop_prob
276
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
277
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
278
+ random_tensor.floor_() # binarize
279
+ output = x.div(keep_prob) * random_tensor
280
+ return output
281
+
282
+
283
+ class DropPath(nn.Module):
284
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
285
+ """
286
+ def __init__(self, drop_prob=None):
287
+ super(DropPath, self).__init__()
288
+ self.drop_prob = drop_prob
289
+
290
+ def forward(self, x):
291
+ return drop_path(x, self.drop_prob, self.training)
292
+
293
+ class PatchEmbed(nn.Module):
294
+ """ 2D Image to Patch Embedding
295
+ """
296
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16,
297
+ enable_fusion=False, fusion_type='None'):
298
+ super().__init__()
299
+ img_size = to_2tuple(img_size)
300
+ patch_size = to_2tuple(patch_size)
301
+ patch_stride = to_2tuple(patch_stride)
302
+ self.img_size = img_size
303
+ self.patch_size = patch_size
304
+ self.patch_stride = patch_stride
305
+ self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
306
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
307
+ self.flatten = flatten
308
+ self.in_chans = in_chans
309
+ self.embed_dim = embed_dim
310
+
311
+ self.enable_fusion = enable_fusion
312
+ self.fusion_type = fusion_type
313
+
314
+ padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)
315
+
316
+ if (self.enable_fusion) and (self.fusion_type == 'channel_map'):
317
+ self.proj = nn.Conv2d(in_chans*4, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
318
+ else:
319
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
320
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
321
+
322
+ if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']):
323
+ self.mel_conv2d = nn.Conv2d(in_chans, embed_dim, kernel_size=(patch_size[0], patch_size[1]*3), stride=(patch_stride[0], patch_stride[1] * 3), padding=padding)
324
+ if self.fusion_type == 'daf_2d':
325
+ self.fusion_model = DAF()
326
+ elif self.fusion_type == 'aff_2d':
327
+ self.fusion_model = AFF(channels=embed_dim, type='2D')
328
+ elif self.fusion_type == 'iaff_2d':
329
+ self.fusion_model = iAFF(channels=embed_dim, type='2D')
330
+ def forward(self, x, longer_idx = None):
331
+ if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']):
332
+ global_x = x[:,0:1,:,:]
333
+
334
+
335
+ # global processing
336
+ B, C, H, W = global_x.shape
337
+ assert H == self.img_size[0] and W == self.img_size[1], \
338
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
339
+ global_x = self.proj(global_x)
340
+ TW = global_x.size(-1)
341
+ if len(longer_idx) > 0:
342
+ # local processing
343
+ local_x = x[longer_idx,1:,:,:].contiguous()
344
+ B, C, H, W = local_x.shape
345
+ local_x = local_x.view(B*C,1,H,W)
346
+ local_x = self.mel_conv2d(local_x)
347
+ local_x = local_x.view(B,C,local_x.size(1),local_x.size(2),local_x.size(3))
348
+ local_x = local_x.permute((0,2,3,1,4)).contiguous().flatten(3)
349
+ TB,TC,TH,_ = local_x.size()
350
+ if local_x.size(-1) < TW:
351
+ local_x = torch.cat([local_x, torch.zeros((TB,TC,TH,TW-local_x.size(-1)), device=global_x.device)], dim=-1)
352
+ else:
353
+ local_x = local_x[:,:,:,:TW]
354
+
355
+ global_x[longer_idx] = self.fusion_model(global_x[longer_idx],local_x)
356
+ x = global_x
357
+ else:
358
+ B, C, H, W = x.shape
359
+ assert H == self.img_size[0] and W == self.img_size[1], \
360
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
361
+ x = self.proj(x)
362
+
363
+ if self.flatten:
364
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
365
+ x = self.norm(x)
366
+ return x
367
+
368
+ class Mlp(nn.Module):
369
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
370
+ """
371
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
372
+ super().__init__()
373
+ out_features = out_features or in_features
374
+ hidden_features = hidden_features or in_features
375
+ self.fc1 = nn.Linear(in_features, hidden_features)
376
+ self.act = act_layer()
377
+ self.fc2 = nn.Linear(hidden_features, out_features)
378
+ self.drop = nn.Dropout(drop)
379
+
380
+ def forward(self, x):
381
+ x = self.fc1(x)
382
+ x = self.act(x)
383
+ x = self.drop(x)
384
+ x = self.fc2(x)
385
+ x = self.drop(x)
386
+ return x
387
+
388
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
389
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
390
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
391
+ def norm_cdf(x):
392
+ # Computes standard normal cumulative distribution function
393
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
394
+
395
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
396
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
397
+ "The distribution of values may be incorrect.",
398
+ stacklevel=2)
399
+
400
+ with torch.no_grad():
401
+ # Values are generated by using a truncated uniform distribution and
402
+ # then using the inverse CDF for the normal distribution.
403
+ # Get upper and lower cdf values
404
+ l = norm_cdf((a - mean) / std)
405
+ u = norm_cdf((b - mean) / std)
406
+
407
+ # Uniformly fill tensor with values from [l, u], then translate to
408
+ # [2l-1, 2u-1].
409
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
410
+
411
+ # Use inverse cdf transform for normal distribution to get truncated
412
+ # standard normal
413
+ tensor.erfinv_()
414
+
415
+ # Transform to proper mean, std
416
+ tensor.mul_(std * math.sqrt(2.))
417
+ tensor.add_(mean)
418
+
419
+ # Clamp to ensure it's in the proper range
420
+ tensor.clamp_(min=a, max=b)
421
+ return tensor
422
+
423
+
424
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
425
+ # type: (Tensor, float, float, float, float) -> Tensor
426
+ r"""Fills the input Tensor with values drawn from a truncated
427
+ normal distribution. The values are effectively drawn from the
428
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
429
+ with values outside :math:`[a, b]` redrawn until they are within
430
+ the bounds. The method used for generating the random values works
431
+ best when :math:`a \leq \text{mean} \leq b`.
432
+ Args:
433
+ tensor: an n-dimensional `torch.Tensor`
434
+ mean: the mean of the normal distribution
435
+ std: the standard deviation of the normal distribution
436
+ a: the minimum cutoff value
437
+ b: the maximum cutoff value
438
+ Examples:
439
+ >>> w = torch.empty(3, 5)
440
+ >>> nn.init.trunc_normal_(w)
441
+ """
442
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
443
+
444
+
445
+ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
446
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
447
+ if mode == 'fan_in':
448
+ denom = fan_in
449
+ elif mode == 'fan_out':
450
+ denom = fan_out
451
+ elif mode == 'fan_avg':
452
+ denom = (fan_in + fan_out) / 2
453
+
454
+ variance = scale / denom
455
+
456
+ if distribution == "truncated_normal":
457
+ # constant is stddev of standard normal truncated to (-2, 2)
458
+ trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
459
+ elif distribution == "normal":
460
+ tensor.normal_(std=math.sqrt(variance))
461
+ elif distribution == "uniform":
462
+ bound = math.sqrt(3 * variance)
463
+ tensor.uniform_(-bound, bound)
464
+ else:
465
+ raise ValueError(f"invalid distribution {distribution}")
466
+
467
+
468
+ def lecun_normal_(tensor):
469
+ variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
470
+
471
+ def window_partition(x, window_size):
472
+ """
473
+ Args:
474
+ x: (B, H, W, C)
475
+ window_size (int): window size
476
+ Returns:
477
+ windows: (num_windows*B, window_size, window_size, C)
478
+ """
479
+ B, H, W, C = x.shape
480
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
481
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
482
+ return windows
483
+
484
+
485
+ def window_reverse(windows, window_size, H, W):
486
+ """
487
+ Args:
488
+ windows: (num_windows*B, window_size, window_size, C)
489
+ window_size (int): Window size
490
+ H (int): Height of image
491
+ W (int): Width of image
492
+ Returns:
493
+ x: (B, H, W, C)
494
+ """
495
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
496
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
497
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
498
+ return x
499
+
500
+
501
+ class WindowAttention(nn.Module):
502
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
503
+ It supports both of shifted and non-shifted window.
504
+ Args:
505
+ dim (int): Number of input channels.
506
+ window_size (tuple[int]): The height and width of the window.
507
+ num_heads (int): Number of attention heads.
508
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
509
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
510
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
511
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
512
+ """
513
+
514
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
515
+
516
+ super().__init__()
517
+ self.dim = dim
518
+ self.window_size = window_size # Wh, Ww
519
+ self.num_heads = num_heads
520
+ head_dim = dim // num_heads
521
+ self.scale = qk_scale or head_dim ** -0.5
522
+
523
+ # define a parameter table of relative position bias
524
+ self.relative_position_bias_table = nn.Parameter(
525
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
526
+
527
+ # get pair-wise relative position index for each token inside the window
528
+ coords_h = torch.arange(self.window_size[0])
529
+ coords_w = torch.arange(self.window_size[1])
530
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
531
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
532
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
533
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
534
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
535
+ relative_coords[:, :, 1] += self.window_size[1] - 1
536
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
537
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
538
+ self.register_buffer("relative_position_index", relative_position_index)
539
+
540
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
541
+ self.attn_drop = nn.Dropout(attn_drop)
542
+ self.proj = nn.Linear(dim, dim)
543
+ self.proj_drop = nn.Dropout(proj_drop)
544
+
545
+ trunc_normal_(self.relative_position_bias_table, std=.02)
546
+ self.softmax = nn.Softmax(dim=-1)
547
+
548
+ def forward(self, x, mask=None):
549
+ """
550
+ Args:
551
+ x: input features with shape of (num_windows*B, N, C)
552
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
553
+ """
554
+ B_, N, C = x.shape
555
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
556
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
557
+
558
+ q = q * self.scale
559
+ attn = (q @ k.transpose(-2, -1))
560
+
561
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
562
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
563
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
564
+ attn = attn + relative_position_bias.unsqueeze(0)
565
+
566
+ if mask is not None:
567
+ nW = mask.shape[0]
568
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
569
+ attn = attn.view(-1, self.num_heads, N, N)
570
+ attn = self.softmax(attn)
571
+ else:
572
+ attn = self.softmax(attn)
573
+
574
+ attn = self.attn_drop(attn)
575
+
576
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
577
+ x = self.proj(x)
578
+ x = self.proj_drop(x)
579
+ return x, attn
580
+
581
+ def extra_repr(self):
582
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
583
+
584
+
585
+ # We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
586
+ class SwinTransformerBlock(nn.Module):
587
+ r""" Swin Transformer Block.
588
+ Args:
589
+ dim (int): Number of input channels.
590
+ input_resolution (tuple[int]): Input resulotion.
591
+ num_heads (int): Number of attention heads.
592
+ window_size (int): Window size.
593
+ shift_size (int): Shift size for SW-MSA.
594
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
595
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
596
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
597
+ drop (float, optional): Dropout rate. Default: 0.0
598
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
599
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
600
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
601
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
602
+ """
603
+
604
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
605
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
606
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
607
+ super().__init__()
608
+ self.dim = dim
609
+ self.input_resolution = input_resolution
610
+ self.num_heads = num_heads
611
+ self.window_size = window_size
612
+ self.shift_size = shift_size
613
+ self.mlp_ratio = mlp_ratio
614
+ self.norm_before_mlp = norm_before_mlp
615
+ if min(self.input_resolution) <= self.window_size:
616
+ # if window size is larger than input resolution, we don't partition windows
617
+ self.shift_size = 0
618
+ self.window_size = min(self.input_resolution)
619
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
620
+
621
+ self.norm1 = norm_layer(dim)
622
+ self.attn = WindowAttention(
623
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
624
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
625
+
626
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
627
+ if self.norm_before_mlp == 'ln':
628
+ self.norm2 = nn.LayerNorm(dim)
629
+ elif self.norm_before_mlp == 'bn':
630
+ self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)
631
+ else:
632
+ raise NotImplementedError
633
+ mlp_hidden_dim = int(dim * mlp_ratio)
634
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
635
+
636
+ if self.shift_size > 0:
637
+ # calculate attention mask for SW-MSA
638
+ H, W = self.input_resolution
639
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
640
+ h_slices = (slice(0, -self.window_size),
641
+ slice(-self.window_size, -self.shift_size),
642
+ slice(-self.shift_size, None))
643
+ w_slices = (slice(0, -self.window_size),
644
+ slice(-self.window_size, -self.shift_size),
645
+ slice(-self.shift_size, None))
646
+ cnt = 0
647
+ for h in h_slices:
648
+ for w in w_slices:
649
+ img_mask[:, h, w, :] = cnt
650
+ cnt += 1
651
+
652
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
653
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
654
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
655
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
656
+ else:
657
+ attn_mask = None
658
+
659
+ self.register_buffer("attn_mask", attn_mask)
660
+
661
+ def forward(self, x):
662
+ # pdb.set_trace()
663
+ H, W = self.input_resolution
664
+ # print("H: ", H)
665
+ # print("W: ", W)
666
+ # pdb.set_trace()
667
+ B, L, C = x.shape
668
+ # assert L == H * W, "input feature has wrong size"
669
+
670
+ shortcut = x
671
+ x = self.norm1(x)
672
+ x = x.view(B, H, W, C)
673
+
674
+ # cyclic shift
675
+ if self.shift_size > 0:
676
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
677
+ else:
678
+ shifted_x = x
679
+
680
+ # partition windows
681
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
682
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
683
+
684
+ # W-MSA/SW-MSA
685
+ attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
686
+
687
+ # merge windows
688
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
689
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
690
+
691
+ # reverse cyclic shift
692
+ if self.shift_size > 0:
693
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
694
+ else:
695
+ x = shifted_x
696
+ x = x.view(B, H * W, C)
697
+
698
+ # FFN
699
+ x = shortcut + self.drop_path(x)
700
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
701
+
702
+ return x, attn
703
+
704
+ def extra_repr(self):
705
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
706
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
707
+
708
+
709
+
710
+ class PatchMerging(nn.Module):
711
+ r""" Patch Merging Layer.
712
+ Args:
713
+ input_resolution (tuple[int]): Resolution of input feature.
714
+ dim (int): Number of input channels.
715
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
716
+ """
717
+
718
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
719
+ super().__init__()
720
+ self.input_resolution = input_resolution
721
+ self.dim = dim
722
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
723
+ self.norm = norm_layer(4 * dim)
724
+
725
+ def forward(self, x):
726
+ """
727
+ x: B, H*W, C
728
+ """
729
+ H, W = self.input_resolution
730
+ B, L, C = x.shape
731
+ assert L == H * W, "input feature has wrong size"
732
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
733
+
734
+ x = x.view(B, H, W, C)
735
+
736
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
737
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
738
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
739
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
740
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
741
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
742
+
743
+ x = self.norm(x)
744
+ x = self.reduction(x)
745
+
746
+ return x
747
+
748
+ def extra_repr(self):
749
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
750
+
751
+
752
+ class BasicLayer(nn.Module):
753
+ """ A basic Swin Transformer layer for one stage.
754
+ Args:
755
+ dim (int): Number of input channels.
756
+ input_resolution (tuple[int]): Input resolution.
757
+ depth (int): Number of blocks.
758
+ num_heads (int): Number of attention heads.
759
+ window_size (int): Local window size.
760
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
761
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
762
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
763
+ drop (float, optional): Dropout rate. Default: 0.0
764
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
765
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
766
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
767
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
768
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
769
+ """
770
+
771
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
772
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
773
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
774
+ norm_before_mlp='ln'):
775
+
776
+ super().__init__()
777
+ self.dim = dim
778
+ self.input_resolution = input_resolution
779
+ self.depth = depth
780
+ self.use_checkpoint = use_checkpoint
781
+
782
+ # build blocks
783
+ self.blocks = nn.ModuleList([
784
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
785
+ num_heads=num_heads, window_size=window_size,
786
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
787
+ mlp_ratio=mlp_ratio,
788
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
789
+ drop=drop, attn_drop=attn_drop,
790
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
791
+ norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
792
+ for i in range(depth)])
793
+
794
+ # patch merging layer
795
+ if downsample is not None:
796
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
797
+ else:
798
+ self.downsample = None
799
+
800
+ def forward(self, x):
801
+ attns = []
802
+ for blk in self.blocks:
803
+ if self.use_checkpoint:
804
+ x = checkpoint.checkpoint(blk, x)
805
+ else:
806
+ x, attn = blk(x)
807
+ if not self.training:
808
+ attns.append(attn.unsqueeze(0))
809
+ if self.downsample is not None:
810
+ x = self.downsample(x)
811
+ if not self.training:
812
+ attn = torch.cat(attns, dim = 0)
813
+ attn = torch.mean(attn, dim = 0)
814
+ return x, attn
815
+
816
+ # if self.downsample is not None:
817
+ # x = self.downsample(x)
818
+ # if not self.training:
819
+ # attn = torch.cat(attns, dim = 0)
820
+ # attn = torch.mean(attn, dim = 0)
821
+ # return x, attn
822
+
823
+ def extra_repr(self):
824
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
825
+
826
+
827
+ # The Core of HTSAT
828
+ class HTSAT_Swin_Transformer(nn.Module):
829
+ r"""HTSAT based on the Swin Transformer
830
+ Args:
831
+ spec_size (int | tuple(int)): Input Spectrogram size. Default 256
832
+ patch_size (int | tuple(int)): Patch size. Default: 4
833
+ path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4
834
+ in_chans (int): Number of input image channels. Default: 1 (mono)
835
+ num_classes (int): Number of classes for classification head. Default: 527
836
+ embed_dim (int): Patch embedding dimension. Default: 96
837
+ depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
838
+ num_heads (tuple(int)): Number of attention heads in different layers.
839
+ window_size (int): Window size. Default: 8
840
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
841
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
842
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
843
+ drop_rate (float): Dropout rate. Default: 0
844
+ attn_drop_rate (float): Attention dropout rate. Default: 0
845
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
846
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
847
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
848
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
849
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
850
+ config (module): The configuration Module from config.py
851
+ """
852
+
853
+ def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4),
854
+ in_chans=1, num_classes=527,
855
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
856
+ window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
857
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
858
+ norm_layer=nn.LayerNorm,
859
+ ape=False, patch_norm=True,
860
+ use_checkpoint=False, norm_before_mlp='ln', config = None,
861
+ enable_fusion = False, fusion_type = 'None', **kwargs):
862
+ super(HTSAT_Swin_Transformer, self).__init__()
863
+
864
+ self.config = config
865
+ self.spec_size = spec_size
866
+ self.patch_stride = patch_stride
867
+ self.patch_size = patch_size
868
+ self.window_size = window_size
869
+ self.embed_dim = embed_dim
870
+ self.depths = depths
871
+ self.ape = ape
872
+ self.in_chans = in_chans
873
+ self.num_classes = num_classes
874
+ self.num_heads = num_heads
875
+ self.num_layers = len(self.depths)
876
+ self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
877
+
878
+ self.drop_rate = drop_rate
879
+ self.attn_drop_rate = attn_drop_rate
880
+ self.drop_path_rate = drop_path_rate
881
+
882
+ self.qkv_bias = qkv_bias
883
+ self.qk_scale = None
884
+
885
+ self.patch_norm = patch_norm
886
+ self.norm_layer = norm_layer if self.patch_norm else None
887
+ self.norm_before_mlp = norm_before_mlp
888
+ self.mlp_ratio = mlp_ratio
889
+
890
+ self.use_checkpoint = use_checkpoint
891
+
892
+ self.enable_fusion = enable_fusion
893
+ self.fusion_type = fusion_type
894
+
895
+ # process mel-spec ; used only once
896
+ self.freq_ratio = self.spec_size // self.config.mel_bins
897
+ window = 'hann'
898
+ center = True
899
+ pad_mode = 'reflect'
900
+ ref = 1.0
901
+ amin = 1e-10
902
+ top_db = None
903
+ self.interpolate_ratio = 32 # Downsampled ratio
904
+ # Spectrogram extractor
905
+ self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size,
906
+ win_length=config.window_size, window=window, center=center, pad_mode=pad_mode,
907
+ freeze_parameters=True)
908
+ # Logmel feature extractor
909
+ self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size,
910
+ n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db,
911
+ freeze_parameters=True)
912
+ # Spec augmenter
913
+ self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
914
+ freq_drop_width=8, freq_stripes_num=2) # 2 2
915
+ self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
916
+
917
+
918
+ # split spctrogram into non-overlapping patches
919
+ self.patch_embed = PatchEmbed(
920
+ img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans,
921
+ embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride,
922
+ enable_fusion=self.enable_fusion, fusion_type=self.fusion_type
923
+ )
924
+
925
+ num_patches = self.patch_embed.num_patches
926
+ patches_resolution = self.patch_embed.grid_size
927
+ self.patches_resolution = patches_resolution
928
+
929
+ # absolute position embedding
930
+ if self.ape:
931
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
932
+ trunc_normal_(self.absolute_pos_embed, std=.02)
933
+
934
+ self.pos_drop = nn.Dropout(p=self.drop_rate)
935
+
936
+ # stochastic depth
937
+ dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule
938
+
939
+ # build layers
940
+ self.layers = nn.ModuleList()
941
+ for i_layer in range(self.num_layers):
942
+ layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),
943
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
944
+ patches_resolution[1] // (2 ** i_layer)),
945
+ depth=self.depths[i_layer],
946
+ num_heads=self.num_heads[i_layer],
947
+ window_size=self.window_size,
948
+ mlp_ratio=self.mlp_ratio,
949
+ qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
950
+ drop=self.drop_rate, attn_drop=self.attn_drop_rate,
951
+ drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
952
+ norm_layer=self.norm_layer,
953
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
954
+ use_checkpoint=use_checkpoint,
955
+ norm_before_mlp=self.norm_before_mlp)
956
+ self.layers.append(layer)
957
+
958
+ self.norm = self.norm_layer(self.num_features)
959
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
960
+ self.maxpool = nn.AdaptiveMaxPool1d(1)
961
+
962
+ SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio
963
+ self.tscam_conv = nn.Conv2d(
964
+ in_channels = self.num_features,
965
+ out_channels = self.num_classes,
966
+ kernel_size = (SF,3),
967
+ padding = (0,1)
968
+ )
969
+ self.head = nn.Linear(num_classes, num_classes)
970
+
971
+ if (self.enable_fusion) and (self.fusion_type in ['daf_1d','aff_1d','iaff_1d']):
972
+ self.mel_conv1d = nn.Sequential(
973
+ nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
974
+ nn.BatchNorm1d(64)
975
+ )
976
+ if self.fusion_type == 'daf_1d':
977
+ self.fusion_model = DAF()
978
+ elif self.fusion_type == 'aff_1d':
979
+ self.fusion_model = AFF(channels=64, type='1D')
980
+ elif self.fusion_type == 'iaff_1d':
981
+ self.fusion_model = iAFF(channels=64, type='1D')
982
+
983
+ self.apply(self._init_weights)
984
+
985
+ def _init_weights(self, m):
986
+ if isinstance(m, nn.Linear):
987
+ trunc_normal_(m.weight, std=.02)
988
+ if isinstance(m, nn.Linear) and m.bias is not None:
989
+ nn.init.constant_(m.bias, 0)
990
+ elif isinstance(m, nn.LayerNorm):
991
+ nn.init.constant_(m.bias, 0)
992
+ nn.init.constant_(m.weight, 1.0)
993
+
994
+ @torch.jit.ignore
995
+ def no_weight_decay(self):
996
+ return {'absolute_pos_embed'}
997
+
998
+ @torch.jit.ignore
999
+ def no_weight_decay_keywords(self):
1000
+ return {'relative_position_bias_table'}
1001
+
1002
+
1003
+ def forward_features(self, x, longer_idx = None):
1004
+ # A deprecated optimization for using a hierarchical output from different blocks
1005
+
1006
+ frames_num = x.shape[2]
1007
+ x = self.patch_embed(x, longer_idx = longer_idx)
1008
+ if self.ape:
1009
+ x = x + self.absolute_pos_embed
1010
+ x = self.pos_drop(x)
1011
+ for i, layer in enumerate(self.layers):
1012
+ x, attn = layer(x)
1013
+ # for x
1014
+ x = self.norm(x)
1015
+ B, N, C = x.shape
1016
+ SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
1017
+ ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
1018
+ x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
1019
+ B, C, F, T = x.shape
1020
+ # group 2D CNN
1021
+ c_freq_bin = F // self.freq_ratio
1022
+ x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
1023
+ x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
1024
+ # get latent_output
1025
+ fine_grained_latent_output = torch.mean(x, dim = 2)
1026
+ fine_grained_latent_output = interpolate(fine_grained_latent_output.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
1027
+
1028
+ latent_output = self.avgpool(torch.flatten(x,2))
1029
+ latent_output = torch.flatten(latent_output, 1)
1030
+
1031
+ # display the attention map, if needed
1032
+
1033
+ x = self.tscam_conv(x)
1034
+ x = torch.flatten(x, 2) # B, C, T
1035
+
1036
+ fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
1037
+
1038
+ x = self.avgpool(x)
1039
+ x = torch.flatten(x, 1)
1040
+
1041
+ output_dict = {
1042
+ 'framewise_output': fpx, # already sigmoided
1043
+ 'clipwise_output': torch.sigmoid(x),
1044
+ 'fine_grained_embedding': fine_grained_latent_output,
1045
+ 'embedding': latent_output
1046
+ }
1047
+
1048
+ return output_dict
1049
+
1050
+ def crop_wav(self, x, crop_size, spe_pos = None):
1051
+ time_steps = x.shape[2]
1052
+ tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
1053
+ for i in range(len(x)):
1054
+ if spe_pos is None:
1055
+ crop_pos = random.randint(0, time_steps - crop_size - 1)
1056
+ else:
1057
+ crop_pos = spe_pos
1058
+ tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]
1059
+ return tx
1060
+
1061
+ # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model
1062
+ def reshape_wav2img(self, x):
1063
+ B, C, T, F = x.shape
1064
+ target_T = int(self.spec_size * self.freq_ratio)
1065
+ target_F = self.spec_size // self.freq_ratio
1066
+ assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size"
1067
+ # to avoid bicubic zero error
1068
+ if T < target_T:
1069
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
1070
+ if F < target_F:
1071
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
1072
+ x = x.permute(0,1,3,2).contiguous()
1073
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)
1074
+ # print(x.shape)
1075
+ x = x.permute(0,1,3,2,4).contiguous()
1076
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
1077
+ return x
1078
+
1079
+ # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model
1080
+ def repeat_wat2img(self, x, cur_pos):
1081
+ B, C, T, F = x.shape
1082
+ target_T = int(self.spec_size * self.freq_ratio)
1083
+ target_F = self.spec_size // self.freq_ratio
1084
+ assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size"
1085
+ # to avoid bicubic zero error
1086
+ if T < target_T:
1087
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
1088
+ if F < target_F:
1089
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
1090
+ x = x.permute(0,1,3,2).contiguous() # B C F T
1091
+ x = x[:,:,:,cur_pos:cur_pos + self.spec_size]
1092
+ x = x.repeat(repeats = (1,1,4,1))
1093
+ return x
1094
+
1095
+ def forward_generator(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False, device=None):# out_feat_keys: List[str] = None):
1096
+
1097
+ n = int(x.shape[1]/480000)
1098
+ assert n * 480000 == x.shape[1]
1099
+ x = rearrange(x, 'b (n t) -> (b n) t', n=n)
1100
+ if not self.enable_fusion:
1101
+ # x = x["waveform"].to(device=device, non_blocking=True)
1102
+ x = x.to(device=device, non_blocking=True)
1103
+ x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
1104
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
1105
+ x = x.transpose(1, 3)
1106
+ x = self.bn0(x)
1107
+ x = x.transpose(1, 3)
1108
+ if self.training:
1109
+ x = self.spec_augmenter(x)
1110
+
1111
+ if self.training and mixup_lambda is not None:
1112
+ x = do_mixup(x, mixup_lambda)
1113
+
1114
+ x = self.reshape_wav2img(x)
1115
+ # output_dict = self.forward_features(x)
1116
+
1117
+ # A deprecated optimization for using a hierarchical output from different blocks
1118
+ longer_idx = None
1119
+ frames_num = x.shape[2]
1120
+ x = self.patch_embed(x, longer_idx = longer_idx)
1121
+ if self.ape:
1122
+ x = x + self.absolute_pos_embed
1123
+ x = self.pos_drop(x)
1124
+ for i, layer in enumerate(self.layers[:3]): # depth: [2,2,12,2]
1125
+ if i == 2:
1126
+ for blk in layer.blocks:
1127
+ x, attn = blk(x)
1128
+ # 512
1129
+ x = rearrange(x, '(b n) t c -> b (n t) c', n=n)
1130
+ x = x if (new_x:=(yield x)) is None else new_x
1131
+ x = rearrange(x, 'b (n t) c -> (b n) t c', n=n)
1132
+ else:
1133
+ x, attn = layer(x)
1134
+
1135
+
1136
+
1137
+ def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False, device=None):# out_feat_keys: List[str] = None):
1138
+
1139
+ n = int(x.shape[1] / 480000)
1140
+ assert n * 480000 == x.shape[1]
1141
+ x = rearrange(x, 'b (n t) -> (b n) t', n = n)
1142
+ if not self.enable_fusion:
1143
+ # x = x["waveform"].to(device=device, non_blocking=True)
1144
+ x = x.to(device=device, non_blocking=True)
1145
+ x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
1146
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
1147
+ x = x.transpose(1, 3)
1148
+ x = self.bn0(x)
1149
+ x = x.transpose(1, 3)
1150
+ if self.training:
1151
+ x = self.spec_augmenter(x)
1152
+
1153
+ if self.training and mixup_lambda is not None:
1154
+ x = do_mixup(x, mixup_lambda)
1155
+
1156
+ x = self.reshape_wav2img(x)
1157
+ # x = self.forward_features(x)
1158
+
1159
+ longer_idx = None
1160
+ frames_num = x.shape[2]
1161
+ x = self.patch_embed(x, longer_idx = longer_idx)
1162
+ if self.ape:
1163
+ x = x + self.absolute_pos_embed
1164
+ x = self.pos_drop(x)
1165
+ for i, layer in enumerate(self.layers):
1166
+ x, attn = layer(x)
1167
+ # for x
1168
+ x = self.norm(x)
1169
+ x = rearrange(x, '(b n) t c -> b (n t) c', n = n)
1170
+ return x
1171
+
1172
+ # B, N, C = x.shape
1173
+ # SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
1174
+ # ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
1175
+ # x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
1176
+ # B, C, F, T = x.shape
1177
+ # # group 2D CNN
1178
+ # c_freq_bin = F // self.freq_ratio
1179
+ # x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
1180
+ # x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
1181
+ # # get latent_output
1182
+ # fine_grained_latent_output = torch.mean(x, dim = 2)
1183
+ # fine_grained_latent_output = interpolate(fine_grained_latent_output.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
1184
+
1185
+ # latent_output = self.avgpool(torch.flatten(x,2))
1186
+ # latent_output = torch.flatten(latent_output, 1)
1187
+
1188
+ # # display the attention map, if needed
1189
+
1190
+ # x = self.tscam_conv(x)
1191
+ # x = torch.flatten(x, 2) # B, C, T
1192
+
1193
+ # fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
1194
+
1195
+ # x = self.avgpool(x)
1196
+ # x = torch.flatten(x, 1)
1197
+ # return x
1198
+
1199
+ def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type='None'):
1200
+ try:
1201
+
1202
+ assert audio_cfg.model_name in ["tiny", "base", "large"], "model name for HTS-AT is wrong!"
1203
+ if audio_cfg.model_name == "tiny":
1204
+ model = HTSAT_Swin_Transformer(
1205
+ spec_size=256,
1206
+ patch_size=4,
1207
+ patch_stride=(4,4),
1208
+ num_classes=audio_cfg.class_num,
1209
+ embed_dim=96,
1210
+ depths=[2,2,6,2],
1211
+ num_heads=[4,8,16,32],
1212
+ window_size=8,
1213
+ config = audio_cfg,
1214
+ enable_fusion = enable_fusion,
1215
+ fusion_type = fusion_type
1216
+ )
1217
+ elif audio_cfg.model_name == "base":
1218
+ model = HTSAT_Swin_Transformer(
1219
+ spec_size=256,
1220
+ patch_size=4,
1221
+ patch_stride=(4,4),
1222
+ num_classes=audio_cfg.class_num,
1223
+ embed_dim=128,
1224
+ depths=[2,2,12,2],
1225
+ num_heads=[4,8,16,32],
1226
+ window_size=8,
1227
+ config = audio_cfg,
1228
+ enable_fusion = enable_fusion,
1229
+ fusion_type = fusion_type
1230
+ )
1231
+ elif audio_cfg.model_name == "large":
1232
+ model = HTSAT_Swin_Transformer(
1233
+ spec_size=256,
1234
+ patch_size=4,
1235
+ patch_stride=(4,4),
1236
+ num_classes=audio_cfg.class_num,
1237
+ embed_dim=256,
1238
+ depths=[2,2,12,2],
1239
+ num_heads=[4,8,16,32],
1240
+ window_size=8,
1241
+ config = audio_cfg,
1242
+ enable_fusion = enable_fusion,
1243
+ fusion_type = fusion_type
1244
+ )
1245
+
1246
+ return model
1247
+ except:
1248
+ raise RuntimeError(f'Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough.')
1249
+
mae_vit.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from mmcls.models import VisionTransformer
3
+ from torch import nn
4
+ from torch.utils.checkpoint import checkpoint
5
+ import copy
6
+
7
+ def build_2d_sincos_position_embedding(patches_resolution,
8
+ embed_dims,
9
+ temperature=10000.,
10
+ cls_token=False):
11
+ """The function is to build position embedding for model to obtain the
12
+ position information of the image patches."""
13
+
14
+ if isinstance(patches_resolution, int):
15
+ patches_resolution = (patches_resolution, patches_resolution)
16
+
17
+ h, w = patches_resolution
18
+ grid_w = torch.arange(w, dtype=torch.float32)
19
+ grid_h = torch.arange(h, dtype=torch.float32)
20
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
21
+ assert embed_dims % 4 == 0, \
22
+ 'Embed dimension must be divisible by 4.'
23
+ pos_dim = embed_dims // 4
24
+
25
+ omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
26
+ omega = 1. / (temperature**omega)
27
+ out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
28
+ out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
29
+
30
+ pos_emb = torch.cat(
31
+ [
32
+ torch.sin(out_w),
33
+ torch.cos(out_w),
34
+ torch.sin(out_h),
35
+ torch.cos(out_h)
36
+ ],
37
+ dim=1,
38
+ )[None, :, :]
39
+
40
+ if cls_token:
41
+ cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32)
42
+ pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1)
43
+
44
+ return pos_emb
45
+
46
+
47
+
48
+ class MAEViT(VisionTransformer):
49
+ """Vision Transformer for MAE pre-training.
50
+
51
+ A PyTorch implement of: `An Image is Worth 16x16 Words: Transformers
52
+ for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_
53
+
54
+ Args:
55
+ arch (str | dict): Vision Transformer architecture
56
+ Default: 'b'
57
+ img_size (int | tuple): Input image size
58
+ patch_size (int | tuple): The patch size
59
+ out_indices (Sequence | int): Output from which stages.
60
+ Defaults to -1, means the last stage.
61
+ drop_rate (float): Probability of an element to be zeroed.
62
+ Defaults to 0.
63
+ drop_path_rate (float): stochastic depth rate. Defaults to 0.
64
+ norm_cfg (dict): Config dict for normalization layer.
65
+ Defaults to ``dict(type='LN')``.
66
+ final_norm (bool): Whether to add a additional layer to normalize
67
+ final feature map. Defaults to True.
68
+ output_cls_token (bool): Whether output the cls_token. If set True,
69
+ `with_cls_token` must be True. Defaults to True.
70
+ interpolate_mode (str): Select the interpolate mode for position
71
+ embeding vector resize. Defaults to "bicubic".
72
+ patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
73
+ layer_cfgs (Sequence | dict): Configs of each transformer layer in
74
+ encoder. Defaults to an empty dict.
75
+ mask_ratio (bool): The ratio of total number of patches to be masked.
76
+ Defaults to 0.75.
77
+ init_cfg (dict, optional): Initialization config dict.
78
+ Defaults to None.
79
+ """
80
+
81
+ arch_zoo = {
82
+ **dict.fromkeys(
83
+ ['mocov3-s', 'mocov3-small'], {
84
+ 'embed_dims': 384,
85
+ 'num_layers': 12,
86
+ 'num_heads': 12,
87
+ 'feedforward_channels': 1536,
88
+ }),
89
+ **dict.fromkeys(
90
+ ['b', 'base'], {
91
+ 'embed_dims': 768,
92
+ 'num_layers': 12,
93
+ 'num_heads': 12,
94
+ 'feedforward_channels': 3072
95
+ }),
96
+ }
97
+
98
+
99
+
100
+ def __init__(self,
101
+ arch='b',
102
+ img_size=224,
103
+ patch_size=16,
104
+ out_indices=-1,
105
+ drop_rate=0,
106
+ drop_path_rate=0,
107
+ norm_cfg=dict(type='LN', eps=1e-6),
108
+ final_norm=True,
109
+ output_cls_token=False,
110
+ interpolate_mode='bicubic',
111
+ patch_cfg=dict(),
112
+ layer_cfgs=dict(),
113
+ gradientCKPT=False,
114
+ mask_ratio=0.75,
115
+ init_cfg=None):
116
+ super().__init__(
117
+ arch=arch,
118
+ img_size=img_size,
119
+ patch_size=patch_size,
120
+ out_indices=out_indices,
121
+ drop_rate=drop_rate,
122
+ drop_path_rate=drop_path_rate,
123
+ norm_cfg=norm_cfg,
124
+ final_norm=final_norm,
125
+ output_cls_token=output_cls_token,
126
+ interpolate_mode=interpolate_mode,
127
+ patch_cfg=patch_cfg,
128
+ layer_cfgs=layer_cfgs,
129
+ init_cfg=init_cfg)
130
+ self.gradientCKPT = gradientCKPT
131
+ self.pos_embed.requires_grad = False
132
+ self.mask_ratio = mask_ratio
133
+ self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]
134
+ # self.mask_embedding = copy.deepcopy(self.patch_embed)
135
+ # self.mask_embedding.norm = None
136
+
137
+ def init_weights(self):
138
+ super(MAEViT, self).init_weights()
139
+ if not (isinstance(self.init_cfg, dict)
140
+ and self.init_cfg['type'] == 'Pretrained'):
141
+ # initialize position embedding in backbone
142
+ pos_embed = build_2d_sincos_position_embedding(
143
+ self.patch_resolution,
144
+ self.pos_embed.shape[-1],
145
+ cls_token=True)
146
+ self.pos_embed.data.copy_(pos_embed.float())
147
+
148
+ w = self.patch_embed.projection.weight.data
149
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
150
+
151
+ torch.nn.init.normal_(self.cls_token, std=.02)
152
+
153
+ self.apply(self._init_weights)
154
+
155
+ # mask_embedding transfers pixel level mask to token level
156
+ # self.mask_embedding.apply(self._init_mask_embedding)
157
+ # for para in self.mask_embedding.parameters():
158
+ # para.requires_grad = False
159
+
160
+ def _init_mask_embedding(self,m):
161
+ if hasattr(m,'weight'):
162
+ nn.init.constant_(m.weight,1.0)
163
+ if hasattr(m, 'bias'):
164
+ nn.init.constant_(m.bias,0)
165
+
166
+ def _init_weights(self, m):
167
+
168
+ if isinstance(m, nn.Linear):
169
+ torch.nn.init.xavier_uniform_(m.weight)
170
+ if isinstance(m, nn.Linear) and m.bias is not None:
171
+ nn.init.constant_(m.bias, 0)
172
+ elif isinstance(m, nn.LayerNorm):
173
+ nn.init.constant_(m.bias, 0)
174
+ nn.init.constant_(m.weight, 1.0)
175
+
176
+ def random_masking(self, x, mask_ratio=0.75, attn_mask=None):
177
+ """Generate the mask for MAE Pre-training.
178
+
179
+ Args:
180
+ x (torch.tensor): Image with data augmentation applied.
181
+ mask_ratio (float): The mask ratio of total patches.
182
+ Defaults to 0.75.
183
+
184
+ Returns:
185
+ tuple[Tensor, Tensor, Tensor]: masked image, mask and the ids
186
+ to restore original image.
187
+
188
+ - x_masked (Tensor): masked image.
189
+ - mask (Tensor): mask used to mask image.
190
+ - ids_restore (Tensor): ids to restore original image.
191
+ """
192
+ N, L, D = x.shape # batch, length, dim
193
+ len_keep = int(L * (1 - mask_ratio))
194
+
195
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
196
+
197
+ # sort noise for each sample
198
+ ids_shuffle = torch.argsort(
199
+ noise, dim=1) # ascend: small is keep, large is remove
200
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
201
+
202
+ # keep the first subset
203
+ ids_keep = ids_shuffle[:, :len_keep]
204
+ x_masked = torch.gather(
205
+ x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
206
+ # modified_attn_mask = None if attn_mask is None else torch.gather(attn_mask,dim=1, index=ids_keep)
207
+
208
+ # generate the binary mask: 0 is keep, 1 is remove
209
+ mask = torch.ones([N, L], device=x.device)
210
+ mask[:, :len_keep] = 0
211
+ # unshuffle to get the binary mask
212
+ mask = torch.gather(mask, dim=1, index=ids_restore)
213
+
214
+ return x_masked, mask, ids_restore #, modified_attn_mask
215
+
216
+ def generate_mask(self, pixel_level_attn_mask):
217
+ '''
218
+ pixel_level_attn_mask: (0,1) attn mask with the same shape as img
219
+ '''
220
+ if pixel_level_attn_mask is None: return None
221
+ # H, W = patch_resolution
222
+ # B, C = pixel_level_attn_mask.shape[:2]
223
+ # attn_mask = torch.ones((B,C,H,W),device=pixel_level_attn_mask)
224
+ # H_splited = torch.chunk(pixel_level_attn_mask, H, -2)
225
+ # HW_splited_mask = (torch.chunk(Hs, W, -1) for Hs in H_splited)
226
+
227
+ # if HW_splited_mask[:,:,hi,wi].sum().item() == 0:
228
+ # attn_mask[:,:,hi,wi] = 0
229
+
230
+ # mask_patches = self.mask_embedding(pixel_level_attn_mask)[0]
231
+ # attn_mask = mask_patches.sum(-1) != 0
232
+
233
+ # return attn_mask
234
+
235
+ def extract_feat(self, img ,attn_mask=None):
236
+ x, *_ = self.forward(img,attn_mask)
237
+ if self.output_cls_token:
238
+ return x[:,0,:]
239
+ else:
240
+ return torch.mean(x,dim=1)
241
+
242
+ def forward(self, x, attn_mask=None):
243
+ if attn_mask is not None: assert self.output_cls_token
244
+
245
+ B = x.shape[0]
246
+ x = self.patch_embed(x)[0]
247
+ # add pos embed w/o cls token
248
+ x = x + self.pos_embed[:, 1:1+x.shape[1], :]
249
+ # masking: length -> length * mask_ratio
250
+ if True:
251
+ assert self.mask_ratio == 0.
252
+ else:
253
+ x, mask, ids_restore = self.random_masking(x, self.mask_ratio)
254
+
255
+ # append cls token
256
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
257
+ cls_tokens = cls_token.expand(B, -1, -1)
258
+ x = torch.cat((cls_tokens, x), dim=1)
259
+ x = self.drop_after_pos(x)
260
+ # if attn_mask is not None:
261
+ # attn_mask = torch.concat((torch.ones((B,1),device=attn_mask.device) , attn_mask),dim=1)
262
+
263
+ for i, layer in enumerate(self.layers):
264
+ if self.gradientCKPT:
265
+ x = checkpoint(layer,x) # ,attn_mask
266
+ else:
267
+ x = layer(x) # ,attn_mask
268
+ if i == len(self.layers) - 1 and self.final_norm:
269
+ x = self.norm1(x)
270
+ if True:
271
+ return x
272
+ else:
273
+ return (x, mask, ids_restore)
274
+
275
+ def forward_generator(self, x, attn_mask=None):
276
+ if attn_mask is not None: assert self.output_cls_token
277
+
278
+ B = x.shape[0]
279
+ x = self.patch_embed(x)[0]
280
+ # add pos embed w/o cls token
281
+ x = x + self.pos_embed[:, 1:1+x.shape[1], :]
282
+
283
+ # append cls token
284
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
285
+ cls_tokens = cls_token.expand(B, -1, -1)
286
+ x = torch.cat((cls_tokens, x), dim=1)
287
+ x = self.drop_after_pos(x)
288
+
289
+ for i, layer in enumerate(self.layers):
290
+ if self.gradientCKPT:
291
+ x = checkpoint(layer,x) # ,attn_mask
292
+ else:
293
+ x = layer(x) # ,attn_mask
294
+
295
+ if i == len(self.layers) - 1 and self.final_norm:
296
+ x = self.norm1(x)
297
+
298
+ x = x if (new_x:=(yield x)) is None else new_x
299
+
300
+ debug = False
301
+ if debug:
302
+ print(f'layer {i}-th forwarded')
303
+
vision_transformer.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import reduce
3
+ from operator import mul
4
+ from ipdb import set_trace
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.nn as nn
9
+ from mmcls.models.backbones import VisionTransformer as _VisionTransformer
10
+ from mmcls.models.utils import to_2tuple
11
+ from mmcv.cnn.bricks.transformer import PatchEmbed
12
+ from torch.nn.modules.batchnorm import _BatchNorm
13
+
14
+
15
+ def build_2d_sincos_position_embedding(patches_resolution,
16
+ embed_dims,
17
+ temperature=10000.,
18
+ cls_token=False):
19
+ """The function is to build position embedding for model to obtain the
20
+ position information of the image patches."""
21
+
22
+ if isinstance(patches_resolution, int):
23
+ patches_resolution = (patches_resolution, patches_resolution)
24
+
25
+ h, w = patches_resolution
26
+ grid_w = torch.arange(w, dtype=torch.float32)
27
+ grid_h = torch.arange(h, dtype=torch.float32)
28
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
29
+ assert embed_dims % 4 == 0, \
30
+ 'Embed dimension must be divisible by 4.'
31
+ pos_dim = embed_dims // 4
32
+
33
+ omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
34
+ omega = 1. / (temperature**omega)
35
+ out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
36
+ out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
37
+
38
+ pos_emb = torch.cat(
39
+ [
40
+ torch.sin(out_w),
41
+ torch.cos(out_w),
42
+ torch.sin(out_h),
43
+ torch.cos(out_h)
44
+ ],
45
+ dim=1,
46
+ )[None, :, :]
47
+
48
+ if cls_token:
49
+ cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32)
50
+ pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1)
51
+
52
+ return pos_emb
53
+
54
+
55
+ class VisionTransformer(_VisionTransformer):
56
+ """Vision Transformer.
57
+
58
+ A pytorch implement of: `An Images is Worth 16x16 Words: Transformers for
59
+ Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.
60
+
61
+ Part of the code is modified from:
62
+ `<https://github.com/facebookresearch/moco-v3/blob/main/vits.py>`_.
63
+
64
+ Args:
65
+ stop_grad_conv1 (bool, optional): whether to stop the gradient of
66
+ convolution layer in `PatchEmbed`. Defaults to False.
67
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
68
+ -1 means not freezing any parameters. Defaults to -1.
69
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
70
+ freeze running stats (mean and var). Note: Effect on Batch Norm
71
+ and its variants only. Defaults to False.
72
+ init_cfg (dict or list[dict], optional): Initialization config dict.
73
+ Defaults to None.
74
+ """
75
+
76
+ arch_zoo = {
77
+ **dict.fromkeys(
78
+ ['mocov3-s', 'mocov3-small'], {
79
+ 'embed_dims': 384,
80
+ 'num_layers': 12,
81
+ 'num_heads': 12,
82
+ 'feedforward_channels': 1536,
83
+ }),
84
+ **dict.fromkeys(
85
+ ['b', 'base'], {
86
+ 'embed_dims': 768,
87
+ 'num_layers': 12,
88
+ 'num_heads': 12,
89
+ 'feedforward_channels': 3072
90
+ }),
91
+ }
92
+
93
+ def __init__(self,
94
+ stop_grad_conv1=False,
95
+ frozen_stages=-1,
96
+ norm_eval=False,
97
+ init_cfg=None,
98
+ **kwargs):
99
+ super(VisionTransformer, self).__init__(init_cfg=init_cfg,)
100
+ self.patch_size = kwargs['patch_size']
101
+ self.frozen_stages = frozen_stages
102
+ self.norm_eval = norm_eval
103
+ self.init_cfg = init_cfg
104
+
105
+
106
+ if isinstance(self.patch_embed, PatchEmbed):
107
+ if stop_grad_conv1:
108
+ self.patch_embed.projection.weight.requires_grad = False
109
+ self.patch_embed.projection.bias.requires_grad = False
110
+
111
+ self._freeze_stages()
112
+
113
+ def init_weights(self):
114
+ super(VisionTransformer, self).init_weights()
115
+
116
+ if not (isinstance(self.init_cfg, dict)
117
+ and self.init_cfg['type'] == 'Pretrained'):
118
+
119
+ # Use fixed 2D sin-cos position embedding
120
+ pos_emb = build_2d_sincos_position_embedding(
121
+ patches_resolution=self.patch_resolution,
122
+ embed_dims=self.embed_dims,
123
+ cls_token=True)
124
+ self.pos_embed.data.copy_(pos_emb)
125
+ self.pos_embed.requires_grad = False
126
+
127
+ # xavier_uniform initialization for PatchEmbed
128
+ if isinstance(self.patch_embed, PatchEmbed):
129
+ val = math.sqrt(
130
+ 6. / float(3 * reduce(mul, to_2tuple(self.patch_size), 1) +
131
+ self.embed_dims))
132
+ nn.init.uniform_(self.patch_embed.projection.weight, -val, val)
133
+ nn.init.zeros_(self.patch_embed.projection.bias)
134
+
135
+ # initialization for linear layers
136
+ for name, m in self.named_modules():
137
+ if isinstance(m, nn.Linear):
138
+ if 'qkv' in name:
139
+ # treat the weights of Q, K, V separately
140
+ val = math.sqrt(
141
+ 6. /
142
+ float(m.weight.shape[0] // 3 + m.weight.shape[1]))
143
+ nn.init.uniform_(m.weight, -val, val)
144
+ else:
145
+ nn.init.xavier_uniform_(m.weight)
146
+ nn.init.zeros_(m.bias)
147
+ nn.init.normal_(self.cls_token, std=1e-6)
148
+
149
+ def _freeze_stages(self):
150
+ """Freeze patch_embed layer, some parameters and stages."""
151
+ if self.frozen_stages >= 0:
152
+ self.patch_embed.eval()
153
+ for param in self.patch_embed.parameters():
154
+ param.requires_grad = False
155
+
156
+ self.cls_token.requires_grad = False
157
+ self.pos_embed.requires_grad = False
158
+
159
+ for i in range(1, self.frozen_stages + 1):
160
+ m = self.layers[i - 1]
161
+ m.eval()
162
+ for param in m.parameters():
163
+ param.requires_grad = False
164
+
165
+ if i == (self.num_layers) and self.final_norm:
166
+ for param in getattr(self, 'norm1').parameters():
167
+ param.requires_grad = False
168
+
169
+ def train(self, mode=True):
170
+ super(VisionTransformer, self).train(mode)
171
+ self._freeze_stages()
172
+ if mode and self.norm_eval:
173
+ for m in self.modules():
174
+ # trick: eval have effect on BatchNorm only
175
+ if isinstance(m, _BatchNorm):
176
+ m.eval()