Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- README.md +38 -0
- __init__.py +8 -0
- cfg.yaml +282 -0
- config.json +12 -0
- model.pt +3 -0
- model_files.zip +3 -0
- modeling_synchformer.py +10 -0
- module_loader.py +32 -0
- modules/dataset/__init__.py +0 -0
- modules/dataset/audioset.py +216 -0
- modules/dataset/dataset_utils.py +112 -0
- modules/dataset/lrs.py +192 -0
- modules/dataset/transforms.py +1074 -0
- modules/dataset/vggsound.py +394 -0
- modules/model/__init__.py +1 -0
- modules/model/modules/bridges.py +178 -0
- modules/model/modules/feat_extractors/audio/ast.py +279 -0
- modules/model/modules/feat_extractors/audio/hf_src/modeling_ast.py +662 -0
- modules/model/modules/feat_extractors/audio/resnet.py +249 -0
- modules/model/modules/feat_extractors/train_clip_src/__init__.py +3 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/__init__.py +13 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/coca_model.py +458 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/constants.py +2 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/factory.py +193 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/generation_utils.py +0 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/hf_configs.py +45 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/hf_model.py +176 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/loss.py +229 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model.py +883 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN101-quickgelu.json +22 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN101.json +21 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50-quickgelu.json +22 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50.json +21 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50x16.json +21 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50x4.json +21 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50x64.json +21 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-16-plus-240.json +16 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-16-plus.json +16 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-16.json +16 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-32-plus-256.json +16 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-32-quickgelu.json +17 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-32.json +16 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-H-14.json +17 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-H-16.json +17 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-14-280.json +16 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-14-336.json +16 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-14.json +16 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-16-320.json +16 -0
- modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-16.json +16 -0
README.md
ADDED
@@ -0,0 +1,38 @@
# Synchformer Hugging Face Model

This repository contains a Synchformer model for audio-visual synchronization. The model predicts the offset between audio and video tracks.

## Usage

```python
from transformers import AutoModel
import torch

# Load the model
model = AutoModel.from_pretrained("AmrMKayid/synchformer-hf")
model.to("cuda" if torch.cuda.is_available() else "cpu")

# Predict offset for a video
results = model.predict_offset(
    "path/to/your/video.mp4",
    offset_sec=0.0,     # Ground truth offset (if known)
    v_start_i_sec=0.0,  # Start time in seconds for video
)

# Print results
print("\nPrediction Results:")
for pred in results["predictions"]:
    print(f'p={pred["probability"]:.4f}, "{pred["offset_sec"]:.2f}" (class {pred["class_idx"]})')
```

## Model Details

This model is based on the Synchformer architecture, which uses a transformer to predict the offset between audio and video tracks.

## Requirements

- torch
- torchaudio
- torchvision
- omegaconf
- ffmpeg (for video processing)
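A note on loading (an assumption, not stated in the README above): if `transformers` has to fetch the custom Synchformer classes from the Hub rather than importing them from a local clone of this repository, the call usually needs `trust_remote_code=True`:

```python
from transformers import AutoModel
import torch

# Assumption: custom architectures served from the Hub require explicit opt-in to remote code.
# If the classes are registered locally (see modeling_synchformer.py below), this flag is not needed.
model = AutoModel.from_pretrained("AmrMKayid/synchformer-hf", trust_remote_code=True)
model.to("cuda" if torch.cuda.is_available() else "cpu")
```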
__init__.py
ADDED
@@ -0,0 +1,8 @@
from .synchformer_config import SynchformerConfig
from .synchformer_model import SynchformerModel
from .modeling_synchformer import *

__all__ = [
    "SynchformerConfig",
    "SynchformerModel",
]
cfg.yaml
ADDED
@@ -0,0 +1,282 @@
action: train_avsync_model
model:
  target: model.sync_model.Synchformer
  params:
    afeat_extractor:
      is_trainable: false
      target: model.modules.feat_extractors.audio.ast.AST
      params:
        ckpt_path: /scratch/project_462000293/vladimir/logs/sync/avclip_models/23-12-22T16-13-38/checkpoints/epoch_e28.pt
        extract_features: true
        max_spec_t: 66
        factorize_freq_time: true
        agg_freq_module: TransformerEncoderLayer
        agg_time_module: torch.nn.Identity
        add_global_repr: false
    vfeat_extractor:
      is_trainable: false
      target: model.modules.feat_extractors.visual.motionformer.MotionFormer
      params:
        ckpt_path: /scratch/project_462000293/vladimir/logs/sync/avclip_models/23-12-22T16-13-38/checkpoints/epoch_e28.pt
        extract_features: true
        factorize_space_time: true
        agg_space_module: TransformerEncoderLayer
        agg_time_module: torch.nn.Identity
        add_global_repr: false
    aproj:
      target: torch.nn.Linear
      params:
        in_features: 768
        out_features: 768
    vproj:
      target: torch.nn.Linear
      params:
        in_features: 768
        out_features: 768
    transformer:
      target: model.sync_model.GlobalTransformer
      params:
        n_layer: 3
        n_head: 8
        n_embd: 768
        tok_pdrop: 0.0
        embd_pdrop: 0.1
        resid_pdrop: 0.1
        attn_pdrop: 0.1
        pos_emb_cfg:
          target: model.modules.transformer.RandInitPositionalEncoding
          params:
            block_shape:
            - 198
            n_embd: 768
        off_head_cfg:
          target: torch.nn.Linear
          params:
            in_features: 768
            out_features: 21
training:
  base_learning_rate: 2.0e-06
  base_batch_size: 16
  num_workers: 7
  num_epochs: 10000
  patience: 50
  to_max_metric: true
  metric_name: accuracy_1
  early_stop_phase: valid
  use_half_precision: true
  seed: 1337
  compile: false
  skip_test: false
  run_test_only: false
  resume: false
  finetune: false
  dist_backend: nccl
  max_clip_norm: 1
  lr_scheduler:
    name: constant_with_warmup
    warmup: 1000
  optimizer:
    name: adam
    betas:
    - 0.9
    - 0.999
    momentum: 0.9
    weight_decay: 0
  local_rank: 0
  global_rank: 0
  world_size: 32
data:
  offset_type: grid
  num_off_cls: 21
  prob_oos: null
  max_off_sec: 2
  crop_len_sec: 5
  step_size_seg: 0.5
  vids_path: /scratch/project_462000293/vladimir/data/audioset/h264_video_25fps_256side_16000hz_aac/
  size_before_crop: 256
  input_size: 224
  segment_size_vframes: 16
  vfps: 25
  afps: 16000
  n_segments: 14
  do_offset: true
  p_color_jitter: 0.0
  p_gray_scale: 0.0
  sometimes_upscale_p: 0.0
  is_spatial_crop_random: true
  is_temporal_crop_random: true
  audio_jitter_sec: 0.05
  p_horizontal_flip: 0.5
  p_audio_aug: 0.0
  dataset:
    target: dataset.audioset.AudioSet
    params:
      load_fixed_offsets_on:
      - valid
      - test
      vis_load_backend: read_video
      size_ratio: null
transform_sequence_train:
- target: dataset.transforms.EqualifyFromRight
  params:
    clip_max_len_sec: 10
- target: dataset.transforms.RGBSpatialCropSometimesUpscale
  params:
    sometimes_p: 0.0
    smaller_input_size: 192
    target_input_size: 224
    is_random: true
- target: dataset.transforms.TemporalCropAndOffset
  params:
    crop_len_sec: 5
    max_off_sec: 2
    max_wiggle_sec: 0.05
    do_offset: true
    offset_type: grid
    prob_oos: null
    grid_size: 21
    segment_size_vframes: 16
    n_segments: 14
    step_size_seg: 0.5
    vfps: 25
- target: dataset.transforms.RandomApplyColorDistortion
  params:
    p_color_jitter: 0.0
    s: 1.0
    p_gray_scale: 0.0
- target: dataset.transforms.RandomHorizontalFlip
  params:
    p: 0.5
- target: dataset.transforms.AudioRandomReverb
  params:
    p: 0.0
- target: dataset.transforms.AudioRandomVolume
  params:
    p: 0.0
    gain: 2.0
    gain_type: amplitude
- target: dataset.transforms.AudioRandomPitchShift
  params:
    p: 0.0
    shift: 1000
- target: dataset.transforms.AudioRandomLowpassFilter
  params:
    p: 0.0
    cutoff_freq: 100
- target: dataset.transforms.AudioRandomGaussNoise
  params:
    p: 0.0
    amplitude: 0.01
- target: dataset.transforms.GenerateMultipleSegments
  params:
    segment_size_vframes: 16
    n_segments: 14
    is_start_random: true
    step_size_seg: 0.5
- target: dataset.transforms.RGBToHalfToZeroOne
- target: dataset.transforms.RGBNormalize
  params:
    mean:
    - 0.5
    - 0.5
    - 0.5
    std:
    - 0.5
    - 0.5
    - 0.5
- target: dataset.transforms.AudioMelSpectrogram
  params:
    sample_rate: 16000
    win_length: 400
    hop_length: 160
    n_fft: 1024
    n_mels: 128
- target: dataset.transforms.AudioLog
- target: dataset.transforms.PadOrTruncate
  params:
    max_spec_t: 66
- target: dataset.transforms.AudioNormalizeAST
  params:
    mean: -4.2677393
    std: 4.5689974
- target: dataset.transforms.PermuteStreams
  params:
    einops_order_audio: S F T -> S 1 F T
    einops_order_rgb: S T C H W -> S T C H W
transform_sequence_test:
- target: dataset.transforms.EqualifyFromRight
- target: dataset.transforms.RGBSpatialCrop
  params:
    input_size: 224
    is_random: false
- target: dataset.transforms.TemporalCropAndOffset
  params:
    crop_len_sec: 5
    max_off_sec: 2
    max_wiggle_sec: 0.0
    do_offset: true
    grid_size: 21
    offset_type: grid
    prob_oos: null
    segment_size_vframes: 16
    n_segments: 14
    step_size_seg: 0.5
    vfps: 25
- target: dataset.transforms.GenerateMultipleSegments
  params:
    segment_size_vframes: 16
    n_segments: 14
    is_start_random: false
    step_size_seg: 0.5
- target: dataset.transforms.RGBToHalfToZeroOne
- target: dataset.transforms.RGBNormalize
  params:
    mean:
    - 0.5
    - 0.5
    - 0.5
    std:
    - 0.5
    - 0.5
    - 0.5
- target: dataset.transforms.AudioMelSpectrogram
  params:
    sample_rate: 16000
    win_length: 400
    hop_length: 160
    n_fft: 1024
    n_mels: 128
- target: dataset.transforms.AudioLog
- target: dataset.transforms.PadOrTruncate
  params:
    max_spec_t: 66
- target: dataset.transforms.AudioNormalizeAST
  params:
    mean: -4.2677393
    std: 4.5689974
- target: dataset.transforms.PermuteStreams
  params:
    einops_order_audio: S F T -> S 1 F T
    einops_order_rgb: S T C H W -> S T C H W
logging:
  logdir: /scratch/project_462000293/vladimir/logs/sync/sync_models/
  log_code_state: true
  log_frequency: 20
  patterns_to_ignore:
  - logs
  - .git
  - __pycache__
  - data
  - '*.pt'
  - sbatch_logs
  - '*.mp4'
  - '*.wav'
  - '*.jpg'
  - '*.gif'
  - misc*
  vis_segment_sim: true
  log_max_items: 500000
  use_wandb: true
start_time: 24-01-04T16-39-21
config: ./configs/sync.yaml
ckpt_path: /scratch/project_462000293/vladimir/logs/sync/sync_models/24-01-04T16-39-21/24-01-04T16-39-21.pt
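The config follows the common `target`/`params` convention: each node names a class by its import path plus the keyword arguments to construct it with. A minimal sketch of how such a node can be resolved is shown below; the repository's own factory (in its `scripts`/`utils` modules) may differ, and the helper name here is illustrative only.

```python
import importlib
from omegaconf import OmegaConf

def instantiate_from_config(node):
    """Import `node.target` and call it with `node.params` as keyword arguments."""
    module_name, cls_name = node.target.rsplit('.', 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    params = node.get('params') or {}
    if OmegaConf.is_config(params):
        params = OmegaConf.to_container(params, resolve=True)
    return cls(**params)

cfg = OmegaConf.load('cfg.yaml')
# e.g. the audio projection head is a plain torch.nn.Linear(768, 768) in this config
aproj = instantiate_from_config(cfg.model.params.aproj)
```

The same pattern covers the feature extractors, the global transformer, and every entry in the transform sequences, which is why the YAML can fully describe the training pipeline.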
config.json
ADDED
@@ -0,0 +1,12 @@
{
  "afeat_extractor_config": {},
  "afps": 16000,
  "in_size": 256,
  "max_off_sec": 2,
  "model_type": "synchformer",
  "transformer_config": {},
  "transformers_version": "4.45.2",
  "use_half_precision": true,
  "vfeat_extractor_config": {},
  "vfps": 25
}
model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5b4b3557fbd96b61aaffa8bc70b28f9ff53f8fa98edc202655c5d94ab3c719ee
size 1131153989
model_files.zip
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:50f5534dc446ddf64b8da751931b490c2990727784d3cab28f3c5b04c783160a
size 787762366
modeling_synchformer.py
ADDED
@@ -0,0 +1,10 @@
from transformers.models.auto.modeling_auto import _BaseAutoModelClass
from transformers.models.auto.configuration_auto import _LazyAutoMapping
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import AutoModel

from synchformer_config import SynchformerConfig
from synchformer_model import SynchformerModel

# Register the model with the transformers library
AutoModel.register(SynchformerConfig, SynchformerModel)
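A related note (an assumption about the intended wiring, not taken from the file above): `AutoModel.from_pretrained` first resolves the `"synchformer"` `model_type` from `config.json` through `AutoConfig`, so the config class typically needs to be registered there as well:

```python
from transformers import AutoConfig, AutoModel

from synchformer_config import SynchformerConfig
from synchformer_model import SynchformerModel

# Map the "synchformer" model_type to the config class, then the config class to the model class.
AutoConfig.register("synchformer", SynchformerConfig)
AutoModel.register(SynchformerConfig, SynchformerModel)
```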
module_loader.py
ADDED
@@ -0,0 +1,32 @@
import os
import sys
import importlib.util

def setup_modules():
    # Get the directory where this script is located
    current_dir = os.path.dirname(os.path.abspath(__file__))
    modules_dir = os.path.join(current_dir, "modules")

    # Add modules directory to path
    if modules_dir not in sys.path:
        sys.path.insert(0, modules_dir)

    # Import required modules
    required_modules = ["dataset", "model", "scripts", "utils"]
    for module_name in required_modules:
        module_path = os.path.join(modules_dir, module_name)
        if os.path.exists(module_path):
            # Check if module is already imported
            if module_name not in sys.modules:
                # Import the module
                spec = importlib.util.spec_from_file_location(
                    module_name,
                    os.path.join(module_path, "__init__.py")
                )
                if spec:
                    module = importlib.util.module_from_spec(spec)
                    sys.modules[module_name] = module
                    spec.loader.exec_module(module)

    return True
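Typical usage is a single call before importing anything from `modules/` (a sketch; `setup_modules` is the helper defined above):

```python
from module_loader import setup_modules

setup_modules()  # puts ./modules on sys.path so `dataset`, `model`, etc. import as top-level packages
from dataset.dataset_utils import get_video_and_audio
```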
modules/dataset/__init__.py
ADDED
File without changes
modules/dataset/audioset.py
ADDED
@@ -0,0 +1,216 @@
import csv
import logging
import random
import sys
from glob import glob
from pathlib import Path

import torch

sys.path.insert(0, '.')  # nopep8
from dataset.dataset_utils import (get_fixed_offsets, get_video_and_audio)


class AudioSet(torch.utils.data.Dataset):

    def __init__(self,
                 split,
                 vids_dir,
                 transforms=None,
                 to_filter_bad_examples=True,
                 splits_path='./data',
                 meta_path='./data/audioset.csv',
                 seed=1337,
                 load_fixed_offsets_on=['valid', 'test'],
                 vis_load_backend='read_video',
                 size_ratio=None,
                 attr_annot_path=None,
                 max_attr_per_vid=None):
        super().__init__()
        self.max_clip_len_sec = None
        self.split = split
        self.vids_dir = Path(vids_dir)
        self.transforms = transforms
        self.to_filter_bad_examples = to_filter_bad_examples
        self.splits_path = Path(splits_path)
        self.meta_path = Path(meta_path)
        self.seed = seed
        self.load_fixed_offsets_on = [] if load_fixed_offsets_on is None else load_fixed_offsets_on
        self.vis_load_backend = vis_load_backend
        self.size_ratio = size_ratio

        self.split2short = {'train': 'unbalanced', 'valid': 'balanced', 'test': 'eval'}
        short2long = {'unbalanced': 'unbalanced_train_segments',
                      'balanced': 'balanced_train_segments',
                      'eval': 'eval_segments'}

        # read meta
        split_meta = []
        for shortdir_vid, start, end, targets, phase in csv.reader(open(meta_path), quotechar='"'):
            if shortdir_vid.startswith(self.split2short[split]):
                # shortdir_vid 'unbalanced/NFap9qgsI_s' -> 'unbalanced_train_segments/NFap9qgsI_s'
                shortdir, vid = shortdir_vid.split('/')
                longdir_vid = '/'.join([short2long[shortdir], vid])
                split_meta.append([longdir_vid, float(start), float(end), targets, phase])

        # filter "bad" examples
        if to_filter_bad_examples:
            split_meta = self.filter_bad_examples(split_meta)

        # label maps
        self.label2target = {l: int(t) for t, _, l in csv.reader(open(self.splits_path / 'audioset_labels.csv'))}
        self.target2label = {t: l for l, t in self.label2target.items()}
        self.video2target = {key: list(map(int, targets.split(','))) for key, _, _, targets, _ in split_meta}

        clip_paths = [self.vids_dir / f'{k}_{int(s*1000)}_{int(e*1000)}.mp4' for k, s, e, t, p in split_meta]
        clip_paths = sorted(clip_paths)

        # loading the fixed offsets. COMMENT THIS IF YOU DON'T HAVE A FILE YET
        if transforms is not None and split in load_fixed_offsets_on:
            logging.info(f'Using fixed offset for {split}')
            self.vid2offset_params = get_fixed_offsets(transforms, split, splits_path, 'audioset')

        self.dataset = clip_paths
        if size_ratio is not None and 0.0 < size_ratio < 1.0:
            cut_off = int(len(self.dataset) * size_ratio)
            random.seed(seed)
            random.shuffle(self.dataset)
            self.dataset = self.dataset[:cut_off]

        logging.info(f'{split} has {len(self.dataset)} items')

    def filter_bad_examples(self, audioset_meta):
        bad = set()
        base_path = Path('./data/filtered_examples_audioset')
        files = sorted(glob(str(base_path / '*.txt')))
        lists = [open(p).read().splitlines() for p in files]
        logging.info(f'Filtering for {files}')
        for s in lists:
            bad = bad.union(s)
        # the ugly string converts '---g-f_I2yQ', '1' into `---g-f_I2yQ_1000_11000`
        audioset_meta = [r for r in audioset_meta if f'{r[0]}_{int(r[1]*1000)}_{int(r[2]*1000)}' not in bad]
        return audioset_meta

    def __getitem__(self, index):
        path = self.dataset[index]
        rgb, audio, meta = self.load_media(path)
        item = self.make_datapoint(path, rgb, audio, meta)
        if self.transforms is not None:
            item = self.transforms(item)
        return item

    def make_datapoint(self, path, rgb, audio, meta):
        # (Tv, 3, H, W) in [0, 255], (Ta, C) in [-1, 1]
        # TODO: since audioset is annotated by tagging, targets have multiple elements: default collate fails
        # targets = self.video2target[f'{Path(path).parent.stem}/{Path(path).stem[:11]}']
        item = {
            'video': rgb,
            'audio': audio,
            'meta': meta,
            'path': str(path),
            # 'targets': {'audioset_target': [targets], 'audioset_label': [self.target2label[t] for t in targets]},
            'targets': {},
            'split': self.split,
        }

        # loading the fixed offsets. COMMENT THIS IF YOU DON'T HAVE A FILE YET
        if self.transforms is not None and self.split in self.load_fixed_offsets_on:
            key = f'{self.split2short[self.split]}/{Path(path).stem}'
            item['targets']['offset_sec'] = self.vid2offset_params[key]['offset_sec']
            item['targets']['v_start_i_sec'] = self.vid2offset_params[key]['v_start_i_sec']

        return item

    def load_media(self, path):
        rgb, audio, meta = get_video_and_audio(path, get_meta=True, end_sec=self.max_clip_len_sec)
        return rgb, audio, meta

    def __len__(self):
        return len(self.dataset)


class AudioSetBalanced737k(AudioSet):

    def __init__(self, split, vids_dir, transforms=None, to_filter_bad_examples=True, splits_path='./data',
                 # here
                 meta_path='./data/audioset_balanced_737k.csv',
                 seed=1337, load_fixed_offsets_on=['valid', 'test'], vis_load_backend='read_video', size_ratio=None,
                 attr_annot_path=None, max_attr_per_vid=None):
        super().__init__(split, vids_dir, transforms, to_filter_bad_examples, splits_path, meta_path,
                         seed, load_fixed_offsets_on, vis_load_backend, size_ratio)


class AudioSetBalanced540k(AudioSet):
    ''' MBT's balanced 500k (from unbalanced part) + 20k from balanced part + 20k from eval part '''

    def __init__(self, split, vids_dir, transforms=None, to_filter_bad_examples=True, splits_path='./data',
                 # here
                 meta_path='./data/audioset_balanced_540k.csv',
                 seed=1337, load_fixed_offsets_on=['valid', 'test'], vis_load_backend='read_video', size_ratio=None,
                 attr_annot_path=None, max_attr_per_vid=None):
        super().__init__(split, vids_dir, transforms, to_filter_bad_examples, splits_path, meta_path,
                         seed, load_fixed_offsets_on, vis_load_backend, size_ratio)


if __name__ == '__main__':
    from omegaconf import OmegaConf
    from scripts.train_utils import get_transforms
    from utils.utils import cfg_sanity_check_and_patch
    cfg = OmegaConf.load('./configs/sparse_sync.yaml')
    vis_load_backend = 'read_video'

    transforms = get_transforms(cfg)

    # vids_path = 'PLACEHOLDER'
    vids_path = '/scratch/project_2000936/vladimir/data/audioset/h264_video_25fps_256side_16000hz_aac'
    load_fixed_offsets_on = []

    cfg.data.dataset.params.size_ratio = 0.1

    cfg_sanity_check_and_patch(cfg)

    datasets = {
        'train': AudioSet('train', vids_path, transforms['train'], vis_load_backend=vis_load_backend,
                          to_filter_bad_examples=True, size_ratio=cfg.data.dataset.params.size_ratio,
                          load_fixed_offsets_on=load_fixed_offsets_on),
        'valid': AudioSet('valid', vids_path, transforms['test'], vis_load_backend=vis_load_backend,
                          to_filter_bad_examples=True, load_fixed_offsets_on=load_fixed_offsets_on),
        'test': AudioSet('test', vids_path, transforms['test'], vis_load_backend=vis_load_backend,
                         to_filter_bad_examples=True, load_fixed_offsets_on=load_fixed_offsets_on),
    }
    for phase in ['train', 'valid', 'test']:
        print(phase, len(datasets[phase]))

    print(datasets['train'][0]['audio'].shape, datasets['train'][0]['video'].shape)
    print(datasets['train'][0]['meta'])
    print(datasets['valid'][0]['audio'].shape, datasets['valid'][0]['video'].shape)
    print(datasets['valid'][0]['meta'])
    print(datasets['test'][0]['audio'].shape, datasets['test'][0]['video'].shape)
    print(datasets['test'][0]['meta'])

    for i in range(300, 1000):
        datasets['train'][i]['path']
    print(datasets['train'][0]['audio'].shape, datasets['train'][0]['video'].shape)
    print(datasets['train'][0]['meta'])

    datasets = {
        'train': AudioSetBalanced737k('train', vids_path, transforms['train'], vis_load_backend=vis_load_backend,
                                      to_filter_bad_examples=True, size_ratio=cfg.data.dataset.params.size_ratio,
                                      load_fixed_offsets_on=load_fixed_offsets_on),
        'valid': AudioSetBalanced737k('valid', vids_path, transforms['test'], vis_load_backend=vis_load_backend,
                                      to_filter_bad_examples=True, load_fixed_offsets_on=load_fixed_offsets_on),
        'test': AudioSetBalanced737k('test', vids_path, transforms['test'], vis_load_backend=vis_load_backend,
                                     to_filter_bad_examples=True, load_fixed_offsets_on=load_fixed_offsets_on),
    }
    for phase in ['train', 'valid', 'test']:
        print(phase, len(datasets[phase]))

    print(datasets['train'][0]['audio'].shape, datasets['train'][0]['video'].shape)
    print(datasets['train'][0]['meta'])
    print(datasets['valid'][0]['audio'].shape, datasets['valid'][0]['video'].shape)
    print(datasets['valid'][0]['meta'])
    print(datasets['test'][0]['audio'].shape, datasets['test'][0]['video'].shape)
    print(datasets['test'][0]['meta'])

    for i in range(300, 1000):
        datasets['train'][i]['path']
    print(datasets['train'][0]['audio'].shape, datasets['train'][0]['video'].shape)
    print(datasets['train'][0]['meta'])
modules/dataset/dataset_utils.py
ADDED
@@ -0,0 +1,112 @@
import csv
import os
import random
from pathlib import Path
from glob import glob
import shutil
import logging

import torchaudio
import torchvision

from utils.utils import get_fixed_off_fname


def get_fixed_offsets(transforms, split, splits_path, dataset_name):
    '''dataset_name: `vggsound` or `lrs3`'''
    logging.info(f'Using fixed offset for {split}')
    vid2offset_params = {}
    fixed_offset_fname = get_fixed_off_fname(transforms, split)
    if fixed_offset_fname is None:
        raise ValueError('Cant find fixed offsets for given params. Perhaps you need to make it first?')
    fixed_offset_path = os.path.join(splits_path, f'fixed_offsets_{dataset_name}', fixed_offset_fname)
    fixed_offset_paths = sorted(glob(fixed_offset_path.replace(split, '*')))
    assert len(fixed_offset_paths) > 0, f'Perhaps: {fixed_offset_path} does not exist. Make fixed offsets'

    for fix_off_path in fixed_offset_paths:
        reader = csv.reader(open(fix_off_path))
        # k700_2020 has no header, and also `vstart` comes before `offset_sec`
        if dataset_name == 'k700_2020':
            header = ['path', 'vstart_sec', 'offset_sec', 'oos_target']
        else:
            header = next(reader)
        for line in reader:
            data = dict()
            for f, value in zip(header, line):
                if f == 'path':
                    v = value
                elif f == 'offset_sec':
                    data[f] = float(value)
                elif f in ['vstart_sec', 'v_start_sec']:
                    f = 'v_start_i_sec'
                    data[f] = float(value)
                elif f == 'oos_target':
                    data[f] = int(value)
                else:
                    data[f] = value
            # assert v not in vid2offset_params, 'otherwise, offs from other splits will override each other'

            # even if we have multiple splits (val=test), we want to make sure that the offsets are the same
            if v in vid2offset_params:
                assert all([vid2offset_params[v][k] == data[k] for k in data]), f'{v} isnt unique and vary'

            vid2offset_params[v] = data
    return vid2offset_params


def maybe_cache_file(path: os.PathLike):
    '''Motivation: if every job reads from a shared disk it`ll get very slow, consider an image can
    be 2MB, then with batch size 32, 16 workers in dataloader you`re already requesting 1GB!! -
    imagine this for all users and all jobs simultaneously.'''
    # checking if we are on cluster, not on a local machine
    if 'LOCAL_SCRATCH' in os.environ:
        cache_dir = os.environ.get('LOCAL_SCRATCH')
        # a bit ugly but we need not just fname to be appended to `cache_dir` but parent folders,
        # otherwise the same fnames in multiple folders will create a bug (the same input for multiple paths)
        cache_path = os.path.join(cache_dir, Path(path).relative_to('/'))
        if not os.path.exists(cache_path):
            os.makedirs(Path(cache_path).parent, exist_ok=True)
            shutil.copyfile(path, cache_path)
        return cache_path
    else:
        return path


def get_video_and_audio(path, get_meta=False, start_sec=0, end_sec=None):
    orig_path = path
    path = maybe_cache_file(path)
    # (Tv, 3, H, W) [0, 255, uint8]; (Ca, Ta)
    rgb, audio, meta = torchvision.io.read_video(str(path), start_sec, end_sec, 'sec', output_format='TCHW')
    assert meta['video_fps'], f'No video fps for {orig_path}'
    # (Ta) <- (Ca, Ta)
    audio = audio.mean(dim=0)
    # FIXME: this is legacy format of `meta` as it used to be loaded by VideoReader.
    meta = {'video': {'fps': [meta['video_fps']]}, 'audio': {'framerate': [meta['audio_fps']]}, }
    return rgb, audio, meta


def get_audio_stream(path, get_meta=False):
    '''Used only in feature extractor training'''
    path = str(Path(path).with_suffix('.wav'))
    path = maybe_cache_file(path)
    waveform, _ = torchaudio.load(path)
    waveform = waveform.mean(dim=0)
    if get_meta:
        info = torchaudio.info(path)
        duration = info.num_frames / info.sample_rate
        meta = {'audio': {'duration': [duration], 'framerate': [info.sample_rate]}}
        return waveform, meta
    else:
        return waveform


def subsample_dataset(dataset: list, size_ratio: float, shuffle: bool = False):
    if size_ratio is not None and 0.0 < size_ratio < 1.0:
        logging.info(f'Subsampling dataset to {size_ratio}')
        # shuffling is important only during subsampling (sometimes paths are sorted by class)
        if shuffle:
            random.shuffle(dataset)
        cut_off = int(len(dataset) * size_ratio)
        # making sure that we have at least one example
        dataset = dataset[:max(1, cut_off)]
        logging.info(f'Subsampled dataset to {size_ratio} (size: {len(dataset)})')
    return dataset
modules/dataset/lrs.py
ADDED
@@ -0,0 +1,192 @@
import csv
import logging
import os
import random
import sys
from glob import glob
from pathlib import Path

import torch


sys.path.insert(0, '.')  # nopep8
from dataset.dataset_utils import (get_fixed_offsets, get_video_and_audio, subsample_dataset)


class LRS3(torch.utils.data.Dataset):

    def __init__(self,
                 split,
                 vids_dir,
                 transforms=None,
                 splits_path='./data',
                 seed=1337,
                 load_fixed_offsets_on=['valid', 'test'],
                 vis_load_backend='VideoReader',
                 size_ratio=None,
                 attr_annot_path=None,
                 max_attr_per_vid=None,
                 to_filter_bad_examples=True,):
        super().__init__()
        self.max_clip_len_sec = 11
        logging.info(f'During IO, the length of clips is limited to {self.max_clip_len_sec} sec')
        self.split = split
        self.vids_dir = vids_dir
        self.transforms = transforms
        self.splits_path = splits_path
        self.seed = seed
        self.load_fixed_offsets_on = [] if load_fixed_offsets_on is None else load_fixed_offsets_on
        self.vis_load_backend = vis_load_backend
        self.size_ratio = size_ratio

        split_clip_ids_path = os.path.join(splits_path, f'lrs3_{split}.txt')
        if not os.path.exists(split_clip_ids_path):
            vid_folder = Path(vids_dir) / 'pretrain'
            clip_paths = sorted(vid_folder.rglob('*/*.mp4'))
            if to_filter_bad_examples:
                clip_paths = self.filter_bad_examples(clip_paths)
            self.make_split_files(clip_paths)

        # read the ids from a split
        split_clip_ids = sorted(open(split_clip_ids_path).read().splitlines())

        # make paths from the ids
        clip_paths = [os.path.join(vids_dir, v + '.mp4') for v in split_clip_ids]

        if split in self.load_fixed_offsets_on:
            logging.info(f'Using fixed offset for {split}')
            self.vid2offset_params = get_fixed_offsets(transforms, split, splits_path, 'lrs3')

        self.dataset = clip_paths
        self.dataset = subsample_dataset(self.dataset, size_ratio, shuffle=split == 'train')

        logging.info(f'{split} has {len(self.dataset)} items')

    def __getitem__(self, index):
        path = self.dataset[index]
        rgb, audio, meta = get_video_and_audio(path, get_meta=True, end_sec=self.max_clip_len_sec)

        # (Tv, 3, H, W) in [0, 255], (Ta, C) in [-1, 1]
        item = {'video': rgb, 'audio': audio, 'meta': meta, 'path': path, 'targets': {}, 'split': self.split}

        # loading fixed offsets so we could evaluate on the same data each time (valid and test)
        if self.split in self.load_fixed_offsets_on:
            unique_id = path.replace(f'{self.vids_dir}/', '').replace(self.vids_dir, '').replace('.mp4', '')
            offset_params = self.vid2offset_params[unique_id]
            item['targets']['offset_sec'] = offset_params['offset_sec']
            item['targets']['v_start_i_sec'] = offset_params['v_start_i_sec']
            if 'oos_target' in offset_params:
                item['targets']['offset_target'] = {
                    'oos': offset_params['oos_target'], 'offset': item['targets']['offset_sec'],
                }

        if self.transforms is not None:
            item = self.transforms(item)

        return item

    def filter_bad_examples(self, paths):
        bad = set()
        base_path = Path('./data/filtered_examples_lrs')
        lists = [open(p).read().splitlines() for p in sorted(glob(str(base_path / '*.txt')))]
        for s in lists:
            bad = bad.union(s)
        logging.info(f'Number of clips before filtering: {len(paths)}')
        video_ids = [str(i).replace(self.vids_dir, '') for i in paths]
        video_ids = [str(i).replace(f'{self.vids_dir}/', '') for i in video_ids]
        paths = sorted([r for r in video_ids if r not in bad])
        logging.info(f'Number of clips after filtering: {len(paths)}')
        return paths

    def make_split_files(self, paths):
        logging.warning(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')

        # will be splitting using videos, not clips to prevent train-test intersection
        all_vids = sorted(list(set([Path(p).parent.name for p in paths])))
        random.Random(self.seed).shuffle(all_vids)

        # 0.1: splits are 8:1:1
        hold_out_ratio = 0.1
        hold_out_size = int(len(all_vids) * hold_out_ratio)
        test_vids, train_valid_vids = all_vids[:hold_out_size], all_vids[hold_out_size:]
        valid_vids, train_vids = train_valid_vids[:hold_out_size], train_valid_vids[hold_out_size:]

        # making files
        for phase, vids in zip(['train', 'valid', 'test'], [train_vids, valid_vids, test_vids]):
            with open(os.path.join(self.splits_path, f'lrs3_{phase}.txt'), 'w') as wfile:
                for path in paths:
                    vid_name = Path(path).parent.name
                    # just in the case I forgot the trailing '/' in the path
                    unique_id = path.replace(f'{self.vids_dir}/', '').replace(self.vids_dir, '') \
                        .replace('.mp4', '')
                    if vid_name in vids:
                        wfile.write(unique_id + '\n')

    def __len__(self):
        return len(self.dataset)


class LongerLRS3(LRS3):
    '''This class differs from the parent in the extra filtering it applies: where the parent
    makes the splits with a 9-second length filter, this class filters at 9.5 seconds.'''

    def __init__(self,
                 split,
                 vids_dir,
                 transforms=None,
                 splits_path='./data',
                 seed=1337,
                 load_fixed_offsets_on=['valid', 'test'],
                 vis_load_backend='VideoReader',
                 size_ratio=None,
                 attr_annot_path=None,
                 max_attr_per_vid=None,
                 to_filter_bad_examples=True,):
        # size_ratio is not passed to the parent as it is applied in this class (avoiding double subsampling)
        super().__init__(split, vids_dir, transforms, splits_path, seed, load_fixed_offsets_on,
                         vis_load_backend, None, attr_annot_path, max_attr_per_vid,
                         to_filter_bad_examples)
        # does extra filtering
        if to_filter_bad_examples:
            self.dataset = self.filter_bad_examples(self.dataset)
        self.dataset = subsample_dataset(self.dataset, size_ratio, shuffle=split == 'train')
        logging.info(f'{split} has {len(self.dataset)} items')

    def filter_bad_examples(self, paths):
        bad = set()
        base_path = Path('./data/filtered_examples_lrs_extra')
        lists = [open(p).read().splitlines() for p in sorted(glob(str(base_path / '*.txt')))]
        for s in lists:
            bad = bad.union(s)
        logging.info(f'Number of clips before filtering: {len(paths)}')
        video_ids = [str(i).replace(self.vids_dir, '').replace(f'{self.vids_dir}/', '') for i in paths]
        paths = sorted([os.path.join(self.vids_dir, r) for r in video_ids if r not in bad])
        logging.info(f'Number of clips after filtering: {len(paths)}')
        return paths


if __name__ == '__main__':
    from time import time
    from omegaconf import OmegaConf
    import sys
    sys.path.insert(0, '.')  # nopep8
    from scripts.train_utils import get_transforms
    from utils.utils import cfg_sanity_check_and_patch
    cfg = OmegaConf.load('./configs/sparse_sync.yaml')
    cfg.data.vids_path = '/scratch/local/hdd/vi/data/lrs3/h264_uncropped_25fps_256side_16000hz_aac/'
    cfg.data.dataset.params.load_fixed_offsets_on = ['valid', 'test']

    cfg_sanity_check_and_patch(cfg)
    transforms = get_transforms(cfg)

    datasets = {
        'train': LRS3('train', cfg.data.vids_path, transforms['train'], load_fixed_offsets_on=[]),
        'valid': LRS3('valid', cfg.data.vids_path, transforms['test'], load_fixed_offsets_on=[]),
        'test': LRS3('test', cfg.data.vids_path, transforms['test'], load_fixed_offsets_on=[]),
    }
    for phase in ['train', 'valid', 'test']:
        print(phase, len(datasets[phase]))

    print(datasets['train'][1]['audio'].shape, datasets['train'][1]['video'].shape)
    print(datasets['valid'][1]['audio'].shape, datasets['valid'][1]['video'].shape)
modules/dataset/transforms.py
ADDED
@@ -0,0 +1,1074 @@
import logging
import math
import random
from typing import Tuple
import torch
import torchvision
import torchaudio
import numpy as np
import einops


def sec2frames(sec, fps):
    return int(sec * fps)

def frames2sec(frames, fps):
    return frames / fps


class EqualifyFromRight(torch.nn.Module):

    def __init__(self, clip_max_len_sec=10):
        '''
        Takes the dataset item and makes sure both streams are of an equal size in terms of fps.
        It, however, assumes that the signal is synched and trims the ending parts ('from the right').
        '''
        super().__init__()
        self.clip_max_len_sec = clip_max_len_sec

    def forward(self, item):
        '''
        `item`: {'video': (Tv, C, H, W), 'audio': (Ta,),
                 'meta': {
                    'audio': {'framerate': [float], 'duration': [float]}
                    'video': {'fps': [float], 'duration': [float]}}
        '''
        a_fps = item['meta']['audio']['framerate'][0]
        v_fps = item['meta']['video']['fps'][0]

        Ta = item['audio'].shape[0]
        Tv, C, H, W = item['video'].shape

        a_len_secs = Ta / a_fps
        v_len_secs = Tv / v_fps
        min_len = min(self.clip_max_len_sec, a_len_secs, v_len_secs)

        a_frames_per_v_frame = a_fps // v_fps
        v_len_frames = int(v_fps * min_len)
        a_len_frames = int(a_frames_per_v_frame * v_len_frames)
        # print(a_len_frames, v_len_frames)

        assert a_len_frames <= Ta and v_len_frames <= Tv

        item['audio'] = item['audio'][:a_len_frames]
        item['video'] = item['video'][:v_len_frames, :, :, :]

        return item


class RGBSpatialCrop(torch.nn.Module):

    def __init__(self, input_size, is_random):
        super().__init__()
        assert input_size is not None, f'smaller_input_size is `{input_size}`'
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        self.input_size = input_size
        self.is_random = is_random

    @staticmethod
    def get_random_crop_sides(vid, output_size):
        '''Slice parameters for random crop'''
        h, w = vid.shape[-2:]
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    @staticmethod
    def get_center_crop_sides(vid, output_size):
        '''Slice parameters for center crop'''
        h, w = vid.shape[-2:]
        th, tw = output_size

        i = int(round((h - th) / 2.))
        j = int(round((w - tw) / 2.))
        return i, j, th, tw

    def forward(self, item):
        # (Tv, C, H, W)
        vid = item['video']
        if self.is_random:
            i, j, h, w = self.get_random_crop_sides(vid, self.input_size)
        else:
            i, j, h, w = self.get_center_crop_sides(vid, self.input_size)
        item['video'] = vid[..., i:(i + h), j:(j + w)]
        return item


class Resize(torchvision.transforms.Resize):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, item):
        item['video'] = super().forward(item['video'])
        return item


class RGBSpatialCropSometimesUpscale(torch.nn.Module):
    '''This (randomly) crops the input video and with prob `sometimes_p` this crop is smaller but upscaled
    to `target_input_size`'''

    def __init__(self, sometimes_p, target_input_size, is_random, smaller_input_size=None):
        super().__init__()
        self.sometimes_p = sometimes_p
        self.do_sometimes_upscale = sometimes_p is not None and sometimes_p > 0

        self.crop_only = RGBSpatialCrop(target_input_size, is_random)

        if self.do_sometimes_upscale:
            self.crop_further_and_upscale = torchvision.transforms.Compose([
                RGBSpatialCrop(smaller_input_size, is_random),
                Resize(target_input_size, antialias=None),
            ])

    def forward(self, item):
        assert len(item['video'].shape) == 4, \
            f"{item['video'].shape}: if it is applied after GenerateMultipleClips," \
            "augs should be applied to each clip separately, not to the whole video array. " \
            "Otherwise, ignore this warning (comment it)."
        if self.do_sometimes_upscale and self.sometimes_p > torch.rand(1):
            return self.crop_further_and_upscale(item)
        else:
            return self.crop_only(item)


class RandomApplyColorDistortion(torch.nn.Module):

    def __init__(self, p_gray_scale=0., p_color_jitter=0., s=1.) -> None:
        super().__init__()
        self.p_gray_scale = p_gray_scale
        self.p_color_jitter = p_color_jitter
        self.s = s
        assert 0 <= self.p_color_jitter <= 1 and 0 <= self.p_gray_scale <= 1, (p_color_jitter, p_gray_scale)
        # SimCLR params
        color_jitter = torchvision.transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
        rand_color_jitter = torchvision.transforms.RandomApply([color_jitter], p_color_jitter)
        rand_gray = torchvision.transforms.RandomGrayscale(p_gray_scale)
        self.transforms = torchvision.transforms.Compose([rand_color_jitter, rand_gray])

    def apply_to_single_clip(self, clip):
        return self.transforms(clip)

    def apply_to_each_clip(self, clips):
        for i, clip in enumerate(clips):
            clips[i] = self.apply_to_single_clip(clip)
        return clips

    def forward(self, item):
        has_batch_dim = len(item['video'].shape) == 5
        if has_batch_dim:
            fn = self.apply_to_each_clip
        else:
            fn = self.apply_to_single_clip
        item['video'] = fn(item['video'])
        return item


class ApplyColorJitterFrameWise(torch.nn.Module):

    def __init__(self, s=1.) -> None:
        super().__init__()
        self.s = s
        # SimCLR params
        self.transform = torchvision.transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)

    def apply_to_single_clip(self, clip):
        for i, frame in enumerate(clip):
            clip[i] = self.transform(frame)
        return clip

    def apply_to_each_clip(self, clips):
        for i, clip in enumerate(clips):
            clips[i] = self.apply_to_single_clip(clip)
        return clips

    def forward(self, item):
        has_batch_dim = len(item['video'].shape) == 5
        if has_batch_dim:
            fn = self.apply_to_each_clip
        else:
            fn = self.apply_to_single_clip
        item['video'] = fn(item['video'])
        return item


class RandomHorizontalFlip(torchvision.transforms.RandomHorizontalFlip):

    def __init__(self, p=0.5):
        super().__init__(p)

    def apply_to_single_clip(self, clip):
        return super().forward(clip)

    def apply_to_each_clip(self, clips):
        for i, clip in enumerate(clips):
            clips[i] = self.apply_to_single_clip(clip)
        return clips

    def forward(self, item):
        has_batch_dim = len(item['video'].shape) == 5
        if has_batch_dim:
            fn = self.apply_to_each_clip
        else:
            fn = self.apply_to_single_clip
        item['video'] = fn(item['video'])
        return item


def make_class_grid(leftmost_val, rightmost_val, grid_size, add_extreme_offset: bool = False,
                    seg_size_vframes: int = None, nseg: int = None, step_size_seg: float = None,
                    vfps: float = None):
    assert grid_size >= 3, f'grid_size: {grid_size} does not make sense. If =2 -> (-1,1); =1 -> (-1); =0 -> ()'
    grid = torch.from_numpy(np.linspace(leftmost_val, rightmost_val, grid_size)).float()
    if add_extreme_offset:
        assert all([seg_size_vframes, nseg, step_size_seg]), f'{seg_size_vframes} {nseg} {step_size_seg}'
        seg_size_sec = seg_size_vframes / vfps
        trim_size_in_seg = nseg - (1 - step_size_seg) * (nseg - 1)
        extreme_value = trim_size_in_seg * seg_size_sec
        grid = torch.cat([grid, torch.tensor([extreme_value])])  # adding extreme offset to the class grid
    return grid


def quantize_offset(grid: torch.Tensor, off_sec: float) -> Tuple[float, int]:
    '''Takes in the offset in seconds and snaps it onto the closest grid element.
    Returns the grid value and its index.'''
    closest_grid_el = (grid - off_sec).abs().argmin()
    return grid[closest_grid_el], closest_grid_el

def apply_a_jitter(a_start_i, a_len_frames, a_crop_len_frames, a_fps, max_a_jitter_sec):
    max_a_start_i = a_len_frames - a_crop_len_frames
    max_a_jitter_i = sec2frames(max_a_jitter_sec, a_fps)
    max_a_jitter_i_left = min(a_start_i, max_a_jitter_i)
    max_a_jitter_i_right = min(max_a_start_i - a_start_i, max_a_jitter_i)
    # jitter is U[left, right]
    a_jitter_i = random.randint(-max_a_jitter_i_left, max_a_jitter_i_right)
    # apply jitter
    a_start_i = a_start_i + a_jitter_i
    # making sure that any value from `a_start_i + U[left, right]` will be inside of [0, len-crop] region
    assert 0 <= a_start_i <= max_a_start_i, f'{a_jitter_i} {max_a_jitter_i_left} {max_a_jitter_i_right} {max_a_start_i}'
    return a_start_i, a_jitter_i


class TemporalCropAndOffset(torch.nn.Module):

    def __init__(self, crop_len_sec: float, max_off_sec: float, offset_type='grid', do_offset: bool = True,
                 grid_size: int = None, max_wiggle_sec: float = None, add_doubt_cls: bool = False,
                 segment_size_vframes: int = None, n_segments: int = None, step_size_seg: float = None,
                 vfps: float = None, prob_oos: float = None):
        super().__init__()
        self.crop_len_sec = crop_len_sec
        self.do_offset = do_offset
        self.grid_size = grid_size
        self.offset_type = offset_type
        self.max_off_sec = max_off_sec
        self.max_a_jitter_sec = max_wiggle_sec
        if do_offset:
            if offset_type == 'grid':
                self.class_grid = make_class_grid(-max_off_sec, max_off_sec, grid_size, add_doubt_cls,
                                                  segment_size_vframes, n_segments, step_size_seg, vfps)
                logging.info(f'Offsets class grid: {self.class_grid}')
                if self.max_a_jitter_sec is not None:
                    assert (max_wiggle_sec-1e-6) <= ((self.class_grid[1] - self.class_grid[0]) / 2), f'{self.class_grid}'
            elif offset_type == 'uniform':
                self.off_dist = torch.distributions.uniform.Uniform(-max_off_sec, max_off_sec)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
import random
|
4 |
+
from typing import Tuple
|
5 |
+
import torch
|
6 |
+
import torchvision
|
7 |
+
import torchaudio
|
8 |
+
import numpy as np
|
9 |
+
import einops
|
10 |
+
|
11 |
+
|
12 |
+
def sec2frames(sec, fps):
|
13 |
+
return int(sec * fps)
|
14 |
+
|
15 |
+
def frames2sec(frames, fps):
|
16 |
+
return frames / fps
|
17 |
+
|
18 |
+
|
19 |
+
class EqualifyFromRight(torch.nn.Module):
|
20 |
+
|
21 |
+
def __init__(self, clip_max_len_sec=10):
|
22 |
+
'''
|
23 |
+
Takes the dataset item and makes sure more streams are of an equal size in terms of fps.
|
24 |
+
It, however, assumes that the signal is synched and trims the ending parts ('from the right').
|
25 |
+
'''
|
26 |
+
super().__init__()
|
27 |
+
self.clip_max_len_sec = clip_max_len_sec
|
28 |
+
|
29 |
+
def forward(self, item):
|
30 |
+
'''
|
31 |
+
`item`: {'video': (Tv, C, H, W), 'audio': (Ta,),
|
32 |
+
'meta': {
|
33 |
+
'audio': {'framerate': [float], 'duration': [float]}
|
34 |
+
'video': {'fps': [float], 'duration': [float]}}
|
35 |
+
'''
|
36 |
+
a_fps = item['meta']['audio']['framerate'][0]
|
37 |
+
v_fps = item['meta']['video']['fps'][0]
|
38 |
+
|
39 |
+
Ta = item['audio'].shape[0]
|
40 |
+
Tv, C, H, W = item['video'].shape
|
41 |
+
|
42 |
+
a_len_secs = Ta / a_fps
|
43 |
+
v_len_secs = Tv / v_fps
|
44 |
+
min_len = min(self.clip_max_len_sec, a_len_secs, v_len_secs)
|
45 |
+
|
46 |
+
a_frames_per_v_frame = a_fps // v_fps
|
47 |
+
v_len_frames = int(v_fps * min_len)
|
48 |
+
a_len_frames = int(a_frames_per_v_frame * v_len_frames)
|
49 |
+
# print(a_len_frames, v_len_frames)
|
50 |
+
|
51 |
+
assert a_len_frames <= Ta and v_len_frames <= Tv
|
52 |
+
|
53 |
+
item['audio'] = item['audio'][:a_len_frames]
|
54 |
+
item['video'] = item['video'][:v_len_frames, :, :, :]
|
55 |
+
|
56 |
+
return item
|
57 |
+
|
58 |
+
|
59 |
+
class RGBSpatialCrop(torch.nn.Module):
|
60 |
+
|
61 |
+
def __init__(self, input_size, is_random):
|
62 |
+
super().__init__()
|
63 |
+
assert input_size is not None, f'smaller_input_size is `{input_size}`'
|
64 |
+
if isinstance(input_size, int):
|
65 |
+
input_size = (input_size, input_size)
|
66 |
+
self.input_size = input_size
|
67 |
+
self.is_random = is_random
|
68 |
+
|
69 |
+
@staticmethod
|
70 |
+
def get_random_crop_sides(vid, output_size):
|
71 |
+
'''Slice parameters for random crop'''
|
72 |
+
h, w = vid.shape[-2:]
|
73 |
+
th, tw = output_size
|
74 |
+
if w == tw and h == th:
|
75 |
+
return 0, 0, h, w
|
76 |
+
i = random.randint(0, h - th)
|
77 |
+
j = random.randint(0, w - tw)
|
78 |
+
return i, j, th, tw
|
79 |
+
|
80 |
+
@staticmethod
|
81 |
+
def get_center_crop_sides(vid, output_size):
|
82 |
+
'''Slice parameters for center crop'''
|
83 |
+
h, w = vid.shape[-2:]
|
84 |
+
th, tw = output_size
|
85 |
+
|
86 |
+
i = int(round((h - th) / 2.))
|
87 |
+
j = int(round((w - tw) / 2.))
|
88 |
+
return i, j, th, tw
|
89 |
+
|
90 |
+
def forward(self, item):
|
91 |
+
# (Tv, C, H, W)
|
92 |
+
vid = item['video']
|
93 |
+
if self.is_random:
|
94 |
+
i, j, h, w = self.get_random_crop_sides(vid, self.input_size)
|
95 |
+
else:
|
96 |
+
i, j, h, w = self.get_center_crop_sides(vid, self.input_size)
|
97 |
+
item['video'] = vid[..., i:(i + h), j:(j + w)]
|
98 |
+
return item
|
99 |
+
|
100 |
+
class Resize(torchvision.transforms.Resize):
|
101 |
+
|
102 |
+
def __init__(self, *args, **kwargs):
|
103 |
+
super().__init__(*args, **kwargs)
|
104 |
+
|
105 |
+
def forward(self, item):
|
106 |
+
item['video'] = super().forward(item['video'])
|
107 |
+
return item
|
108 |
+
|
109 |
+
|
110 |
+
class RGBSpatialCropSometimesUpscale(torch.nn.Module):
|
111 |
+
'''This (randomly) crops the input video and with prob `sometimes_p` this crop is smaller but upscaled
|
112 |
+
to `target_input_size`'''
|
113 |
+
|
114 |
+
def __init__(self, sometimes_p, target_input_size, is_random, smaller_input_size=None):
|
115 |
+
super().__init__()
|
116 |
+
self.sometimes_p = sometimes_p
|
117 |
+
self.do_sometimes_upscale = sometimes_p is not None and sometimes_p > 0
|
118 |
+
|
119 |
+
self.crop_only = RGBSpatialCrop(target_input_size, is_random)
|
120 |
+
|
121 |
+
if self.do_sometimes_upscale:
|
122 |
+
self.crop_further_and_upscale = torchvision.transforms.Compose([
|
123 |
+
RGBSpatialCrop(smaller_input_size, is_random),
|
124 |
+
Resize(target_input_size, antialias=None),
|
125 |
+
])
|
126 |
+
|
127 |
+
def forward(self, item):
|
128 |
+
assert len(item['video'].shape) == 4, \
|
129 |
+
f"{item['video'].shape}: if it is applied after GenerateMultipleClips," \
|
130 |
+
"augs should be applied to each clip separately, not to the whole video array. " \
|
131 |
+
"Otherwise, ignore this warning (comment it)."
|
132 |
+
if self.do_sometimes_upscale and self.sometimes_p > torch.rand(1):
|
133 |
+
return self.crop_further_and_upscale(item)
|
134 |
+
else:
|
135 |
+
return self.crop_only(item)
|
136 |
+
|
137 |
+
|
138 |
+
class RandomApplyColorDistortion(torch.nn.Module):
|
139 |
+
|
140 |
+
def __init__(self, p_gray_scale=0., p_color_jitter=0., s=1.) -> None:
|
141 |
+
super().__init__()
|
142 |
+
self.p_gray_scale = p_gray_scale
|
143 |
+
self.p_color_jitter = p_color_jitter
|
144 |
+
self.s = s
|
145 |
+
assert 0 <= self.p_color_jitter <= 1 and 0 <= self.p_gray_scale <= 1, (p_color_jitter, p_gray_scale)
|
146 |
+
# SimCLR params
|
147 |
+
color_jitter = torchvision.transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
|
148 |
+
rand_color_jitter = torchvision.transforms.RandomApply([color_jitter], p_color_jitter)
|
149 |
+
rand_gray = torchvision.transforms.RandomGrayscale(p_gray_scale)
|
150 |
+
self.transforms = torchvision.transforms.Compose([rand_color_jitter, rand_gray])
|
151 |
+
|
152 |
+
def apply_to_single_clip(self, clip):
|
153 |
+
return self.transforms(clip)
|
154 |
+
|
155 |
+
def apply_to_each_clip(self, clips):
|
156 |
+
for i, clip in enumerate(clips):
|
157 |
+
clips[i] = self.apply_to_single_clip(clip)
|
158 |
+
return clips
|
159 |
+
|
160 |
+
def forward(self, item):
|
161 |
+
has_batch_dim = len(item['video'].shape) == 5
|
162 |
+
if has_batch_dim:
|
163 |
+
fn = self.apply_to_each_clip
|
164 |
+
else:
|
165 |
+
fn = self.apply_to_single_clip
|
166 |
+
item['video'] = fn(item['video'])
|
167 |
+
return item
|
168 |
+
|
169 |
+
|
170 |
+
class ApplyColorJitterFrameWise(torch.nn.Module):
|
171 |
+
|
172 |
+
def __init__(self, s=1.) -> None:
|
173 |
+
super().__init__()
|
174 |
+
self.s = s
|
175 |
+
# SimCLR params
|
176 |
+
self.transform = torchvision.transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
|
177 |
+
|
178 |
+
def apply_to_single_clip(self, clip):
|
179 |
+
for i, frame in enumerate(clip):
|
180 |
+
clip[i] = self.transform(frame)
|
181 |
+
return clip
|
182 |
+
|
183 |
+
def apply_to_each_clip(self, clips):
|
184 |
+
for i, clip in enumerate(clips):
|
185 |
+
clips[i] = self.apply_to_single_clip(clip)
|
186 |
+
return clips
|
187 |
+
|
188 |
+
def forward(self, item):
|
189 |
+
has_batch_dim = len(item['video'].shape) == 5
|
190 |
+
if has_batch_dim:
|
191 |
+
fn = self.apply_to_each_clip
|
192 |
+
else:
|
193 |
+
fn = self.apply_to_single_clip
|
194 |
+
item['video'] = fn(item['video'])
|
195 |
+
return item
|
196 |
+
|
197 |
+
|
198 |
+
class RandomHorizontalFlip(torchvision.transforms.RandomHorizontalFlip):
|
199 |
+
|
200 |
+
def __init__(self, p=0.5):
|
201 |
+
super().__init__(p)
|
202 |
+
|
203 |
+
def apply_to_single_clip(self, clip):
|
204 |
+
return super().forward(clip)
|
205 |
+
|
206 |
+
def apply_to_each_clip(self, clips):
|
207 |
+
for i, clip in enumerate(clips):
|
208 |
+
clips[i] = self.apply_to_single_clip(clip)
|
209 |
+
return clips
|
210 |
+
|
211 |
+
def forward(self, item):
|
212 |
+
has_batch_dim = len(item['video'].shape) == 5
|
213 |
+
if has_batch_dim:
|
214 |
+
fn = self.apply_to_each_clip
|
215 |
+
else:
|
216 |
+
fn = self.apply_to_single_clip
|
217 |
+
item['video'] = fn(item['video'])
|
218 |
+
return item
|
219 |
+
|
220 |
+
|
221 |
+
def make_class_grid(leftmost_val, rightmost_val, grid_size, add_extreme_offset: bool = False,
|
222 |
+
seg_size_vframes: int = None, nseg: int = None, step_size_seg: float = None,
|
223 |
+
vfps: float = None):
|
224 |
+
assert grid_size >= 3, f'grid_size: {grid_size} does not make sense. If =2 -> (-1,1); =1 -> (-1); =0 -> ()'
|
225 |
+
grid = torch.from_numpy(np.linspace(leftmost_val, rightmost_val, grid_size)).float()
|
226 |
+
if add_extreme_offset:
|
227 |
+
assert all([seg_size_vframes, nseg, step_size_seg]), f'{seg_size_vframes} {nseg} {step_size_seg}'
|
228 |
+
seg_size_sec = seg_size_vframes / vfps
|
229 |
+
trim_size_in_seg = nseg - (1 - step_size_seg) * (nseg - 1)
|
230 |
+
extreme_value = trim_size_in_seg * seg_size_sec
|
231 |
+
grid = torch.cat([grid, torch.tensor([extreme_value])]) # adding extreme offset to the class grid
|
232 |
+
return grid
|
233 |
+
|
234 |
+
|
235 |
+
def quantize_offset(grid: torch.Tensor, off_sec: float) -> Tuple[float, int]:
|
236 |
+
'''Takes in the offset in seconds and snaps it onto the closest grid element.
|
237 |
+
Returns the grid value and its index.'''
|
238 |
+
closest_grid_el = (grid - off_sec).abs().argmin()
|
239 |
+
return grid[closest_grid_el], closest_grid_el
|
240 |
+
|
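For reference, a minimal sketch of how the two helpers above interact; the values are illustrative and not taken from the repo's configs:

    # a 21-point grid over [-2, 2] s has 0.2 s spacing: -2.0, -1.8, ..., 2.0
    grid = make_class_grid(-2.0, 2.0, 21)
    # 0.31 s is closer to the 0.4 s bin than to the 0.2 s bin
    label, target = quantize_offset(grid, off_sec=0.31)
    # label ~= tensor(0.4), target == tensor(12), since grid[12] == -2.0 + 12 * 0.2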
241 |
+
def apply_a_jitter(a_start_i, a_len_frames, a_crop_len_frames, a_fps, max_a_jitter_sec):
|
242 |
+
max_a_start_i = a_len_frames - a_crop_len_frames
|
243 |
+
max_a_jitter_i = sec2frames(max_a_jitter_sec, a_fps)
|
244 |
+
max_a_jitter_i_left = min(a_start_i, max_a_jitter_i)
|
245 |
+
max_a_jitter_i_right = min(max_a_start_i - a_start_i, max_a_jitter_i)
|
246 |
+
# jitter is U[left, right]
|
247 |
+
a_jitter_i = random.randint(-max_a_jitter_i_left, max_a_jitter_i_right)
|
248 |
+
# apply jitter
|
249 |
+
a_start_i = a_start_i + a_jitter_i
|
250 |
+
# making sure that any value from `a_start_i + U[left, right]` will be inside the [0, len-crop] region
|
251 |
+
assert 0 <= a_start_i <= max_a_start_i, f'{a_jitter_i} {max_a_jitter_i_left} {max_a_jitter_i_right} {max_a_start_i}'
|
252 |
+
return a_start_i, a_jitter_i
|
253 |
+
|
254 |
+
|
255 |
+
class TemporalCropAndOffset(torch.nn.Module):
|
256 |
+
|
257 |
+
def __init__(self, crop_len_sec: float, max_off_sec: float, offset_type='grid', do_offset: bool = True,
|
258 |
+
grid_size: int = None, max_wiggle_sec: float = None, add_doubt_cls: bool = False,
|
259 |
+
segment_size_vframes: int = None, n_segments: int = None, step_size_seg: float = None,
|
260 |
+
vfps: float = None, prob_oos: float = None):
|
261 |
+
super().__init__()
|
262 |
+
self.crop_len_sec = crop_len_sec
|
263 |
+
self.do_offset = do_offset
|
264 |
+
self.grid_size = grid_size
|
265 |
+
self.offset_type = offset_type
|
266 |
+
self.max_off_sec = max_off_sec
|
267 |
+
self.max_a_jitter_sec = max_wiggle_sec
|
268 |
+
if do_offset:
|
269 |
+
if offset_type == 'grid':
|
270 |
+
self.class_grid = make_class_grid(-max_off_sec, max_off_sec, grid_size, add_doubt_cls,
|
271 |
+
segment_size_vframes, n_segments, step_size_seg, vfps)
|
272 |
+
logging.info(f'Offsets class grid: {self.class_grid}')
|
273 |
+
if self.max_a_jitter_sec is not None:
|
274 |
+
assert (max_wiggle_sec-1e-6) <= ((self.class_grid[1] - self.class_grid[0]) / 2), f'{self.class_grid}'
|
275 |
+
elif offset_type == 'uniform':
|
276 |
+
self.off_dist = torch.distributions.uniform.Uniform(-max_off_sec, max_off_sec)
|
277 |
+
logging.info(f'Offset uniform distribution: {self.off_dist}')
|
278 |
+
elif offset_type == 'uniform_binary':
|
279 |
+
self.itu_t_range = (-0.125, 0.045)
|
280 |
+
self.prob_oos = prob_oos
|
281 |
+
self.ins_dist = torch.distributions.uniform.Uniform(self.itu_t_range[0], self.itu_t_range[1])
|
282 |
+
self.off_dist = torch.distributions.uniform.Uniform(-max_off_sec, max_off_sec)
|
283 |
+
else:
|
284 |
+
raise NotImplementedError(f'Unknown offset type: {offset_type}')
|
285 |
+
|
286 |
+
def forward(self, item):
|
287 |
+
vid = item['video']
|
288 |
+
aud = item['audio']
|
289 |
+
v_len_frames, C, H, W = vid.shape
|
290 |
+
a_len_frames = aud.shape[0]
|
291 |
+
|
292 |
+
v_fps = int(item['meta']['video']['fps'][0])
|
293 |
+
a_fps = int(item['meta']['audio']['framerate'][0])
|
294 |
+
|
295 |
+
v_crop_len_frames = sec2frames(self.crop_len_sec, v_fps)
|
296 |
+
a_crop_len_frames = sec2frames(self.crop_len_sec, a_fps)
|
297 |
+
|
298 |
+
if self.do_offset:
|
299 |
+
# trying to get the offset parameters (for instance during valid and test we have fixed offsets)
|
300 |
+
offset_sec = item['targets'].get('offset_sec', None)
|
301 |
+
v_start_i_sec = item['targets'].get('v_start_i_sec', None)
|
302 |
+
if 'offset_target' in item['targets']:
|
303 |
+
is_oos = item['targets']['offset_target'].get('oos', None)
|
304 |
+
# train-time
|
305 |
+
if offset_sec is None and v_start_i_sec is None:
|
306 |
+
# aud starts `offset_sec` earlier than it should; aud has what will be shown after offset_sec
|
307 |
+
if self.offset_type == 'grid':
|
308 |
+
offset_sec = random.choice(self.class_grid.tolist())
|
309 |
+
elif self.offset_type == 'uniform':
|
310 |
+
offset_sec = self.off_dist.sample().item()
|
311 |
+
elif self.offset_type == 'uniform_binary':
|
312 |
+
# in-sync: Uniform(-0.125, 0.045)
|
313 |
+
# out-of-sync: Uniform(-5.5, 5.5) and resampled until not in Uniform(-0.125, 0.045)
|
314 |
+
# first, we sample whether the offset is out-of-sync with probability prob_oos
|
315 |
+
is_oos = (torch.rand(1) < self.prob_oos).item()
|
316 |
+
if is_oos:
|
317 |
+
# second, we sample the offset itself (if in in-sync range, trying again)
|
318 |
+
offset_sec = self.off_dist.sample().item()
|
319 |
+
while self.itu_t_range[0] <= offset_sec <= self.itu_t_range[1]:
|
320 |
+
offset_sec = self.off_dist.sample().item()
|
321 |
+
else:
|
322 |
+
offset_sec = self.ins_dist.sample().item()
|
323 |
+
offset_sec = round(offset_sec, 2)
|
324 |
+
v_start_max_sec = frames2sec(v_len_frames - v_crop_len_frames, v_fps)
|
325 |
+
assert v_start_max_sec > 0, f'{v_len_frames} {v_crop_len_frames} {v_fps} @ {item["path"]}'
|
326 |
+
# `v_start_sec` IS NOT rounded to the fps grid
|
327 |
+
v_start_sec = random.uniform(max(0, -offset_sec), min(v_start_max_sec, v_start_max_sec-offset_sec))
|
328 |
+
assert 0 <= v_start_sec <= v_start_max_sec, f'{v_start_sec} {v_start_max_sec} {item["path"]}'
|
329 |
+
v_start_i = sec2frames(v_start_sec, v_fps)
|
330 |
+
# `v_start_i_sec` IS rounded to the fps grid
|
331 |
+
v_start_i_sec = frames2sec(v_start_i, v_fps)
|
332 |
+
else:
|
333 |
+
offset_sec = round(offset_sec, 2)
|
334 |
+
v_start_i = sec2frames(v_start_i_sec, v_fps)
|
335 |
+
v_end_i = v_start_i + v_crop_len_frames
|
336 |
+
# `a_start_i` depends on the rounded value `v_start_i_sec`, otherwise
|
337 |
+
# (v_start_sec) we have ±0.1 jittering
|
338 |
+
a_start_i = sec2frames(v_start_i_sec + offset_sec, a_fps)
|
339 |
+
else:
|
340 |
+
offset_sec = 0.0
|
341 |
+
is_random_crop = item['split'] == 'train'
|
342 |
+
v_start_i, v_end_i = self.get_crop_idx(v_len_frames, v_crop_len_frames, is_random=is_random_crop)
|
343 |
+
v_start_i_sec = frames2sec(v_start_i, v_fps)
|
344 |
+
a_start_i = sec2frames(v_start_i_sec, a_fps)
|
345 |
+
|
346 |
+
# sometimes, due to rounding, e.g. v_start_sec = 1.505 snaps to 1.48 once on the 25 fps frame grid;
|
347 |
+
# given offset is -1.5, the a_start_i will be a small negative value. (likely a_fps * 1/v_fps * 0.5)
|
348 |
+
if a_start_i < 0:
|
349 |
+
how_much_out = a_start_i
|
350 |
+
logging.info(f'a_start_i is negative ({how_much_out}) at {item["path"]}')
|
351 |
+
if abs(how_much_out) <= a_fps / v_fps:
|
352 |
+
logging.info('fixing it')
|
353 |
+
a_start_i += abs(how_much_out)
|
354 |
+
else:
|
355 |
+
raise Exception(f'{how_much_out} {item["path"]}')
|
356 |
+
|
357 |
+
if self.max_a_jitter_sec is not None and self.max_a_jitter_sec > 0:
|
358 |
+
a_start_i, a_jitter_i = apply_a_jitter(a_start_i, a_len_frames, a_crop_len_frames, a_fps,
|
359 |
+
self.max_a_jitter_sec)
|
360 |
+
item['meta']['a_jitter_i'] = a_jitter_i
|
361 |
+
|
362 |
+
a_end_i = a_start_i + a_crop_len_frames
|
363 |
+
|
364 |
+
assert v_start_i < v_end_i and a_start_i < a_end_i
|
365 |
+
assert aud.shape[0] >= a_end_i, f'{aud.shape} {a_end_i} {item["path"]}'
|
366 |
+
assert vid.shape[0] >= v_end_i, f'{vid.shape} {v_end_i} {item["path"]}'
|
367 |
+
|
368 |
+
vid, aud = vid[v_start_i:v_end_i, :, :, :], aud[a_start_i:a_end_i]
|
369 |
+
|
370 |
+
item['video'] = vid
|
371 |
+
item['audio'] = aud
|
372 |
+
|
373 |
+
assert item['video'].shape[0] == v_fps * self.crop_len_sec, f'{item["video"].shape} {item["path"]}'
|
374 |
+
assert item['audio'].shape[0] == a_fps * self.crop_len_sec, f'{item["audio"].shape} {item["path"]}'
|
375 |
+
|
376 |
+
# caching parameters
|
377 |
+
if self.do_offset:
|
378 |
+
if self.offset_type == 'grid':
|
379 |
+
offset_label, offset_target = quantize_offset(self.class_grid, offset_sec)
|
380 |
+
elif self.offset_type == 'uniform':
|
381 |
+
offset_label, offset_target = offset_sec, offset_sec
|
382 |
+
elif self.offset_type == 'uniform_binary':
|
383 |
+
offset_label, offset_target = offset_sec, {'oos': is_oos, 'offset': offset_sec}
|
384 |
+
item['targets']['offset_sec'] = offset_sec
|
385 |
+
item['targets']['v_start_i_sec'] = v_start_i_sec
|
386 |
+
item['targets']['offset_label'] = offset_label
|
387 |
+
# assert 'offset_target' not in item['targets'], f'{item["targets"]}. What passed it there?'
|
388 |
+
item['targets']['offset_target'] = offset_target
|
389 |
+
|
390 |
+
return item
|
391 |
+
|
392 |
+
def get_crop_idx(self, len_frames: int, crop_len_frames: int, is_random=True):
|
393 |
+
if len_frames == crop_len_frames:
|
394 |
+
return 0, len_frames
|
395 |
+
if is_random:
|
396 |
+
left_i = random.randint(0, len_frames - crop_len_frames)
|
397 |
+
else:
|
398 |
+
left_i = int(round((len_frames - crop_len_frames) / 2.))
|
399 |
+
return left_i, left_i+crop_len_frames
|
400 |
+
|
401 |
+
|
402 |
+
class GenerateMultipleSegments(torch.nn.Module):
|
403 |
+
'''
|
404 |
+
Given an item with video and audio, generates a batch of `n_segments` segments
|
405 |
+
of length `segment_size_vframes` (if `n_segments` is None, the maximum possible number of segments is made).
|
406 |
+
If `is_start_random` is True, the starting position of the 1st segment will be random but respecting
|
407 |
+
n_segments.
|
408 |
+
`audio_jitter_sec` is the amount of audio offset in seconds.
|
409 |
+
'''
|
410 |
+
|
411 |
+
def __init__(self, segment_size_vframes: int, n_segments: int = None, is_start_random: bool = False,
|
412 |
+
audio_jitter_sec: float = 0., step_size_seg: float = 1):
|
413 |
+
super().__init__()
|
414 |
+
self.segment_size_vframes = segment_size_vframes
|
415 |
+
self.n_segments = n_segments
|
416 |
+
self.is_start_random = is_start_random
|
417 |
+
self.audio_jitter_sec = audio_jitter_sec
|
418 |
+
self.step_size_seg = step_size_seg
|
419 |
+
logging.info(f'Segment step size: {self.step_size_seg}')
|
420 |
+
|
421 |
+
def forward(self, item):
|
422 |
+
v_len_frames, C, H, W = item['video'].shape
|
423 |
+
a_len_frames = item['audio'].shape[0]
|
424 |
+
|
425 |
+
v_fps = int(item['meta']['video']['fps'][0])
|
426 |
+
a_fps = int(item['meta']['audio']['framerate'][0])
|
427 |
+
|
428 |
+
## Determining the number of segments
|
429 |
+
# segment size
|
430 |
+
segment_size_vframes = self.segment_size_vframes
|
431 |
+
segment_size_aframes = sec2frames(frames2sec(self.segment_size_vframes, v_fps), a_fps)
|
432 |
+
# step size (stride)
|
433 |
+
stride_vframes = int(self.step_size_seg * segment_size_vframes)
|
434 |
+
stride_aframes = int(self.step_size_seg * segment_size_aframes)
|
435 |
+
# calculating the number of segments. (W - F + 2P) / S + 1
|
436 |
+
n_segments_max_v = math.floor((v_len_frames - segment_size_vframes) / stride_vframes) + 1
|
437 |
+
n_segments_max_a = math.floor((a_len_frames - segment_size_aframes) / stride_aframes) + 1
|
438 |
+
# making sure audio and video can accommodate the same number of segments
|
439 |
+
n_segments_max = min(n_segments_max_v, n_segments_max_a)
|
440 |
+
n_segments = n_segments_max if self.n_segments is None else self.n_segments
|
441 |
+
|
442 |
+
assert n_segments <= n_segments_max, \
|
443 |
+
f'cant make {n_segments} segs of len {self.segment_size_vframes} in a vid ' \
|
444 |
+
f'of len {v_len_frames} for {item["path"]}'
|
445 |
+
|
446 |
+
# (n_segments, 2) each
|
447 |
+
v_ranges, a_ranges = self.get_sequential_seg_ranges(v_len_frames, a_len_frames, v_fps, a_fps,
|
448 |
+
n_segments, segment_size_aframes)
|
449 |
+
|
450 |
+
# segmenting original streams (n_segments, segment_size_frames, C, H, W)
|
451 |
+
item['video'] = torch.stack([item['video'][s:e] for s, e in v_ranges], dim=0)
|
452 |
+
item['audio'] = torch.stack([item['audio'][s:e] for s, e in a_ranges], dim=0)
|
453 |
+
return item
|
454 |
+
|
455 |
+
def get_sequential_seg_ranges(self, v_len_frames, a_len_frames, v_fps, a_fps, n_seg, seg_size_aframes):
|
456 |
+
# if is_start_random is True, the starting position of the 1st segment will
|
457 |
+
# be random but respecting n_segments like so: "-CCCCCCCC---" (maybe with fixed overlap),
|
458 |
+
# else the segments are taken from the middle of the video respecting n_segments: "--CCCCCCCC--"
|
459 |
+
|
460 |
+
seg_size_vframes = self.segment_size_vframes # for brevity
|
461 |
+
|
462 |
+
# calculating the step size in frames
|
463 |
+
step_size_vframes = int(self.step_size_seg * seg_size_vframes)
|
464 |
+
step_size_aframes = int(self.step_size_seg * seg_size_aframes)
|
465 |
+
|
466 |
+
# calculating the length of the segment sequence (in segment units, then in video/audio frames)
|
467 |
+
seg_seq_len = n_seg * self.step_size_seg + (1 - self.step_size_seg)
|
468 |
+
vframes_seg_seq_len = int(seg_seq_len * seg_size_vframes)
|
469 |
+
aframes_seg_seq_len = int(seg_seq_len * seg_size_aframes)
|
470 |
+
|
471 |
+
# doing temporal crop
|
472 |
+
max_v_start_i = v_len_frames - vframes_seg_seq_len
|
473 |
+
if self.is_start_random:
|
474 |
+
v_start_i = random.randint(0, max_v_start_i)
|
475 |
+
else:
|
476 |
+
v_start_i = max_v_start_i // 2
|
477 |
+
a_start_i = sec2frames(frames2sec(v_start_i, v_fps), a_fps) # vid frames -> seconds -> aud frames
|
478 |
+
|
479 |
+
# make segments starts
|
480 |
+
v_start_seg_i = torch.tensor([v_start_i + i * step_size_vframes for i in range(n_seg)]).int()
|
481 |
+
a_start_seg_i = torch.tensor([a_start_i + i * step_size_aframes for i in range(n_seg)]).int()
|
482 |
+
|
483 |
+
# apply jitter to audio
|
484 |
+
if self.audio_jitter_sec > 0:
|
485 |
+
jitter_aframes = sec2frames(self.audio_jitter_sec, a_fps)
|
486 |
+
# making sure after applying jitter, the audio is still within the audio boundaries
|
487 |
+
jitter_aframes = min(jitter_aframes, a_start_i, a_len_frames-a_start_i-aframes_seg_seq_len)
|
488 |
+
a_start_seg_i += random.randint(-jitter_aframes, jitter_aframes) # applying jitter to segments
|
489 |
+
|
490 |
+
# make segment ends
|
491 |
+
v_ends_seg_i = v_start_seg_i + seg_size_vframes
|
492 |
+
a_ends_seg_i = a_start_seg_i + seg_size_aframes # using the adjusted a_start_seg_i (with jitter)
|
493 |
+
|
494 |
+
# make ranges
|
495 |
+
v_ranges = torch.stack([v_start_seg_i, v_ends_seg_i], dim=1)
|
496 |
+
a_ranges = torch.stack([a_start_seg_i, a_ends_seg_i], dim=1)
|
497 |
+
assert (a_ranges >= 0).all() and (a_ranges <= a_len_frames).all(), f'{a_ranges} out of {a_len_frames}'
|
498 |
+
assert (v_ranges <= v_len_frames).all(), f'{v_ranges} out of {v_len_frames}'
|
499 |
+
return v_ranges, a_ranges
|
500 |
+
|
501 |
+
|
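A worked example of the `(W - F) / S + 1` segment-count formula used above (illustrative numbers only):

    import math

    v_len_frames = 125                # e.g. 5 s of video at 25 fps
    segment_size_vframes = 16
    step_size_seg = 0.5               # 50% overlap between neighbouring segments
    stride_vframes = int(step_size_seg * segment_size_vframes)   # 8 frames

    n_segments_max = math.floor((v_len_frames - segment_size_vframes) / stride_vframes) + 1
    # floor((125 - 16) / 8) + 1 = 14 segments fit into the clip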
502 |
+
class TemporalCropAndOffsetForSyncabilityTraining(torch.nn.Module):
|
503 |
+
|
504 |
+
def __init__(self, max_off_sec: float, do_offset: bool = True,
|
505 |
+
grid_size: int = None, max_wiggle_sec: float = None,
|
506 |
+
segment_size_vframes: int = None, n_segments: int = None, step_size_seg: float = None,
|
507 |
+
vfps: float = None):
|
508 |
+
super().__init__()
|
509 |
+
seg_size_sec = segment_size_vframes / vfps
|
510 |
+
trim_size_in_seg = n_segments - (1 - step_size_seg) * (n_segments - 1)
|
511 |
+
self.crop_len_sec = round(trim_size_in_seg * seg_size_sec, 2)
|
512 |
+
logging.info(f'Crop len: {self.crop_len_sec}')
|
513 |
+
self.do_offset = do_offset
|
514 |
+
self.grid_size = grid_size
|
515 |
+
self.max_off_sec = max_off_sec
|
516 |
+
self.max_a_jitter_sec = max_wiggle_sec
|
517 |
+
self.segment_size_vframes = segment_size_vframes
|
518 |
+
self.n_segments = n_segments
|
519 |
+
self.step_size_seg = step_size_seg
|
520 |
+
self.prob_syncable = 0.5
|
521 |
+
if do_offset:
|
522 |
+
self.class_grid = make_class_grid(-max_off_sec, max_off_sec, grid_size)
|
523 |
+
logging.info(f'Offset class grid: {self.class_grid}')
|
524 |
+
if self.max_a_jitter_sec is not None:
|
525 |
+
assert (max_wiggle_sec-1e-6) <= ((self.class_grid[1] - self.class_grid[0]) / 2), f'{self.class_grid}'
|
526 |
+
|
527 |
+
def forward(self, item):
|
528 |
+
vid = item['video']
|
529 |
+
aud = item['audio']
|
530 |
+
v_len_frames, C, H, W = vid.shape
|
531 |
+
a_len_frames = aud.shape[0]
|
532 |
+
|
533 |
+
v_fps = int(item['meta']['video']['fps'][0])
|
534 |
+
a_fps = int(item['meta']['audio']['framerate'][0])
|
535 |
+
|
536 |
+
v_crop_len_frames = sec2frames(self.crop_len_sec, v_fps)
|
537 |
+
a_crop_len_frames = sec2frames(self.crop_len_sec, a_fps)
|
538 |
+
|
539 |
+
if self.do_offset:
|
540 |
+
# trying to get the offset parameters (for instance during valid and test we have fixed offsets)
|
541 |
+
offset_sec = item['targets'].get('offset_sec', None)
|
542 |
+
v_start_i_sec = item['targets'].get('v_start_i_sec', None)
|
543 |
+
# train-time
|
544 |
+
if offset_sec is None and v_start_i_sec is None:
|
545 |
+
|
546 |
+
# for the syncability training, we want to have a syncable or non-syncable offset with 50% prob
|
547 |
+
offset_is_syncable = random.random() < self.prob_syncable # 1=syncable, 0=non-syncable
|
548 |
+
if offset_is_syncable:
|
549 |
+
offset_sec = random.choice(self.class_grid.tolist())
|
550 |
+
else:
|
551 |
+
offset_sec = random.choice([-self.crop_len_sec, self.crop_len_sec]) # either - or + offset
|
552 |
+
# aud starts `offset_sec` earlier than it should; aud has what will be shown after offset_sec
|
553 |
+
|
554 |
+
offset_sec = round(offset_sec, 2)
|
555 |
+
v_start_max_sec = frames2sec(v_len_frames - v_crop_len_frames, v_fps)
|
556 |
+
assert v_start_max_sec > 0, f'{v_len_frames} {v_crop_len_frames} {v_fps} @ {item["path"]}'
|
557 |
+
# `v_start_sec` IS NOT rounded to the fps grid
|
558 |
+
v_start_sec = random.uniform(max(0, -offset_sec), min(v_start_max_sec, v_start_max_sec-offset_sec))
|
559 |
+
assert 0 <= v_start_sec <= v_start_max_sec, f'{v_start_sec} {v_start_max_sec} {item["path"]}'
|
560 |
+
v_start_i = sec2frames(v_start_sec, v_fps)
|
561 |
+
v_end_i = v_start_i + v_crop_len_frames
|
562 |
+
# `v_start_i_sec` IS rounded to the fps grid
|
563 |
+
v_start_i_sec = frames2sec(v_start_i, v_fps)
|
564 |
+
# `a_start_i` depends on the rounded value `v_start_i_sec`, otherwise
|
565 |
+
# (v_start_sec) we have ±0.1 jittering
|
566 |
+
a_start_i = sec2frames(v_start_i_sec + offset_sec, a_fps)
|
567 |
+
if self.max_a_jitter_sec is not None and self.max_a_jitter_sec > 0:
|
568 |
+
a_start_i, a_jitter_i = apply_a_jitter(a_start_i, a_len_frames, a_crop_len_frames, a_fps,
|
569 |
+
self.max_a_jitter_sec)
|
570 |
+
item['meta']['a_jitter_i'] = a_jitter_i
|
571 |
+
a_end_i = a_start_i + a_crop_len_frames
|
572 |
+
else:
|
573 |
+
offset_sec = round(offset_sec, 2)
|
574 |
+
v_start_i = sec2frames(v_start_i_sec, v_fps)
|
575 |
+
a_start_i = sec2frames(v_start_i_sec + offset_sec, a_fps)
|
576 |
+
v_end_i = v_start_i + v_crop_len_frames
|
577 |
+
a_end_i = a_start_i + a_crop_len_frames
|
578 |
+
else:
|
579 |
+
offset_sec = 0.0
|
580 |
+
is_random_crop = item['split'] == 'train'
|
581 |
+
v_start_i, v_end_i = self.get_crop_idx(v_len_frames, v_crop_len_frames, is_random=is_random_crop)
|
582 |
+
v_start_i_sec = frames2sec(v_start_i, v_fps)
|
583 |
+
a_start_i = sec2frames(v_start_i_sec, a_fps)
|
584 |
+
if self.max_a_jitter_sec is not None and self.max_a_jitter_sec > 0:
|
585 |
+
a_start_i, a_jitter_i = apply_a_jitter(a_start_i, a_len_frames, a_crop_len_frames, a_fps,
|
586 |
+
self.max_a_jitter_sec)
|
587 |
+
item['meta']['a_jitter_i'] = a_jitter_i
|
588 |
+
a_end_i = a_start_i + a_crop_len_frames
|
589 |
+
|
590 |
+
# sometimes, due to rounding, e.g. v_start_sec = 1.505 snaps to 1.48 once on the 25 fps frame grid;
|
591 |
+
# given offset is -1.5, the a_start_i will be a small negative value. (likely a_fps * 1/v_fps * 0.5)
|
592 |
+
if a_start_i < 0:
|
593 |
+
how_much_out = a_start_i
|
594 |
+
logging.info(f'a_start_i is negative ({how_much_out}) at {item["path"]}')
|
595 |
+
if abs(how_much_out) <= a_fps / v_fps:
|
596 |
+
logging.info('fixing it')
|
597 |
+
a_start_i += abs(how_much_out)
|
598 |
+
a_end_i += abs(how_much_out)
|
599 |
+
else:
|
600 |
+
raise Exception(f'{how_much_out} {item["path"]}')
|
601 |
+
|
602 |
+
assert v_start_i < v_end_i and a_start_i < a_end_i
|
603 |
+
assert aud.shape[0] >= a_end_i, f'{aud.shape} {a_end_i} {item["path"]}'
|
604 |
+
assert vid.shape[0] >= v_end_i, f'{vid.shape} {v_end_i} {item["path"]}'
|
605 |
+
|
606 |
+
vid, aud = vid[v_start_i:v_end_i, :, :, :], aud[a_start_i:a_end_i]
|
607 |
+
|
608 |
+
item['video'] = vid
|
609 |
+
item['audio'] = aud
|
610 |
+
|
611 |
+
assert item['video'].shape[0] == int(v_fps*self.crop_len_sec), f'{item["video"].shape} {item["path"]}'
|
612 |
+
assert item['audio'].shape[0] == int(a_fps*self.crop_len_sec), f'{item["audio"].shape} {item["path"]}'
|
613 |
+
|
614 |
+
# caching parameters
|
615 |
+
if self.do_offset:
|
616 |
+
# NOTE: this is useless for the extreme offsetting
|
617 |
+
offset_label, offset_target = quantize_offset(self.class_grid, offset_sec)
|
618 |
+
item['targets']['offset_sec'] = offset_sec
|
619 |
+
item['targets']['offset_label'] = offset_label
|
620 |
+
# assert 'offset_target' not in item['targets'], f'{item["targets"]}. What passed it there?'
|
621 |
+
item['targets']['offset_target'] = offset_target
|
622 |
+
item['targets']['v_start_i_sec'] = v_start_i_sec
|
623 |
+
item['targets']['sync_target'] = int(offset_is_syncable)
|
624 |
+
|
625 |
+
return item
|
626 |
+
|
627 |
+
def get_crop_idx(self, len_frames: int, crop_len_frames: int, is_random=True):
|
628 |
+
if len_frames == crop_len_frames:
|
629 |
+
return 0, len_frames
|
630 |
+
if is_random:
|
631 |
+
left_i = random.randint(0, len_frames - crop_len_frames)
|
632 |
+
else:
|
633 |
+
left_i = int(round((len_frames - crop_len_frames) / 2.))
|
634 |
+
return left_i, left_i+crop_len_frames
|
635 |
+
|
636 |
+
|
637 |
+
class RGBToFloatToZeroOne(torch.nn.Module):
|
638 |
+
|
639 |
+
def __init__(self) -> None:
|
640 |
+
super().__init__()
|
641 |
+
|
642 |
+
def forward(self, item):
|
643 |
+
item['video'] = item['video'].to(torch.float32).div(255.)
|
644 |
+
return item
|
645 |
+
|
646 |
+
|
647 |
+
class RGBToHalfToZeroOne(torch.nn.Module):
|
648 |
+
|
649 |
+
def __init__(self) -> None:
|
650 |
+
super().__init__()
|
651 |
+
|
652 |
+
def forward(self, item):
|
653 |
+
item['video'] = item['video'].half().div(255.)
|
654 |
+
return item
|
655 |
+
|
656 |
+
|
657 |
+
class RGBNormalize(torchvision.transforms.Normalize):
|
658 |
+
'''The same as torchvision's Normalize, but with a dict-based interface.
|
659 |
+
This should work for any shape (..., C, H, W)'''
|
660 |
+
|
661 |
+
def __init__(self, mean, std, inplace=False):
|
662 |
+
super().__init__(mean, std, inplace)
|
663 |
+
logging.info(f'RGBNormalize: mean={mean}, std={std}')
|
664 |
+
|
665 |
+
def forward(self, item):
|
666 |
+
item['video'] = super().forward(item['video'])
|
667 |
+
item['meta']['video']['norm_stats'] = {'mean': torch.as_tensor(self.mean),
|
668 |
+
'std': torch.as_tensor(self.std)}
|
669 |
+
return item
|
670 |
+
|
671 |
+
|
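The RGB transforms above all operate on the same item dict, so they can be chained with torchvision's Compose; a minimal sketch of an eval-style chain (the actual composition is driven by the configs, not fixed here; RGBSpatialCrop and Resize are the dict-aware transforms defined earlier in this file):

    rgb_eval = torchvision.transforms.Compose([
        RGBSpatialCrop((224, 224), is_random=False),   # deterministic crop when is_random=False
        Resize((224, 224), antialias=None),
        RGBToFloatToZeroOne(),                         # uint8 [0, 255] -> float32 [0, 1]
        RGBNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),   # ImageNet stats
    ])
    # item = rgb_eval(item)  # expects the usual dict with 'video' and 'meta'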
672 |
+
class AudioRandomVolume(torch.nn.Module):
|
673 |
+
|
674 |
+
def __init__(self, p: float, **kwargs):
|
675 |
+
super().__init__()
|
676 |
+
transform = torchaudio.transforms.Vol(**kwargs)
|
677 |
+
self.transform = torchvision.transforms.RandomApply([transform], p)
|
678 |
+
|
679 |
+
def apply_to_single_clip(self, clip):
|
680 |
+
return self.transform(clip)
|
681 |
+
|
682 |
+
def apply_to_each_clip(self, clips):
|
683 |
+
for i, clip in enumerate(clips):
|
684 |
+
clips[i] = self.apply_to_single_clip(clip)
|
685 |
+
return clips
|
686 |
+
|
687 |
+
def forward(self, item):
|
688 |
+
has_batch_dim = len(item['audio'].shape) == 2
|
689 |
+
if has_batch_dim:
|
690 |
+
fn = self.apply_to_each_clip
|
691 |
+
else:
|
692 |
+
fn = self.apply_to_single_clip
|
693 |
+
item['audio'] = fn(item['audio'])
|
694 |
+
return item
|
695 |
+
|
696 |
+
|
697 |
+
class AudioRandomLowpassFilter(torch.nn.Module):
|
698 |
+
|
699 |
+
def __init__(self, p: float, cutoff_freq: float, Q: float = 0.707):
|
700 |
+
super().__init__()
|
701 |
+
self.p = p
|
702 |
+
self.cutoff_freq = cutoff_freq
|
703 |
+
self.Q = Q
|
704 |
+
|
705 |
+
def apply_to_single_clip(self, clip, sr):
|
706 |
+
if self.p > torch.rand(1):
|
707 |
+
return torchaudio.functional.lowpass_biquad(clip, sr, self.cutoff_freq, self.Q)
|
708 |
+
else:
|
709 |
+
return clip
|
710 |
+
|
711 |
+
def apply_to_each_clip(self, clips, sr):
|
712 |
+
for i, clip in enumerate(clips):
|
713 |
+
clips[i] = self.apply_to_single_clip(clip, sr)
|
714 |
+
return clips
|
715 |
+
|
716 |
+
def forward(self, item):
|
717 |
+
has_batch_dim = len(item['audio'].shape) == 2
|
718 |
+
sr = int(item['meta']['audio']['framerate'][0])
|
719 |
+
if has_batch_dim:
|
720 |
+
fn = self.apply_to_each_clip
|
721 |
+
else:
|
722 |
+
fn = self.apply_to_single_clip
|
723 |
+
item['audio'] = fn(item['audio'], sr)
|
724 |
+
return item
|
725 |
+
|
726 |
+
|
727 |
+
class AudioRandomPitchShift(torch.nn.Module):
|
728 |
+
|
729 |
+
def __init__(self, p: float, shift: int) -> None:
|
730 |
+
super().__init__()
|
731 |
+
self.p = p
|
732 |
+
self.shift = shift
|
733 |
+
|
734 |
+
def apply_to_single_clip(self, wave, sr):
|
735 |
+
if self.p > torch.rand(1):
|
736 |
+
effects = [['pitch', f'{self.shift}'], ['rate', f'{sr}']]
|
737 |
+
wave = wave.unsqueeze(0)
|
738 |
+
wave, _ = torchaudio.sox_effects.apply_effects_tensor(wave, sr, effects)
|
739 |
+
wave = wave.squeeze(0)
|
740 |
+
return wave
|
741 |
+
|
742 |
+
def apply_to_each_clip(self, waves, sr):
|
743 |
+
for i, wave in enumerate(waves):
|
744 |
+
waves[i] = self.apply_to_single_clip(wave, sr)
|
745 |
+
return waves
|
746 |
+
|
747 |
+
def forward(self, item):
|
748 |
+
has_batch_dim = len(item['audio'].shape) == 2
|
749 |
+
sr = int(item['meta']['audio']['framerate'][0])
|
750 |
+
if has_batch_dim:
|
751 |
+
fn = self.apply_to_each_clip
|
752 |
+
else:
|
753 |
+
fn = self.apply_to_single_clip
|
754 |
+
item['audio'] = fn(item['audio'], sr)
|
755 |
+
return item
|
756 |
+
|
757 |
+
|
758 |
+
class AudioRandomReverb(torch.nn.Module):
|
759 |
+
|
760 |
+
def __init__(self, p: float) -> None:
|
761 |
+
super().__init__()
|
762 |
+
self.p = p
|
763 |
+
self.effects = [['reverb', '-w']]
|
764 |
+
|
765 |
+
def apply_to_single_clip(self, wave, fps):
|
766 |
+
if self.p > torch.rand(1):
|
767 |
+
wave = wave.unsqueeze(0)
|
768 |
+
wave, _ = torchaudio.sox_effects.apply_effects_tensor(wave, fps, self.effects)
|
769 |
+
wave = wave.mean(dim=0)
|
770 |
+
return wave
|
771 |
+
|
772 |
+
def apply_to_each_clip(self, waves, fps):
|
773 |
+
for i, wave in enumerate(waves):
|
774 |
+
waves[i] = self.apply_to_single_clip(wave, fps)
|
775 |
+
return waves
|
776 |
+
|
777 |
+
def forward(self, item):
|
778 |
+
has_batch_dim = len(item['audio'].shape) == 2
|
779 |
+
sr = int(item['meta']['audio']['framerate'][0])
|
780 |
+
if has_batch_dim:
|
781 |
+
fn = self.apply_to_each_clip
|
782 |
+
else:
|
783 |
+
fn = self.apply_to_single_clip
|
784 |
+
item['audio'] = fn(item['audio'], sr)
|
785 |
+
return item
|
786 |
+
|
787 |
+
class AudioRandomGaussNoise(torch.nn.Module):
|
788 |
+
|
789 |
+
def __init__(self, p: float, amplitude=0.01) -> None:
|
790 |
+
super().__init__()
|
791 |
+
self.p = p
|
792 |
+
self.amplitude = amplitude
|
793 |
+
|
794 |
+
def apply_to_single_clip(self, wave):
|
795 |
+
if self.p > torch.rand(1):
|
796 |
+
noise = torch.randn_like(wave, dtype=wave.dtype)
|
797 |
+
wave = wave + self.amplitude * noise
|
798 |
+
return wave
|
799 |
+
|
800 |
+
def apply_to_each_clip(self, waves):
|
801 |
+
for i, wave in enumerate(waves):
|
802 |
+
waves[i] = self.apply_to_single_clip(wave)
|
803 |
+
return waves
|
804 |
+
|
805 |
+
def forward(self, item):
|
806 |
+
has_batch_dim = len(item['audio'].shape) == 2
|
807 |
+
if has_batch_dim:
|
808 |
+
fn = self.apply_to_each_clip
|
809 |
+
else:
|
810 |
+
fn = self.apply_to_single_clip
|
811 |
+
item['audio'] = fn(item['audio'])
|
812 |
+
return item
|
813 |
+
|
814 |
+
|
815 |
+
class AudioMelSpectrogram(torch.nn.Module):
|
816 |
+
|
817 |
+
def __init__(self, **kwargs):
|
818 |
+
super().__init__()
|
819 |
+
self.spec = torchaudio.transforms.MelSpectrogram(**kwargs)
|
820 |
+
|
821 |
+
def forward(self, item):
|
822 |
+
item['audio'] = self.spec(item['audio']) # safe for batched input
|
823 |
+
return item
|
824 |
+
|
825 |
+
|
826 |
+
class AudioLog(torch.nn.Module):
|
827 |
+
|
828 |
+
def __init__(self, eps=1e-6) -> None:
|
829 |
+
super().__init__()
|
830 |
+
self.eps = eps
|
831 |
+
|
832 |
+
def forward(self, item):
|
833 |
+
item['audio'] = torch.log(item['audio'] + self.eps)
|
834 |
+
return item
|
835 |
+
|
836 |
+
class PadOrTruncate(torch.nn.Module):
|
837 |
+
|
838 |
+
def __init__(self, max_spec_t: int, pad_mode: str = 'constant', pad_value: float = 0.0):
|
839 |
+
super().__init__()
|
840 |
+
self.max_spec_t = max_spec_t
|
841 |
+
self.pad_mode = pad_mode
|
842 |
+
self.pad_value = pad_value
|
843 |
+
|
844 |
+
def forward(self, item):
|
845 |
+
item['audio'] = self.pad_or_truncate(item['audio'])
|
846 |
+
return item
|
847 |
+
|
848 |
+
def pad_or_truncate(self, audio):
|
849 |
+
difference = self.max_spec_t - audio.shape[-1] # safe for batched input
|
850 |
+
# pad or truncate, depending on difference
|
851 |
+
if difference > 0:
|
852 |
+
# pad the last dim (time) -> (..., n_mels, 0+time+difference) # safe for batched input
|
853 |
+
pad_dims = (0, difference)
|
854 |
+
audio = torch.nn.functional.pad(audio, pad_dims, self.pad_mode, self.pad_value)
|
855 |
+
elif difference < 0:
|
856 |
+
logging.warning(f'Truncating spec ({audio.shape}) to max_spec_t ({self.max_spec_t}).')
|
857 |
+
audio = audio[..., :self.max_spec_t] # safe for batched input
|
858 |
+
return audio
|
859 |
+
|
860 |
+
|
861 |
+
class AudioNormalizeAST(torch.nn.Module):
|
862 |
+
'''Normalizes with the specified mean and std, dividing by 2 * std (AST-style).'''
|
863 |
+
def __init__(self, mean: float, std: float) -> None:
|
864 |
+
super().__init__()
|
865 |
+
self.mean = mean
|
866 |
+
self.std = std
|
867 |
+
|
868 |
+
def forward(self, item):
|
869 |
+
item['audio'] = (item['audio'] - self.mean) / (2 * self.std)
|
870 |
+
item['meta']['audio']['norm_stats'] = {'mean': self.mean, 'std': self.std}
|
871 |
+
return item
|
872 |
+
|
873 |
+
|
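Chained together, the audio modules above turn a waveform into a normalized log-mel spectrogram; a minimal sketch with purely illustrative parameters (the actual values presumably come from the repo's configs, e.g. cfg.yaml):

    audio_feats = torchvision.transforms.Compose([
        AudioMelSpectrogram(sample_rate=16000, n_fft=400, hop_length=160, n_mels=128),  # illustrative kwargs
        AudioLog(),                               # log(spec + eps)
        PadOrTruncate(max_spec_t=1024),           # fix the time dimension
        AudioNormalizeAST(mean=-4.27, std=4.57),  # illustrative AST-style stats
    ])
    # item = audio_feats(item)  # expects the dict with 'audio' (waveform) and 'meta'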
874 |
+
class PermuteStreams(torch.nn.Module):
|
875 |
+
|
876 |
+
def __init__(self, einops_order_audio: str, einops_order_rgb: str) -> None:
|
877 |
+
''' For example:
|
878 |
+
einops_order_audio: "S F T -> S T F"
|
879 |
+
einops_order_rgb: "S T C H W -> S C T H W"'''
|
880 |
+
super().__init__()
|
881 |
+
self.einops_order_audio = einops_order_audio
|
882 |
+
self.einops_order_rgb = einops_order_rgb
|
883 |
+
|
884 |
+
def forward(self, item):
|
885 |
+
if self.einops_order_audio is not None:
|
886 |
+
item['audio'] = einops.rearrange(item['audio'], self.einops_order_audio).contiguous()
|
887 |
+
if self.einops_order_rgb is not None:
|
888 |
+
item['video'] = einops.rearrange(item['video'], self.einops_order_rgb).contiguous()
|
889 |
+
return item
|
890 |
+
|
891 |
+
|
892 |
+
class ResampleAudio(torch.nn.Module):
|
893 |
+
|
894 |
+
def __init__(self, new_fps: int):
|
895 |
+
super().__init__()
|
896 |
+
self.new_fps = new_fps
|
897 |
+
|
898 |
+
def forward(self, item):
|
899 |
+
orig_fps = int(item['meta']['audio']['framerate'][0])
|
900 |
+
item['meta']['audio']['orig_shape'] = item['audio'].shape
|
901 |
+
if orig_fps != self.new_fps:
|
902 |
+
item['audio'] = torchaudio.functional.resample(item['audio'], orig_fps, self.new_fps)
|
903 |
+
item['meta']['audio']['framerate'][0] = self.new_fps
|
904 |
+
return item
|
905 |
+
|
906 |
+
class ResampleRGB(torch.nn.Module):
|
907 |
+
|
908 |
+
def __init__(self, new_fps: int) -> None:
|
909 |
+
super().__init__()
|
910 |
+
self.new_fps = new_fps
|
911 |
+
|
912 |
+
def forward(self, item):
|
913 |
+
orig_fps = float(item['meta']['video']['fps'][0])
|
914 |
+
item['meta']['video']['orig_shape'] = item['video'].shape
|
915 |
+
if orig_fps != self.new_fps:
|
916 |
+
duration_sec = item['video'].shape[0] / orig_fps
|
917 |
+
indices = torch.arange(0, orig_fps * duration_sec - 1e-9, orig_fps / self.new_fps)
|
918 |
+
# basically, rounding
|
919 |
+
indices = indices.to(dtype=torch.long)
|
920 |
+
item['video'] = item['video'][indices]
|
921 |
+
item['meta']['video']['fps'][0] = self.new_fps
|
922 |
+
return item
|
923 |
+
|
924 |
+
class ResizeAndLetterboxPad(torch.nn.Module):
|
925 |
+
'''Adapted from Amazon's WACV24 challenge.'''
|
926 |
+
|
927 |
+
def __init__(self, new_h, new_w):
|
928 |
+
super().__init__()
|
929 |
+
self.new_h = new_h
|
930 |
+
self.new_w = new_w
|
931 |
+
self.aspect_ratio = new_w / new_h
|
932 |
+
|
933 |
+
def forward(self, item):
|
934 |
+
item['video'] = self.resize_and_pad(item['video'])
|
935 |
+
return item
|
936 |
+
|
937 |
+
def resize_and_pad(self, rgb: torch.Tensor):
|
938 |
+
_, _, height, width = rgb.shape
|
939 |
+
current_aspect_ratio = width / height
|
940 |
+
if current_aspect_ratio > self.aspect_ratio:
|
941 |
+
scaled_height = round(self.new_w / current_aspect_ratio)
|
942 |
+
rgb = torchvision.transforms.functional.resize(rgb, (scaled_height, self.new_w), antialias=None)
|
943 |
+
top = (self.new_h - scaled_height) // 2
|
944 |
+
bottom = self.new_h - (scaled_height + top)
|
945 |
+
rgb = torch.nn.ConstantPad2d((0, 0, top, bottom), 0)(rgb)
|
946 |
+
elif current_aspect_ratio < self.aspect_ratio:
|
947 |
+
scaled_width = round(self.new_h*current_aspect_ratio)
|
948 |
+
rgb = torchvision.transforms.functional.resize(rgb, (self.new_h, scaled_width), antialias=None)
|
949 |
+
left = (self.new_w - scaled_width) // 2
|
950 |
+
right = self.new_w - (scaled_width + left)
|
951 |
+
rgb = torch.nn.ConstantPad2d((left, right, 0, 0), 0)(rgb)
|
952 |
+
return rgb
|
953 |
+
|
954 |
+
|
955 |
+
class ResampleResizeLetterboxPad(torch.nn.Module):
|
956 |
+
|
957 |
+
def __init__(self, afps, vfps, new_h, new_w) -> None:
|
958 |
+
super().__init__()
|
959 |
+
self.transforms = torchvision.transforms.Compose([
|
960 |
+
ResampleAudio(new_fps=afps),
|
961 |
+
ResampleRGB(new_fps=vfps),
|
962 |
+
ResizeAndLetterboxPad(new_h=new_h, new_w=new_w)
|
963 |
+
])
|
964 |
+
|
965 |
+
def forward(self, x: dict) -> dict:
|
966 |
+
return self.transforms(x)
|
967 |
+
|
968 |
+
class DoNothing(torch.nn.Module):
|
969 |
+
def __init__(self, *args, **kwargs) -> None:
|
970 |
+
super().__init__()
|
971 |
+
|
972 |
+
def forward(self, x: dict) -> dict:
|
973 |
+
return x
|
974 |
+
|
975 |
+
|
976 |
+
if __name__ == '__main__':
|
977 |
+
grid = make_class_grid(-1, 1, 21)
|
978 |
+
grid = make_class_grid(-2, 2, 41)
|
979 |
+
print('grid:', grid)
|
980 |
+
print('value quantization:', quantize_offset(grid, 0.06))
|
981 |
+
v_fps = 25.0
|
982 |
+
duration = 10.0
|
983 |
+
|
984 |
+
input = {
|
985 |
+
'video': torch.randint(0, 256, (int(duration * v_fps), 3, 720//2, 1280//2), dtype=torch.uint8),
|
986 |
+
'audio': torch.arange(221184-1).float(),
|
987 |
+
'targets': {},
|
988 |
+
'meta': {
|
989 |
+
'video': {'duration': [duration], 'fps': [v_fps]},
|
990 |
+
'audio': {'duration': [duration], 'framerate': [22050.0]},
|
991 |
+
'subtitles': {'duration': []},
|
992 |
+
'cc': {'duration': []},
|
993 |
+
},
|
994 |
+
'path': '/home/nvme/data/vggsound/video/-5cWCaoEDlE_261000_271000.mp4',
|
995 |
+
'split': 'train',
|
996 |
+
}
|
997 |
+
|
998 |
+
print(input['audio'].shape, input['video'].shape)
|
999 |
+
|
1000 |
+
fn = EqualifyFromRight(clip_max_len_sec=10)
|
1001 |
+
input = fn(input)
|
1002 |
+
print(input['audio'].shape, input['video'].shape)
|
1003 |
+
|
1004 |
+
fn = RGBSpatialCrop((224, 224), is_random=True)
|
1005 |
+
# fn = RGBSpatialCrop((112, 112), is_random=True)
|
1006 |
+
input = fn(input)
|
1007 |
+
print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
|
1008 |
+
|
1009 |
+
fn = Resize((224, 224))
|
1010 |
+
input = fn(input)
|
1011 |
+
print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
|
1012 |
+
|
1013 |
+
fn = GenerateMultipleSegments(segment_size_vframes=16, n_segments=14,
|
1014 |
+
is_start_random=False, audio_jitter_sec=0.05, step_size_seg=0.5)
|
1015 |
+
input = fn(input)
|
1016 |
+
print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
|
1017 |
+
|
1018 |
+
fn = RandomApplyColorDistortion(p_gray_scale=0.5, p_color_jitter=0.5, s=1.0)
|
1019 |
+
input = fn(input)
|
1020 |
+
print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
|
1021 |
+
|
1022 |
+
fn = RGBToFloatToZeroOne()
|
1023 |
+
input = fn(input)
|
1024 |
+
print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
|
1025 |
+
print(input['meta'])
|
1026 |
+
|
1027 |
+
fn = RGBNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
1028 |
+
input = fn(input)
|
1029 |
+
print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
|
1030 |
+
print(input['video'].mean(dim=(0, 2, 3)))
|
1031 |
+
print(input['meta'])
|
1032 |
+
|
1033 |
+
fn = AudioRandomReverb(p=1.0)
|
1034 |
+
input = fn(input)
|
1035 |
+
|
1036 |
+
fn = AudioRandomVolume(p=1.0, gain=2.0, gain_type='amplitude')
|
1037 |
+
input = fn(input)
|
1038 |
+
print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
|
1039 |
+
|
1040 |
+
fn = AudioRandomPitchShift(p=1.0, shift=1000)
|
1041 |
+
input = fn(input)
|
1042 |
+
print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
|
1043 |
+
|
1044 |
+
fn = AudioRandomLowpassFilter(p=1.0, cutoff_freq=100)
|
1045 |
+
input = fn(input)
|
1046 |
+
print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
|
1047 |
+
|
1048 |
+
fn = AudioRandomGaussNoise(p=1.0, amplitude=0.01)
|
1049 |
+
input = fn(input)
|
1050 |
+
print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
|
1051 |
+
|
1052 |
+
fn = AudioLog()
|
1053 |
+
input = fn(input)
|
1054 |
+
print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
|
1055 |
+
|
1056 |
+
# audio only
|
1057 |
+
input = {
|
1058 |
+
'audio': torch.arange(221184).float(),
|
1059 |
+
'meta': {
|
1060 |
+
'video': {'duration': [10.0], 'fps': [10.0]},
|
1061 |
+
'audio': {'duration': [11.0], 'framerate': [22050.0]},
|
1062 |
+
'subtitles': {'duration': []},
|
1063 |
+
'cc': {'duration': []}
|
1064 |
+
},
|
1065 |
+
'path': '/home/nvme/data/vggsound/video/-5cWCaoEDlE_261000_271000.mp4'
|
1066 |
+
}
|
1067 |
+
|
1068 |
+
print(input['audio'].shape)
|
1069 |
+
|
1070 |
+
fn = AudioLog()
|
1071 |
+
input = fn(input)
|
1072 |
+
print(input['audio'].shape, input['meta']['audio'])
|
1073 |
+
print(input['meta'])
|
1074 |
+
print(input['audio'].min(), input['audio'].max())
|
modules/dataset/vggsound.py
ADDED
@@ -0,0 +1,394 @@
1 |
+
import csv
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import sys
|
6 |
+
from collections import Counter
|
7 |
+
from glob import glob
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
sys.path.insert(0, '.') # nopep8
|
13 |
+
from dataset.dataset_utils import get_fixed_offsets, get_video_and_audio, subsample_dataset
|
14 |
+
|
15 |
+
|
16 |
+
class VGGSound(torch.utils.data.Dataset):
|
17 |
+
|
18 |
+
def __init__(self,
|
19 |
+
split,
|
20 |
+
vids_dir,
|
21 |
+
transforms=None,
|
22 |
+
to_filter_bad_examples=True,
|
23 |
+
splits_path='./data',
|
24 |
+
meta_path='./data/vggsound.csv',
|
25 |
+
seed=1337,
|
26 |
+
load_fixed_offsets_on=['valid', 'test'],
|
27 |
+
vis_load_backend='read_video',
|
28 |
+
size_ratio=None,
|
29 |
+
attr_annot_path=None,
|
30 |
+
max_attr_per_vid=None):
|
31 |
+
super().__init__()
|
32 |
+
self.max_clip_len_sec = None
|
33 |
+
self.split = split
|
34 |
+
self.vids_dir = vids_dir
|
35 |
+
self.transforms = transforms
|
36 |
+
self.to_filter_bad_examples = to_filter_bad_examples
|
37 |
+
self.splits_path = splits_path
|
38 |
+
self.meta_path = meta_path
|
39 |
+
self.seed = seed
|
40 |
+
self.load_fixed_offsets_on = [] if load_fixed_offsets_on is None else load_fixed_offsets_on
|
41 |
+
self.vis_load_backend = vis_load_backend
|
42 |
+
self.size_ratio = size_ratio
|
43 |
+
|
44 |
+
vggsound_meta = list(csv.reader(open(meta_path), quotechar='"'))
|
45 |
+
|
46 |
+
# filter "bad" examples
|
47 |
+
if to_filter_bad_examples:
|
48 |
+
vggsound_meta = self.filter_bad_examples(vggsound_meta)
|
49 |
+
|
50 |
+
unique_classes = sorted(list(set(row[2] for row in vggsound_meta)))
|
51 |
+
self.label2target = {label: target for target, label in enumerate(unique_classes)}
|
52 |
+
self.target2label = {target: label for label, target in self.label2target.items()}
|
53 |
+
self.video2target = {row[0]: self.label2target[row[2]] for row in vggsound_meta}
|
54 |
+
|
55 |
+
split_clip_ids_path = os.path.join(splits_path, f'vggsound_{split}.txt')
|
56 |
+
if not os.path.exists(split_clip_ids_path):
|
57 |
+
self.make_split_files()
|
58 |
+
# the ugly string converts ['AdfsGsfII2yQ', '1'] into `AdfsGsfII2yQ_1000_11000`
|
59 |
+
meta_available = set([f'{r[0]}_{int(r[1])*1000}_{(int(r[1])+10)*1000}' for r in vggsound_meta])
|
60 |
+
within_split = set(open(split_clip_ids_path).read().splitlines())
|
61 |
+
clip_paths = [os.path.join(vids_dir, v + '.mp4') for v in meta_available.intersection(within_split)]
|
62 |
+
clip_paths = sorted(clip_paths)
|
63 |
+
|
64 |
+
if split in self.load_fixed_offsets_on:
|
65 |
+
logging.info(f'Using fixed offset for {split}')
|
66 |
+
self.vid2offset_params = get_fixed_offsets(transforms, split, splits_path, 'vggsound')
|
67 |
+
|
68 |
+
# making sure that all classes have at least one example
|
69 |
+
counter = Counter([self.video2target[Path(p).stem[:11]] for p in clip_paths])
|
70 |
+
assert all(counter[c] > 0 for c in self.target2label.keys()), \
|
71 |
+
f'Some classes have 0 count: {dict(counter)}'
|
72 |
+
|
73 |
+
self.dataset = clip_paths
|
74 |
+
self.dataset = subsample_dataset(self.dataset, size_ratio, shuffle=split == 'train')
|
75 |
+
|
76 |
+
def filter_bad_examples(self, vggsound_meta):
|
77 |
+
bad = set()
|
78 |
+
base_path = Path('./data/filtered_examples_vggsound')
|
79 |
+
lists = [open(p).read().splitlines() for p in sorted(glob(str(base_path / '*.txt')))]
|
80 |
+
for s in lists:
|
81 |
+
bad = bad.union(s)
|
82 |
+
# the ugly string converts '---g-f_I2yQ', '1' into `---g-f_I2yQ_1000_11000`
|
83 |
+
vggsound_meta = [r for r in vggsound_meta if f'{r[0]}_{int(r[1])*1000}_{(int(r[1])+10)*1000}' not in bad]
|
84 |
+
return vggsound_meta
|
85 |
+
|
86 |
+
def __getitem__(self, index):
|
87 |
+
path = self.dataset[index]
|
88 |
+
rgb, audio, meta = self.load_media(path)
|
89 |
+
item = self.make_datapoint(path, rgb, audio, meta)
|
90 |
+
if self.transforms is not None:
|
91 |
+
item = self.transforms(item)
|
92 |
+
return item
|
93 |
+
|
94 |
+
def make_datapoint(self, path, rgb, audio, meta):
|
95 |
+
# (Tv, 3, H, W) in [0, 255], (Ta, C) in [-1, 1]
|
96 |
+
target = self.video2target[Path(path).stem[:11]]
|
97 |
+
item = {
|
98 |
+
'video': rgb,
|
99 |
+
'audio': audio,
|
100 |
+
'meta': meta,
|
101 |
+
'path': str(path),
|
102 |
+
'targets': {'vggsound_target': target, 'vggsound_label': self.target2label[target]},
|
103 |
+
'split': self.split,
|
104 |
+
}
|
105 |
+
|
106 |
+
if self.split in self.load_fixed_offsets_on:
|
107 |
+
unique_id = Path(path).stem
|
108 |
+
offset_params = self.vid2offset_params[unique_id]
|
109 |
+
item['targets']['offset_sec'] = self.vid2offset_params[unique_id]['offset_sec']
|
110 |
+
item['targets']['v_start_i_sec'] = self.vid2offset_params[unique_id]['v_start_i_sec']
|
111 |
+
if 'oos_target' in offset_params:
|
112 |
+
item['targets']['offset_target'] = {
|
113 |
+
'oos': offset_params['oos_target'], 'offset': item['targets']['offset_sec'],
|
114 |
+
}
|
115 |
+
|
116 |
+
return item
|
117 |
+
|
118 |
+
def load_media(self, path):
|
119 |
+
rgb, audio, meta = get_video_and_audio(path, get_meta=True, end_sec=self.max_clip_len_sec)
|
120 |
+
return rgb, audio, meta
|
121 |
+
|
122 |
+
def make_split_files(self):
|
123 |
+
if self.to_filter_bad_examples:
|
124 |
+
logging.warning('`to_filter_bad_examples` is True. `make_split_files` expects otherwise')
|
125 |
+
|
126 |
+
logging.info(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')
|
127 |
+
# The downloaded videos (some went missing on YouTube and no longer available)
|
128 |
+
available_vid_paths = sorted(glob(os.path.join(self.vids_dir, '*.mp4')))
|
129 |
+
logging.info(f'The number of clips available after download: {len(available_vid_paths)}')
|
130 |
+
|
131 |
+
# original (full) train and test sets
|
132 |
+
vggsound_meta = list(csv.reader(open(self.meta_path), quotechar='"'))
|
133 |
+
train_vids = {row[0] for row in vggsound_meta if row[3] == 'train'}
|
134 |
+
test_vids = {row[0] for row in vggsound_meta if row[3] == 'test'}
|
135 |
+
|
136 |
+
# # the cleaned test set
|
137 |
+
# vggsound_meta_test_v2 = list(csv.reader(open(self.meta_path_clean_test), quotechar='"'))
|
138 |
+
logging.info(f'The number of videos in vggsound train set: {len(train_vids)}')
|
139 |
+
logging.info(f'The number of videos in vggsound test set: {len(test_vids)}')
|
140 |
+
|
141 |
+
# class counts in test set. We would like to have the same distribution in valid
|
142 |
+
unique_classes = sorted(list(set(row[2] for row in vggsound_meta)))
|
143 |
+
label2target = {label: target for target, label in enumerate(unique_classes)}
|
144 |
+
video2target = {row[0]: label2target[row[2]] for row in vggsound_meta}
|
145 |
+
test_vid_classes = [video2target[vid] for vid in test_vids]
|
146 |
+
test_target2count = Counter(test_vid_classes)
|
147 |
+
|
148 |
+
# now, given the counts from the test set, sample the same count for validation and leave the rest in train
|
149 |
+
train_vids_wo_valid, valid_vids = set(), set()
|
150 |
+
for target, label in enumerate(label2target.keys()):
|
151 |
+
class_train_vids = [vid for vid in sorted(list(train_vids)) if video2target[vid] == target]
|
152 |
+
random.Random(self.seed).shuffle(class_train_vids)
|
153 |
+
count = test_target2count[target]
|
154 |
+
valid_vids.update(class_train_vids[:count])
|
155 |
+
train_vids_wo_valid.update(class_train_vids[count:])
|
156 |
+
|
157 |
+
# make files with the lists of available videos per split (each entry keeps its timestamps)
|
158 |
+
train_i = valid_i = test_i = 0
|
159 |
+
with open(os.path.join(self.splits_path, 'vggsound_train.txt'), 'w') as train_file, \
|
160 |
+
open(os.path.join(self.splits_path, 'vggsound_valid.txt'), 'w') as valid_file, \
|
161 |
+
open(os.path.join(self.splits_path, 'vggsound_test.txt'), 'w') as test_file:
|
162 |
+
# open(os.path.join(self.splits_path, 'vggsound_test_v2.txt'), 'w') as test_file_v2:
|
163 |
+
for path in available_vid_paths:
|
164 |
+
path = path.replace('.mp4', '')
|
165 |
+
vid_name = Path(path).name
|
166 |
+
# 'zyTX_1BXKDE_16000_26000'[:11] -> 'zyTX_1BXKDE'
|
167 |
+
if vid_name[:11] in train_vids_wo_valid:
|
168 |
+
train_file.write(vid_name + '\n')
|
169 |
+
train_i += 1
|
170 |
+
elif vid_name[:11] in valid_vids:
|
171 |
+
valid_file.write(vid_name + '\n')
|
172 |
+
valid_i += 1
|
173 |
+
elif vid_name[:11] in test_vids:
|
174 |
+
test_file.write(vid_name + '\n')
|
175 |
+
test_i += 1
|
176 |
+
# else:
|
177 |
+
# raise Exception(f'Clip {vid_name} is neither in train, valid nor test. Strange.')
|
178 |
+
|
179 |
+
logging.info(f'Put {train_i} clips to the train set and saved it to ./data/vggsound_train.txt')
|
180 |
+
logging.info(f'Put {valid_i} clips to the valid set and saved it to ./data/vggsound_valid.txt')
|
181 |
+
logging.info(f'Put {test_i} clips to the test set and saved it to ./data/vggsound_test.txt')
|
182 |
+
|
183 |
+
def __len__(self):
|
184 |
+
return len(self.dataset)
|
185 |
+
|
186 |
+
class VGGSoundSparse(VGGSound):
|
187 |
+
'''
|
188 |
+
The same as VGGSound, except the list of videos is filtered for sparse sounds (sparse_meta_path)
|
189 |
+
'''
|
190 |
+
|
191 |
+
def __init__(self, split, vids_dir, transforms=None, to_filter_bad_examples=True,
|
192 |
+
splits_path='./data', meta_path='./data/vggsound.csv',
|
193 |
+
sparse_meta_path='./data/sparse_classes.csv', seed=1337, load_fixed_offsets_on=['valid', 'test'],
|
194 |
+
vis_load_backend='read_video', size_ratio=None, attr_annot_path=None, max_attr_per_vid=None):
|
195 |
+
super().__init__(split, vids_dir, transforms, to_filter_bad_examples, splits_path, meta_path, seed,
|
196 |
+
load_fixed_offsets_on, vis_load_backend, size_ratio)
|
197 |
+
self.sparse_meta_path = sparse_meta_path
|
198 |
+
sparse_meta = list(csv.reader(open(sparse_meta_path), quotechar='"', delimiter='\t'))
|
199 |
+
sparse_classes = set([row[0] for row in sparse_meta if row[1] == 'y'])
|
200 |
+
label2new_target = {label: target for target, label in enumerate(sorted(list(sparse_classes)))}
|
201 |
+
new_target2label = {target: label for label, target in label2new_target.items()}
|
202 |
+
|
203 |
+
sparse_dataset = []
|
204 |
+
video2new_target = {}
|
205 |
+
for path in self.dataset:
|
206 |
+
vid_id = Path(path).stem[:11]
|
207 |
+
vid_target = self.video2target[vid_id]
|
208 |
+
vid_label = self.target2label[vid_target]
|
209 |
+
if vid_label in sparse_classes:
|
210 |
+
sparse_dataset.append(path)
|
211 |
+
video2new_target[vid_id] = label2new_target[vid_label]
|
212 |
+
|
213 |
+
self.dataset = sparse_dataset
|
214 |
+
logging.debug(f'Filtered VGGSound dataset to sparse classes from {Path(sparse_meta_path).name}.')
|
215 |
+
|
216 |
+
# redefining the label <-> target variable
|
217 |
+
self.label2target = label2new_target
|
218 |
+
self.target2label = new_target2label
|
219 |
+
self.video2target = video2new_target
|
220 |
+
logging.debug('Redefined label <-> target mapping to sparse classes.')
|
221 |
+
|
222 |
+
counter = Counter([self.video2target[Path(p).stem[:11]] for p in self.dataset])
|
223 |
+
assert len(self.dataset) < 1000 or all(counter[c] > 0 for c in self.target2label.keys()), \
|
224 |
+
f'Some classes have 0 count: {dict(counter)}'
|
225 |
+
|
226 |
+
|
227 |
+
class VGGSoundSparsePicked(VGGSoundSparse):
|
228 |
+
|
229 |
+
def __init__(self, split, vids_dir, transforms=None, to_filter_bad_examples=True,
|
230 |
+
splits_path='./data', meta_path='./data/vggsound.csv',
|
231 |
+
sparse_meta_path='./data/picked_sparse_classes.csv', seed=1337,
|
232 |
+
load_fixed_offsets_on=['valid', 'test'], vis_load_backend='read_video', size_ratio=None,
|
233 |
+
attr_annot_path=None, max_attr_per_vid=None):
|
234 |
+
super().__init__(split, vids_dir, transforms, to_filter_bad_examples, splits_path,
|
235 |
+
meta_path, sparse_meta_path, seed, load_fixed_offsets_on, vis_load_backend,
|
236 |
+
size_ratio)
|
237 |
+
|
238 |
+
|
239 |
+
class VGGSoundSparsePickedCleanTest(VGGSoundSparse):
|
240 |
+
|
241 |
+
def __init__(self, split, vids_dir, transforms=None, to_filter_bad_examples=True,
|
242 |
+
splits_path='./data', meta_path='./data/vggsound.csv',
|
243 |
+
sparse_meta_path='./data/picked_sparse_classes.csv', seed=1337,
|
244 |
+
load_fixed_offsets_on=['valid', 'test'], vis_load_backend='read_video', size_ratio=None,
|
245 |
+
attr_annot_path=None, max_attr_per_vid=None):
|
246 |
+
super().__init__(split, vids_dir, transforms, to_filter_bad_examples, splits_path,
|
247 |
+
meta_path, sparse_meta_path, seed, load_fixed_offsets_on, vis_load_backend,
|
248 |
+
size_ratio)
|
249 |
+
|
250 |
+
def filter_bad_examples(self, vggsound_meta):
|
251 |
+
bad = set()
|
252 |
+
base_path1 = Path('./data/filtered_examples_vggsound')
|
253 |
+
base_path2 = Path('./data/filtered_examples_vggsound_extra')
|
254 |
+
lists1 = [open(p).read().splitlines() for p in sorted(glob(str(base_path1 / '*.txt')))]
|
255 |
+
lists2 = [open(p).read().splitlines() for p in sorted(glob(str(base_path2 / '*.txt')))]
|
256 |
+
lists = lists1 + lists2
|
257 |
+
for s in lists:
|
258 |
+
bad = bad.union(s)
|
259 |
+
# the ugly string converts '---g-f_I2yQ', '1' into `---g-f_I2yQ_1000_11000`
|
260 |
+
vggsound_meta = [r for r in vggsound_meta if f'{r[0]}_{int(r[1])*1000}_{(int(r[1])+10)*1000}' not in bad]
|
261 |
+
return vggsound_meta
|
262 |
+
|
263 |
+
|
264 |
+
class VGGSoundSparsePickedCleanTestFixedOffsets(VGGSoundSparse):
|
265 |
+
'''This dataset only operates on manually annotated fixed offsets. Meant for evaluation purpose.'''
|
266 |
+
|
267 |
+
def __init__(self, split, vids_dir, transforms=None, to_filter_bad_examples=True,
|
268 |
+
splits_path='./data', meta_path='./data/vggsound.csv',
|
269 |
+
sparse_meta_path='./data/picked_sparse_classes.csv', seed=1337,
|
270 |
+
load_fixed_offsets_on=['valid', 'test'], vis_load_backend='read_video', size_ratio=None,
|
271 |
+
attr_annot_path=None, max_attr_per_vid=None):
|
272 |
+
super().__init__(split, vids_dir, transforms, to_filter_bad_examples, splits_path,
|
273 |
+
meta_path, sparse_meta_path, seed, load_fixed_offsets_on, vis_load_backend,
|
274 |
+
size_ratio)
|
275 |
+
# redefine the dataset to only use fixed offsets
|
276 |
+
fix_off_path = './data/vggsound_sparse_clean_fixed_offsets.csv'
|
277 |
+
self.vid2offset_params = {}
|
278 |
+
# lines have the format: dataset_name,video_id,vstart_sec,offset_sec,is_sync
|
279 |
+
reader = csv.reader(open(fix_off_path))
|
280 |
+
next(reader) # skipping the header
|
281 |
+
for _, vi, st, of, sy in reader:
|
282 |
+
# FIXME: if there are more than one fixed offset per video, we will need to change the logic here
|
283 |
+
assert vi not in self.vid2offset_params, 'offsets from other splits will override each other'
|
284 |
+
if sy == '1':
|
285 |
+
self.vid2offset_params[vi] = {'offset_sec': float(of), 'v_start_i_sec': float(st)}
|
286 |
+
logging.debug(f'Loaded {len(self.vid2offset_params)} fixed offsets from {Path(fix_off_path).name}.')
|
287 |
+
# since we care only about the videos in the fixed offset split, we filter the dataset
|
288 |
+
self.dataset = [p for p in self.dataset if Path(p).stem in self.vid2offset_params]
|
289 |
+
logging.debug(f'Filtered VGGSoundSparse dataset to fixed offsets from {Path(fix_off_path).name}.')
|
290 |
+
|
291 |
+
|
292 |
+
class LongerVGGSound(VGGSound):
|
293 |
+
'''This class is different to the parent in the extra filtering it does. If the parent was
|
294 |
+
making the splits with filtering for shorter than 9 second, this class filters for shorter than 9.5 sec.
|
295 |
+
by applying extra filtering.
|
296 |
+
'''
|
297 |
+
|
298 |
+
def __init__(self,
|
299 |
+
split,
|
300 |
+
vids_dir,
|
301 |
+
transforms=None,
|
302 |
+
to_filter_bad_examples=True,
|
303 |
+
splits_path='./data',
|
304 |
+
meta_path='./data/vggsound.csv',
|
305 |
+
seed=1337,
|
306 |
+
load_fixed_offsets_on=['valid', 'test'],
|
307 |
+
vis_load_backend='read_video',
|
308 |
+
size_ratio=None,
|
309 |
+
attr_annot_path=None,
|
310 |
+
max_attr_per_vid=None):
|
311 |
+
# size_ratio=None and fixed_offsets_on=[] to avoid double subsampling and loading fixed offsets
|
312 |
+
super().__init__(split, vids_dir, transforms, to_filter_bad_examples, splits_path, meta_path, seed,
|
313 |
+
[], vis_load_backend, None, attr_annot_path, max_attr_per_vid)
|
314 |
+
# redefining the load_fixed_offsets_on because the parent class does not load them (see above)
|
315 |
+
self.load_fixed_offsets_on = load_fixed_offsets_on
|
316 |
+
# doing the extra filtering for longer than 9.5 sec
|
317 |
+
if to_filter_bad_examples:
|
318 |
+
filt_ex_path = Path('./data/filtered_examples_vggsound_shorter/less_than_9.5s.txt')
|
319 |
+
bad = set([p for p in open(filt_ex_path).read().splitlines()])
|
320 |
+
self.dataset = [p for p in self.dataset if Path(p).stem not in bad]
|
321 |
+
|
322 |
+
# longer vid clips require new fixed offsets, loading them here again (VGGSound loads them in init)
|
323 |
+
if split in self.load_fixed_offsets_on:
|
324 |
+
logging.info(f'Using fixed offset for {split}')
|
325 |
+
self.vid2offset_params = get_fixed_offsets(transforms, split, splits_path, 'vggsound')
|
326 |
+
|
327 |
+
self.dataset = subsample_dataset(self.dataset, size_ratio, shuffle=split == 'train')
|
328 |
+
logging.info(f'{split} has {len(self.dataset)} items')
|
329 |
+
|
330 |
+
|
331 |
+
|
332 |
+
if __name__ == '__main__':
|
333 |
+
from omegaconf import OmegaConf
|
334 |
+
from scripts.train_utils import get_transforms
|
335 |
+
from utils.utils import cfg_sanity_check_and_patch
|
336 |
+
cfg = OmegaConf.load('./configs/transformer.yaml')
|
337 |
+
vis_load_backend = 'read_video'
|
338 |
+
|
339 |
+
# set logging for info level
|
340 |
+
logging.basicConfig(level=logging.INFO)
|
341 |
+
|
342 |
+
transforms = get_transforms(cfg)
|
343 |
+
|
344 |
+
vids_path = 'PLACEHOLDER'
|
345 |
+
|
346 |
+
# cfg.data.dataset.params.size_ratio = 0.1
|
347 |
+
|
348 |
+
cfg_sanity_check_and_patch(cfg)
|
349 |
+
|
350 |
+
datasets = {
|
351 |
+
'train': VGGSound('train', vids_path, transforms['train'], vis_load_backend=vis_load_backend,
|
352 |
+
size_ratio=cfg.data.dataset.params.size_ratio),
|
353 |
+
'valid': VGGSound('valid', vids_path, transforms['test'], vis_load_backend=vis_load_backend),
|
354 |
+
'test': VGGSound('test', vids_path, transforms['test'], vis_load_backend=vis_load_backend),
|
355 |
+
}
|
356 |
+
for phase in ['train', 'valid', 'test']:
|
357 |
+
print(phase, len(datasets[phase]))
|
358 |
+
|
359 |
+
print(datasets['train'][0]['audio'].shape, datasets['train'][0]['video'].shape)
|
360 |
+
print(datasets['train'][0]['meta'])
|
361 |
+
print(datasets['valid'][0]['audio'].shape, datasets['valid'][0]['video'].shape)
|
362 |
+
print(datasets['valid'][0]['meta'])
|
363 |
+
print(datasets['test'][0]['audio'].shape, datasets['test'][0]['video'].shape)
|
364 |
+
print(datasets['test'][0]['meta'])
|
365 |
+
|
366 |
+
# for i in range(300, 1000):
|
367 |
+
# datasets['train'][i]['path']
|
368 |
+
# print(datasets['train'][0]['audio'].shape, datasets['train'][0]['video'].shape)
|
369 |
+
# print(datasets['train'][0]['meta'])
|
370 |
+
|
371 |
+
datasets = {
|
372 |
+
'train': VGGSoundSparse('train', vids_path, transforms['train']),
|
373 |
+
'valid': VGGSoundSparse('valid', vids_path, transforms['test']),
|
374 |
+
'test': VGGSoundSparse('test', vids_path, transforms['test']),
|
375 |
+
}
|
376 |
+
for phase in ['train', 'valid', 'test']:
|
377 |
+
print(phase, len(datasets[phase]))
|
378 |
+
|
379 |
+
print(datasets['train'][0]['audio'].shape, datasets['train'][0]['video'].shape)
|
380 |
+
print(datasets['train'][0]['meta'])
|
381 |
+
print(datasets['valid'][0]['audio'].shape, datasets['valid'][0]['video'].shape)
|
382 |
+
print(datasets['valid'][0]['meta'])
|
383 |
+
print(datasets['test'][0]['audio'].shape, datasets['test'][0]['video'].shape)
|
384 |
+
print(datasets['test'][0]['meta'])
|
385 |
+
|
386 |
+
datasets = {
|
387 |
+
'test': VGGSoundSparsePickedCleanTestFixedOffsets('test', vids_path, transforms['test']),
|
388 |
+
}
|
389 |
+
for phase in ['test']:
|
390 |
+
print(phase, len(datasets[phase]))
|
391 |
+
|
392 |
+
for i in range(len(datasets['test'])):
|
393 |
+
datasets['test'][i]['path']
|
394 |
+
print(datasets['test'][i]['audio'].shape, datasets['test'][i]['video'].shape)
|
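The clip-id string built in `filter_bad_examples` above is easy to get wrong, so here is a small, self-contained sketch (not part of the diff) of how a `(video_id, start_sec)` row from `vggsound.csv` maps to the `<video_id>_<start_ms>_<end_ms>` stem stored in the filtered-example lists. The `clip_stem` helper and the example ids are hypothetical; the 10-second clip length follows the comment in the code above.

# illustrative sketch only -- `clip_stem` is a hypothetical helper, not part of the repo
def clip_stem(video_id: str, start_sec: str, clip_len_sec: int = 10) -> str:
    start_ms = int(start_sec) * 1000
    end_ms = (int(start_sec) + clip_len_sec) * 1000
    return f'{video_id}_{start_ms}_{end_ms}'

if __name__ == '__main__':
    bad = {'---g-f_I2yQ_1000_11000'}
    meta_row = ('---g-f_I2yQ', '1', 'playing drums')   # (video_id, start_sec, label)
    print(clip_stem(meta_row[0], meta_row[1]))         # ---g-f_I2yQ_1000_11000
    print(clip_stem(meta_row[0], meta_row[1]) in bad)  # True -> this row would be dropped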
modules/model/__init__.py
ADDED
@@ -0,0 +1 @@
# Auto-generated __init__.py
modules/model/modules/bridges.py
ADDED
@@ -0,0 +1,178 @@
import torch
from torch.nn import functional as F


class BridgeBase(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.bridge = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        try:
            x = self.bridge(x)
            return x
        except TypeError as e:
            raise TypeError("The class can't be called on its own. Please use a class that inherits it", e)


class ConvBridgeBase(BridgeBase):

    def __init__(self, block, **kwargs) -> None:
        super().__init__()
        self.bridge = torch.nn.Sequential(
            block(**kwargs),
            torch.nn.GELU(),
        )


class AvgPoolBridgeBase(BridgeBase):

    def __init__(self, block, **kwargs) -> None:
        super().__init__()
        # a list [1, 2, 3] specified in an OmegaConf config has type ListConfig, which is not accepted by pytorch
        # interestingly, conv layers don't care
        for k in ['kernel_size', 'stride']:
            kwargs[k] = list(kwargs[k])
        self.bridge = block(**kwargs)


class ConvBridgeAudio(ConvBridgeBase):

    def __init__(self, **kwargs) -> None:
        super().__init__(block=torch.nn.Conv2d, **kwargs)


class ConvBridgeVisual(ConvBridgeBase):

    def __init__(self, **kwargs) -> None:
        super().__init__(block=torch.nn.Conv3d, **kwargs)


class AvgPoolBridgeVisual(AvgPoolBridgeBase):

    def __init__(self, **kwargs) -> None:
        super().__init__(block=torch.nn.AvgPool3d, **kwargs)


class AvgPoolBridgeAudio(AvgPoolBridgeBase):

    def __init__(self, **kwargs) -> None:
        super().__init__(block=torch.nn.AvgPool2d, **kwargs)


class DoNothingBridge(BridgeBase):

    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.bridge = torch.nn.Identity(**kwargs)


class AppendZerosToHidden(BridgeBase):

    def __init__(self, target_hidden_size, dim) -> None:
        super().__init__()
        self.target_hidden_size = target_hidden_size
        self.dim = dim

    def forward(self, x):
        d_res = self.target_hidden_size - x.shape[self.dim]
        # going to insert the new dimension into the x.shape output
        shape_target = list(x.shape[:self.dim]) + [d_res] + list(x.shape[self.dim+1:])
        # creating the zeros to append to x
        zeros = torch.zeros(shape_target).to(x)
        x = torch.cat([x, zeros], self.dim)
        return x


class SpatialpoolConvTemporalpool(torch.nn.Module):
    '''Similar to S3D but with a slightly different kernel for F.avg_pool3d.
    To be used in AVCLIP in the visual branch.'''

    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.conv = torch.nn.Conv3d(**kwargs)

    def forward(self, x: torch.Tensor):
        B, t, d, h, w = x.shape
        x = x.permute(0, 2, 1, 3, 4)  # (B, d, t, h, w)
        # pool as in S3D but without temporal pooling (2-->1, h, w)
        x = F.avg_pool3d(x, (1, h, w), stride=1)  # (B, d, t, 1, 1)
        x = self.conv(x)  # (B, D, t, 1, 1)
        x = x.view(B, self.conv.out_channels, t)  # squeeze the spatial dimensions
        x = x.mean(dim=-1)  # temporal pooling
        return x  # (B, d)


class FrequencypoolConvTemporalpool(torch.nn.Module):
    '''Similar to the visual branch of S3D, which is a stack of spatial pool, conv, temporal pool blocks.
    Instead, this is a stack of frequency pool, conv, temporal pooling blocks.
    To be used in AVCLIP in the audio branch.'''

    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(**kwargs)

    def forward(self, x: torch.Tensor):
        B, d, f, t = x.shape
        # frequency pooling (f-->1)
        x = F.avg_pool2d(x, (f, 1), stride=1)  # (B, d, 1, t)
        x = self.conv(x)  # (B, D, 1, t)
        x = x.view(B, self.conv.out_channels, t)  # squeeze the frequency dimension
        x = x.mean(dim=-1)  # temporal pooling
        return x  # (B, d)


if __name__ == '__main__':
    v = torch.rand(2, 50, 512, 7, 7)
    a = torch.rand(2, 512, 9, 27)

    in_channels = 512
    out_channels = 512
    kernel_size_v = [1, 7, 7]
    kernel_size_a = [9, 1]
    stride = 1
    bias = True

    conv_bridge_a = ConvBridgeAudio(in_channels=in_channels, out_channels=out_channels,
                                    kernel_size=kernel_size_a, stride=stride, bias=bias)
    conv_bridge_v = ConvBridgeVisual(in_channels=in_channels, out_channels=out_channels,
                                     kernel_size=kernel_size_v, stride=stride, bias=bias)
    avg_bridge_a = AvgPoolBridgeAudio(kernel_size=kernel_size_a)
    avg_bridge_v = AvgPoolBridgeVisual(kernel_size=kernel_size_v)
    i_bridge_a = DoNothingBridge(some_arg=123)
    i_bridge_v = DoNothingBridge(some_arg=123)
    h_bridge_a = AppendZerosToHidden(target_hidden_size=1024, dim=1)
    h_bridge_v = AppendZerosToHidden(target_hidden_size=1024, dim=1)

    print('v', v.shape)
    print('conv_v(v)', conv_bridge_v(v.permute(0, 2, 1, 3, 4)).permute(0, 2, 1, 3, 4).shape)
    print()

    print('a', a.shape)
    print('conv_a(a)', conv_bridge_a(a).shape)
    print()

    print('v', v.shape)
    print('avg3d(v)', avg_bridge_v(v.permute(0, 2, 1, 3, 4)).permute(0, 2, 1, 3, 4).shape)
    print()

    print('a', a.shape)
    print('avg2d(a)', avg_bridge_a(a).shape)
    print()

    print('v', v.shape)
    print('i(v)', i_bridge_v(v.permute(0, 2, 1, 3, 4)).permute(0, 2, 1, 3, 4).shape)
    print()

    print('a', a.shape)
    print('i(a)', i_bridge_a(a).shape)
    print()

    print('v', v.shape)
    print('h(v)', h_bridge_v(v.permute(0, 2, 1, 3, 4)).permute(0, 2, 1, 3, 4).shape)
    print()

    print('a', a.shape)
    print('h(a)', h_bridge_a(a).shape)
    print()
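As a quick usage note (not part of the diff), the sketch below shows what `AppendZerosToHidden` does to the hidden dimension: 512-d features are zero-padded to a target size of 1024 along `dim`. The import path is assumed from this repository layout, and the tensor size mirrors the `__main__` block above.

# illustrative sketch only -- import path assumed from the repo layout shown in this diff
import torch

from modules.model.modules.bridges import AppendZerosToHidden

bridge = AppendZerosToHidden(target_hidden_size=1024, dim=1)
a = torch.rand(2, 512, 9, 27)   # (B, d, f, t) audio-style features
print(bridge(a).shape)          # torch.Size([2, 1024, 9, 27]) -- 512 zeros appended on dim=1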
modules/model/modules/feat_extractors/audio/ast.py
ADDED
@@ -0,0 +1,279 @@
import logging

import torch
from torch import nn
# importing modified version of AST
from model.modules.feat_extractors.audio.hf_src.modeling_ast import ASTForAudioClassification, ASTConfig
from transformers.modeling_outputs import BaseModelOutputWithPooling

from model.modules.feat_extractors.visual.motionformer import (AveragePooling, BaseEncoderLayer,
                                                               TemporalTransformerEncoderLayer)
from utils.utils import check_if_file_exists_else_download


class AST(torch.nn.Module):
    def __init__(self,
                 extract_features: bool = False,
                 ckpt_path: str = None,
                 feat_type: str = None,
                 max_spec_t: int = None,
                 factorize_freq_time: bool = None,
                 agg_freq_module: str = None,
                 agg_time_module: str = None,
                 add_global_repr: bool = True,
                 agg_segments_module: str = None,
                 max_segments: int = None,
                 ) -> None:
        '''
        extract_features: if True, then the model will return the features instead of the head's output
        ckpt_path: is not a path to a ckpt file, but a name of a model from the HuggingFace model hub.
        feat_type: if extract_features is True, this parameter specifies the type of features to return
        max_spec_t: if specified, then the model (pos emb) will be patched to support this spectrogram length
        factorize_freq_time: if True, then the model will use a factorized freq/time aggregation
        agg_freq_module: if specified, then the model will use this module for freq aggregation
        agg_time_module: if specified, then the model will use this module for time aggregation
        add_global_repr: if True, adds a global representation to the features (aggregation over segments)
        agg_segments_module: if specified, then the model will use this module for segments aggregation
        max_segments: if specified, the initialization of PE in the global agg module will use this value.
                      This should correspond to the max number of segments per video (if None, 16 is used)
        '''
        super().__init__()
        self.extract_features = extract_features
        self.ckpt_path = ckpt_path
        self.max_spec_t = max_spec_t
        self.max_segments = max_segments

        # depending on whether the feat extractor was pre-trained contrastively or not, we need to
        # load the state dict differently.

        # if ckpt is specified, then load the model from the HuggingFace model hub, otherwise init a new model
        if ckpt_path == 'MIT/ast-finetuned-audioset-10-10-0.4593':
            revision = 'c1c0c66'  # fixing the revision for compatibility (v4.27.4)
            self.config = ASTConfig.from_pretrained(ckpt_path, revision=revision)
            full_model = ASTForAudioClassification.from_pretrained(ckpt_path, revision=revision)
            logging.info(f'Loaded AST from {ckpt_path}')
        else:
            self.config = ASTConfig()
            self.config.num_labels = 527  # 2 by default, AudioSet has 527 labels
            full_model = ASTForAudioClassification(self.config)
            logging.info('Initialized AST from scratch with the AST AudioSet config')

        was_pt_on_avclip = ckpt_path is not None and ckpt_path.endswith('.pt')

        # feature extractor
        self.ast = full_model.audio_spectrogram_transformer

        if self.extract_features:
            # assign `feat_type` (use default if not specified)
            self.feat_type = 'last_hidden_state' if feat_type is None else feat_type
            # define adapters if needed
            self.factorize_freq_time = factorize_freq_time
            # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
            transf_enc_layer_kwargs = dict(
                d_model=self.config.hidden_size, nhead=self.config.num_attention_heads,
                dim_feedforward=self.config.intermediate_size, activation=nn.GELU(), batch_first=True,
                dropout=self.config.attention_probs_dropout_prob, layer_norm_eps=1e-6, norm_first=True,
            )
            if factorize_freq_time:
                self.feat_type = 'last_hidden_state'  # this feat_type supports factorization
                # frequency aggregation
                if agg_freq_module == 'TransformerEncoderLayer':
                    self.freq_attn_agg = FrequencyTransformerEncoderLayer(**transf_enc_layer_kwargs)
                elif agg_freq_module == 'AveragePooling':
                    self.freq_attn_agg = AveragePooling(avg_pattern='BS D f t -> BS D t',
                                                        then_permute_pattern='BS D t -> BS t D')
                # time aggregation
                if agg_time_module == 'TransformerEncoderLayer':
                    self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
                elif agg_time_module == 'AveragePooling':
                    self.temp_attn_agg = AveragePooling(avg_pattern='BS t D -> BS D')
                elif 'Identity' in agg_time_module:
                    self.temp_attn_agg = nn.Identity()
            # define a global aggregation layer (aggregate over segments)
            self.add_global_repr = add_global_repr
            if add_global_repr:
                if agg_segments_module == 'TransformerEncoderLayer':
                    # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
                    # we need to add pos emb (PE) because previously we added the same PE for each segment
                    pos_max_len = max_segments if max_segments is not None else 16  # 16 = 10sec//0.64sec + 1
                    self.global_attn_agg = TemporalTransformerEncoderLayer(
                        add_pos_emb=True, pos_emb_drop=self.config.hidden_dropout_prob,
                        pos_max_len=pos_max_len, **transf_enc_layer_kwargs
                    )
                elif agg_segments_module == 'AveragePooling':
                    self.global_attn_agg = AveragePooling(avg_pattern='B S D -> B D')
        else:
            self.classifier = full_model.classifier

        # AST.device fails with AttributeError. This is a workaround
        self.device = full_model.device

        # pre-trained on 12*101+2=1214 tokens, but we have fewer (e.g. 12*6+2=74)
        self.patch_position_emb()

        if was_pt_on_avclip:
            # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
            # and keep only the state_dict of the feat extractor
            check_if_file_exists_else_download(self.ckpt_path)
            ckpt = torch.load(ckpt_path, map_location='cpu')
            ckpt_weights = dict()
            for k, v in ckpt['state_dict'].items():
                if k.startswith(('module.a_encoder.', 'a_encoder.')):
                    k = k.replace('module.', '').replace('a_encoder.', '')
                    ckpt_weights[k] = v
            _load_status = self.load_state_dict(ckpt_weights, strict=False)
            if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
                logging.warning(f'Loading exact afeat_extractor ckpt from {self.ckpt_path} failed. \n'
                                f'Missing keys ({len(_load_status.missing_keys)}): '
                                f'{_load_status.missing_keys}, \n'
                                f'Unexpected keys ({len(_load_status.unexpected_keys)}): '
                                f'{_load_status.unexpected_keys} \n'
                                f'temp_attn_agg are expected to be missing if the ckpt was pre-trained contrastively.')
            else:
                logging.info(f'Loading afeat_extractor ckpt from {self.ckpt_path} succeeded.')

        # print the number of parameters
        logging.info(f'AST: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}')

    def forward(self, x: torch.Tensor, for_loop: bool = False, cont_mask: torch.Tensor = None,
                **ast_kwargs) -> torch.Tensor:
        '''
        x: (B, S, T, F) where S is the number of segments and F is the number of (mel) frequency bins
        ast_kwargs: additional arguments for the AST model
        cont_mask: (B, S, T, F) where 0s are the values to be masked out
        If `for_loop=True`, we use a for loop to extract features for each segment separately.
        If `for_loop=False`, we extract features for all segments at once.
        Using the for loop is slower but more memory efficient, while processing all segments at once
        is faster but uses more memory.
        The for loop allows controlling the memory footprint by varying the number of videos in a
        batch (batch size) rather than the number of segments in a video.
        '''
        B, S, T, F = x.shape

        if for_loop:
            assert cont_mask is None, 'cont_mask is not supported with for_loop=True'
            orig_shape_s = (B, 1, T, F)
            # NOTE: since x is (B, S, T, F) and forward_segments expects (BS, T, F),
            # (B, S, T, F)[:, s] is (B, T, F), i.e. (BS, T, F) with S=1.
            x = torch.cat(
                [self.forward_segments(x[:, s], orig_shape_s, **ast_kwargs).unsqueeze(1) for s in range(S)],
                dim=1)
        else:
            orig_shape = (B, S, T, F)
            x = x.view(B * S, T, F)
            if cont_mask is not None:
                cont_mask = cont_mask.reshape(B * S, T, F)
            # AST expects a tensor of shape (B*S, T, F).
            x = self.forward_segments(x, orig_shape=orig_shape, cont_mask=cont_mask, **ast_kwargs)
            # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D))
            x = x.view(B, S, *x.shape[1:])
        # x now is of shape (B, S, D) or (B, S, t, D) if `self.temp_attn_agg` is `Identity`

        global_x = None
        if self.extract_features and self.add_global_repr:  # lazy execution, throws AttributeError
            assert len(x.shape) == 3, f'Local representation should be (B, S, D) {x.shape}'
            global_x = self.global_attn_agg(x)  # (B, D)

        return x, global_x  # x is (B, S, ...), global_x is (B, D) or None

    def forward_segments(self, x, orig_shape: tuple, cont_mask: torch.Tensor = None, **ast_kwargs):
        '''x is (BS, T, F), where S is the number of segments; cont_mask is (BS, T, F): 0s to be masked out'''
        # 'pooler_output': (B, D); or 'last_hidden_state': (B, T, D) where T is [CLS, DISTILL, <tokens>]
        # x_mask is (B, T) where 0s are the values to be masked out
        x, x_mask = self.ast(x, cont_mask=cont_mask, **ast_kwargs)

        if self.extract_features:
            x = self.get_features_by_type(x)
            if self.factorize_freq_time:
                x = self.restore_freq_temp_dims(x, orig_shape)  # (BS, D, f, t) <- (B*S, T, D)
                if cont_mask is not None:
                    # duplicating the mask for the latent dimension (D) to be compatible with the next func
                    x_mask = x_mask.unsqueeze(-1).expand(-1, -1, self.config.hidden_size)
                    x_mask = self.restore_freq_temp_dims(x_mask, orig_shape)  # (BS, D, f, t) <- (B*S, T, D)
                    # again removing the latent
                    x_mask = x_mask[:, 0, :, :]
                else:
                    x_mask = None
                x = self.freq_attn_agg(x, x_mask)  # (BS, t, D)
                x = self.temp_attn_agg(x)  # (BS, D) or (BS, t, D) if self.temp_attn_agg is Identity
        else:
            x = x['pooler_output']
            x = self.classifier(x)
        return x

    def get_features_by_type(self, x: BaseModelOutputWithPooling) -> torch.Tensor:
        if self.feat_type == 'pooler_output':
            return x['pooler_output']  # (B, D)
        elif self.feat_type == 'CLS':
            return x['last_hidden_state'][:, 0, :]  # (B, D)
        elif self.feat_type == 'last_hidden_state':
            return x['last_hidden_state']  # (B, 2+T, D)
        elif self.feat_type == 'last_hidden_state_no_AUX':
            return x['last_hidden_state'][:, 2:, :]  # (B, T, D) removing CLS and distill tokens
        else:
            raise ValueError(f'Unknown feature type: {self.feat_type}')

    def restore_freq_temp_dims(self, feats, orig_shape: tuple):
        '''
        feats are of shape (B*S, T, D)
        where T = 2 + f * t (if feat_type == 'last_hidden_state')
        where T = f * t (if feat_type == 'last_hidden_state_no_AUX')
        Our goal is to make them of shape (B*S, D, f, t) where f and t are the dimensions after patching.
        From `self.ast.embeddings.patch_embeddings`, it follows that we can reshape feats with
        `feats.transpose(1, 2).view(B*S, D, f, t)`.

        (A similar function is defined for RGB features in `motionformer.py`)
        '''
        B, S, T, F = orig_shape
        D = self.config.hidden_size

        # num patches in each dimension
        f, t = self.ast.embeddings.get_shape(self.config)

        if self.feat_type == 'last_hidden_state':
            feats = feats[:, 2:, :]  # removing CLS and distill tokens

        feats = feats.permute(0, 2, 1)  # (B*S, D, T)
        feats = feats.view(B * S, D, f, t)  # (B*S, D, f, t)

        return feats

    def patch_position_emb(self):
        if self.max_spec_t is not None:
            self.config.max_length = self.max_spec_t
            f, t = self.ast.embeddings.get_shape(self.config)
            shortened = self.ast.embeddings.position_embeddings[:, :f*t+2].clone()  # +2 for CLS and distill tokens
            self.ast.embeddings.position_embeddings = torch.nn.Parameter(shortened).to(self.device)

    def to(self, device):
        '''AST.device fails with AttributeError. This is a workaround.'''
        self.device = torch.device(device)
        return super().to(device)


class FrequencyTransformerEncoderLayer(BaseEncoderLayer):
    '''This layer is used to aggregate the features along the frequency axis.
    It follows the same logic as spatio-temporal aggregation in the visual feature extractor.
    Thus, it is recommended to check the definition of `BaseEncoderLayer` in `motionformer.py`.'''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        '''x: (B*S, D, f, t); if specified, x_mask is (B*S, f, t), 0s are the values to be masked out'''
        BS, D, f, t = x.shape

        # time as a batch dimension
        x = x.permute(0, 3, 2, 1)  # (B*S, t, f, D)
        x = x.reshape(BS * t, f, D)  # .view() fails with non-contiguous memory
        # similarly for the mask
        if x_mask is not None:
            x_mask = x_mask.permute(0, 2, 1)  # (B*S, t, f)
            x_mask = x_mask.reshape(BS * t, f)

        # apply encoder layer (BaseEncoderLayer.forward) - it will add a CLS token and output its representation
        x = super().forward(x=x, x_mask=x_mask)  # (B*S*t, D)

        # reshape back to (B*S, t, D)
        x = x.view(BS, t, D)

        return x  # (B*S, t, D)
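For readers following `restore_freq_temp_dims` above, the shape-only sketch below (pure `torch`, independent of the classes in this file, not part of the diff) illustrates the token-to-grid reshape it performs before frequency and temporal aggregation. The sizes `f=12`, `t=6`, `D=768` are illustrative stand-ins for what `ast.embeddings.get_shape(config)` would return.

# illustrative sketch only -- sizes are assumptions, not values taken from the repo config
import torch

BS, f, t, D = 4, 12, 6, 768
tokens = torch.rand(BS, 2 + f * t, D)  # AST output: [CLS, DISTILL, <f*t patch tokens>]

feats = tokens[:, 2:, :]      # drop CLS and distill tokens -> (BS, f*t, D)
feats = feats.permute(0, 2, 1)  # (BS, D, f*t)
feats = feats.view(BS, D, f, t)  # split the token axis back into the (f, t) grid

print(feats.shape)  # torch.Size([4, 768, 12, 6])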
modules/model/modules/feat_extractors/audio/hf_src/modeling_ast.py
ADDED
@@ -0,0 +1,662 @@
# coding=utf-8
# Copyright 2022 MIT and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Modified by v-iashin to support token masking

""" PyTorch Audio Spectrogram Transformer (AST) model."""

import math
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import ASTConfig
from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "ASTConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "MIT/ast-finetuned-audioset-10-10-0.4593"
_EXPECTED_OUTPUT_SHAPE = [1, 1214, 768]

# Audio classification docstring
_SEQ_CLASS_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"
_SEQ_CLASS_EXPECTED_OUTPUT = "'Speech'"
_SEQ_CLASS_EXPECTED_LOSS = 0.17


AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "MIT/ast-finetuned-audioset-10-10-0.4593",
    # See all Audio Spectrogram Transformer models at https://huggingface.co/models?filter=ast
]


class ASTEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings.
    """

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.patch_embeddings = ASTPatchEmbeddings(config)

        frequency_out_dimension, time_out_dimension = self.get_shape(config)
        num_patches = frequency_out_dimension * time_out_dimension
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def get_shape(self, config):
        # see Karpathy's cs231n blog on how to calculate the output dimensions
        # https://cs231n.github.io/convolutional-networks/#conv
        frequency_out_dimension = (config.num_mel_bins - config.patch_size) // config.frequency_stride + 1
        time_out_dimension = (config.max_length - config.patch_size) // config.time_stride + 1

        return frequency_out_dimension, time_out_dimension

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        batch_size = input_values.shape[0]
        embeddings = self.patch_embeddings(input_values)

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
        embeddings = embeddings + self.position_embeddings
        embeddings = self.dropout(embeddings)

        return embeddings


class ASTPatchEmbeddings(nn.Module):
    """
    This class turns `input_values` into the initial `hidden_states` (patch embeddings) of shape `(batch_size,
    seq_length, hidden_size)` to be consumed by a Transformer.
    """

    def __init__(self, config):
        super().__init__()

        patch_size = config.patch_size
        frequency_stride = config.frequency_stride
        time_stride = config.time_stride

        self.projection = nn.Conv2d(
            1, config.hidden_size, kernel_size=(patch_size, patch_size), stride=(frequency_stride, time_stride)
        )

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        input_values = input_values.unsqueeze(1)
        input_values = input_values.transpose(2, 3)
        embeddings = self.projection(input_values).flatten(2).transpose(1, 2)
        return embeddings


# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->AST
class ASTSelfAttention(nn.Module):
    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, tok_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # apply masking if provided, tok_mask is (BS, N): 1s - keep; attention_scores is (BS, H, N, N)
        if tok_mask is not None:
            BS, N = tok_mask.shape
            attention_scores = attention_scores.masked_fill(tok_mask.view(BS, 1, 1, N) == 0, float('-inf'))

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST
class ASTSelfOutput(nn.Module):
    """
    The residual connection is defined in ASTLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->AST
class ASTAttention(nn.Module):
    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        self.attention = ASTSelfAttention(config)
        self.output = ASTSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        tok_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, tok_mask, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST
class ASTIntermediate(nn.Module):
    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->AST
class ASTOutput(nn.Module):
    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor

        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST
class ASTLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ASTAttention(config)
        self.intermediate = ASTIntermediate(config)
        self.output = ASTOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        tok_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in AST, layernorm is applied before self-attention
            tok_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in AST, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->AST
class ASTEncoder(nn.Module):
    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ASTLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        tok_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    tok_mask,
                    layer_head_mask,
                )
            else:
                layer_outputs = layer_module(hidden_states, tok_mask, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class ASTPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ASTConfig
    base_model_prefix = "audio_spectrogram_transformer"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True

    # Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights
    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST
    def _set_gradient_checkpointing(self, module: ASTEncoder, value: bool = False) -> None:
        if isinstance(module, ASTEncoder):
            module.gradient_checkpointing = value


AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`ASTConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
            [`ASTFeatureExtractor.__call__`] for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare AST Model transformer outputting raw hidden-states without any specific head on top.",
    AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING,
)
class ASTModel(ASTPreTrainedModel):
    def __init__(self, config: ASTConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = ASTEmbeddings(config)
        self.encoder = ASTEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> ASTPatchEmbeddings:
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        cont_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_values is None:
            raise ValueError("You have to specify input_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(input_values)

        # transforms the mask on spectrogram dims into the token mask obtained after patching.
        # Due to the overlap in patching, getting the token mask from the spectrogram mask is not straightforward,
        # because one 16x16 content patch is encoded in two tokens if the stride is <16. So, to get the mask for
        # tokens, I apply the patching func (self.embeddings) to a tensor with infinities at the masked
        # content positions. For infs, the patching fn will return nans, which I use to get the token mask.
        if cont_mask is not None:
            indicator = torch.ones_like(input_values).to(input_values.dtype)
            # replace content mask (0s) with infs
            indicator[~cont_mask] = torch.inf
            # apply patching; now nans are where the content mask was
            with torch.no_grad():
                indicator = self.embeddings(indicator)  # BS, N, D
            # replace nans with 0s; these are the tokens that correspond to the masked content
            tok_mask = ~torch.isnan(indicator)
            # since all values in the D-dimension (latent) will also be nans, we can just use the first el
            tok_mask = tok_mask[:, :, 0]  # (BS, 2+num_patches) -- 2 is from CLS and DISTIL tokens
        else:
            tok_mask = None

        encoder_outputs = self.encoder(
            embedding_output,
            tok_mask=tok_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        pooled_output = (sequence_output[:, 0] + sequence_output[:, 1]) / 2

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        ), tok_mask


class ASTMLPHead(nn.Module):
    def __init__(self, config: ASTConfig):
        super().__init__()
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dense = nn.Linear(
            config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

    def forward(self, hidden_state):
        hidden_state = self.layernorm(hidden_state)
        hidden_state = self.dense(hidden_state)
        return hidden_state


@add_start_docstrings(
    """
    Audio Spectrogram Transformer model with an audio classification head on top (a linear layer on top of the pooled
    output) e.g. for datasets like AudioSet, Speech Commands v2.
    """,
    AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING,
)
class ASTForAudioClassification(ASTPreTrainedModel):
    def __init__(self, config: ASTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.audio_spectrogram_transformer = ASTModel(config)

        # Classifier head
        self.classifier = ASTMLPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_SEQ_CLASS_CHECKPOINT,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        cont_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the audio classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.audio_spectrogram_transformer(
            input_values,
            cont_mask=cont_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
636 |
+
self.config.problem_type = "single_label_classification"
|
637 |
+
else:
|
638 |
+
self.config.problem_type = "multi_label_classification"
|
639 |
+
|
640 |
+
if self.config.problem_type == "regression":
|
641 |
+
loss_fct = MSELoss()
|
642 |
+
if self.num_labels == 1:
|
643 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
644 |
+
else:
|
645 |
+
loss = loss_fct(logits, labels)
|
646 |
+
elif self.config.problem_type == "single_label_classification":
|
647 |
+
loss_fct = CrossEntropyLoss()
|
648 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
649 |
+
elif self.config.problem_type == "multi_label_classification":
|
650 |
+
loss_fct = BCEWithLogitsLoss()
|
651 |
+
loss = loss_fct(logits, labels)
|
652 |
+
|
653 |
+
if not return_dict:
|
654 |
+
output = (logits,) + outputs[2:]
|
655 |
+
return ((loss,) + output) if loss is not None else output
|
656 |
+
|
657 |
+
return SequenceClassifierOutput(
|
658 |
+
loss=loss,
|
659 |
+
logits=logits,
|
660 |
+
hidden_states=outputs.hidden_states,
|
661 |
+
attentions=outputs.attentions,
|
662 |
+
)
|
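The comment block in ASTModel.forward above describes turning a spectrogram-level content mask into a token-level mask by pushing inf values through the patch-embedding conv. The following is a minimal standalone sketch of that trick, not the repo's code: a plain Conv2d stands in for self.embeddings, shapes and stride are illustrative, and it conservatively checks all channels for non-finite values instead of NaNs on the first channel.

import torch

# a plain Conv2d stands in for the AST patch embedding; stride < kernel -> overlapping patches
patch = torch.nn.Conv2d(1, 8, kernel_size=16, stride=10)

spec = torch.rand(2, 1, 128, 64)           # (B, 1, F, T) spectrogram
cont_mask = torch.rand_like(spec) > 0.3    # True = keep, False = masked content

indicator = torch.ones_like(spec)
indicator[~cont_mask] = torch.inf          # masked cells become inf
with torch.no_grad():
    tok = patch(indicator)                 # infs/nans propagate into every token they touch
# a token counts as masked if any cell in its receptive field was masked
tok_mask = torch.isfinite(tok).all(dim=1)  # (B, f', t'); the model flattens this to (B, N)
print(tok_mask.shape, tok_mask.float().mean().item())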
modules/model/modules/feat_extractors/audio/resnet.py
ADDED
@@ -0,0 +1,249 @@
import sys
from pathlib import Path
import logging
import einops

import torch
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet

sys.path.append('.')  # nopep8

from utils.utils import check_if_file_exists_else_download
from model.modules.feat_extractors.audio.ast import FrequencyTransformerEncoderLayer
from model.modules.feat_extractors.visual.motionformer import AveragePooling, TemporalTransformerEncoderLayer


class ResNetAudio(ResNet):

    def __init__(self, arch_name, num_classes, extract_features, ckpt_path=None, **kwargs):

        if arch_name == 'resnet18':
            block = BasicBlock
            layers = [2, 2, 2, 2]
        elif arch_name == 'resnet34':
            block = BasicBlock
            layers = [3, 4, 6, 3]
        elif arch_name == 'resnet50':
            block = Bottleneck
            layers = [3, 4, 6, 3]
        elif arch_name == 'resnet101':
            block = Bottleneck
            layers = [3, 4, 23, 3]
        elif arch_name == 'resnet152':
            block = Bottleneck
            layers = [3, 8, 36, 3]
        else:
            raise NotImplementedError

        super().__init__(block, layers, num_classes, **kwargs)

        # replacing the old conv1 with a new one (RGB - 3 channels; spectrogram - 1)
        conv1 = self.conv1
        self.conv1 = torch.nn.Conv2d(1, conv1.out_channels, conv1.kernel_size,
                                     conv1.stride, conv1.padding, bias=conv1.bias)
        self.extract_features = extract_features
        self.embed_dim = self.fc.in_features

        # load the ckpt
        load_state_dict_resnet(self, ckpt_path, prefix='afeat_extractor.')

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.extract_features:
            return x

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return super().forward(x)


class ResNet18AudioFeatures(ResNetAudio):

    # ckpt_path should default to None, otherwise when no pre-training is desired it will throw an error
    def __init__(self,
                 extract_features: bool = False,
                 ckpt_path: str = None,
                 feat_type: str = None,
                 max_spec_t: int = None,
                 factorize_freq_time: bool = None,
                 agg_freq_module: str = None,
                 agg_time_module: str = None,
                 add_global_repr: bool = True,
                 agg_segments_module: str = None,
                 max_segments: int = None,
                 ) -> None:
        super().__init__(arch_name='resnet18', num_classes=308, extract_features=extract_features,
                         ckpt_path=ckpt_path)
        assert extract_features, 'Not implemented otherwise'
        self.extract_features = extract_features
        self.feat_type = feat_type
        self.max_spec_t = max_spec_t
        # similar to s3d
        self.nhead = 8
        self.mlp_ratio = 4
        self.drop_rate = 0.0

        if ckpt_path is not None:
            ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
            was_pt_on_vgs_cls = 'ResNetAudio-' in Path(ckpt_path).stem
            if was_pt_on_vgs_cls:
                self.load_state_dict(ckpt['model'], strict=True)

        # saving some memory
        if extract_features:
            self.avgpool = torch.nn.Identity()
            self.fc = torch.nn.Identity()

        # define adapters if needed
        self.factorize_freq_time = factorize_freq_time
        # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
        transf_enc_layer_kwargs = dict(
            d_model=self.embed_dim, nhead=self.nhead, dim_feedforward=self.mlp_ratio*self.embed_dim,
            activation=torch.nn.GELU(), batch_first=True, dropout=self.drop_rate, layer_norm_eps=1e-6,
            norm_first=True,
        )
        if factorize_freq_time:
            self.feat_type = 'last_hidden_state'  # this feat_type supports factorization
            # frequency aggregation
            if agg_freq_module == 'TransformerEncoderLayer':
                self.freq_attn_agg = FrequencyTransformerEncoderLayer(**transf_enc_layer_kwargs)
            elif agg_freq_module == 'AveragePooling':
                self.freq_attn_agg = AveragePooling(avg_pattern='BS D f t -> BS D t',
                                                    then_permute_pattern='BS D t -> BS t D')
            # time aggregation
            if agg_time_module == 'TransformerEncoderLayer':
                self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
            elif agg_time_module == 'AveragePooling':
                self.temp_attn_agg = AveragePooling(avg_pattern='BS t D -> BS D')
            elif 'Identity' in agg_time_module:
                self.temp_attn_agg = torch.nn.Identity()
        # define a global aggregation layer (aggregate over segments)
        self.add_global_repr = add_global_repr
        if add_global_repr:
            if agg_segments_module == 'TransformerEncoderLayer':
                # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
                # we need to add pos emb (PE) because previously we added the same PE for each segment
                pos_max_len = max_segments if max_segments is not None else 16  # 16 = 10sec//0.64sec + 1
                self.global_attn_agg = TemporalTransformerEncoderLayer(
                    add_pos_emb=True, pos_emb_drop=self.drop_rate,
                    pos_max_len=pos_max_len, **transf_enc_layer_kwargs
                )
            elif agg_segments_module == 'AveragePooling':
                self.global_attn_agg = AveragePooling(avg_pattern='B S D -> B D')

        # do not keep fc to save memory
        self.fc = torch.nn.Identity()

        if ckpt_path is not None:
            ckpt = ckpt['state_dict']
            was_pt_on_avclip = any('a_encoder.' in k[0] or 'v_encoder.' in k[0] for k in ckpt.items())
            assert was_pt_on_vgs_cls is False, f'Unexpected ckpt: {ckpt_path}'
            if was_pt_on_avclip:
                ckpt_weights = dict()
                for k, v in ckpt.items():
                    if k.startswith(('module.a_encoder.', 'a_encoder.')):
                        k = k.replace('module.', '').replace('a_encoder.', '')
                        ckpt_weights[k] = v
                _load_status = self.load_state_dict(ckpt_weights, strict=False)
                if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
                    logging.warning(f'Loading exact ckpt from {ckpt_path} failed. \n' \
                                    f'Missing keys ({len(_load_status.missing_keys)}): ' \
                                    f'{_load_status.missing_keys}, \n' \
                                    f'Unexpected keys ({len(_load_status.unexpected_keys)}): ' \
                                    f'{_load_status.unexpected_keys} \n' \
                                    f'freq_attn_agg are expected to be unexpected if ckpt was pt contrastively '\
                                    f'as well as fc could be missing because we use features, not a classifier.')
                else:
                    logging.info(f'Loading ResNet ckpt from {ckpt_path} succeeded.')

        # print the number of parameters
        logging.info(f'afeat_extractor: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}')

    def forward(self, x: torch.Tensor, for_loop: bool = False, cont_mask: torch.Tensor = None):
        assert for_loop is False and cont_mask is None, 'Not implemented'
        B, S, T, F = x.shape

        # (BS, D) <- (B, S, T, D)
        x = self.forward_segments(x)

        # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D))
        x = x.view(B, S, *x.shape[1:])
        # x now is of shape (B, S, D) or (B, S, t, D) if `self.temp_attn_agg` is `Identity`

        global_x = None
        if self.extract_features and self.add_global_repr:  # lazy execution, throws AttributeError
            assert len(x.shape) == 3, f'Local representation should be (B, S, D) {x.shape}'
            global_x = self.global_attn_agg(x)  # (B, D)

        return x, global_x  # x is (B, S, ...), global_x is (B, D) or None

    def forward_segments(self, x):
        x = einops.rearrange(x, 'B S T F -> (B S) 1 F T')
        # (BS, D, f, t) <- (BS, 1, F, T)
        x = super().forward(x)

        if self.extract_features:
            if self.factorize_freq_time:
                x = self.freq_attn_agg(x)  # (BS, t, D)
                x = self.temp_attn_agg(x)  # (BS, D) or (BS, t, D) if self.temp_attn_agg is Identity
        return x


def load_state_dict_resnet(model, ckpt_path, prefix):
    if ckpt_path is not None:
        check_if_file_exists_else_download(ckpt_path)
        ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
        ckpt = ckpt.get('model', ckpt.get('state_dict', ckpt))
        # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
        # and keep only the state_dict of the feat extractor
        # FIXME: this is a bit hacky, but it works
        was_pt_on_avclip = any('a_encoder.' in k[0] or 'v_encoder.' in k[0] for k in ckpt.items())
        if not was_pt_on_avclip:
            model.load_state_dict(ckpt)
            logging.info(f'Loading ResNet ckpt from {ckpt_path} succeeded.')
        # if was_pt_on_avclip:
        #     ckpt_weights = dict()
        #     for k, v in ckpt.items():
        #         if k.startswith(('module.a_encoder.', 'a_encoder.')):
        #             k = k.replace('module.', '').replace('a_encoder.', '')
        #             ckpt_weights[k] = v
        #     _load_status = model.load_state_dict(ckpt_weights, strict=False)
        #     if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
        #         logging.warning(f'Loading exact ckpt from {ckpt_path} failed. \n' \
        #                         f'Missing keys ({len(_load_status.missing_keys)}): ' \
        #                         f'{_load_status.missing_keys}, \n' \
        #                         f'Unexpected keys ({len(_load_status.unexpected_keys)}): ' \
        #                         f'{_load_status.unexpected_keys} \n' \
        #                         f'freq_attn_agg are expected to be unexpected if ckpt was pt contrastively '\
        #                         f'as well as fc could be missing because we use features, not a classifier.')
        #     else:
        #         logging.info(f'Loading ResNet ckpt from {ckpt_path} succeeded.')
        # else:
        #     model.load_state_dict(ckpt)
        #     logging.info(f'Loading ResNet ckpt from {ckpt_path} succeeded.')


if __name__ == '__main__':
    B = 2
    ckpt_path = './model/modules/feat_extractors/audio/22-06-24T08-10-33/ResNetAudio-22-06-24T08-10-33.pt'
    afeat_extractor = ResNet18AudioFeatures(ckpt_path=ckpt_path)
    # x = torch.rand(B, 1, 257, 1551)
    # x = torch.rand(B, 1, 257, 1379)
    x = torch.rand(B, 1, 128, 66)
    x = afeat_extractor(x)
    print(x.shape)
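For reference, a tiny shape-flow sketch of the segment handling in forward / forward_segments above, using random tensors and a dummy per-segment encoder; the dimensions are illustrative and the einops pattern is the one from the code.

import torch
import einops

B, S, T, F, D = 2, 14, 66, 128, 512                  # batch, segments, time, freq bins, feature dim
x = torch.rand(B, S, T, F)

x = einops.rearrange(x, 'B S T F -> (B S) 1 F T')    # pack segments into the batch dim
feats = torch.rand(B * S, D)                         # stand-in for ResNet features + freq/time aggregation
feats = feats.view(B, S, *feats.shape[1:])           # unpack: (B, S, D) local (per-segment) features
global_feats = feats.mean(dim=1)                     # stand-in for the global (segment-level) aggregation
print(feats.shape, global_feats.shape)               # torch.Size([2, 14, 512]) torch.Size([2, 512])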
modules/model/modules/feat_extractors/train_clip_src/__init__.py
ADDED
@@ -0,0 +1,3 @@
import sys
sys.path.append('model/modules/feat_extractors')
sys.path.append('model/modules/feat_extractors/train_clip_src')
modules/model/modules/feat_extractors/train_clip_src/open_clip/__init__.py
ADDED
@@ -0,0 +1,13 @@
from .coca_model import CoCa
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model, create_model_from_pretrained, get_tokenizer, create_loss
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
    convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
    get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
from .tokenizer import SimpleTokenizer, tokenize, decode
from .transform import image_transform, AugmentationCfg
modules/model/modules/feat_extractors/train_clip_src/open_clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
modules/model/modules/feat_extractors/train_clip_src/open_clip/coca_model.py
ADDED
@@ -0,0 +1,458 @@
from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from dataclasses import dataclass

from .transformer import (
    LayerNormFp32,
    LayerNorm,
    QuickGELU,
    MultimodalTransformer,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower

try:
    from transformers import (
        BeamSearchScorer,
        LogitsProcessorList,
        TopPLogitsWarper,
        TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
        MinLengthLogitsProcessor,
        MaxLengthCriteria,
        StoppingCriteriaList
    )

    GENERATION_TYPES = {
        "top_k": TopKLogitsWarper,
        "top_p": TopPLogitsWarper,
        "beam_search": "beam_search"
    }
    _has_transformers = True
except ImportError as e:
    GENERATION_TYPES = {
        "top_k": None,
        "top_p": None,
        "beam_search": "beam_search"
    }
    _has_transformers = False


@dataclass
class MultimodalCfg(CLIPTextCfg):
    mlp_ratio: int = 4
    dim_head: int = 64
    heads: int = 8
    n_queries: int = 256
    attn_pooler_heads: int = 8


def _build_text_decoder_tower(
        embed_dim,
        multimodal_cfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
    act_layer = QuickGELU if quick_gelu else nn.GELU
    norm_layer = (
        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    )

    decoder = MultimodalTransformer(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )

    return decoder


class CoCa(nn.Module):
    def __init__(
            self,
            embed_dim,
            multimodal_cfg: MultimodalCfg,
            text_cfg: CLIPTextCfg,
            vision_cfg: CLIPVisionCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            pad_id: int = 0,
    ):
        super().__init__()
        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        vocab_size = (
            text_cfg.vocab_size  # for hf models
            if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
            else text_cfg.vocab_size
        )

        self.visual = _build_vision_tower(
            embed_dim=embed_dim,
            vision_cfg=vision_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        self.text_decoder = _build_text_decoder_tower(
            vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.pad_id = pad_id

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        self.text_decoder.set_grad_checkpointing(enable)

    def _encode_image(self, images, normalize=True):
        image_latent, tokens_embs = self.visual(images)
        image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
        return image_latent, tokens_embs

    def _encode_text(self, text, normalize=True, embed_cls=True):
        text = text[:, :-1] if embed_cls else text  # make space for CLS token
        text_latent, token_emb = self.text(text)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
        return text_latent, token_emb

    def encode_image(self, images, normalize=True):
        image_latent, _ = self._encode_image(images, normalize=normalize)
        return image_latent

    def encode_text(self, text, normalize=True, embed_cls=True):
        text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
        return text_latent

    def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
        text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
        if image_latent is None or image_embs is None:
            image_latent, image_embs = self._encode_image(image)

        # TODO: add assertion to avoid bugs?
        labels = text[:, -token_embs.shape[1]:]

        logits = self.text_decoder(image_embs, token_embs)
        return {
            "image_features": image_latent,
            "text_features": text_latent,
            "logits": logits,
            "labels": labels,
            "logit_scale": self.logit_scale.exp()
        }

    def generate(
        self,
        image,
        text=None,
        seq_len=30,
        max_seq_len=77,
        temperature=1.,
        generation_type="beam_search",
        top_p=0.1,  # keep tokens in the 1 - top_p quantile
        top_k=1,  # keeps the top_k most probable tokens
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        repetition_penalty=1.0,
        fixed_output_length=False  # if True output.shape == (batch_size, seq_len)
    ):
        # taking many ideas and components from HuggingFace GenerationMixin
        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"

        with torch.no_grad():
            sot_token_id = 49406 if sot_token_id is None else sot_token_id
            eos_token_id = 49407 if eos_token_id is None else eos_token_id
            pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
            logit_processor = LogitsProcessorList(
                [
                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                    RepetitionPenaltyLogitsProcessor(repetition_penalty),
                ]
            )

            if stopping_criteria is None:
                stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]

            stopping_criteria = StoppingCriteriaList(
                stopping_criteria
            )

            device = image.device

            if generation_type == "beam_search":
                output = self._generate_beamsearch(
                    image_inputs=image,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    sot_token_id=sot_token_id,
                    num_beams=num_beams,
                    num_beam_groups=num_beam_groups,
                    min_seq_len=min_seq_len,
                    stopping_criteria=stopping_criteria,
                    logit_processor=logit_processor,
                )
                if fixed_output_length and output.shape[1] < seq_len:
                    return torch.cat(
                        (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
                        dim=1
                    )
                return output

            elif generation_type == "top_p":
                logit_warper = GENERATION_TYPES[generation_type](top_p)
            elif generation_type == "top_k":
                logit_warper = GENERATION_TYPES[generation_type](top_k)
            else:
                raise ValueError(
                    f"generation_type has to be one of "
                    f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                )

            image_latent, image_embs = self._encode_image(image)

            if text is None:
                text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

            was_training = self.training
            num_dims = len(text.shape)

            if num_dims == 1:
                text = text[None, :]

            cur_len = text.shape[1]
            self.eval()
            out = text

            while True:
                x = out[:, -max_seq_len:]
                cur_len = x.shape[1]
                logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                if mask.all():
                    if not fixed_output_length:
                        break
                else:
                    logits = logits[~mask, :]
                    filtered_logits = logit_processor(x[~mask, :], logits)
                    filtered_logits = logit_warper(x[~mask, :], filtered_logits)
                    probs = F.softmax(filtered_logits / temperature, dim=-1)

                    if (cur_len + 1 == seq_len):
                        sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
                    else:
                        sample[~mask, :] = torch.multinomial(probs, 1)

                out = torch.cat((out, sample), dim=-1)

                cur_len += 1

                if stopping_criteria(out, None):
                    break

            if num_dims == 1:
                out = out.squeeze(0)

            self.train(was_training)
            return out

    def _generate_beamsearch(
            self,
            image_inputs,
            pad_token_id=None,
            eos_token_id=None,
            sot_token_id=None,
            num_beams=6,
            num_beam_groups=3,
            min_seq_len=5,
            stopping_criteria=None,
            logit_processor=None,
            logit_warper=None,
    ):
        device = image_inputs.device
        batch_size = image_inputs.shape[0]
        image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
        image_latent, image_embs = self._encode_image(image_inputs)

        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )
        # instantiate logits processors
        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_beam_size, cur_len = input_ids.shape
        beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
        # the same group don't produce the same tokens every time.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:

            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
            outputs = self(
                model_inputs['images'],
                model_inputs['text'],
                embed_cls=False,
                image_latent=image_latent,
                image_embs=image_embs
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of the current group only
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                # stateless
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
                )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            # increase cur_len
            cur_len = cur_len + 1
            if beam_scorer.is_done or stopping_criteria(input_ids, None):
                break

        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )
        return sequence_outputs['sequences']


def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    else:
        position_ids = None
    return {
        "text": input_ids,
        "images": image_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
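A minimal, self-contained sketch (toy token ids and random logits, not the repo's tokenizer or model) of the per-step bookkeeping in the top-k/top-p branch of generate above: sequences that already ended keep receiving PAD, the rest sample from temperature-scaled logits.

import torch
import torch.nn.functional as F

eos_id, pad_id, temperature = 3, 0, 1.0
out = torch.tensor([[5, 7, 3], [5, 9, 2]])   # two partial sequences; the first already emitted EOS
logits = torch.randn(2, 11)                  # next-token logits (vocab of 11, purely illustrative)

finished = (out[:, -1] == eos_id) | (out[:, -1] == pad_id)
sample = torch.full((out.shape[0], 1), pad_id, dtype=torch.long)
if not finished.all():
    probs = F.softmax(logits[~finished] / temperature, dim=-1)
    sample[~finished] = torch.multinomial(probs, 1)
out = torch.cat((out, sample), dim=-1)
print(out)  # the finished row ends with PAD, the other row got a sampled token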
modules/model/modules/feat_extractors/train_clip_src/open_clip/constants.py
ADDED
@@ -0,0 +1,2 @@
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
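These are the RGB mean/std used by OpenAI's CLIP image preprocessing. A one-line usage sketch (torchvision is assumed available; the values are inlined for self-containment):

from torchvision import transforms

normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                                 std=(0.26862954, 0.26130258, 0.27577711))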
modules/model/modules/feat_extractors/train_clip_src/open_clip/factory.py
ADDED
@@ -0,0 +1,193 @@
import json
import re
from copy import deepcopy
from pathlib import Path
from typing import Optional, Tuple, Union

import torch

from utils.utils import instantiate_from_config

from .model import convert_to_custom_text_state_dict, resize_pos_embed
from .loss import AVCLIPLoss, ClipLoss, DistillClipLoss, CoCaLoss, MultilevelAVCLIPLoss
from .transform import image_transform
from .tokenizer import HFTokenizer, tokenize


HF_HUB_PREFIX = 'hf-hub:'
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs


def _natural_key(string_):
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def _rescan_model_configs():
    global _MODEL_CONFIGS

    config_ext = ('.json',)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f'*{ext}'))

    for cf in config_files:
        with open(cf, 'r') as f:
            model_cfg = json.load(f)
            if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
                _MODEL_CONFIGS[cf.stem] = model_cfg

    _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}


_rescan_model_configs()  # initial populate of model config registry


def list_models():
    """ enumerate available model architectures based on config files """
    return list(_MODEL_CONFIGS.keys())


def add_model_config(path):
    """ add model config path or file and update registry """
    if not isinstance(path, Path):
        path = Path(path)
    _MODEL_CONFIG_PATHS.append(path)
    _rescan_model_configs()


def get_model_config(model_name):
    if model_name in _MODEL_CONFIGS:
        return deepcopy(_MODEL_CONFIGS[model_name])
    else:
        return None


def get_tokenizer(model_name):
    if model_name.startswith(HF_HUB_PREFIX):
        tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
    else:
        config = get_model_config(model_name)
        tokenizer = HFTokenizer(
            config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
    return tokenizer


def load_state_dict(checkpoint_path: str, map_location='cpu'):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    if next(iter(state_dict.items()))[0].startswith('module'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    return state_dict


def load_checkpoint(model, checkpoint_path, strict=True):
    state_dict = load_state_dict(checkpoint_path)
    # detect old format and make compatible with new format
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)
    resize_pos_embed(state_dict, model)
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys


def create_model(cfg, device: Union[str, torch.device] = 'cpu'):
    if isinstance(device, str):
        device = torch.device(device)
    model = instantiate_from_config(cfg.model)
    model.to(device=device)
    return model


def create_loss(args):
    if args.distill:
        return DistillClipLoss(
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
        )
    elif 'coca' in args.model.target.split('.')[-1].lower():
        return CoCaLoss(
            caption_loss_weight=args.coca_caption_loss_weight,
            clip_loss_weight=args.coca_contrastive_loss_weight,
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
        )
    elif 'multilevel' in args.model.target.split('.')[-1].lower():
        return MultilevelAVCLIPLoss(
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
        )
    elif 'avclip' in args.model.target.split('.')[-1].lower():
        return AVCLIPLoss(
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
        )

    return ClipLoss(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
    )


def create_model_from_pretrained(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        return_transform: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        cache_dir: Optional[str] = None,
):
    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_image_size=force_image_size,
        cache_dir=cache_dir,
        require_pretrained=True,
    )

    if not return_transform:
        return model

    image_mean = image_mean or getattr(model.visual, 'image_mean', None)
    image_std = image_std or getattr(model.visual, 'image_std', None)
    preprocess = image_transform(
        model.visual.image_size,
        is_train=False,
        mean=image_mean,
        std=image_std,
    )

    return model, preprocess
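create_model above delegates construction to utils.utils.instantiate_from_config, which is not part of this diff. A minimal sketch of the target/params config convention it appears to follow (inferred from create_loss reading args.model.target; the helper name and exact behaviour here are assumptions, not the repo's implementation):

import importlib

def instantiate_from_config_sketch(cfg: dict):
    # cfg = {'target': 'pkg.module.ClassName', 'params': {...}}
    module_path, name = cfg['target'].rsplit('.', 1)
    obj = getattr(importlib.import_module(module_path), name)
    return obj(**cfg.get('params', {}))

# purely illustrative target: any importable class works the same way
print(instantiate_from_config_sketch({'target': 'torch.nn.LayerNorm',
                                      'params': {'normalized_shape': 768}}))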
modules/model/modules/feat_extractors/train_clip_src/open_clip/generation_utils.py
ADDED
File without changes
modules/model/modules/feat_extractors/train_clip_src/open_clip/hf_configs.py
ADDED
@@ -0,0 +1,45 @@
# HF architecture dict:
arch_dict = {
    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
    "roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
    "xlm-roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
    "mt5": {
        "config_names": {
            # unlimited seqlen
            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
            "context_length": "",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "num_heads",
            "layers": "num_layers",
            "layer_attr": "block",
            "token_embeddings_attr": "embed_tokens"
        },
        "pooler": "mean_pooler",
    },
}
modules/model/modules/feat_extractors/train_clip_src/open_clip/hf_model.py
ADDED
@@ -0,0 +1,176 @@
""" huggingface model adapter

Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
"""

import re

import torch
import torch.nn as nn
from torch import TensorType

try:
    import transformers
    from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
    from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
        BaseModelOutputWithPoolingAndCrossAttentions
except ImportError as e:
    transformers = None


    class BaseModelOutput:
        pass


    class PretrainedConfig:
        pass

from .hf_configs import arch_dict


# utils
def _camel2snake(s):
    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()


# TODO: ?last - for gpt-like models
_POOLERS = {}


def register_pooler(cls):
    """Decorator registering pooler class"""
    _POOLERS[_camel2snake(cls.__name__)] = cls
    return cls


@register_pooler
class MeanPooler(nn.Module):
    """Mean pooling"""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)


@register_pooler
class MaxPooler(nn.Module):
    """Max pooling"""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
        return masked_output.max(1).values


@register_pooler
class ClsPooler(nn.Module):
    """CLS token pooling"""

    def __init__(self, use_pooler_output=True):
        super().__init__()
        self.cls_token_position = 0
        self.use_pooler_output = use_pooler_output

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        if (self.use_pooler_output and
            isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
            (x.pooler_output is not None)
        ):
            return x.pooler_output

        return x.last_hidden_state[:, self.cls_token_position, :]


class HFTextEncoder(nn.Module):
    """HuggingFace model adapter"""
    output_tokens: torch.jit.Final[bool]

    def __init__(
            self,
            model_name_or_path: str,
            output_dim: int,
            config: PretrainedConfig = None,
            pooler_type: str = None,
            proj: str = None,
            pretrained: bool = True,
            output_tokens: bool = False,
    ):
        super().__init__()
        self.output_tokens = output_tokens
        self.output_dim = output_dim

        # TODO: find better way to get this information
        uses_transformer_pooler = (pooler_type == "cls_pooler")

        if transformers is None:
            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
        if config is None:
            self.config = AutoConfig.from_pretrained(model_name_or_path)
            create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
                AutoModel.from_config, self.config)
            # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
            if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
                self.transformer = create_func(model_args)
                self.transformer = self.transformer.encoder
            else:
                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
        else:
            self.config = config
            self.transformer = AutoModel.from_config(config)
        if pooler_type is None:  # get default arch pooler
            pooler_type = (arch_dict[self.config.model_type]["pooler"])

        self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
        if (d_model == output_dim) and (proj is None):  # do we always need a proj?
            self.proj = nn.Identity()
        elif proj == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        elif proj == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=False),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=False),
            )

    def forward(self, x: TensorType):
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        pooled_out = self.pooler(out, attn_mask)
        projected = self.proj(pooled_out)

        seq_len = out.last_hidden_state.shape[1]
        tokens = (
            out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
            if type(self.pooler) == ClsPooler
            else out.last_hidden_state
        )

        if self.output_tokens:
            return projected, tokens
        return projected

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        if not unlocked_layers:  # full freezing
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
            return

        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
        embeddings = getattr(
            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
        modules = [embeddings, *layer_list][:-unlocked_layers]
        # freeze layers
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.transformer.gradient_checkpointing_enable()

    def init_parameters(self):
        pass
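A tiny sketch (illustrative tensors, not the adapter itself) of the masked mean pooling that MeanPooler above performs: padded positions are zeroed, then each sequence is averaged over its real tokens only.

import torch

hidden = torch.randn(2, 4, 8)                        # (batch, seq, width) stand-in for last_hidden_state
attn_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

masked = hidden * attn_mask.unsqueeze(-1)            # zero out padded tokens
pooled = masked.sum(dim=1) / attn_mask.sum(-1, keepdim=True)
print(pooled.shape)                                  # torch.Size([2, 8])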
modules/model/modules/feat_extractors/train_clip_src/open_clip/loss.py
ADDED
@@ -0,0 +1,229 @@
import torch
import torch.nn as nn
from torch.nn import functional as F

try:
    import torch.distributed.nn
    from torch import distributed as dist

    has_distributed = True
except ImportError:
    has_distributed = False


def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
):
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
    # We gather tensors from all gpus
    if gather_with_grad:
        all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
        all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
    else:
        gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
        gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
        dist.all_gather(gathered_image_features, image_features)
        dist.all_gather(gathered_text_features, text_features)
        if not local_loss:
            # ensure grads for local rank when all_* features don't have a gradient
            gathered_image_features[rank] = image_features
            gathered_text_features[rank] = text_features
        all_image_features = torch.cat(gathered_image_features, dim=0)
        all_text_features = torch.cat(gathered_text_features, dim=0)

    return all_image_features, all_text_features


class ClipLoss(nn.Module):

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        # calculate ground-truth and cache if enabled
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        if self.world_size > 1:
            all_image_features, all_text_features = gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad, self.rank, self.world_size)

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)

        labels = self.get_ground_truth(device, logits_per_image.shape[0])

        total_loss = (
            F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)
        ) / 2

        return {"contrastive_loss": total_loss} if output_dict else total_loss


class CoCaLoss(ClipLoss):
    def __init__(
            self,
            caption_loss_weight,
            clip_loss_weight,
            pad_id=0,  # pad_token for open_clip custom tokenizer
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
    ):
        super().__init__(
            local_loss=local_loss,
            gather_with_grad=gather_with_grad,
            cache_labels=cache_labels,
            rank=rank,
            world_size=world_size,
        )

        self.clip_loss_weight = clip_loss_weight
        self.caption_loss_weight = caption_loss_weight
        self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)

    def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
        clip_loss = super().forward(image_features, text_features, logit_scale)
        clip_loss = self.clip_loss_weight * clip_loss

        caption_loss = self.caption_loss(
            logits.permute(0, 2, 1),
            labels,
        )
        caption_loss = caption_loss * self.caption_loss_weight

        if output_dict:
            return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}

        return clip_loss, caption_loss


class DistillClipLoss(ClipLoss):

    def dist_loss(self, teacher_logits, student_logits):
        return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)

    def forward(
            self,
            image_features,
            text_features,
            logit_scale,
            dist_image_features,
            dist_text_features,
            dist_logit_scale,
            output_dict=False,
    ):
        logits_per_image, logits_per_text = \
            self.get_logits(image_features, text_features, logit_scale)

        dist_logits_per_image, dist_logits_per_text = \
            self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)

        labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])

        contrastive_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        distill_loss = (
            self.dist_loss(dist_logits_per_image, logits_per_image) +
            self.dist_loss(dist_logits_per_text, logits_per_text)
        ) / 2

        if output_dict:
            return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}

        return contrastive_loss, distill_loss


class AVCLIPLoss(ClipLoss):
    '''This loss resembles the CLIP loss, but it simply renames the variables'''

    def __init__(self, local_loss=False, gather_with_grad=False, cache_labels=False, rank=0, world_size=1):
        super().__init__(local_loss, gather_with_grad, cache_labels, rank, world_size)

    def forward(self, rgb_features, audio_features, logit_scale, output_dict=False):
        return super().forward(rgb_features, audio_features, logit_scale, output_dict)


class MultilevelAVCLIPLoss(nn.Module):

    def __init__(self, local_loss=False, gather_with_grad=False, cache_labels=False, rank=0, world_size=1):
        super().__init__()
        self.segment_avclip_loss = AVCLIPLoss(
            local_loss, gather_with_grad, cache_labels, rank, world_size)
        self.global_avclip_loss = AVCLIPLoss(
            local_loss, gather_with_grad, cache_labels, rank, world_size)

    def forward(self, rgb_features, audio_features, logit_scales, output_dict=False):
        segment_rgb, global_rgb = rgb_features
        segment_audio, global_audio = audio_features
        segment_logit_scale, global_logit_scale = logit_scales

        assert segment_rgb.shape[0] == segment_audio.shape[0], f'{segment_rgb.shape} != {segment_audio.shape}'
        B, S, D = segment_rgb.shape

        # (B*S, D) <- (B, S, D)
        segment_loss = self.segment_avclip_loss(segment_rgb.view(B*S, D), segment_audio.view(B*S, D),
                                                segment_logit_scale, output_dict)
        global_loss = None
        if global_rgb is not None:
            global_loss = self.global_avclip_loss(global_rgb, global_audio, global_logit_scale, output_dict)

        if output_dict:
            losses = {'segment_contrastive_loss': segment_loss['contrastive_loss']}
            if global_rgb is not None:
                losses['global_contrastive_loss'] = global_loss['contrastive_loss']
            return losses
        else:
            if global_loss is None:
                return segment_loss
            else:
                return segment_loss, global_loss
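ClipLoss above builds a symmetric cross-entropy over the image-text similarity matrix; with world_size=1 no feature gathering happens. A minimal usage sketch with random, pre-normalized features (batch size, embedding width, and the scale value are illustrative):

import torch
import torch.nn.functional as F

loss_fn = ClipLoss()  # defaults: local_loss=False, world_size=1

image_features = F.normalize(torch.randn(8, 512), dim=-1)
text_features = F.normalize(torch.randn(8, 512), dim=-1)
logit_scale = torch.tensor(1 / 0.07)  # exponentiated CLIP temperature

loss = loss_fn(image_features, text_features, logit_scale)
print(float(loss))  # average of image->text and text->image cross-entropy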
modules/model/modules/feat_extractors/train_clip_src/open_clip/model.py
ADDED
@@ -0,0 +1,883 @@
""" CLIP Model

Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
from dataclasses import dataclass
import logging
import math
from typing import Optional, Tuple, Union

import numpy as np
from omegaconf import OmegaConf
import torch
import torch.nn.functional as F
from torch import nn

from utils.utils import instantiate_from_config

from .hf_model import HFTextEncoder
from .modified_resnet import ModifiedResNet
from .timm_model import TimmModel
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
from .utils import to_2tuple


@dataclass
class CLIPVisionCfg:
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224
    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    input_patchnorm: bool = False  # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
    global_average_pool: bool = False  # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer
    n_queries: int = 256  # n_queries for attentional pooler
    attn_pooler_heads: int = 8  # n heads for attentional_pooling
    timm_model_name: str = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth
    output_tokens: bool = False


@dataclass
class CLIPTextCfg:
    context_length: int = 77
    vocab_size: int = 49408
    width: int = 512
    heads: int = 8
    layers: int = 12
    ls_init_value: Optional[float] = None  # layer scale initial value
    hf_model_name: str = None
    hf_tokenizer_name: str = None
    hf_model_pretrained: bool = True
    proj: str = 'mlp'
    pooler_type: str = 'mean_pooler'
    embed_cls: bool = False
    pad_id: int = 0
    output_tokens: bool = False


def get_cast_dtype(precision: str):
    cast_dtype = None
    if precision == 'bf16':
        cast_dtype = torch.bfloat16
    elif precision == 'fp16':
        cast_dtype = torch.float16
    return cast_dtype


def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
    # memory efficient in recent PyTorch releases (>= 1.10).
    # NOTE: timm models always use native GELU regardless of quick_gelu flag.
    act_layer = QuickGELU if quick_gelu else nn.GELU

    if vision_cfg.timm_model_name:
        visual = TimmModel(
            vision_cfg.timm_model_name,
            pretrained=vision_cfg.timm_model_pretrained,
            pool=vision_cfg.timm_pool,
            proj=vision_cfg.timm_proj,
            proj_bias=vision_cfg.timm_proj_bias,
            drop=vision_cfg.timm_drop,
            drop_path=vision_cfg.timm_drop_path,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size,
        )
        act_layer = nn.GELU  # so that text transformer doesn't use QuickGELU w/ timm models
    elif isinstance(vision_cfg.layers, (tuple, list)):
        vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
        visual = ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=vision_heads,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width,
        )
    else:
        vision_heads = vision_cfg.width // vision_cfg.head_width
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        visual = VisionTransformer(
            image_size=vision_cfg.image_size,
            patch_size=vision_cfg.patch_size,
            width=vision_cfg.width,
            layers=vision_cfg.layers,
            heads=vision_heads,
            mlp_ratio=vision_cfg.mlp_ratio,
            ls_init_value=vision_cfg.ls_init_value,
            patch_dropout=vision_cfg.patch_dropout,
            input_patchnorm=vision_cfg.input_patchnorm,
            global_average_pool=vision_cfg.global_average_pool,
            attentional_pool=vision_cfg.attentional_pool,
            n_queries=vision_cfg.n_queries,
            attn_pooler_heads=vision_cfg.attn_pooler_heads,
            output_tokens=vision_cfg.output_tokens,
            output_dim=embed_dim,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

    return visual


def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    if text_cfg.hf_model_name:
        text = HFTextEncoder(
            text_cfg.hf_model_name,
            output_dim=embed_dim,
            proj=text_cfg.proj,
            pooler_type=text_cfg.pooler_type,
            pretrained=text_cfg.hf_model_pretrained,
            output_tokens=text_cfg.output_tokens,
        )
    else:
        act_layer = QuickGELU if quick_gelu else nn.GELU
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm

        text = TextTransformer(
            context_length=text_cfg.context_length,
            vocab_size=text_cfg.vocab_size,
            width=text_cfg.width,
            heads=text_cfg.heads,
            layers=text_cfg.layers,
            ls_init_value=text_cfg.ls_init_value,
            output_dim=embed_dim,
            embed_cls=text_cfg.embed_cls,
            output_tokens=text_cfg.output_tokens,
            pad_id=text_cfg.pad_id,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
    return text


class CLIP(nn.Module):
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text.transformer
        self.vocab_size = text.vocab_size
        self.token_embedding = text.token_embedding
        self.positional_embedding = text.positional_embedding
        self.ln_final = text.ln_final
        self.text_projection = text.text_projection
        self.register_buffer('attn_mask', text.attn_mask, persistent=False)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return F.normalize(x, dim=-1) if normalize else x

    def forward(self, image, text):
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
        return image_features, text_features, self.logit_scale.exp()


class CustomTextCLIP(nn.Module):
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        self.text.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def forward(self, image, text):
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
        return image_features, text_features, self.logit_scale.exp()


def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16)"""

    def _convert_weights(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.to(dtype)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(dtype)

        if isinstance(l, (nn.MultiheadAttention, Attention)):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.to(dtype)

    model.apply(_convert_weights)


convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat


# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
    if 'text_projection' in state_dict:
        # old format state_dict, move text tower -> .text
        new_state_dict = {}
        for k, v in state_dict.items():
            if any(k.startswith(p) for p in (
                'text_projection',
                'positional_embedding',
                'token_embedding',
                'transformer',
                'ln_final',
            )):
                k = 'text.' + k
            new_state_dict[k] = v
        return new_state_dict
    return state_dict


def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len(
            [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        counts: list = [
            len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

    vision_cfg = CLIPVisionCfg(
        layers=vision_layers,
        width=vision_width,
        patch_size=vision_patch_size,
        image_size=image_size,
    )
    text_cfg = CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers,
    )
    model = CLIP(
        embed_dim,
        vision_cfg=vision_cfg,
        text_cfg=text_cfg,
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        state_dict.pop(key, None)

    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()


def trace_model(model, batch_size=256, device=torch.device('cpu')):
    model.eval()
    image_size = model.visual.image_size
    example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
    example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
    model = torch.jit.trace_module(
        model,
        inputs=dict(
            forward=(example_images, example_text),
            encode_text=(example_text,),
            encode_image=(example_images,)
        ))
    model.visual.image_size = image_size
    return model


def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    # Rescale the grid of position embeddings when loading from state_dict
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
        return
    grid_size = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        return

    if extra_tokens:
        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok, pos_emb_img = None, old_pos_embed
    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
    if pos_emb_tok is not None:
        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
    else:
        new_pos_embed = pos_emb_img
    state_dict['visual.positional_embedding'] = new_pos_embed

class AVCLIP(nn.Module):

    def __init__(self, n_embd: int, afeat_extractor: OmegaConf, vfeat_extractor: OmegaConf,
                 aproj: OmegaConf, vproj: OmegaConf, init_scale: float = 0.07, clamp_scale_min: float = 0.001,
                 clamp_scale_max: float = 0.5, gather_for_loss: bool = False):
        super().__init__()
        self.output_dict = True
        self.n_embd = n_embd

        # loading audio and rgb towers
        self.v_encoder = instantiate_from_config(vfeat_extractor)
        self.a_encoder = instantiate_from_config(afeat_extractor)
        # loading audio and rgb towers and projection layers to account for different feature dimensions
        self.aproj = instantiate_from_config(aproj)
        self.vproj = instantiate_from_config(vproj)

        self.clamp_scale_min, self.clamp_scale_max = clamp_scale_min, clamp_scale_max
        self.init_scale = init_scale
        self.logit_scale = nn.Parameter(torch.ones([]) * self.init_scale)  # NOTE: exp(1/OpenCLIP)

        self.gather_for_loss = gather_for_loss

        # self.ln_final = text.ln_final  # perhaps only useful for transformer towers
        # self.register_buffer('attn_mask', text.attn_mask, persistent=False)

    def forward(self, vis: torch.Tensor, aud: torch.Tensor, alpha: float = 0.0, for_loop: bool = False,
                world_size=1):
        '''
        Args:
            vis (torch.Tensor): RGB frames (B, S, C, Tv, H, W)
            aud (torch.Tensor): audio spectrograms (B, S, Ta, F)
            alpha (float): linear interpolation coefficient for pseudo-targets (with 0.0 targets are 1-hot)
            for_loop (bool): whether to process each segment in a for loop or all at once
                             (speed-memory tradeoff)
        Returns:
            rgb_features (tuple(torch.Tensor)): local (B, S, D) and global (B, D) or None RGB features
            audio_features (tuple(torch.Tensor)): local (B, S, D) and global (B, D) or None audio features
            logit_scale (tuple(torch.Tensor)): local and global logit scales (1, )
        '''
        assert alpha == 0.0, f'alpha={alpha} not supported yet'
        logit_scales = self.clamp_logit_scales()
        vfeat, _, afeat, _ = self.encode_streams(vis, aud, for_loop, do_norm=True)

        if world_size > 1 and self.gather_for_loss:  # gather all features
            vfeat_all = torch.cat(torch.distributed.nn.all_gather(vfeat), dim=0)
            afeat_all = torch.cat(torch.distributed.nn.all_gather(afeat), dim=0)
        else:
            vfeat_all = vfeat
            afeat_all = afeat

        loss_avc, _ = self.compute_loss(vfeat, afeat, vfeat_all.mT, afeat_all.mT, self.logit_scale, alpha=0)
        out = {
            'rgb_features': (vfeat, None), 'audio_features': (afeat, None),
            'logit_scales': logit_scales,
            'losses': {'segment_contrastive_loss': loss_avc},
        }
        return out

    def compute_loss(self, vfeat, afeat, vfeat_all, afeat_all, scale, alpha=0.0, vfeat_m=None, afeat_m=None):
        '''For Multi-level contrastive learning, the losses are made the same way for all levels'''
        sim_v2a = vfeat @ afeat_all / scale
        sim_a2v = afeat @ vfeat_all / scale
        sim_v2a_t, sim_a2v_t = self._make_targets(sim_v2a, vfeat_all, afeat_all, scale, alpha, vfeat_m, afeat_m)
        loss = self._loss(sim_v2a, sim_a2v, sim_v2a_t, sim_a2v_t)
        return loss, (sim_v2a, sim_a2v)

    @torch.no_grad()
    def _make_targets(self, sim_v2a, vfeat_all, afeat_all, scale, alpha, vfeat_m, afeat_m):
        # NOTE: for simplicity, we assume that sim_v2a.shape[0] == sim_a2v.shape[0]
        # NOTE: sim_targets is not square (sim_v2a.shape is (bsize, bsize+Qsize) )
        sim_targets = torch.eye(*sim_v2a.shape, device=sim_v2a.device, dtype=sim_v2a.dtype)
        sim_v2a_targets = sim_targets
        sim_a2v_targets = sim_targets
        return sim_v2a_targets, sim_a2v_targets

    def _loss(self, sim_v2a, sim_a2v, sim_v2a_targets, sim_a2v_targets):
        loss_v2a = F.cross_entropy(sim_v2a, sim_v2a_targets)
        loss_a2v = F.cross_entropy(sim_a2v, sim_a2v_targets)
        return (loss_v2a + loss_a2v) / 2

    def encode_streams(self, vis, aud, for_loop, do_norm=True):
        # (B*S, D), (B, D) or None; because `flatten_to_2D = True`
        flatten_to_2D = True
        vfeat, _ = self.encode_stream(vis, self.v_encoder, self.vproj, do_norm, flatten_to_2D, for_loop)
        afeat, _ = self.encode_stream(aud, self.a_encoder, self.aproj, do_norm, flatten_to_2D, for_loop)
        return vfeat, None, afeat, None

    def encode_stream(self, x, feat_extractor_fn, proj_fn, do_norm, flatten_to_2D, for_loop):
        # x is (B, S, ...)
        segment_x, _ = feat_extractor_fn(x, for_loop)  # segment_x: (B, S, D), global_x: (B, D)
        if flatten_to_2D:
            B, S, D = segment_x.shape
            segment_x = segment_x.view(B*S, D)  # flatten batch and segment dims
        segment_x = proj_fn(segment_x)
        segment_x = F.normalize(segment_x, dim=-1) if do_norm else segment_x
        # do_global is passed in to avoid computing global features when not needed (e.g. during eval)
        return segment_x, None  # (B*S, D), (B, D) or None

    def forward_for_logging(self, vis, aud, for_momentum=False, for_loop=False, do_norm=True):
        '''
        Runs the forward pass but keeps certain tensors in memory for logging purposes, ie code duplication.
        NOTE: to be used outside of this module, most likely during logging

        Args:
            vis (torch.Tensor): RGB frames (B, S, C, Tv, H, W)
            aud (torch.Tensor): audio spectrograms (B, S, Ta, F)
        '''
        flatten_to_2D = True
        out = dict()

        # factorizing self.encode_streams into encode_visual/encode_audio to avoid unnecessary computations
        # (B*S, D), (B, D) or None;
        # vfeat, _ = self.encode_stream(vis, self.v_encoder, self.vproj, do_norm, flatten_to_2D, for_loop)
        # afeat, _ = self.encode_stream(aud, self.a_encoder, self.aproj, do_norm, flatten_to_2D, for_loop)
        vfeat, _, afeat, _ = self.encode_streams(vis, aud, for_loop, do_norm)
        # cache features (for 0-shot evaluation)
        out['segment_vfeat'] = vfeat.clone()
        out['segment_afeat'] = afeat.clone()
        # and similarity matrices (for visualization) (B*S, B*S)
        out['segment_sim_v2a'] = out['segment_vfeat'] @ out['segment_afeat'].mT / self.logit_scale
        out['segment_sim_a2v'] = out['segment_afeat'] @ out['segment_vfeat'].mT / self.logit_scale
        # self
        out['segment_sim_v2v'] = out['segment_vfeat'] @ out['segment_vfeat'].mT / self.logit_scale
        out['segment_sim_a2a'] = out['segment_afeat'] @ out['segment_afeat'].mT / self.logit_scale

        # compute losses
        loss, _ = self.compute_loss(vfeat, afeat, vfeat.mT, afeat.mT, self.logit_scale)
        out['segment_contrastive_loss'] = loss
        return out

    @torch.no_grad()
    def clamp_logit_scales(self):
        self.logit_scale.clamp_(self.clamp_scale_min, self.clamp_scale_max)
        return (self.logit_scale, None)


class MultilevelMoCoCLIP(nn.Module):

    def __init__(self, n_embd: int, queue_size: int, momentum: float,
                 afeat_extractor: OmegaConf, vfeat_extractor: OmegaConf, aproj: OmegaConf, vproj: OmegaConf,
                 init_scale: float = 0.07, clamp_scale_min: float = 0.001, clamp_scale_max: float = 0.5):
        super().__init__()
        self.output_dict = True
        self.n_embd = n_embd
        self.momentum = momentum
        self.to_add_global_repr = afeat_extractor.params.add_global_repr

        # loading audio and rgb towers
        self.v_encoder = instantiate_from_config(vfeat_extractor)
        self.a_encoder = instantiate_from_config(afeat_extractor)
        # loading audio and rgb towers and projection layers to account for different feature dimensions
        self.segment_aproj = instantiate_from_config(aproj)
        self.segment_vproj = instantiate_from_config(vproj)
        self.global_aproj = instantiate_from_config(aproj) if self.to_add_global_repr else None
        self.global_vproj = instantiate_from_config(vproj) if self.to_add_global_repr else None

        self.clamp_scale_min, self.clamp_scale_max = clamp_scale_min, clamp_scale_max
        self.init_scale = init_scale
        self.segment_logit_scale = nn.Parameter(torch.ones([]) * self.init_scale)  # NOTE: exp(1/OpenCLIP)
        self.global_logit_scale = nn.Parameter(torch.ones([]) * self.init_scale) if self.to_add_global_repr else None

        # create momentum models
        self.v_encoder_m = instantiate_from_config(vfeat_extractor)
        self.a_encoder_m = instantiate_from_config(afeat_extractor)
        self.segment_aproj_m = instantiate_from_config(aproj)
        self.segment_vproj_m = instantiate_from_config(vproj)
        self.global_aproj_m = instantiate_from_config(aproj) if self.to_add_global_repr else None
        self.global_vproj_m = instantiate_from_config(vproj) if self.to_add_global_repr else None

        self.model_pairs = [
            [self.v_encoder, self.v_encoder_m], [self.segment_vproj, self.segment_vproj_m],
            [self.a_encoder, self.a_encoder_m], [self.segment_aproj, self.segment_aproj_m],
        ]
        if self.to_add_global_repr:
            self.model_pairs += [
                [self.global_aproj, self.global_aproj_m], [self.global_vproj, self.global_vproj_m],
            ]

        self.copy_params()

        self.segment_queue_size = queue_size * afeat_extractor.params.max_segments  # scaled by # of segments
        self.global_queue_size = queue_size if self.to_add_global_repr else None
        self.init_Qs(self.segment_queue_size, self.global_queue_size, self.n_embd)

        # self.ln_final = text.ln_final  # perhaps only useful for transformer towers
        # self.register_buffer('attn_mask', text.attn_mask, persistent=False)

    def forward(self, vis: torch.Tensor, aud: torch.Tensor, alpha: float = 0.0, for_loop: bool = False,
                world_size=None):
        '''
        Args:
            vis (torch.Tensor): RGB frames (B, S, C, Tv, H, W)
            aud (torch.Tensor): audio spectrograms (B, S, Ta, F)
            alpha (float): linear interpolation coefficient for pseudo-targets (with 0.0 targets are 1-hot)
            for_loop (bool): whether to process each segment in a for loop or all at once
                             (speed-memory tradeoff)
        Returns:
            rgb_features (tuple(torch.Tensor)): local (B, S, D) and global (B, D) or None RGB features
            audio_features (tuple(torch.Tensor)): local (B, S, D) and global (B, D) or None audio features
            logit_scale (tuple(torch.Tensor)): local and global logit scales (1, )
        '''
        logit_scales = self.clamp_logit_scales()
        to_add_global_repr = self.to_add_global_repr  # for readability only

        feats = self.encode_streams(vis, aud, for_momentum=False, for_loop=for_loop, do_norm=True)
        segment_vfeat, global_vfeat, segment_afeat, global_afeat = feats

        # get momentum features
        with torch.no_grad():
            if self.training:
                self._momentum_update()
            feats_m = self.encode_streams(vis, aud, for_momentum=True, for_loop=for_loop, do_norm=True)
            segment_vfeat_m, global_vfeat_m, segment_afeat_m, global_afeat_m = feats_m

            # cat with queue to extend the list of negatives
            segment_vfeat_all = torch.cat([segment_vfeat_m.t(), self.segment_v_queue.clone().detach()], dim=1)
            segment_afeat_all = torch.cat([segment_afeat_m.t(), self.segment_a_queue.clone().detach()], dim=1)
            if to_add_global_repr:
                global_vfeat_all = torch.cat([global_vfeat_m.t(), self.global_v_queue.clone().detach()], dim=1)
                global_afeat_all = torch.cat([global_afeat_m.t(), self.global_a_queue.clone().detach()], dim=1)

        segment_loss_avc, _ = self.compute_loss(segment_vfeat, segment_afeat, segment_vfeat_all,
                                                segment_afeat_all, self.segment_logit_scale,
                                                alpha, segment_vfeat_m, segment_afeat_m)

        global_loss_avc = None
        if to_add_global_repr:
            global_loss_avc, _ = self.compute_loss(global_vfeat, global_afeat, global_vfeat_all,
                                                   global_afeat_all, self.global_logit_scale,
                                                   alpha, global_vfeat_m, global_afeat_m)

        if self.training:
            self._multilevel_dequeue_and_enqueue(segment_vfeat_m, segment_afeat_m, global_vfeat_m, global_afeat_m)
        else:
            raise Exception('This module is used only during training. Use model.something instead.')

        out = {
            'rgb_features': (segment_vfeat, global_vfeat),
            'audio_features': (segment_afeat, global_afeat), 'logit_scales': logit_scales,
            'losses': {'segment_contrastive_loss': segment_loss_avc},
        }
        if global_loss_avc is not None:
            out['losses']['global_contrastive_loss'] = global_loss_avc
        return out

    def compute_loss(self, vfeat, afeat, vfeat_all, afeat_all, scale, alpha=0.0, vfeat_m=None, afeat_m=None):
        '''For Multi-level contrastive learning, the losses are made the same way for all levels'''
        sim_v2a = vfeat @ afeat_all / scale
        sim_a2v = afeat @ vfeat_all / scale
        sim_v2a_t, sim_a2v_t = self._make_targets(sim_v2a, vfeat_all, afeat_all, scale, alpha, vfeat_m, afeat_m)
        loss = self._loss(sim_v2a, sim_a2v, sim_v2a_t, sim_a2v_t)
        return loss, (sim_v2a, sim_a2v)

    @torch.no_grad()
    def _make_targets(self, sim_v2a, vfeat_all, afeat_all, scale, alpha, vfeat_m, afeat_m):
        # NOTE: for simplicity, we assume that sim_v2a.shape[0] == sim_a2v.shape[0]
        # NOTE: sim_targets is not square (sim_v2a.shape is (bsize, bsize+Qsize) )
        sim_targets = torch.eye(*sim_v2a.shape, device=sim_v2a.device, dtype=sim_v2a.dtype)
        # the ALBEF alpha trick
        if alpha > 0.0:
            sim_v2a_m = vfeat_m @ afeat_all / scale
            sim_a2v_m = afeat_m @ vfeat_all / scale
            sim_v2a_targets = alpha * F.softmax(sim_v2a_m, dim=1) + (1 - alpha) * sim_targets
            sim_a2v_targets = alpha * F.softmax(sim_a2v_m, dim=1) + (1 - alpha) * sim_targets
        else:
            sim_v2a_targets = sim_targets
            sim_a2v_targets = sim_targets
        return sim_v2a_targets, sim_a2v_targets

    def _loss(self, sim_v2a, sim_a2v, sim_v2a_targets, sim_a2v_targets):
        loss_v2a = F.cross_entropy(sim_v2a, sim_v2a_targets)
        loss_a2v = F.cross_entropy(sim_a2v, sim_a2v_targets)
        return (loss_v2a + loss_a2v) / 2

    def encode_streams(self, vis, aud, for_momentum, for_loop, do_norm=True):
        # (B*S, D), (B, D) or None; because `flatten_to_2D = True`
        flatten_to_2D = True
        segment_vfeat, global_vfeat = self.encode_visual(vis, for_momentum, self.to_add_global_repr, do_norm,
                                                         for_loop=for_loop, flatten_to_2D=flatten_to_2D)
        segment_afeat, global_afeat = self.encode_audio(aud, for_momentum, self.to_add_global_repr, do_norm,
                                                        for_loop=for_loop, flatten_to_2D=flatten_to_2D)
        return segment_vfeat, global_vfeat, segment_afeat, global_afeat

    def encode_audio(self, x, for_momentum: bool = False, do_global: bool = True, do_norm: bool = True,
                     flatten_to_2D=True, for_loop=False):
        # define callables
        encode_fn = self.a_encoder_m if for_momentum else self.a_encoder
        segment_proj_fn = self.segment_aproj_m if for_momentum else self.segment_aproj
        global_proj_fn = self.global_aproj_m if for_momentum else self.global_aproj
        # do the encoding
        return self.encode_stream(x, encode_fn, segment_proj_fn, global_proj_fn, do_global, do_norm,
                                  flatten_to_2D, for_loop)

    def encode_visual(self, x, for_momentum: bool = False, do_global: bool = True, do_norm: bool = True,
                      flatten_to_2D=True, for_loop=False):
        # define callables
        encode_fn = self.v_encoder_m if for_momentum else self.v_encoder
        segment_proj_fn = self.segment_vproj_m if for_momentum else self.segment_vproj
        global_proj_fn = self.global_vproj_m if for_momentum else self.global_vproj
        # do the encoding
        return self.encode_stream(x, encode_fn, segment_proj_fn, global_proj_fn, do_global, do_norm,
                                  flatten_to_2D, for_loop)

    def encode_stream(self, x, feat_extractor_fn, segment_proj_fn, global_proj_fn, do_global, do_norm,
                      flatten_to_2D, for_loop):
        # x is (B, S, ...)
        segment_x, global_x = feat_extractor_fn(x, for_loop)  # segment_x: (B, S, D), global_x: (B, D)
        if flatten_to_2D:
            B, S, D = segment_x.shape
            segment_x = segment_x.view(B*S, D)  # flatten batch and segment dims
        segment_x = segment_proj_fn(segment_x)
        segment_x = F.normalize(segment_x, dim=-1) if do_norm else segment_x
        # do_global is passed in to avoid computing global features when not needed (e.g. during eval)
        if do_global and self.to_add_global_repr:
            global_x = global_proj_fn(global_x)
            global_x = F.normalize(global_x, dim=-1) if do_norm else global_x
        return segment_x, global_x  # (B*S, D), (B, D) or None

    def forward_for_logging(self, vis, aud, for_momentum=False, for_loop=False, do_norm=True):
        '''
        Runs the forward pass but keeps certain tensors in memory for logging purposes, ie code duplication.
        NOTE: to be used outside of this module, most likely during logging

        Args:
            vis (torch.Tensor): RGB frames (B, S, C, Tv, H, W)
            aud (torch.Tensor): audio spectrograms (B, S, Ta, F)
        '''
        flatten_to_2D = True

        out = dict()

        # factorizing self.encode_streams into encode_visual/encode_audio to avoid unnecessary computations
        # (B*S, D), (B, D) or None;
        segment_vfeat, global_vfeat = self.encode_visual(vis, for_momentum, self.to_add_global_repr, do_norm,
                                                         flatten_to_2D, for_loop)
        segment_afeat, global_afeat = self.encode_audio(aud, for_momentum, self.to_add_global_repr, do_norm,
                                                        flatten_to_2D, for_loop)
        # cache features (for 0-shot evaluation)
        out['segment_vfeat'] = segment_vfeat.clone()
        out['segment_afeat'] = segment_afeat.clone()
        # and similarity matrices (for visualization) (B*S, B*S)
        out['segment_sim_v2a'] = out['segment_vfeat'] @ out['segment_afeat'].mT / self.segment_logit_scale
        out['segment_sim_a2v'] = out['segment_afeat'] @ out['segment_vfeat'].mT / self.segment_logit_scale
        # self
        out['segment_sim_v2v'] = out['segment_vfeat'] @ out['segment_vfeat'].mT / self.segment_logit_scale
        out['segment_sim_a2a'] = out['segment_afeat'] @ out['segment_afeat'].mT / self.segment_logit_scale

        # compute losses
        segment_loss, _ = self.compute_loss(segment_vfeat, segment_afeat, segment_vfeat.mT, segment_afeat.mT,
                                            self.segment_logit_scale)
        out['segment_contrastive_loss'] = segment_loss
        if self.to_add_global_repr:
            global_loss, _ = self.compute_loss(global_vfeat, global_afeat, global_vfeat.mT, global_afeat.mT,
                                               self.global_logit_scale)
            out['global_contrastive_loss'] = global_loss

        return out


    @torch.no_grad()
    def clamp_logit_scales(self):
        self.segment_logit_scale.clamp_(self.clamp_scale_min, self.clamp_scale_max)
        if self.to_add_global_repr:
            self.global_logit_scale.clamp_(self.clamp_scale_min, self.clamp_scale_max)
        return (self.segment_logit_scale, self.global_logit_scale)

    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False  # not update by gradient

    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)

    @torch.no_grad()
    def _multilevel_dequeue_and_enqueue(self, segment_vfeat_m, segment_afeat_m, global_vfeat_m, global_afeat_m):
        if self.segment_queue_size > 0:
            self._dequeue_and_enqueue(segment_vfeat_m, segment_afeat_m, 'segment_')
        if self.to_add_global_repr and self.global_queue_size > 0:
            self._dequeue_and_enqueue(global_vfeat_m, global_afeat_m, 'global_')

    @torch.no_grad()
    def _dequeue_and_enqueue(self, vfeat, afeat, level_prefix_: str):
        # gather keys before updating queue
        if torch.distributed.is_initialized():
            vfeats = concat_all_gather(vfeat)
            afeats = concat_all_gather(afeat)
        else:
            vfeats = vfeat
            afeats = afeat

        batch_size = vfeats.shape[0]
        queue_size = getattr(self, level_prefix_ + 'queue_size')

        # same as `ptr = int(self.segment_queue_ptr)` but allows accessing the attribute by string
        ptr = int(getattr(self, level_prefix_ + 'queue_ptr'))
        assert queue_size % batch_size == 0, f'For simplicity: {queue_size} % {batch_size} == 0'

        # replace the keys at ptr (dequeue and enqueue)
        getattr(self, level_prefix_ + 'v_queue')[:, ptr:ptr + batch_size] = vfeats.T
        getattr(self, level_prefix_ + 'a_queue')[:, ptr:ptr + batch_size] = afeats.T
        ptr = (ptr + batch_size) % queue_size  # move pointer

        getattr(self, level_prefix_ + 'queue_ptr')[0] = ptr

    def init_Qs(self, segment_queue_size: int, global_queue_size: int, n_embd: int):
        # create the queues; TODO: flip the dimensions, yikes!
        self.register_buffer('segment_v_queue', torch.randn(n_embd, segment_queue_size))
        self.register_buffer('segment_a_queue', torch.randn(n_embd, segment_queue_size))
        self.register_buffer('segment_queue_ptr', torch.zeros(1, dtype=torch.long))
        self.segment_v_queue = nn.functional.normalize(self.segment_v_queue, dim=0)
        self.segment_a_queue = nn.functional.normalize(self.segment_a_queue, dim=0)
        if self.to_add_global_repr:
            self.register_buffer('global_v_queue', torch.randn(n_embd, global_queue_size))
            self.register_buffer('global_a_queue', torch.randn(n_embd, global_queue_size))
            self.register_buffer('global_queue_ptr', torch.zeros(1, dtype=torch.long))
            self.global_v_queue = nn.functional.normalize(self.global_v_queue, dim=0)
            self.global_a_queue = nn.functional.normalize(self.global_a_queue, dim=0)

@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
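A minimal sketch of instantiating the CLIP class defined above from the two config dataclasses and running a dummy batch; the dimensions are illustrative and it assumes the VisionTransformer/TextTransformer imports at the top of the file resolve:

import torch

vision_cfg = CLIPVisionCfg(layers=12, width=768, patch_size=16, image_size=224)
text_cfg = CLIPTextCfg(context_length=77, vocab_size=49408, width=512, heads=8, layers=12)
model = CLIP(embed_dim=512, vision_cfg=vision_cfg, text_cfg=text_cfg, output_dict=True)

images = torch.randn(2, 3, 224, 224)                         # dummy RGB batch
tokens = torch.randint(0, 49408, (2, 77), dtype=torch.long)  # dummy tokenized text
out = model(images, tokens)
print(out["image_features"].shape, out["text_features"].shape, out["logit_scale"].item())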
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN101-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
+{
+    "embed_dim": 512,
+    "quick_gelu": true,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": [
+            3,
+            4,
+            23,
+            3
+        ],
+        "width": 64,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
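The model configs that follow all share this schema: a joint embed_dim plus a vision_cfg/text_cfg pair, looked up by file stem (e.g. RN101-quickgelu) by the open_clip factory; the vendored copy under train_clip_src presumably mirrors the upstream behaviour. A stdlib-only sketch for inspecting one of them (path relative to the repo root):

import json
from pathlib import Path

cfg_path = Path('modules/model/modules/feat_extractors/train_clip_src/'
                'open_clip/model_configs/RN101-quickgelu.json')
cfg = json.loads(cfg_path.read_text())

# In upstream open_clip, a list of stage depths in vision_cfg.layers marks a
# ModifiedResNet vision tower, while a plain int marks a ViT.
is_resnet = isinstance(cfg['vision_cfg']['layers'], (list, tuple))
print(cfg['embed_dim'], 'ResNet' if is_resnet else 'ViT', cfg['vision_cfg']['layers'])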
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN101.json
ADDED
@@ -0,0 +1,21 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": [
+            3,
+            4,
+            23,
+            3
+        ],
+        "width": 64,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
+{
+    "embed_dim": 1024,
+    "quick_gelu": true,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": [
+            3,
+            4,
+            6,
+            3
+        ],
+        "width": 64,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50.json
ADDED
@@ -0,0 +1,21 @@
+{
+    "embed_dim": 1024,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": [
+            3,
+            4,
+            6,
+            3
+        ],
+        "width": 64,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50x16.json
ADDED
@@ -0,0 +1,21 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 384,
+        "layers": [
+            6,
+            8,
+            18,
+            8
+        ],
+        "width": 96,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50x4.json
ADDED
@@ -0,0 +1,21 @@
+{
+    "embed_dim": 640,
+    "vision_cfg": {
+        "image_size": 288,
+        "layers": [
+            4,
+            6,
+            10,
+            6
+        ],
+        "width": 80,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 640,
+        "heads": 10,
+        "layers": 12
+    }
+}
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50x64.json
ADDED
@@ -0,0 +1,21 @@
+{
+    "embed_dim": 1024,
+    "vision_cfg": {
+        "image_size": 448,
+        "layers": [
+            3,
+            15,
+            36,
+            10
+        ],
+        "width": 128,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 1024,
+        "heads": 16,
+        "layers": 12
+    }
+}
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-16-plus-240.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 640,
+    "vision_cfg": {
+        "image_size": 240,
+        "layers": 12,
+        "width": 896,
+        "patch_size": 16
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 640,
+        "heads": 10,
+        "layers": 12
+    }
+}
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-16-plus.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 640,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 896,
+        "patch_size": 16
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 640,
+        "heads": 10,
+        "layers": 12
+    }
+}
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-16.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 16
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
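For the ViT configs, image_size and patch_size fix the token grid the vision tower sees: ViT-B-16 at 224 px gives a 14×14 grid, i.e. 196 patch tokens plus the class token. A quick sanity check with a hypothetical helper:

def vit_tokens(image_size: int, patch_size: int) -> int:
    # number of patch tokens for a square input, excluding the class token
    grid = image_size // patch_size
    return grid * grid

print(vit_tokens(224, 16))  # 196  (ViT-B-16)
print(vit_tokens(240, 16))  # 225  (ViT-B-16-plus-240)
print(vit_tokens(336, 14))  # 576  (ViT-L-14-336)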
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-32-plus-256.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 640,
+    "vision_cfg": {
+        "image_size": 256,
+        "layers": 12,
+        "width": 896,
+        "patch_size": 32
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 640,
+        "heads": 10,
+        "layers": 12
+    }
+}
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-32-quickgelu.json
ADDED
@@ -0,0 +1,17 @@
+{
+    "embed_dim": 512,
+    "quick_gelu": true,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 32
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-32.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 32
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-H-14.json
ADDED
@@ -0,0 +1,17 @@
+{
+    "embed_dim": 1024,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 32,
+        "width": 1280,
+        "head_width": 80,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 1024,
+        "heads": 16,
+        "layers": 24
+    }
+}
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-H-16.json
ADDED
@@ -0,0 +1,17 @@
+{
+    "embed_dim": 1024,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 32,
+        "width": 1280,
+        "head_width": 80,
+        "patch_size": 16
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 1024,
+        "heads": 16,
+        "layers": 24
+    }
+}
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-14-280.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 280,
+        "layers": 24,
+        "width": 1024,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-14-336.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 336,
+        "layers": 24,
+        "width": 1024,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-14.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 24,
+        "width": 1024,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-16-320.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 320,
+        "layers": 24,
+        "width": 1024,
+        "patch_size": 16
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-16.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 24,
+        "width": 1024,
+        "patch_size": 16
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}