AmrMKayid committed on
Commit 1a6ac97 · verified · 1 Parent(s): 0ed633a

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +38 -0
  2. __init__.py +8 -0
  3. cfg.yaml +282 -0
  4. config.json +12 -0
  5. model.pt +3 -0
  6. model_files.zip +3 -0
  7. modeling_synchformer.py +10 -0
  8. module_loader.py +32 -0
  9. modules/dataset/__init__.py +0 -0
  10. modules/dataset/audioset.py +216 -0
  11. modules/dataset/dataset_utils.py +112 -0
  12. modules/dataset/lrs.py +192 -0
  13. modules/dataset/transforms.py +1074 -0
  14. modules/dataset/vggsound.py +394 -0
  15. modules/model/__init__.py +1 -0
  16. modules/model/modules/bridges.py +178 -0
  17. modules/model/modules/feat_extractors/audio/ast.py +279 -0
  18. modules/model/modules/feat_extractors/audio/hf_src/modeling_ast.py +662 -0
  19. modules/model/modules/feat_extractors/audio/resnet.py +249 -0
  20. modules/model/modules/feat_extractors/train_clip_src/__init__.py +3 -0
  21. modules/model/modules/feat_extractors/train_clip_src/open_clip/__init__.py +13 -0
  22. modules/model/modules/feat_extractors/train_clip_src/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  23. modules/model/modules/feat_extractors/train_clip_src/open_clip/coca_model.py +458 -0
  24. modules/model/modules/feat_extractors/train_clip_src/open_clip/constants.py +2 -0
  25. modules/model/modules/feat_extractors/train_clip_src/open_clip/factory.py +193 -0
  26. modules/model/modules/feat_extractors/train_clip_src/open_clip/generation_utils.py +0 -0
  27. modules/model/modules/feat_extractors/train_clip_src/open_clip/hf_configs.py +45 -0
  28. modules/model/modules/feat_extractors/train_clip_src/open_clip/hf_model.py +176 -0
  29. modules/model/modules/feat_extractors/train_clip_src/open_clip/loss.py +229 -0
  30. modules/model/modules/feat_extractors/train_clip_src/open_clip/model.py +883 -0
  31. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN101-quickgelu.json +22 -0
  32. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN101.json +21 -0
  33. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50-quickgelu.json +22 -0
  34. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50.json +21 -0
  35. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50x16.json +21 -0
  36. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50x4.json +21 -0
  37. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50x64.json +21 -0
  38. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-16-plus-240.json +16 -0
  39. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-16-plus.json +16 -0
  40. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-16.json +16 -0
  41. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-32-plus-256.json +16 -0
  42. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-32-quickgelu.json +17 -0
  43. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-32.json +16 -0
  44. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-H-14.json +17 -0
  45. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-H-16.json +17 -0
  46. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-14-280.json +16 -0
  47. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-14-336.json +16 -0
  48. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-14.json +16 -0
  49. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-16-320.json +16 -0
  50. modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-16.json +16 -0
README.md ADDED
@@ -0,0 +1,38 @@
1
+ # Synchformer Hugging Face Model
2
+
3
+ This repository contains a Synchformer model for audio-visual synchronization. The model predicts the offset between audio and video tracks.
4
+
5
+ ## Usage
6
+
7
+ ```python
8
+ from transformers import AutoModel
9
+ import torch
10
+
11
+ # Load the model
12
+ model = AutoModel.from_pretrained("AmrMKayid/synchformer-hf")
13
+ model.to("cuda" if torch.cuda.is_available() else "cpu")
14
+
15
+ # Predict offset for a video
16
+ results = model.predict_offset(
17
+ "path/to/your/video.mp4",
18
+ offset_sec=0.0, # Ground truth offset (if known)
19
+ v_start_i_sec=0.0 # Start time in seconds for video
20
+ )
21
+
22
+ # Print results
23
+ print("\nPrediction Results:")
24
+ for pred in results["predictions"]:
25
+ print(f'p={pred["probability"]:.4f}, "{pred["offset_sec"]:.2f}" (class {pred["class_idx"]})')
26
+ ```
27
+
28
+ ## Model Details
29
+
30
+ This model is based on the Synchformer architecture, which uses a transformer to predict the offset between audio and video tracks.
31
+
32
+ ## Requirements
33
+
34
+ - torch
35
+ - torchaudio
36
+ - torchvision
37
+ - omegaconf
38
+ - ffmpeg (for video processing)
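As a follow-up to the usage snippet above, a minimal sketch of picking the single most likely offset from the returned `results`; the `probability`, `offset_sec`, and `class_idx` keys are taken from that snippet, and the rest of the return structure is not shown here, so treat this as illustrative only:

```python
# Continues from the README usage snippet: `results` is the dict returned by
# model.predict_offset(...). Each entry in results["predictions"] is assumed to
# carry "probability", "offset_sec" and "class_idx", as in the example above.
best = max(results["predictions"], key=lambda p: p["probability"])
print(f'Most likely offset: {best["offset_sec"]:.2f}s '
      f'(p={best["probability"]:.4f}, class {best["class_idx"]})')
```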
__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from .synchformer_config import SynchformerConfig
2
+ from .synchformer_model import SynchformerModel
3
+ from .modeling_synchformer import *
4
+
5
+ __all__ = [
6
+ "SynchformerConfig",
7
+ "SynchformerModel",
8
+ ]
cfg.yaml ADDED
@@ -0,0 +1,282 @@
1
+ action: train_avsync_model
2
+ model:
3
+ target: model.sync_model.Synchformer
4
+ params:
5
+ afeat_extractor:
6
+ is_trainable: false
7
+ target: model.modules.feat_extractors.audio.ast.AST
8
+ params:
9
+ ckpt_path: /scratch/project_462000293/vladimir/logs/sync/avclip_models/23-12-22T16-13-38/checkpoints/epoch_e28.pt
10
+ extract_features: true
11
+ max_spec_t: 66
12
+ factorize_freq_time: true
13
+ agg_freq_module: TransformerEncoderLayer
14
+ agg_time_module: torch.nn.Identity
15
+ add_global_repr: false
16
+ vfeat_extractor:
17
+ is_trainable: false
18
+ target: model.modules.feat_extractors.visual.motionformer.MotionFormer
19
+ params:
20
+ ckpt_path: /scratch/project_462000293/vladimir/logs/sync/avclip_models/23-12-22T16-13-38/checkpoints/epoch_e28.pt
21
+ extract_features: true
22
+ factorize_space_time: true
23
+ agg_space_module: TransformerEncoderLayer
24
+ agg_time_module: torch.nn.Identity
25
+ add_global_repr: false
26
+ aproj:
27
+ target: torch.nn.Linear
28
+ params:
29
+ in_features: 768
30
+ out_features: 768
31
+ vproj:
32
+ target: torch.nn.Linear
33
+ params:
34
+ in_features: 768
35
+ out_features: 768
36
+ transformer:
37
+ target: model.sync_model.GlobalTransformer
38
+ params:
39
+ n_layer: 3
40
+ n_head: 8
41
+ n_embd: 768
42
+ tok_pdrop: 0.0
43
+ embd_pdrop: 0.1
44
+ resid_pdrop: 0.1
45
+ attn_pdrop: 0.1
46
+ pos_emb_cfg:
47
+ target: model.modules.transformer.RandInitPositionalEncoding
48
+ params:
49
+ block_shape:
50
+ - 198
51
+ n_embd: 768
52
+ off_head_cfg:
53
+ target: torch.nn.Linear
54
+ params:
55
+ in_features: 768
56
+ out_features: 21
57
+ training:
58
+ base_learning_rate: 2.0e-06
59
+ base_batch_size: 16
60
+ num_workers: 7
61
+ num_epochs: 10000
62
+ patience: 50
63
+ to_max_metric: true
64
+ metric_name: accuracy_1
65
+ early_stop_phase: valid
66
+ use_half_precision: true
67
+ seed: 1337
68
+ compile: false
69
+ skip_test: false
70
+ run_test_only: false
71
+ resume: false
72
+ finetune: false
73
+ dist_backend: nccl
74
+ max_clip_norm: 1
75
+ lr_scheduler:
76
+ name: constant_with_warmup
77
+ warmup: 1000
78
+ optimizer:
79
+ name: adam
80
+ betas:
81
+ - 0.9
82
+ - 0.999
83
+ momentum: 0.9
84
+ weight_decay: 0
85
+ local_rank: 0
86
+ global_rank: 0
87
+ world_size: 32
88
+ data:
89
+ offset_type: grid
90
+ num_off_cls: 21
91
+ prob_oos: null
92
+ max_off_sec: 2
93
+ crop_len_sec: 5
94
+ step_size_seg: 0.5
95
+ vids_path: /scratch/project_462000293/vladimir/data/audioset/h264_video_25fps_256side_16000hz_aac/
96
+ size_before_crop: 256
97
+ input_size: 224
98
+ segment_size_vframes: 16
99
+ vfps: 25
100
+ afps: 16000
101
+ n_segments: 14
102
+ do_offset: true
103
+ p_color_jitter: 0.0
104
+ p_gray_scale: 0.0
105
+ sometimes_upscale_p: 0.0
106
+ is_spatial_crop_random: true
107
+ is_temporal_crop_random: true
108
+ audio_jitter_sec: 0.05
109
+ p_horizontal_flip: 0.5
110
+ p_audio_aug: 0.0
111
+ dataset:
112
+ target: dataset.audioset.AudioSet
113
+ params:
114
+ load_fixed_offsets_on:
115
+ - valid
116
+ - test
117
+ vis_load_backend: read_video
118
+ size_ratio: null
119
+ transform_sequence_train:
120
+ - target: dataset.transforms.EqualifyFromRight
121
+ params:
122
+ clip_max_len_sec: 10
123
+ - target: dataset.transforms.RGBSpatialCropSometimesUpscale
124
+ params:
125
+ sometimes_p: 0.0
126
+ smaller_input_size: 192
127
+ target_input_size: 224
128
+ is_random: true
129
+ - target: dataset.transforms.TemporalCropAndOffset
130
+ params:
131
+ crop_len_sec: 5
132
+ max_off_sec: 2
133
+ max_wiggle_sec: 0.05
134
+ do_offset: true
135
+ offset_type: grid
136
+ prob_oos: null
137
+ grid_size: 21
138
+ segment_size_vframes: 16
139
+ n_segments: 14
140
+ step_size_seg: 0.5
141
+ vfps: 25
142
+ - target: dataset.transforms.RandomApplyColorDistortion
143
+ params:
144
+ p_color_jitter: 0.0
145
+ s: 1.0
146
+ p_gray_scale: 0.0
147
+ - target: dataset.transforms.RandomHorizontalFlip
148
+ params:
149
+ p: 0.5
150
+ - target: dataset.transforms.AudioRandomReverb
151
+ params:
152
+ p: 0.0
153
+ - target: dataset.transforms.AudioRandomVolume
154
+ params:
155
+ p: 0.0
156
+ gain: 2.0
157
+ gain_type: amplitude
158
+ - target: dataset.transforms.AudioRandomPitchShift
159
+ params:
160
+ p: 0.0
161
+ shift: 1000
162
+ - target: dataset.transforms.AudioRandomLowpassFilter
163
+ params:
164
+ p: 0.0
165
+ cutoff_freq: 100
166
+ - target: dataset.transforms.AudioRandomGaussNoise
167
+ params:
168
+ p: 0.0
169
+ amplitude: 0.01
170
+ - target: dataset.transforms.GenerateMultipleSegments
171
+ params:
172
+ segment_size_vframes: 16
173
+ n_segments: 14
174
+ is_start_random: true
175
+ step_size_seg: 0.5
176
+ - target: dataset.transforms.RGBToHalfToZeroOne
177
+ - target: dataset.transforms.RGBNormalize
178
+ params:
179
+ mean:
180
+ - 0.5
181
+ - 0.5
182
+ - 0.5
183
+ std:
184
+ - 0.5
185
+ - 0.5
186
+ - 0.5
187
+ - target: dataset.transforms.AudioMelSpectrogram
188
+ params:
189
+ sample_rate: 16000
190
+ win_length: 400
191
+ hop_length: 160
192
+ n_fft: 1024
193
+ n_mels: 128
194
+ - target: dataset.transforms.AudioLog
195
+ - target: dataset.transforms.PadOrTruncate
196
+ params:
197
+ max_spec_t: 66
198
+ - target: dataset.transforms.AudioNormalizeAST
199
+ params:
200
+ mean: -4.2677393
201
+ std: 4.5689974
202
+ - target: dataset.transforms.PermuteStreams
203
+ params:
204
+ einops_order_audio: S F T -> S 1 F T
205
+ einops_order_rgb: S T C H W -> S T C H W
206
+ transform_sequence_test:
207
+ - target: dataset.transforms.EqualifyFromRight
208
+ - target: dataset.transforms.RGBSpatialCrop
209
+ params:
210
+ input_size: 224
211
+ is_random: false
212
+ - target: dataset.transforms.TemporalCropAndOffset
213
+ params:
214
+ crop_len_sec: 5
215
+ max_off_sec: 2
216
+ max_wiggle_sec: 0.0
217
+ do_offset: true
218
+ grid_size: 21
219
+ offset_type: grid
220
+ prob_oos: null
221
+ segment_size_vframes: 16
222
+ n_segments: 14
223
+ step_size_seg: 0.5
224
+ vfps: 25
225
+ - target: dataset.transforms.GenerateMultipleSegments
226
+ params:
227
+ segment_size_vframes: 16
228
+ n_segments: 14
229
+ is_start_random: false
230
+ step_size_seg: 0.5
231
+ - target: dataset.transforms.RGBToHalfToZeroOne
232
+ - target: dataset.transforms.RGBNormalize
233
+ params:
234
+ mean:
235
+ - 0.5
236
+ - 0.5
237
+ - 0.5
238
+ std:
239
+ - 0.5
240
+ - 0.5
241
+ - 0.5
242
+ - target: dataset.transforms.AudioMelSpectrogram
243
+ params:
244
+ sample_rate: 16000
245
+ win_length: 400
246
+ hop_length: 160
247
+ n_fft: 1024
248
+ n_mels: 128
249
+ - target: dataset.transforms.AudioLog
250
+ - target: dataset.transforms.PadOrTruncate
251
+ params:
252
+ max_spec_t: 66
253
+ - target: dataset.transforms.AudioNormalizeAST
254
+ params:
255
+ mean: -4.2677393
256
+ std: 4.5689974
257
+ - target: dataset.transforms.PermuteStreams
258
+ params:
259
+ einops_order_audio: S F T -> S 1 F T
260
+ einops_order_rgb: S T C H W -> S T C H W
261
+ logging:
262
+ logdir: /scratch/project_462000293/vladimir/logs/sync/sync_models/
263
+ log_code_state: true
264
+ log_frequency: 20
265
+ patterns_to_ignore:
266
+ - logs
267
+ - .git
268
+ - __pycache__
269
+ - data
270
+ - '*.pt'
271
+ - sbatch_logs
272
+ - '*.mp4'
273
+ - '*.wav'
274
+ - '*.jpg'
275
+ - '*.gif'
276
+ - misc*
277
+ vis_segment_sim: true
278
+ log_max_items: 500000
279
+ use_wandb: true
280
+ start_time: 24-01-04T16-39-21
281
+ config: ./configs/sync.yaml
282
+ ckpt_path: /scratch/project_462000293/vladimir/logs/sync/sync_models/24-01-04T16-39-21/24-01-04T16-39-21.pt
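Every model and transform in the config above follows the same `target`/`params` convention: `target` is a dotted import path and `params` holds the constructor's keyword arguments. A minimal sketch of resolving such a node with OmegaConf, using a hypothetical `instantiate_from_config` helper that is not part of the files shown in this commit:

```python
# Minimal sketch of resolving one `target`/`params` node from cfg.yaml.
# `instantiate_from_config` is a hypothetical helper written for illustration.
import importlib
from omegaconf import OmegaConf

def instantiate_from_config(node):
    """Import `node.target` and construct it with `node.params` as kwargs."""
    module_name, cls_name = node["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    params = node.get("params") or {}
    if OmegaConf.is_config(params):
        params = OmegaConf.to_container(params, resolve=True)
    return cls(**params)

cfg = OmegaConf.load("cfg.yaml")
# e.g. the audio projection in this config is a plain torch.nn.Linear(768, 768)
aproj = instantiate_from_config(cfg.model.params.aproj)
print(aproj)
```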
config.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "afeat_extractor_config": {},
3
+ "afps": 16000,
4
+ "in_size": 256,
5
+ "max_off_sec": 2,
6
+ "model_type": "synchformer",
7
+ "transformer_config": {},
8
+ "transformers_version": "4.45.2",
9
+ "use_half_precision": true,
10
+ "vfeat_extractor_config": {},
11
+ "vfps": 25
12
+ }
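The config is plain JSON, so the top-level fields (frame rates, input size, maximum offset) can be inspected without any transformers machinery. A minimal sketch, assuming the file sits in the current directory:

```python
import json

# Reads the config.json shown above (assumed to be in the current directory).
with open("config.json") as f:
    cfg = json.load(f)

print(cfg["model_type"], cfg["vfps"], cfg["afps"], cfg["max_off_sec"])
# -> synchformer 25 16000 2
```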
model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b4b3557fbd96b61aaffa8bc70b28f9ff53f8fa98edc202655c5d94ab3c719ee
3
+ size 1131153989
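The checkpoint above is stored as a Git LFS pointer, so the Hub client has to resolve it to the actual ~1.1 GB binary. A minimal sketch of fetching it with `huggingface_hub`; the repo id is taken from the README, and whether `model.pt` holds a state dict or a pickled module is not confirmed by the files shown here, so the object is only inspected:

```python
# Minimal sketch: resolve the LFS pointer above to the real checkpoint and peek
# at what it contains. The repo id comes from the README; the layout of the
# loaded object is an assumption and is only printed, not used further.
from huggingface_hub import hf_hub_download
import torch

ckpt_path = hf_hub_download(repo_id="AmrMKayid/synchformer-hf", filename="model.pt")
ckpt = torch.load(ckpt_path, map_location="cpu")
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys())[:5])
```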
model_files.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50f5534dc446ddf64b8da751931b490c2990727784d3cab28f3c5b04c783160a
3
+ size 787762366
modeling_synchformer.py ADDED
@@ -0,0 +1,10 @@
1
+ from transformers.models.auto.modeling_auto import _BaseAutoModelClass
2
+ from transformers.models.auto.configuration_auto import _LazyAutoMapping
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.models.auto import AutoModel
5
+
6
+ from synchformer_config import SynchformerConfig
7
+ from synchformer_model import SynchformerModel
8
+
9
+ # Register the model with the transformers library
10
+ AutoModel.register(SynchformerConfig, SynchformerModel)
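The only effect of importing this module is the `AutoModel.register(SynchformerConfig, SynchformerModel)` call, which tells the `transformers` auto classes how to dispatch the custom config to the custom model. A minimal sketch of relying on that registration, run from the repository root so the absolute imports above resolve; `synchformer_config.py` and `synchformer_model.py` are not among the 50 files shown, so a standard `PretrainedConfig`/`PreTrainedModel` pair is an assumption here:

```python
# Minimal sketch, run from the repository root so the absolute imports in
# modeling_synchformer.py resolve. Assumes synchformer_config.py defines a
# standard PretrainedConfig subclass (it is not among the files shown here).
import modeling_synchformer  # noqa: F401  side effect: AutoModel.register(...)
from synchformer_config import SynchformerConfig
from transformers import AutoModel

config = SynchformerConfig.from_pretrained(".")  # reads the config.json shown above
model = AutoModel.from_config(config)            # dispatched to SynchformerModel via register()
```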
module_loader.py ADDED
@@ -0,0 +1,32 @@
1
+
2
+ import os
3
+ import sys
4
+ import importlib.util
5
+
6
+ def setup_modules():
7
+ # Get the directory where this script is located
8
+ current_dir = os.path.dirname(os.path.abspath(__file__))
9
+ modules_dir = os.path.join(current_dir, "modules")
10
+
11
+ # Add modules directory to path
12
+ if modules_dir not in sys.path:
13
+ sys.path.insert(0, modules_dir)
14
+
15
+ # Import required modules
16
+ required_modules = ["dataset", "model", "scripts", "utils"]
17
+ for module_name in required_modules:
18
+ module_path = os.path.join(modules_dir, module_name)
19
+ if os.path.exists(module_path):
20
+ # Check if module is already imported
21
+ if module_name not in sys.modules:
22
+ # Import the module
23
+ spec = importlib.util.spec_from_file_location(
24
+ module_name,
25
+ os.path.join(module_path, "__init__.py")
26
+ )
27
+ if spec:
28
+ module = importlib.util.module_from_spec(spec)
29
+ sys.modules[module_name] = module
30
+ spec.loader.exec_module(module)
31
+
32
+ return True
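`setup_modules` does two things: it puts `modules/` on `sys.path` and eagerly imports whichever of the `dataset`, `model`, `scripts`, and `utils` packages exist there. A minimal usage sketch, assuming this file sits next to the `modules/` directory as in the file layout above and that the packages' own dependencies (torch, torchvision, torchaudio, numpy, einops) are installed:

```python
# Minimal sketch of using the loader above. Assumes module_loader.py sits next
# to the modules/ directory from this commit and its dependencies are installed.
from module_loader import setup_modules

setup_modules()  # adds modules/ to sys.path and pre-imports its packages

# modules/dataset/transforms.py is now importable as a top-level package
from dataset import transforms
print(transforms.make_class_grid(-2, 2, 21))  # the 21-class offset grid used in cfg.yaml
```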
modules/dataset/__init__.py ADDED
File without changes
modules/dataset/audioset.py ADDED
@@ -0,0 +1,216 @@
1
+ import csv
2
+ import logging
3
+ import random
4
+ import sys
5
+ from glob import glob
6
+ from pathlib import Path
7
+
8
+ import torch
9
+
10
+ sys.path.insert(0, '.') # nopep8
11
+ from dataset.dataset_utils import (get_fixed_offsets, get_video_and_audio)
12
+
13
+
14
+ class AudioSet(torch.utils.data.Dataset):
15
+
16
+ def __init__(self,
17
+ split,
18
+ vids_dir,
19
+ transforms=None,
20
+ to_filter_bad_examples=True,
21
+ splits_path='./data',
22
+ meta_path='./data/audioset.csv',
23
+ seed=1337,
24
+ load_fixed_offsets_on=['valid', 'test'],
25
+ vis_load_backend='read_video',
26
+ size_ratio=None,
27
+ attr_annot_path=None,
28
+ max_attr_per_vid=None):
29
+ super().__init__()
30
+ self.max_clip_len_sec = None
31
+ self.split = split
32
+ self.vids_dir = Path(vids_dir)
33
+ self.transforms = transforms
34
+ self.to_filter_bad_examples = to_filter_bad_examples
35
+ self.splits_path = Path(splits_path)
36
+ self.meta_path = Path(meta_path)
37
+ self.seed = seed
38
+ self.load_fixed_offsets_on = [] if load_fixed_offsets_on is None else load_fixed_offsets_on
39
+ self.vis_load_backend = vis_load_backend
40
+ self.size_ratio = size_ratio
41
+
42
+ self.split2short = {'train': 'unbalanced', 'valid': 'balanced', 'test': 'eval'}
43
+ short2long = {'unbalanced': 'unbalanced_train_segments',
44
+ 'balanced': 'balanced_train_segments',
45
+ 'eval': 'eval_segments'}
46
+
47
+ # read meta
48
+ split_meta = []
49
+ for shortdir_vid, start, end, targets, phase in csv.reader(open(meta_path), quotechar='"'):
50
+ if shortdir_vid.startswith(self.split2short[split]):
51
+ # shortdir_vid 'unbalanced/NFap9qgsI_s' -> 'unbalanced_train_segments/NFap9qgsI_s'
52
+ shortdir, vid = shortdir_vid.split('/')
53
+ longdir_vid = '/'.join([short2long[shortdir], vid])
54
+ split_meta.append([longdir_vid, float(start), float(end), targets, phase])
55
+
56
+ # filter "bad" examples
57
+ if to_filter_bad_examples:
58
+ split_meta = self.filter_bad_examples(split_meta)
59
+
60
+ # label maps
61
+ self.label2target = {l: int(t) for t, _, l in csv.reader(open(self.splits_path / 'audioset_labels.csv'))}
62
+ self.target2label = {t: l for l, t in self.label2target.items()}
63
+ self.video2target = {key: list(map(int, targets.split(','))) for key, _, _, targets, _ in split_meta}
64
+
65
+ clip_paths = [self.vids_dir / f'{k}_{int(s*1000)}_{int(e*1000)}.mp4' for k, s, e, t, p in split_meta]
66
+ clip_paths = sorted(clip_paths)
67
+
68
+ # loading the fixed offsets. COMMENT THIS IF YOU DON'T HAVE A FILE YET
69
+ if transforms is not None and split in load_fixed_offsets_on:
70
+ logging.info(f'Using fixed offset for {split}')
71
+ self.vid2offset_params = get_fixed_offsets(transforms, split, splits_path, 'audioset')
72
+
73
+ self.dataset = clip_paths
74
+ if size_ratio is not None and 0.0 < size_ratio < 1.0:
75
+ cut_off = int(len(self.dataset) * size_ratio)
76
+ random.seed(seed)
77
+ random.shuffle(self.dataset)
78
+ self.dataset = self.dataset[:cut_off]
79
+
80
+ logging.info(f'{split} has {len(self.dataset)} items')
81
+
82
+ def filter_bad_examples(self, audioset_meta):
83
+ bad = set()
84
+ base_path = Path('./data/filtered_examples_audioset')
85
+ files = sorted(glob(str(base_path / '*.txt')))
86
+ lists = [open(p).read().splitlines() for p in files]
87
+ logging.info(f'Filtering for {files}')
88
+ for s in lists:
89
+ bad = bad.union(s)
90
+ # the ugly string converts '---g-f_I2yQ', '1' into `---g-f_I2yQ_1000_11000`
91
+ audioset_meta = [r for r in audioset_meta if f'{r[0]}_{int(r[1]*1000)}_{int(r[2]*1000)}' not in bad]
92
+ return audioset_meta
93
+
94
+ def __getitem__(self, index):
95
+ path = self.dataset[index]
96
+ rgb, audio, meta = self.load_media(path)
97
+ item = self.make_datapoint(path, rgb, audio, meta)
98
+ if self.transforms is not None:
99
+ item = self.transforms(item)
100
+ return item
101
+
102
+ def make_datapoint(self, path, rgb, audio, meta):
103
+ # (Tv, 3, H, W) in [0, 255], (Ta, C) in [-1, 1]
104
+ # TODO: since audioset is annotated by tagging, targets have multiple elements: default collate fails
105
+ # targets = self.video2target[f'{Path(path).parent.stem}/{Path(path).stem[:11]}']
106
+ item = {
107
+ 'video': rgb,
108
+ 'audio': audio,
109
+ 'meta': meta,
110
+ 'path': str(path),
111
+ # 'targets': {'audioset_target': [targets], 'audioset_label': [self.target2label[t] for t in targets]},
112
+ 'targets': {},
113
+ 'split': self.split,
114
+ }
115
+
116
+ # loading the fixed offsets. COMMENT THIS IF YOU DON'T HAVE A FILE YET
117
+ if self.transforms is not None and self.split in self.load_fixed_offsets_on:
118
+ key = f'{self.split2short[self.split]}/{Path(path).stem}'
119
+ item['targets']['offset_sec'] = self.vid2offset_params[key]['offset_sec']
120
+ item['targets']['v_start_i_sec'] = self.vid2offset_params[key]['v_start_i_sec']
121
+
122
+ return item
123
+
124
+ def load_media(self, path):
125
+ rgb, audio, meta = get_video_and_audio(path, get_meta=True, end_sec=self.max_clip_len_sec)
126
+ return rgb, audio, meta
127
+
128
+ def __len__(self):
129
+ return len(self.dataset)
130
+
131
+ class AudioSetBalanced737k(AudioSet):
132
+
133
+ def __init__(self, split, vids_dir, transforms=None, to_filter_bad_examples=True, splits_path='./data',
134
+ # here
135
+ meta_path='./data/audioset_balanced_737k.csv',
136
+ seed=1337, load_fixed_offsets_on=['valid', 'test'], vis_load_backend='read_video', size_ratio=None,
137
+ attr_annot_path=None, max_attr_per_vid=None):
138
+ super().__init__(split, vids_dir, transforms, to_filter_bad_examples, splits_path, meta_path,
139
+ seed, load_fixed_offsets_on, vis_load_backend, size_ratio)
140
+
141
+ class AudioSetBalanced540k(AudioSet):
142
+ ''' MBT's balanced 500k (from unbalanced part) + 20k from balanced part + 20k from eval part '''
143
+
144
+ def __init__(self, split, vids_dir, transforms=None, to_filter_bad_examples=True, splits_path='./data',
145
+ # here
146
+ meta_path='./data/audioset_balanced_540k.csv',
147
+ seed=1337, load_fixed_offsets_on=['valid', 'test'], vis_load_backend='read_video', size_ratio=None,
148
+ attr_annot_path=None, max_attr_per_vid=None):
149
+ super().__init__(split, vids_dir, transforms, to_filter_bad_examples, splits_path, meta_path,
150
+ seed, load_fixed_offsets_on, vis_load_backend, size_ratio)
151
+
152
+
153
+ if __name__ == '__main__':
154
+ from omegaconf import OmegaConf
155
+ from scripts.train_utils import get_transforms
156
+ from utils.utils import cfg_sanity_check_and_patch
157
+ cfg = OmegaConf.load('./configs/sparse_sync.yaml')
158
+ vis_load_backend = 'read_video'
159
+
160
+ transforms = get_transforms(cfg)
161
+
162
+ # vids_path = 'PLACEHOLDER'
163
+ vids_path = '/scratch/project_2000936/vladimir/data/audioset/h264_video_25fps_256side_16000hz_aac'
164
+ load_fixed_offsets_on = []
165
+
166
+ cfg.data.dataset.params.size_ratio = 0.1
167
+
168
+ cfg_sanity_check_and_patch(cfg)
169
+
170
+ datasets = {
171
+ 'train': AudioSet('train', vids_path, transforms['train'], vis_load_backend=vis_load_backend,
172
+ to_filter_bad_examples=True, size_ratio=cfg.data.dataset.params.size_ratio,
173
+ load_fixed_offsets_on=load_fixed_offsets_on),
174
+ 'valid': AudioSet('valid', vids_path, transforms['test'], vis_load_backend=vis_load_backend,
175
+ to_filter_bad_examples=True, load_fixed_offsets_on=load_fixed_offsets_on),
176
+ 'test': AudioSet('test', vids_path, transforms['test'], vis_load_backend=vis_load_backend,
177
+ to_filter_bad_examples=True, load_fixed_offsets_on=load_fixed_offsets_on),
178
+ }
179
+ for phase in ['train', 'valid', 'test']:
180
+ print(phase, len(datasets[phase]))
181
+
182
+ print(datasets['train'][0]['audio'].shape, datasets['train'][0]['video'].shape)
183
+ print(datasets['train'][0]['meta'])
184
+ print(datasets['valid'][0]['audio'].shape, datasets['valid'][0]['video'].shape)
185
+ print(datasets['valid'][0]['meta'])
186
+ print(datasets['test'][0]['audio'].shape, datasets['test'][0]['video'].shape)
187
+ print(datasets['test'][0]['meta'])
188
+
189
+ for i in range(300, 1000):
190
+ datasets['train'][i]['path']
191
+ print(datasets['train'][0]['audio'].shape, datasets['train'][0]['video'].shape)
192
+ print(datasets['train'][0]['meta'])
193
+
194
+ datasets = {
195
+ 'train': AudioSetBalanced737k('train', vids_path, transforms['train'], vis_load_backend=vis_load_backend,
196
+ to_filter_bad_examples=True, size_ratio=cfg.data.dataset.params.size_ratio,
197
+ load_fixed_offsets_on=load_fixed_offsets_on),
198
+ 'valid': AudioSetBalanced737k('valid', vids_path, transforms['test'], vis_load_backend=vis_load_backend,
199
+ to_filter_bad_examples=True, load_fixed_offsets_on=load_fixed_offsets_on),
200
+ 'test': AudioSetBalanced737k('test', vids_path, transforms['test'], vis_load_backend=vis_load_backend,
201
+ to_filter_bad_examples=True, load_fixed_offsets_on=load_fixed_offsets_on),
202
+ }
203
+ for phase in ['train', 'valid', 'test']:
204
+ print(phase, len(datasets[phase]))
205
+
206
+ print(datasets['train'][0]['audio'].shape, datasets['train'][0]['video'].shape)
207
+ print(datasets['train'][0]['meta'])
208
+ print(datasets['valid'][0]['audio'].shape, datasets['valid'][0]['video'].shape)
209
+ print(datasets['valid'][0]['meta'])
210
+ print(datasets['test'][0]['audio'].shape, datasets['test'][0]['video'].shape)
211
+ print(datasets['test'][0]['meta'])
212
+
213
+ for i in range(300, 1000):
214
+ datasets['train'][i]['path']
215
+ print(datasets['train'][0]['audio'].shape, datasets['train'][0]['video'].shape)
216
+ print(datasets['train'][0]['meta'])
modules/dataset/dataset_utils.py ADDED
@@ -0,0 +1,112 @@
1
+ import csv
2
+ import os
3
+ import random
4
+ from pathlib import Path
5
+ from glob import glob
6
+ import shutil
7
+ import logging
8
+
9
+ import torchaudio
10
+ import torchvision
11
+
12
+ from utils.utils import get_fixed_off_fname
13
+
14
+
15
+ def get_fixed_offsets(transforms, split, splits_path, dataset_name):
16
+ '''dataset_name: `vggsound` or `lrs3`'''
17
+ logging.info(f'Using fixed offset for {split}')
18
+ vid2offset_params = {}
19
+ fixed_offset_fname = get_fixed_off_fname(transforms, split)
20
+ if fixed_offset_fname is None:
21
+ raise ValueError('Cant find fixed offsets for given params. Perhaps you need to make it first?')
22
+ fixed_offset_path = os.path.join(splits_path, f'fixed_offsets_{dataset_name}', fixed_offset_fname)
23
+ fixed_offset_paths = sorted(glob(fixed_offset_path.replace(split, '*')))
24
+ assert len(fixed_offset_paths) > 0, f'Perhaps: {fixed_offset_path} does not exist. Make fixed offsets'
25
+
26
+ for fix_off_path in fixed_offset_paths:
27
+ reader = csv.reader(open(fix_off_path))
28
+ # k700_2020 has no header, and also `vstart` comes before `offset_sec`
29
+ if dataset_name == 'k700_2020':
30
+ header = ['path', 'vstart_sec', 'offset_sec', 'oos_target']
31
+ else:
32
+ header = next(reader)
33
+ for line in reader:
34
+ data = dict()
35
+ for f, value in zip(header, line):
36
+ if f == 'path':
37
+ v = value
38
+ elif f == 'offset_sec':
39
+ data[f] = float(value)
40
+ elif f in ['vstart_sec', 'v_start_sec']:
41
+ f = 'v_start_i_sec'
42
+ data[f] = float(value)
43
+ elif f == 'oos_target':
44
+ data[f] = int(value)
45
+ else:
46
+ data[f] = value
47
+ # assert v not in vid2offset_params, 'otherwise, offs from other splits will override each other'
48
+
49
+ # even if we have multiple splits (val=test), we want to make sure that the offsets are the same
50
+ if v in vid2offset_params:
51
+ assert all([vid2offset_params[v][k] == data[k] for k in data]), f'{v} is not unique and its offsets vary'
52
+
53
+ vid2offset_params[v] = data
54
+ return vid2offset_params
55
+
56
+
57
+ def maybe_cache_file(path: os.PathLike):
58
+ '''Motivation: if every job reads from a shared disk it'll get very slow; consider that an image can
59
+ be 2MB, then with batch size 32 and 16 dataloader workers you're already requesting 1GB! Now
60
+ imagine this for all users and all jobs simultaneously.'''
61
+ # checking if we are on cluster, not on a local machine
62
+ if 'LOCAL_SCRATCH' in os.environ:
63
+ cache_dir = os.environ.get('LOCAL_SCRATCH')
64
+ # a bit ugly but we need not just fname to be appended to `cache_dir` but parent folders,
65
+ # otherwise the same fnames in multiple folders will create a bug (the same input for multiple paths)
66
+ cache_path = os.path.join(cache_dir, Path(path).relative_to('/'))
67
+ if not os.path.exists(cache_path):
68
+ os.makedirs(Path(cache_path).parent, exist_ok=True)
69
+ shutil.copyfile(path, cache_path)
70
+ return cache_path
71
+ else:
72
+ return path
73
+
74
+
75
+ def get_video_and_audio(path, get_meta=False, start_sec=0, end_sec=None):
76
+ orig_path = path
77
+ path = maybe_cache_file(path)
78
+ # (Tv, 3, H, W) [0, 255, uint8]; (Ca, Ta)
79
+ rgb, audio, meta = torchvision.io.read_video(str(path), start_sec, end_sec, 'sec', output_format='TCHW')
80
+ assert meta['video_fps'], f'No video fps for {orig_path}'
81
+ # (Ta) <- (Ca, Ta)
82
+ audio = audio.mean(dim=0)
83
+ # FIXME: this is legacy format of `meta` as it used to be loaded by VideoReader.
84
+ meta = {'video': {'fps': [meta['video_fps']]}, 'audio': {'framerate': [meta['audio_fps']]}, }
85
+ return rgb, audio, meta
86
+
87
+
88
+ def get_audio_stream(path, get_meta=False):
89
+ '''Used only in feature extractor training'''
90
+ path = str(Path(path).with_suffix('.wav'))
91
+ path = maybe_cache_file(path)
92
+ waveform, _ = torchaudio.load(path)
93
+ waveform = waveform.mean(dim=0)
94
+ if get_meta:
95
+ info = torchaudio.info(path)
96
+ duration = info.num_frames / info.sample_rate
97
+ meta = {'audio': {'duration': [duration], 'framerate': [info.sample_rate]}}
98
+ return waveform, meta
99
+ else:
100
+ return waveform
101
+
102
+ def subsample_dataset(dataset: list, size_ratio: float, shuffle: bool = False):
103
+ if size_ratio is not None and 0.0 < size_ratio < 1.0:
104
+ logging.info(f'Subsampling dataset to {size_ratio}')
105
+ # shuffling is important only during subsampling (sometimes paths are sorted by class)
106
+ if shuffle:
107
+ random.shuffle(dataset)
108
+ cut_off = int(len(dataset) * size_ratio)
109
+ # making sure that we have at least one example
110
+ dataset = dataset[:max(1, cut_off)]
111
+ logging.info(f'Subsampled dataset to {size_ratio} (size: {len(dataset)})')
112
+ return dataset
modules/dataset/lrs.py ADDED
@@ -0,0 +1,192 @@
1
+ import csv
2
+ import logging
3
+ import os
4
+ import random
5
+ import sys
6
+ from glob import glob
7
+ from pathlib import Path
8
+
9
+ import torch
10
+
11
+
12
+ sys.path.insert(0, '.') # nopep8
13
+ from dataset.dataset_utils import (get_fixed_offsets, get_video_and_audio, subsample_dataset)
14
+
15
+
16
+ class LRS3(torch.utils.data.Dataset):
17
+
18
+ def __init__(self,
19
+ split,
20
+ vids_dir,
21
+ transforms=None,
22
+ splits_path='./data',
23
+ seed=1337,
24
+ load_fixed_offsets_on=['valid', 'test'],
25
+ vis_load_backend='VideoReader',
26
+ size_ratio=None,
27
+ attr_annot_path=None,
28
+ max_attr_per_vid=None,
29
+ to_filter_bad_examples=True,):
30
+ super().__init__()
31
+ self.max_clip_len_sec = 11
32
+ logging.info(f'During IO, the length of clips is limited to {self.max_clip_len_sec} sec')
33
+ self.split = split
34
+ self.vids_dir = vids_dir
35
+ self.transforms = transforms
36
+ self.splits_path = splits_path
37
+ self.seed = seed
38
+ self.load_fixed_offsets_on = [] if load_fixed_offsets_on is None else load_fixed_offsets_on
39
+ self.vis_load_backend = vis_load_backend
40
+ self.size_ratio = size_ratio
41
+
42
+ split_clip_ids_path = os.path.join(splits_path, f'lrs3_{split}.txt')
43
+ if not os.path.exists(split_clip_ids_path):
44
+ vid_folder = Path(vids_dir) / 'pretrain'
45
+ clip_paths = sorted(vid_folder.rglob('*/*.mp4'))
46
+ if to_filter_bad_examples:
47
+ clip_paths = self.filter_bad_examples(clip_paths)
48
+ self.make_split_files(clip_paths)
49
+
50
+ # read the ids from a split
51
+ split_clip_ids = sorted(open(split_clip_ids_path).read().splitlines())
52
+
53
+ # make paths from the ids
54
+ clip_paths = [os.path.join(vids_dir, v + '.mp4') for v in split_clip_ids]
55
+
56
+ if split in self.load_fixed_offsets_on:
57
+ logging.info(f'Using fixed offset for {split}')
58
+ self.vid2offset_params = get_fixed_offsets(transforms, split, splits_path, 'lrs3')
59
+
60
+ self.dataset = clip_paths
61
+ self.dataset = subsample_dataset(self.dataset, size_ratio, shuffle=split == 'train')
62
+
63
+ logging.info(f'{split} has {len(self.dataset)} items')
64
+
65
+ def __getitem__(self, index):
66
+ path = self.dataset[index]
67
+ rgb, audio, meta = get_video_and_audio(path, get_meta=True, end_sec=self.max_clip_len_sec)
68
+
69
+ # (Tv, 3, H, W) in [0, 255], (Ta, C) in [-1, 1]
70
+ item = {'video': rgb, 'audio': audio, 'meta': meta, 'path': path, 'targets': {}, 'split': self.split}
71
+
72
+ # loading fixed offsets so we could evaluate on the same data each time (valid and test)
73
+ if self.split in self.load_fixed_offsets_on:
74
+ unique_id = path.replace(f'{self.vids_dir}/', '').replace(self.vids_dir, '').replace('.mp4', '')
75
+ offset_params = self.vid2offset_params[unique_id]
76
+ item['targets']['offset_sec'] = offset_params['offset_sec']
77
+ item['targets']['v_start_i_sec'] = offset_params['v_start_i_sec']
78
+ if 'oos_target' in offset_params:
79
+ item['targets']['offset_target'] = {
80
+ 'oos': offset_params['oos_target'], 'offset': item['targets']['offset_sec'],
81
+ }
82
+
83
+ if self.transforms is not None:
84
+ item = self.transforms(item)
85
+
86
+ return item
87
+
88
+ def filter_bad_examples(self, paths):
89
+ bad = set()
90
+ base_path = Path('./data/filtered_examples_lrs')
91
+ lists = [open(p).read().splitlines() for p in sorted(glob(str(base_path / '*.txt')))]
92
+ for s in lists:
93
+ bad = bad.union(s)
94
+ logging.info(f'Number of clips before filtering: {len(paths)}')
95
+ video_ids = [str(i).replace(self.vids_dir, '') for i in paths]
96
+ video_ids = [str(i).replace(f'{self.vids_dir}/', '') for i in video_ids]
97
+ paths = sorted([r for r in video_ids if r not in bad])
98
+ logging.info(f'Number of clips after filtering: {len(paths)}')
99
+ return paths
100
+
101
+ def make_split_files(self, paths):
102
+ logging.warning(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')
103
+
104
+ # will be splitting using videos, not clips to prevent train-test intersection
105
+ all_vids = sorted(list(set([Path(p).parent.name for p in paths])))
106
+ random.Random(self.seed).shuffle(all_vids)
107
+
108
+ # 0.1: splits are 8:1:1
109
+ hold_out_ratio = 0.1
110
+ hold_out_size = int(len(all_vids) * hold_out_ratio)
111
+ test_vids, train_valid_vids = all_vids[:hold_out_size], all_vids[hold_out_size:]
112
+ valid_vids, train_vids = train_valid_vids[:hold_out_size], train_valid_vids[hold_out_size:]
113
+
114
+ # making files
115
+ for phase, vids in zip(['train', 'valid', 'test'], [train_vids, valid_vids, test_vids]):
116
+ with open(os.path.join(self.splits_path, f'lrs3_{phase}.txt'), 'w') as wfile:
117
+ for path in paths:
118
+ vid_name = Path(path).parent.name
119
+ # just in the case I forgot the trailing '/' in the path
120
+ unique_id = path.replace(f'{self.vids_dir}/', '').replace(self.vids_dir, '') \
121
+ .replace('.mp4', '')
122
+ if vid_name in vids:
123
+ wfile.write(unique_id + '\n')
124
+
125
+ def __len__(self):
126
+ return len(self.dataset)
127
+
128
+ class LongerLRS3(LRS3):
129
+ '''This class is different to the parent in the extra filtering it does. If the parent was
130
+ making the splits with filtering for shorter than 9 second, this class filters for shorter than 9.5 sec.
131
+ by applying extra filtering.
132
+ '''
133
+
134
+ def __init__(self,
135
+ split,
136
+ vids_dir,
137
+ transforms=None,
138
+ splits_path='./data',
139
+ seed=1337,
140
+ load_fixed_offsets_on=['valid', 'test'],
141
+ vis_load_backend='VideoReader',
142
+ size_ratio=None,
143
+ attr_annot_path=None,
144
+ max_attr_per_vid=None,
145
+ to_filter_bad_examples=True,):
146
+ # size_ratio is not passed to the parent as we apply it in this class (avoiding double subsampling)
147
+ super().__init__(split, vids_dir, transforms, splits_path, seed, load_fixed_offsets_on,
148
+ vis_load_backend, None, attr_annot_path, max_attr_per_vid,
149
+ to_filter_bad_examples)
150
+ # does extra filtering
151
+ if to_filter_bad_examples:
152
+ self.dataset = self.filter_bad_examples(self.dataset)
153
+ self.dataset = subsample_dataset(self.dataset, size_ratio, shuffle=split == 'train')
154
+ logging.info(f'{split} has {len(self.dataset)} items')
155
+
156
+ def filter_bad_examples(self, paths):
157
+ bad = set()
158
+ base_path = Path('./data/filtered_examples_lrs_extra')
159
+ lists = [open(p).read().splitlines() for p in sorted(glob(str(base_path / '*.txt')))]
160
+ for s in lists:
161
+ bad = bad.union(s)
162
+ logging.info(f'Number of clips before filtering: {len(paths)}')
163
+ video_ids = [str(i).replace(self.vids_dir, '').replace(f'{self.vids_dir}/', '') for i in paths]
164
+ paths = sorted([os.path.join(self.vids_dir, r) for r in video_ids if r not in bad])
165
+ logging.info(f'Number of clips after filtering: {len(paths)}')
166
+ return paths
167
+
168
+
169
+ if __name__ == '__main__':
170
+ from time import time
171
+ from omegaconf import OmegaConf
172
+ import sys
173
+ sys.path.insert(0, '.') # nopep8
174
+ from scripts.train_utils import get_transforms
175
+ from utils.utils import cfg_sanity_check_and_patch
176
+ cfg = OmegaConf.load('./configs/sparse_sync.yaml')
177
+ cfg.data.vids_path = '/scratch/local/hdd/vi/data/lrs3/h264_uncropped_25fps_256side_16000hz_aac/'
178
+ cfg.data.dataset.params.load_fixed_offsets_on = ['valid', 'test']
179
+
180
+ cfg_sanity_check_and_patch(cfg)
181
+ transforms = get_transforms(cfg)
182
+
183
+ datasets = {
184
+ 'train': LRS3('train', cfg.data.vids_path, transforms['train'], load_fixed_offsets_on=[]),
185
+ 'valid': LRS3('valid', cfg.data.vids_path, transforms['test'], load_fixed_offsets_on=[]),
186
+ 'test': LRS3('test', cfg.data.vids_path, transforms['test'], load_fixed_offsets_on=[]),
187
+ }
188
+ for phase in ['train', 'valid', 'test']:
189
+ print(phase, len(datasets[phase]))
190
+
191
+ print(datasets['train'][1]['audio'].shape, datasets['train'][1]['video'].shape)
192
+ print(datasets['valid'][1]['audio'].shape, datasets['valid'][1]['video'].shape)
modules/dataset/transforms.py ADDED
@@ -0,0 +1,1074 @@
1
+ import logging
2
+ import math
3
+ import random
4
+ from typing import Tuple
5
+ import torch
6
+ import torchvision
7
+ import torchaudio
8
+ import numpy as np
9
+ import einops
10
+
11
+
12
+ def sec2frames(sec, fps):
13
+ return int(sec * fps)
14
+
15
+ def frames2sec(frames, fps):
16
+ return frames / fps
17
+
18
+
19
+ class EqualifyFromRight(torch.nn.Module):
20
+
21
+ def __init__(self, clip_max_len_sec=10):
22
+ '''
23
+ Takes the dataset item and makes sure both streams are of an equal size in terms of fps.
24
+ It, however, assumes that the signal is synched and trims the ending parts ('from the right').
25
+ '''
26
+ super().__init__()
27
+ self.clip_max_len_sec = clip_max_len_sec
28
+
29
+ def forward(self, item):
30
+ '''
31
+ `item`: {'video': (Tv, C, H, W), 'audio': (Ta,),
32
+ 'meta': {
33
+ 'audio': {'framerate': [float], 'duration': [float]}
34
+ 'video': {'fps': [float], 'duration': [float]}}
35
+ '''
36
+ a_fps = item['meta']['audio']['framerate'][0]
37
+ v_fps = item['meta']['video']['fps'][0]
38
+
39
+ Ta = item['audio'].shape[0]
40
+ Tv, C, H, W = item['video'].shape
41
+
42
+ a_len_secs = Ta / a_fps
43
+ v_len_secs = Tv / v_fps
44
+ min_len = min(self.clip_max_len_sec, a_len_secs, v_len_secs)
45
+
46
+ a_frames_per_v_frame = a_fps // v_fps
47
+ v_len_frames = int(v_fps * min_len)
48
+ a_len_frames = int(a_frames_per_v_frame * v_len_frames)
49
+ # print(a_len_frames, v_len_frames)
50
+
51
+ assert a_len_frames <= Ta and v_len_frames <= Tv
52
+
53
+ item['audio'] = item['audio'][:a_len_frames]
54
+ item['video'] = item['video'][:v_len_frames, :, :, :]
55
+
56
+ return item
57
+
58
+
59
+ class RGBSpatialCrop(torch.nn.Module):
60
+
61
+ def __init__(self, input_size, is_random):
62
+ super().__init__()
63
+ assert input_size is not None, f'smaller_input_size is `{input_size}`'
64
+ if isinstance(input_size, int):
65
+ input_size = (input_size, input_size)
66
+ self.input_size = input_size
67
+ self.is_random = is_random
68
+
69
+ @staticmethod
70
+ def get_random_crop_sides(vid, output_size):
71
+ '''Slice parameters for random crop'''
72
+ h, w = vid.shape[-2:]
73
+ th, tw = output_size
74
+ if w == tw and h == th:
75
+ return 0, 0, h, w
76
+ i = random.randint(0, h - th)
77
+ j = random.randint(0, w - tw)
78
+ return i, j, th, tw
79
+
80
+ @staticmethod
81
+ def get_center_crop_sides(vid, output_size):
82
+ '''Slice parameters for center crop'''
83
+ h, w = vid.shape[-2:]
84
+ th, tw = output_size
85
+
86
+ i = int(round((h - th) / 2.))
87
+ j = int(round((w - tw) / 2.))
88
+ return i, j, th, tw
89
+
90
+ def forward(self, item):
91
+ # (Tv, C, H, W)
92
+ vid = item['video']
93
+ if self.is_random:
94
+ i, j, h, w = self.get_random_crop_sides(vid, self.input_size)
95
+ else:
96
+ i, j, h, w = self.get_center_crop_sides(vid, self.input_size)
97
+ item['video'] = vid[..., i:(i + h), j:(j + w)]
98
+ return item
99
+
100
+ class Resize(torchvision.transforms.Resize):
101
+
102
+ def __init__(self, *args, **kwargs):
103
+ super().__init__(*args, **kwargs)
104
+
105
+ def forward(self, item):
106
+ item['video'] = super().forward(item['video'])
107
+ return item
108
+
109
+
110
+ class RGBSpatialCropSometimesUpscale(torch.nn.Module):
111
+ '''This (randomly) crops the input video and with prob `sometimes_p` this crop is smaller but upscaled
112
+ to `target_input_size`'''
113
+
114
+ def __init__(self, sometimes_p, target_input_size, is_random, smaller_input_size=None):
115
+ super().__init__()
116
+ self.sometimes_p = sometimes_p
117
+ self.do_sometimes_upscale = sometimes_p is not None and sometimes_p > 0
118
+
119
+ self.crop_only = RGBSpatialCrop(target_input_size, is_random)
120
+
121
+ if self.do_sometimes_upscale:
122
+ self.crop_further_and_upscale = torchvision.transforms.Compose([
123
+ RGBSpatialCrop(smaller_input_size, is_random),
124
+ Resize(target_input_size, antialias=None),
125
+ ])
126
+
127
+ def forward(self, item):
128
+ assert len(item['video'].shape) == 4, \
129
+ f"{item['video'].shape}: if it is applied after GenerateMultipleClips," \
130
+ "augs should be applied to each clip separately, not to the whole video array. " \
131
+ "Otherwise, ignore this warning (comment it)."
132
+ if self.do_sometimes_upscale and self.sometimes_p > torch.rand(1):
133
+ return self.crop_further_and_upscale(item)
134
+ else:
135
+ return self.crop_only(item)
136
+
137
+
138
+ class RandomApplyColorDistortion(torch.nn.Module):
139
+
140
+ def __init__(self, p_gray_scale=0., p_color_jitter=0., s=1.) -> None:
141
+ super().__init__()
142
+ self.p_gray_scale = p_gray_scale
143
+ self.p_color_jitter = p_color_jitter
144
+ self.s = s
145
+ assert 0 <= self.p_color_jitter <= 1 and 0 <= self.p_gray_scale <= 1, (p_color_jitter, p_gray_scale)
146
+ # SimCLR params
147
+ color_jitter = torchvision.transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
148
+ rand_color_jitter = torchvision.transforms.RandomApply([color_jitter], p_color_jitter)
149
+ rand_gray = torchvision.transforms.RandomGrayscale(p_gray_scale)
150
+ self.transforms = torchvision.transforms.Compose([rand_color_jitter, rand_gray])
151
+
152
+ def apply_to_single_clip(self, clip):
153
+ return self.transforms(clip)
154
+
155
+ def apply_to_each_clip(self, clips):
156
+ for i, clip in enumerate(clips):
157
+ clips[i] = self.apply_to_single_clip(clip)
158
+ return clips
159
+
160
+ def forward(self, item):
161
+ has_batch_dim = len(item['video'].shape) == 5
162
+ if has_batch_dim:
163
+ fn = self.apply_to_each_clip
164
+ else:
165
+ fn = self.apply_to_single_clip
166
+ item['video'] = fn(item['video'])
167
+ return item
168
+
169
+
170
+ class ApplyColorJitterFrameWise(torch.nn.Module):
171
+
172
+ def __init__(self, s=1.) -> None:
173
+ super().__init__()
174
+ self.s = s
175
+ # SimCLR params
176
+ self.transform = torchvision.transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
177
+
178
+ def apply_to_single_clip(self, clip):
179
+ for i, frame in enumerate(clip):
180
+ clip[i] = self.transform(frame)
181
+ return clip
182
+
183
+ def apply_to_each_clip(self, clips):
184
+ for i, clip in enumerate(clips):
185
+ clips[i] = self.apply_to_single_clip(clip)
186
+ return clips
187
+
188
+ def forward(self, item):
189
+ has_batch_dim = len(item['video'].shape) == 5
190
+ if has_batch_dim:
191
+ fn = self.apply_to_each_clip
192
+ else:
193
+ fn = self.apply_to_single_clip
194
+ item['video'] = fn(item['video'])
195
+ return item
196
+
197
+
198
+ class RandomHorizontalFlip(torchvision.transforms.RandomHorizontalFlip):
199
+
200
+ def __init__(self, p=0.5):
201
+ super().__init__(p)
202
+
203
+ def apply_to_single_clip(self, clip):
204
+ return super().forward(clip)
205
+
206
+ def apply_to_each_clip(self, clips):
207
+ for i, clip in enumerate(clips):
208
+ clips[i] = self.apply_to_single_clip(clip)
209
+ return clips
210
+
211
+ def forward(self, item):
212
+ has_batch_dim = len(item['video'].shape) == 5
213
+ if has_batch_dim:
214
+ fn = self.apply_to_each_clip
215
+ else:
216
+ fn = self.apply_to_single_clip
217
+ item['video'] = fn(item['video'])
218
+ return item
219
+
220
+
221
+ def make_class_grid(leftmost_val, rightmost_val, grid_size, add_extreme_offset: bool = False,
222
+ seg_size_vframes: int = None, nseg: int = None, step_size_seg: float = None,
223
+ vfps: float = None):
224
+ assert grid_size >= 3, f'grid_size: {grid_size} does not make sense. If =2 -> (-1,1); =1 -> (-1); =0 -> ()'
225
+ grid = torch.from_numpy(np.linspace(leftmost_val, rightmost_val, grid_size)).float()
226
+ if add_extreme_offset:
227
+ assert all([seg_size_vframes, nseg, step_size_seg]), f'{seg_size_vframes} {nseg} {step_size_seg}'
228
+ seg_size_sec = seg_size_vframes / vfps
229
+ trim_size_in_seg = nseg - (1 - step_size_seg) * (nseg - 1)
230
+ extreme_value = trim_size_in_seg * seg_size_sec
231
+ grid = torch.cat([grid, torch.tensor([extreme_value])]) # adding extreme offset to the class grid
232
+ return grid
233
+
234
+
235
+ def quantize_offset(grid: torch.Tensor, off_sec: float) -> Tuple[float, int]:
236
+ '''Takes in the offset in seconds and snaps it onto the closest grid element.
237
+ Returns the grid value and its index.'''
238
+ closest_grid_el = (grid - off_sec).abs().argmin()
239
+ return grid[closest_grid_el], closest_grid_el
240
+
241
+ def apply_a_jitter(a_start_i, a_len_frames, a_crop_len_frames, a_fps, max_a_jitter_sec):
242
+ max_a_start_i = a_len_frames - a_crop_len_frames
243
+ max_a_jitter_i = sec2frames(max_a_jitter_sec, a_fps)
244
+ max_a_jitter_i_left = min(a_start_i, max_a_jitter_i)
245
+ max_a_jitter_i_right = min(max_a_start_i - a_start_i, max_a_jitter_i)
246
+ # jitter is U[left, right]
247
+ a_jitter_i = random.randint(-max_a_jitter_i_left, max_a_jitter_i_right)
248
+ # apply jitter
249
+ a_start_i = a_start_i + a_jitter_i
250
+ # making sure that any value from `a_start_i + U[left, right]` will be inside of [0, len-crop] region
251
+ assert 0 <= a_start_i <= max_a_start_i, f'{a_jitter_i} {max_a_jitter_i_left} {max_a_jitter_i_right} {max_a_start_i}'
252
+ return a_start_i, a_jitter_i
253
+
254
+
255
+ class TemporalCropAndOffset(torch.nn.Module):
256
+
257
+ def __init__(self, crop_len_sec: float, max_off_sec: float, offset_type='grid', do_offset: bool = True,
258
+ grid_size: int = None, max_wiggle_sec: float = None, add_doubt_cls: bool = False,
259
+ segment_size_vframes: int = None, n_segments: int = None, step_size_seg: float = None,
260
+ vfps: float = None, prob_oos: float = None):
261
+ super().__init__()
262
+ self.crop_len_sec = crop_len_sec
263
+ self.do_offset = do_offset
264
+ self.grid_size = grid_size
265
+ self.offset_type = offset_type
266
+ self.max_off_sec = max_off_sec
267
+ self.max_a_jitter_sec = max_wiggle_sec
268
+ if do_offset:
269
+ if offset_type == 'grid':
270
+ self.class_grid = make_class_grid(-max_off_sec, max_off_sec, grid_size, add_doubt_cls,
271
+ segment_size_vframes, n_segments, step_size_seg, vfps)
272
+ logging.info(f'Offsets class grid: {self.class_grid}')
273
+ if self.max_a_jitter_sec is not None:
274
+ assert (max_wiggle_sec-1e-6) <= ((self.class_grid[1] - self.class_grid[0]) / 2), f'{self.class_grid}'
275
+ elif offset_type == 'uniform':
276
+ self.off_dist = torch.distributions.uniform.Uniform(-max_off_sec, max_off_sec)
277
+ logging.info(f'Offset uniform distribution: {self.off_dist}')
278
+ elif offset_type == 'uniform_binary':
279
+ self.itu_t_range = (-0.125, 0.045)
280
+ self.prob_oos = prob_oos
281
+ self.ins_dist = torch.distributions.uniform.Uniform(self.itu_t_range[0], self.itu_t_range[1])
282
+ self.off_dist = torch.distributions.uniform.Uniform(-max_off_sec, max_off_sec)
283
+ else:
284
+ raise NotImplementedError(f'Unknown offset type: {offset_type}')
285
+
286
+ def forward(self, item):
287
+ vid = item['video']
288
+ aud = item['audio']
289
+ v_len_frames, C, H, W = vid.shape
290
+ a_len_frames = aud.shape[0]
291
+
292
+ v_fps = int(item['meta']['video']['fps'][0])
293
+ a_fps = int(item['meta']['audio']['framerate'][0])
294
+
295
+ v_crop_len_frames = sec2frames(self.crop_len_sec, v_fps)
296
+ a_crop_len_frames = sec2frames(self.crop_len_sec, a_fps)
297
+
298
+ if self.do_offset:
299
+ # trying to get the offset parameters (for instance during valid and test we have fixed offsets)
300
+ offset_sec = item['targets'].get('offset_sec', None)
301
+ v_start_i_sec = item['targets'].get('v_start_i_sec', None)
302
+ if 'offset_target' in item['targets']:
303
+ is_oos = item['targets']['offset_target'].get('oos', None)
304
+ # train-time
305
+ if offset_sec is None and v_start_i_sec is None:
306
+ # aud starts `offset_sec` earlier than it should; aud has what will be shown after offset_sec
307
+ if self.offset_type == 'grid':
308
+ offset_sec = random.choice(self.class_grid.tolist())
309
+ elif self.offset_type == 'uniform':
310
+ offset_sec = self.off_dist.sample().item()
311
+ elif self.offset_type == 'uniform_binary':
312
+ # in-sync: Uniform(-0.125, 0.045)
313
+ # out-of-sync: Uniform(-5.5, 5.5) and resampled until not in Uniform(-0.125, 0.045)
314
+ # first, we sample if the offset is out-of-sync with prob_oos
315
+ is_oos = (torch.rand(1) < self.prob_oos).item()
316
+ if is_oos:
317
+ # second, we sample the offset itself (if in in-sync range, trying again)
318
+ offset_sec = self.off_dist.sample().item()
319
+ while self.itu_t_range[0] <= offset_sec <= self.itu_t_range[1]:
320
+ offset_sec = self.off_dist.sample().item()
321
+ else:
322
+ offset_sec = self.ins_dist.sample().item()
323
+ offset_sec = round(offset_sec, 2)
324
+ v_start_max_sec = frames2sec(v_len_frames - v_crop_len_frames, v_fps)
325
+ assert v_start_max_sec > 0, f'{v_len_frames} {v_crop_len_frames} {v_fps} @ {item["path"]}'
326
+ # `v_start_sec` IS NOT rounded to the fps grid
327
+ v_start_sec = random.uniform(max(0, -offset_sec), min(v_start_max_sec, v_start_max_sec-offset_sec))
328
+ assert 0 <= v_start_sec <= v_start_max_sec, f'{v_start_sec} {v_start_max_sec} {item["path"]}'
329
+ v_start_i = sec2frames(v_start_sec, v_fps)
330
+ # `v_start_i_sec` IS rounded to the fps grid
331
+ v_start_i_sec = frames2sec(v_start_i, v_fps)
332
+ else:
333
+ offset_sec = round(offset_sec, 2)
334
+ v_start_i = sec2frames(v_start_i_sec, v_fps)
335
+ v_end_i = v_start_i + v_crop_len_frames
336
+ # `a_start_i` depends on the rounded value `v_start_i_sec`, otherwise
337
+ # (v_start_sec) we have ±0.1 jittering
338
+ a_start_i = sec2frames(v_start_i_sec + offset_sec, a_fps)
339
+ else:
340
+ offset_sec = 0.0
341
+ is_random_crop = item['split'] == 'train'
342
+ v_start_i, v_end_i = self.get_crop_idx(v_len_frames, v_crop_len_frames, is_random=is_random_crop)
343
+ v_start_i_sec = frames2sec(v_start_i, v_fps)
344
+ a_start_i = sec2frames(v_start_i_sec, a_fps)
345
+
346
+ # sometimes due to the rounding error e.g. v_start_sec = 1.505 but sec2frames(1.505, 25) = 1.48
347
+ # given offset is -1.5, the a_start_i will be a small negative value. (likely a_fps * 1/v_fps * 0.5)
348
+ if a_start_i < 0:
349
+ how_much_out = a_start_i
350
+ logging.info(f'a_start_i is negative ({how_much_out}) at {item["path"]}')
351
+ if abs(how_much_out) <= a_fps / v_fps:
352
+ logging.info('fixing it')
353
+ a_start_i += abs(how_much_out)
354
+ else:
355
+ raise Exception(f'{how_much_out} {item["path"]}')
356
+
357
+ if self.max_a_jitter_sec is not None and self.max_a_jitter_sec > 0:
358
+ a_start_i, a_jitter_i = apply_a_jitter(a_start_i, a_len_frames, a_crop_len_frames, a_fps,
359
+ self.max_a_jitter_sec)
360
+ item['meta']['a_jitter_i'] = a_jitter_i
361
+
362
+ a_end_i = a_start_i + a_crop_len_frames
363
+
364
+ assert v_start_i < v_end_i and a_start_i < a_end_i
365
+ assert aud.shape[0] >= a_end_i, f'{aud.shape} {a_end_i} {item["path"]}'
366
+ assert vid.shape[0] >= v_end_i, f'{vid.shape} {v_end_i} {item["path"]}'
367
+
368
+ vid, aud = vid[v_start_i:v_end_i, :, :, :], aud[a_start_i:a_end_i]
369
+
370
+ item['video'] = vid
371
+ item['audio'] = aud
372
+
373
+ assert item['video'].shape[0] == v_fps * self.crop_len_sec, f'{item["video"].shape} {item["path"]}'
374
+ assert item['audio'].shape[0] == a_fps * self.crop_len_sec, f'{item["audio"].shape} {item["path"]}'
375
+
376
+ # caching parameters
377
+ if self.do_offset:
378
+ if self.offset_type == 'grid':
379
+ offset_label, offset_target = quantize_offset(self.class_grid, offset_sec)
380
+ elif self.offset_type == 'uniform':
381
+ offset_label, offset_target = offset_sec, offset_sec
382
+ elif self.offset_type == 'uniform_binary':
383
+ offset_label, offset_target = offset_sec, {'oos': is_oos, 'offset': offset_sec}
384
+ item['targets']['offset_sec'] = offset_sec
385
+ item['targets']['v_start_i_sec'] = v_start_i_sec
386
+ item['targets']['offset_label'] = offset_label
387
+ # assert 'offset_target' not in item['targets'], f'{item["targets"]}. What passed it there?'
388
+ item['targets']['offset_target'] = offset_target
389
+
390
+ return item
391
+
392
+ def get_crop_idx(self, len_frames: int, crop_len_frames: int, is_random=True):
393
+ if len_frames == crop_len_frames:
394
+ return 0, len_frames
395
+ if is_random:
396
+ left_i = random.randint(0, len_frames - crop_len_frames)
397
+ else:
398
+ left_i = int(round((len_frames - crop_len_frames) / 2.))
399
+ return left_i, left_i+crop_len_frames
400
+
401
+
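A minimal sketch of the 'uniform_binary' offset sampling described above; the in-sync band, out-of-sync range and prob_oos below are illustrative assumptions mirroring the comments, not values read from the actual config:

import torch

itu_t_range = (-0.125, 0.045)                      # assumed in-sync band
ins_dist = torch.distributions.Uniform(*itu_t_range)
off_dist = torch.distributions.Uniform(-5.5, 5.5)  # assumed out-of-sync range
prob_oos = 0.5                                     # assumed probability of an out-of-sync sample

def sample_offset_uniform_binary():
    if torch.rand(1).item() < prob_oos:
        # rejection sampling: draw until the offset falls outside the in-sync band
        offset_sec = off_dist.sample().item()
        while itu_t_range[0] <= offset_sec <= itu_t_range[1]:
            offset_sec = off_dist.sample().item()
        return round(offset_sec, 2), True   # (offset_sec, is_oos)
    return round(ins_dist.sample().item(), 2), False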
402
+ class GenerateMultipleSegments(torch.nn.Module):
403
+ '''
404
+ Given an item with video and audio, generates a batch of `n_segments` segments
405
+ of length `segment_size_vframes` (if `n_segments` is None, the maximum possible number of segments is made).
406
+ If `is_start_random` is True, the starting position of the 1st segment will be random but respecting
407
+ n_segments.
408
+ `audio_jitter_sec` is the amount of audio offset in seconds.
409
+ '''
410
+
411
+ def __init__(self, segment_size_vframes: int, n_segments: int = None, is_start_random: bool = False,
412
+ audio_jitter_sec: float = 0., step_size_seg: float = 1):
413
+ super().__init__()
414
+ self.segment_size_vframes = segment_size_vframes
415
+ self.n_segments = n_segments
416
+ self.is_start_random = is_start_random
417
+ self.audio_jitter_sec = audio_jitter_sec
418
+ self.step_size_seg = step_size_seg
419
+ logging.info(f'Segment step size: {self.step_size_seg}')
420
+
421
+ def forward(self, item):
422
+ v_len_frames, C, H, W = item['video'].shape
423
+ a_len_frames = item['audio'].shape[0]
424
+
425
+ v_fps = int(item['meta']['video']['fps'][0])
426
+ a_fps = int(item['meta']['audio']['framerate'][0])
427
+
428
+ ## Determining the number of segments
429
+ # segment size
430
+ segment_size_vframes = self.segment_size_vframes
431
+ segment_size_aframes = sec2frames(frames2sec(self.segment_size_vframes, v_fps), a_fps)
432
+ # step size (stride)
433
+ stride_vframes = int(self.step_size_seg * segment_size_vframes)
434
+ stride_aframes = int(self.step_size_seg * segment_size_aframes)
435
+ # calculating the number of segments. (W - F + 2P) / S + 1
436
+ n_segments_max_v = math.floor((v_len_frames - segment_size_vframes) / stride_vframes) + 1
437
+ n_segments_max_a = math.floor((a_len_frames - segment_size_aframes) / stride_aframes) + 1
438
+ # making sure audio and video can accommodate the same number of segments
439
+ n_segments_max = min(n_segments_max_v, n_segments_max_a)
440
+ n_segments = n_segments_max if self.n_segments is None else self.n_segments
441
+
442
+ assert n_segments <= n_segments_max, \
443
+ f'cannot make {n_segments} segs of len {self.segment_size_vframes} in a vid ' \
444
+ f'of len {v_len_frames} for {item["path"]}'
445
+
446
+ # (n_segments, 2) each
447
+ v_ranges, a_ranges = self.get_sequential_seg_ranges(v_len_frames, a_len_frames, v_fps, a_fps,
448
+ n_segments, segment_size_aframes)
449
+
450
+ # segmenting original streams (n_segments, segment_size_frames, C, H, W)
451
+ item['video'] = torch.stack([item['video'][s:e] for s, e in v_ranges], dim=0)
452
+ item['audio'] = torch.stack([item['audio'][s:e] for s, e in a_ranges], dim=0)
453
+ return item
454
+
455
+ def get_sequential_seg_ranges(self, v_len_frames, a_len_frames, v_fps, a_fps, n_seg, seg_size_aframes):
456
+ # if is_start_random is True, the starting position of the 1st segment will
457
+ # be random but respecting n_segments like so: "-CCCCCCCC---" (maybe with fixed overlap),
458
+ # else the segments are taken from the middle of the video respecting n_segments: "--CCCCCCCC--"
459
+
460
+ seg_size_vframes = self.segment_size_vframes # for brevity
461
+
462
+ # calculating the step size in frames
463
+ step_size_vframes = int(self.step_size_seg * seg_size_vframes)
464
+ step_size_aframes = int(self.step_size_seg * seg_size_aframes)
465
+
466
+ # calculating the length of the sequence of segments (and in frames)
467
+ seg_seq_len = n_seg * self.step_size_seg + (1 - self.step_size_seg)
468
+ vframes_seg_seq_len = int(seg_seq_len * seg_size_vframes)
469
+ aframes_seg_seq_len = int(seg_seq_len * seg_size_aframes)
470
+
471
+ # doing temporal crop
472
+ max_v_start_i = v_len_frames - vframes_seg_seq_len
473
+ if self.is_start_random:
474
+ v_start_i = random.randint(0, max_v_start_i)
475
+ else:
476
+ v_start_i = max_v_start_i // 2
477
+ a_start_i = sec2frames(frames2sec(v_start_i, v_fps), a_fps) # vid frames -> seconds -> aud frames
478
+
479
+ # make segments starts
480
+ v_start_seg_i = torch.tensor([v_start_i + i * step_size_vframes for i in range(n_seg)]).int()
481
+ a_start_seg_i = torch.tensor([a_start_i + i * step_size_aframes for i in range(n_seg)]).int()
482
+
483
+ # apply jitter to audio
484
+ if self.audio_jitter_sec > 0:
485
+ jitter_aframes = sec2frames(self.audio_jitter_sec, a_fps)
486
+ # making sure after applying jitter, the audio is still within the audio boundaries
487
+ jitter_aframes = min(jitter_aframes, a_start_i, a_len_frames-a_start_i-aframes_seg_seq_len)
488
+ a_start_seg_i += random.randint(-jitter_aframes, jitter_aframes) # applying jitter to segments
489
+
490
+ # make segment ends
491
+ v_ends_seg_i = v_start_seg_i + seg_size_vframes
492
+ a_ends_seg_i = a_start_seg_i + seg_size_aframes # using the adjusted a_start_seg_i (with jitter)
493
+
494
+ # make ranges
495
+ v_ranges = torch.stack([v_start_seg_i, v_ends_seg_i], dim=1)
496
+ a_ranges = torch.stack([a_start_seg_i, a_ends_seg_i], dim=1)
497
+ assert (a_ranges >= 0).all() and (a_ranges <= a_len_frames).all(), f'{a_ranges} out of {a_len_frames}'
498
+ assert (v_ranges <= v_len_frames).all(), f'{v_ranges} out of {v_len_frames}'
499
+ return v_ranges, a_ranges
500
+
501
+
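As a quick illustration of the (W - F) / S + 1 segment-count arithmetic above, here is the calculation traced on assumed values (a 4.8 s clip, 16-frame segments at 25 fps video / 16 kHz audio, 0.5 segment stride); the numbers are for illustration only:

import math

v_len_frames, a_len_frames = 120, 76800            # 4.8 s at 25 fps / 16 kHz (assumed)
segment_size_vframes, step_size_seg = 16, 0.5
v_fps, a_fps = 25, 16000
segment_size_aframes = round(segment_size_vframes / v_fps * a_fps)   # 10240
stride_vframes = int(step_size_seg * segment_size_vframes)           # 8
stride_aframes = int(step_size_seg * segment_size_aframes)           # 5120
n_segments_max_v = math.floor((v_len_frames - segment_size_vframes) / stride_vframes) + 1  # 14
n_segments_max_a = math.floor((a_len_frames - segment_size_aframes) / stride_aframes) + 1  # 14
print(min(n_segments_max_v, n_segments_max_a))  # 14 half-overlapping segments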
502
+ class TemporalCropAndOffsetForSyncabilityTraining(torch.nn.Module):
503
+
504
+ def __init__(self, max_off_sec: float, do_offset: bool = True,
505
+ grid_size: int = None, max_wiggle_sec: float = None,
506
+ segment_size_vframes: int = None, n_segments: int = None, step_size_seg: float = None,
507
+ vfps: float = None):
508
+ super().__init__()
509
+ seg_size_sec = segment_size_vframes / vfps
510
+ trim_size_in_seg = n_segments - (1 - step_size_seg) * (n_segments - 1)
511
+ self.crop_len_sec = round(trim_size_in_seg * seg_size_sec, 2)
512
+ logging.info(f'Crop len: {self.crop_len_sec}')
513
+ self.do_offset = do_offset
514
+ self.grid_size = grid_size
515
+ self.max_off_sec = max_off_sec
516
+ self.max_a_jitter_sec = max_wiggle_sec
517
+ self.segment_size_vframes = segment_size_vframes
518
+ self.n_segments = n_segments
519
+ self.step_size_seg = step_size_seg
520
+ self.prob_syncable = 0.5
521
+ if do_offset:
522
+ self.class_grid = make_class_grid(-max_off_sec, max_off_sec, grid_size)
523
+ logging.info(f'Offset class grid: {self.class_grid}')
524
+ if self.max_a_jitter_sec is not None:
525
+ assert (max_wiggle_sec-1e-6) <= ((self.class_grid[1] - self.class_grid[0]) / 2), f'{self.class_grid}'
526
+
527
+ def forward(self, item):
528
+ vid = item['video']
529
+ aud = item['audio']
530
+ v_len_frames, C, H, W = vid.shape
531
+ a_len_frames = aud.shape[0]
532
+
533
+ v_fps = int(item['meta']['video']['fps'][0])
534
+ a_fps = int(item['meta']['audio']['framerate'][0])
535
+
536
+ v_crop_len_frames = sec2frames(self.crop_len_sec, v_fps)
537
+ a_crop_len_frames = sec2frames(self.crop_len_sec, a_fps)
538
+
539
+ if self.do_offset:
540
+ # trying to get the offset parameters (for instance during valid and test we have fixed offsets)
541
+ offset_sec = item['targets'].get('offset_sec', None)
542
+ v_start_i_sec = item['targets'].get('v_start_i_sec', None)
543
+ # train-time
544
+ if offset_sec is None and v_start_i_sec is None:
545
+
546
+ # for the syncability training, we want to have a syncable or non-syncable offset with 50% prob
547
+ offset_is_syncable = random.random() < self.prob_syncable # 1=syncable, 0=non-syncable
548
+ if offset_is_syncable:
549
+ offset_sec = random.choice(self.class_grid.tolist())
550
+ else:
551
+ offset_sec = random.choice([-self.crop_len_sec, self.crop_len_sec]) # either - or + offset
552
+ # aud starts `offset_sec` earlier than it should; aud has what will be shown after offset_sec
553
+
554
+ offset_sec = round(offset_sec, 2)
555
+ v_start_max_sec = frames2sec(v_len_frames - v_crop_len_frames, v_fps)
556
+ assert v_start_max_sec > 0, f'{v_len_frames} {v_crop_len_frames} {v_fps} @ {item["path"]}'
557
+ # `v_start_sec` IS NOT rounded to the fps grid
558
+ v_start_sec = random.uniform(max(0, -offset_sec), min(v_start_max_sec, v_start_max_sec-offset_sec))
559
+ assert 0 <= v_start_sec <= v_start_max_sec, f'{v_start_sec} {v_start_max_sec} {item["path"]}'
560
+ v_start_i = sec2frames(v_start_sec, v_fps)
561
+ v_end_i = v_start_i + v_crop_len_frames
562
+ # `v_start_i_sec` IS rounded to the fps grid
563
+ v_start_i_sec = frames2sec(v_start_i, v_fps)
564
+ # `a_start_i` depends on the rounded value `v_start_i_sec`, otherwise
565
+ # (v_start_sec) we have ±0.1 jittering
566
+ a_start_i = sec2frames(v_start_i_sec + offset_sec, a_fps)
567
+ if self.max_a_jitter_sec is not None and self.max_a_jitter_sec > 0:
568
+ a_start_i, a_jitter_i = apply_a_jitter(a_start_i, a_len_frames, a_crop_len_frames, a_fps,
569
+ self.max_a_jitter_sec)
570
+ item['meta']['a_jitter_i'] = a_jitter_i
571
+ a_end_i = a_start_i + a_crop_len_frames
572
+ else:
573
+ offset_sec = round(offset_sec, 2)
574
+ v_start_i = sec2frames(v_start_i_sec, v_fps)
575
+ a_start_i = sec2frames(v_start_i_sec + offset_sec, a_fps)
576
+ v_end_i = v_start_i + v_crop_len_frames
577
+ a_end_i = a_start_i + a_crop_len_frames
578
+ else:
579
+ offset_sec = 0.0
580
+ is_random_crop = item['split'] == 'train'
581
+ v_start_i, v_end_i = self.get_crop_idx(v_len_frames, v_crop_len_frames, is_random=is_random_crop)
582
+ v_start_i_sec = frames2sec(v_start_i, v_fps)
583
+ a_start_i = sec2frames(v_start_i_sec, a_fps)
584
+ if self.max_a_jitter_sec is not None and self.max_a_jitter_sec > 0:
585
+ a_start_i, a_jitter_i = apply_a_jitter(a_start_i, a_len_frames, a_crop_len_frames, a_fps,
586
+ self.max_a_jitter_sec)
587
+ item['meta']['a_jitter_i'] = a_jitter_i
588
+ a_end_i = a_start_i + a_crop_len_frames
589
+
590
+ # sometimes, due to rounding, e.g. v_start_sec = 1.505 maps to frames2sec(sec2frames(1.505, 25), 25) = 1.48;
591
+ # given an offset of -1.5, a_start_i then becomes a small negative value (roughly a_fps * 1/v_fps * 0.5)
592
+ if a_start_i < 0:
593
+ how_much_out = a_start_i
594
+ logging.info(f'a_start_i is negative ({how_much_out}) at {item["path"]}')
595
+ if abs(how_much_out) <= a_fps / v_fps:
596
+ logging.info('fixing it')
597
+ a_start_i += abs(how_much_out)
598
+ a_end_i += abs(how_much_out)
599
+ else:
600
+ raise Exception(f'{how_much_out} {item["path"]}')
601
+
602
+ assert v_start_i < v_end_i and a_start_i < a_end_i
603
+ assert aud.shape[0] >= a_end_i, f'{aud.shape} {a_end_i} {item["path"]}'
604
+ assert vid.shape[0] >= v_end_i, f'{vid.shape} {v_end_i} {item["path"]}'
605
+
606
+ vid, aud = vid[v_start_i:v_end_i, :, :, :], aud[a_start_i:a_end_i]
607
+
608
+ item['video'] = vid
609
+ item['audio'] = aud
610
+
611
+ assert item['video'].shape[0] == int(v_fps*self.crop_len_sec), f'{item["video"].shape} {item["path"]}'
612
+ assert item['audio'].shape[0] == int(a_fps*self.crop_len_sec), f'{item["audio"].shape} {item["path"]}'
613
+
614
+ # caching parameters
615
+ if self.do_offset:
616
+ # NOTE: this is useless for the extreme offsetting
617
+ offset_label, offset_target = quantize_offset(self.class_grid, offset_sec)
618
+ item['targets']['offset_sec'] = offset_sec
619
+ item['targets']['offset_label'] = offset_label
620
+ # assert 'offset_target' not in item['targets'], f'{item["targets"]}. What passed it there?'
621
+ item['targets']['offset_target'] = offset_target
622
+ item['targets']['v_start_i_sec'] = v_start_i_sec
623
+ item['targets']['sync_target'] = int(offset_is_syncable)
624
+
625
+ return item
626
+
627
+ def get_crop_idx(self, len_frames: int, crop_len_frames: int, is_random=True):
628
+ if len_frames == crop_len_frames:
629
+ return 0, len_frames
630
+ if is_random:
631
+ left_i = random.randint(0, len_frames - crop_len_frames)
632
+ else:
633
+ left_i = int(round((len_frames - crop_len_frames) / 2.))
634
+ return left_i, left_i+crop_len_frames
635
+
636
+
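For reference, the crop-length formula in the __init__ above, traced on assumed values (segment_size_vframes=16, vfps=25, n_segments=14, step_size_seg=0.5, matching the toy example at the bottom of this file):

segment_size_vframes, vfps, n_segments, step_size_seg = 16, 25.0, 14, 0.5
seg_size_sec = segment_size_vframes / vfps                               # 0.64 s per segment
trim_size_in_seg = n_segments - (1 - step_size_seg) * (n_segments - 1)   # 7.5 "effective" segments
crop_len_sec = round(trim_size_in_seg * seg_size_sec, 2)                 # 4.8 s
print(crop_len_sec)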
637
+ class RGBToFloatToZeroOne(torch.nn.Module):
638
+
639
+ def __init__(self) -> None:
640
+ super().__init__()
641
+
642
+ def forward(self, item):
643
+ item['video'] = item['video'].to(torch.float32).div(255.)
644
+ return item
645
+
646
+
647
+ class RGBToHalfToZeroOne(torch.nn.Module):
648
+
649
+ def __init__(self) -> None:
650
+ super().__init__()
651
+
652
+ def forward(self, item):
653
+ item['video'] = item['video'].half().div(255.)
654
+ return item
655
+
656
+
657
+ class RGBNormalize(torchvision.transforms.Normalize):
658
+ '''Same as torchvision's Normalize, but with a dict interface.
659
+ This should work for any shape (..., C, H, W)'''
660
+
661
+ def __init__(self, mean, std, inplace=False):
662
+ super().__init__(mean, std, inplace)
663
+ logging.info(f'RGBNormalize: mean={mean}, std={std}')
664
+
665
+ def forward(self, item):
666
+ item['video'] = super().forward(item['video'])
667
+ item['meta']['video']['norm_stats'] = {'mean': torch.as_tensor(self.mean),
668
+ 'std': torch.as_tensor(self.std)}
669
+ return item
670
+
671
+
672
+ class AudioRandomVolume(torch.nn.Module):
673
+
674
+ def __init__(self, p: float, **kwargs):
675
+ super().__init__()
676
+ transform = torchaudio.transforms.Vol(**kwargs)
677
+ self.transform = torchvision.transforms.RandomApply([transform], p)
678
+
679
+ def apply_to_single_clip(self, clip):
680
+ return self.transform(clip)
681
+
682
+ def apply_to_each_clip(self, clips):
683
+ for i, clip in enumerate(clips):
684
+ clips[i] = self.apply_to_single_clip(clip)
685
+ return clips
686
+
687
+ def forward(self, item):
688
+ has_batch_dim = len(item['audio'].shape) == 2
689
+ if has_batch_dim:
690
+ fn = self.apply_to_each_clip
691
+ else:
692
+ fn = self.apply_to_single_clip
693
+ item['audio'] = fn(item['audio'])
694
+ return item
695
+
696
+
697
+ class AudioRandomLowpassFilter(torch.nn.Module):
698
+
699
+ def __init__(self, p: float, cutoff_freq: float, Q: float = 0.707):
700
+ super().__init__()
701
+ self.p = p
702
+ self.cutoff_freq = cutoff_freq
703
+ self.Q = Q
704
+
705
+ def apply_to_single_clip(self, clip, sr):
706
+ if self.p > torch.rand(1):
707
+ return torchaudio.functional.lowpass_biquad(clip, sr, self.cutoff_freq, self.Q)
708
+ else:
709
+ return clip
710
+
711
+ def apply_to_each_clip(self, clips, sr):
712
+ for i, clip in enumerate(clips):
713
+ clips[i] = self.apply_to_single_clip(clip, sr)
714
+ return clips
715
+
716
+ def forward(self, item):
717
+ has_batch_dim = len(item['audio'].shape) == 2
718
+ sr = int(item['meta']['audio']['framerate'][0])
719
+ if has_batch_dim:
720
+ fn = self.apply_to_each_clip
721
+ else:
722
+ fn = self.apply_to_single_clip
723
+ item['audio'] = fn(item['audio'], sr)
724
+ return item
725
+
726
+
727
+ class AudioRandomPitchShift(torch.nn.Module):
728
+
729
+ def __init__(self, p: float, shift: int) -> None:
730
+ super().__init__()
731
+ self.p = p
732
+ self.shift = shift
733
+
734
+ def apply_to_single_clip(self, wave, sr):
735
+ if self.p > torch.rand(1):
736
+ effects = [['pitch', f'{self.shift}'], ['rate', f'{sr}']]
737
+ wave = wave.unsqueeze(0)
738
+ wave, _ = torchaudio.sox_effects.apply_effects_tensor(wave, sr, effects)
739
+ wave = wave.squeeze(0)
740
+ return wave
741
+
742
+ def apply_to_each_clip(self, waves, sr):
743
+ for i, wave in enumerate(waves):
744
+ waves[i] = self.apply_to_single_clip(wave, sr)
745
+ return waves
746
+
747
+ def forward(self, item):
748
+ has_batch_dim = len(item['audio'].shape) == 2
749
+ sr = int(item['meta']['audio']['framerate'][0])
750
+ if has_batch_dim:
751
+ fn = self.apply_to_each_clip
752
+ else:
753
+ fn = self.apply_to_single_clip
754
+ item['audio'] = fn(item['audio'], sr)
755
+ return item
756
+
757
+
758
+ class AudioRandomReverb(torch.nn.Module):
759
+
760
+ def __init__(self, p: float) -> None:
761
+ super().__init__()
762
+ self.p = p
763
+ self.effects = [['reverb', '-w']]
764
+
765
+ def apply_to_single_clip(self, wave, fps):
766
+ if self.p > torch.rand(1):
767
+ wave = wave.unsqueeze(0)
768
+ wave, _ = torchaudio.sox_effects.apply_effects_tensor(wave, fps, self.effects)
769
+ wave = wave.mean(dim=0)
770
+ return wave
771
+
772
+ def apply_to_each_clip(self, waves, fps):
773
+ for i, wave in enumerate(waves):
774
+ waves[i] = self.apply_to_single_clip(wave, fps)
775
+ return waves
776
+
777
+ def forward(self, item):
778
+ has_batch_dim = len(item['audio'].shape) == 2
779
+ sr = int(item['meta']['audio']['framerate'][0])
780
+ if has_batch_dim:
781
+ fn = self.apply_to_each_clip
782
+ else:
783
+ fn = self.apply_to_single_clip
784
+ item['audio'] = fn(item['audio'], sr)
785
+ return item
786
+
787
+ class AudioRandomGaussNoise(torch.nn.Module):
788
+
789
+ def __init__(self, p: float, amplitude=0.01) -> None:
790
+ super().__init__()
791
+ self.p = p
792
+ self.amplitude = amplitude
793
+
794
+ def apply_to_single_clip(self, wave):
795
+ if self.p > torch.rand(1):
796
+ noise = torch.randn_like(wave, dtype=wave.dtype)
797
+ wave = wave + self.amplitude * noise
798
+ return wave
799
+
800
+ def apply_to_each_clip(self, waves):
801
+ for i, wave in enumerate(waves):
802
+ waves[i] = self.apply_to_single_clip(wave)
803
+ return waves
804
+
805
+ def forward(self, item):
806
+ has_batch_dim = len(item['audio'].shape) == 2
807
+ if has_batch_dim:
808
+ fn = self.apply_to_each_clip
809
+ else:
810
+ fn = self.apply_to_single_clip
811
+ item['audio'] = fn(item['audio'])
812
+ return item
813
+
814
+
815
+ class AudioMelSpectrogram(torch.nn.Module):
816
+
817
+ def __init__(self, **kwargs):
818
+ super().__init__()
819
+ self.spec = torchaudio.transforms.MelSpectrogram(**kwargs)
820
+
821
+ def forward(self, item):
822
+ item['audio'] = self.spec(item['audio']) # safe for batched input
823
+ return item
824
+
825
+
826
+ class AudioLog(torch.nn.Module):
827
+
828
+ def __init__(self, eps=1e-6) -> None:
829
+ super().__init__()
830
+ self.eps = eps
831
+
832
+ def forward(self, item):
833
+ item['audio'] = torch.log(item['audio'] + self.eps)
834
+ return item
835
+
836
+ class PadOrTruncate(torch.nn.Module):
837
+
838
+ def __init__(self, max_spec_t: int, pad_mode: str = 'constant', pad_value: float = 0.0):
839
+ super().__init__()
840
+ self.max_spec_t = max_spec_t
841
+ self.pad_mode = pad_mode
842
+ self.pad_value = pad_value
843
+
844
+ def forward(self, item):
845
+ item['audio'] = self.pad_or_truncate(item['audio'])
846
+ return item
847
+
848
+ def pad_or_truncate(self, audio):
849
+ difference = self.max_spec_t - audio.shape[-1] # safe for batched input
850
+ # pad or truncate, depending on difference
851
+ if difference > 0:
852
+ # pad the last dim (time) -> (..., n_mels, 0+time+difference) # safe for batched input
853
+ pad_dims = (0, difference)
854
+ audio = torch.nn.functional.pad(audio, pad_dims, self.pad_mode, self.pad_value)
855
+ elif difference < 0:
856
+ logging.warning(f'Truncating spec ({audio.shape}) to max_spec_t ({self.max_spec_t}).')
857
+ audio = audio[..., :self.max_spec_t] # safe for batched input
858
+ return audio
859
+
860
+
861
+ class AudioNormalizeAST(torch.nn.Module):
862
+ '''Normalizes with the given mean and twice the given std, i.e. (x - mean) / (2 * std), as in AST'''
863
+ def __init__(self, mean: float, std: float) -> None:
864
+ super().__init__()
865
+ self.mean = mean
866
+ self.std = std
867
+
868
+ def forward(self, item):
869
+ item['audio'] = (item['audio'] - self.mean) / (2 * self.std)
870
+ item['meta']['audio']['norm_stats'] = {'mean': self.mean, 'std': self.std}
871
+ return item
872
+
873
+
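A minimal sketch of how the audio transforms above could be chained into a log-mel front-end; the spectrogram parameters and the normalization statistics below are assumptions for illustration, not the project's configuration:

import torch
import torchvision

audio_front_end = torchvision.transforms.Compose([
    AudioMelSpectrogram(sample_rate=16000, n_fft=1024, hop_length=256, n_mels=128),
    AudioLog(),
    PadOrTruncate(max_spec_t=1024),
    AudioNormalizeAST(mean=-4.27, std=4.57),  # assumed stats, not the project's values
])
item = {'audio': torch.randn(16000 * 5), 'meta': {'audio': {}}}
item = audio_front_end(item)
print(item['audio'].shape)  # (n_mels, max_spec_t) == (128, 1024)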
874
+ class PermuteStreams(torch.nn.Module):
875
+
876
+ def __init__(self, einops_order_audio: str, einops_order_rgb: str) -> None:
877
+ ''' For example:
878
+ einops_order_audio: "S F T -> S T F"
879
+ einops_order_rgb: "S T C H W -> S C T H W"'''
880
+ super().__init__()
881
+ self.einops_order_audio = einops_order_audio
882
+ self.einops_order_rgb = einops_order_rgb
883
+
884
+ def forward(self, item):
885
+ if self.einops_order_audio is not None:
886
+ item['audio'] = einops.rearrange(item['audio'], self.einops_order_audio).contiguous()
887
+ if self.einops_order_rgb is not None:
888
+ item['video'] = einops.rearrange(item['video'], self.einops_order_rgb).contiguous()
889
+ return item
890
+
891
+
892
+ class ResampleAudio(torch.nn.Module):
893
+
894
+ def __init__(self, new_fps: int):
895
+ super().__init__()
896
+ self.new_fps = new_fps
897
+
898
+ def forward(self, item):
899
+ orig_fps = int(item['meta']['audio']['framerate'][0])
900
+ item['meta']['audio']['orig_shape'] = item['audio'].shape
901
+ if orig_fps != self.new_fps:
902
+ item['audio'] = torchaudio.functional.resample(item['audio'], orig_fps, self.new_fps)
903
+ item['meta']['audio']['framerate'][0] = self.new_fps
904
+ return item
905
+
906
+ class ResampleRGB(torch.nn.Module):
907
+
908
+ def __init__(self, new_fps: int) -> None:
909
+ super().__init__()
910
+ self.new_fps = new_fps
911
+
912
+ def forward(self, item):
913
+ orig_fps = float(item['meta']['video']['fps'][0])
914
+ item['meta']['video']['orig_shape'] = item['video'].shape
915
+ if orig_fps != self.new_fps:
916
+ duration_sec = item['video'].shape[0] / orig_fps
917
+ indices = torch.arange(0, orig_fps * duration_sec - 1e-9, orig_fps / self.new_fps)
918
+ # basically, rounding
919
+ indices = indices.to(dtype=torch.long)
920
+ item['video'] = item['video'][indices]
921
+ item['meta']['video']['fps'][0] = self.new_fps
922
+ return item
923
+
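The frame-index selection in ResampleRGB above, traced on a tiny assumed example (25 fps downsampled to 12.5 fps keeps every other frame):

import torch

orig_fps, new_fps, n_frames = 25.0, 12.5, 10
duration_sec = n_frames / orig_fps
indices = torch.arange(0, orig_fps * duration_sec - 1e-9, orig_fps / new_fps).to(dtype=torch.long)
print(indices.tolist())  # [0, 2, 4, 6, 8]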
924
+ class ResizeAndLetterboxPad(torch.nn.Module):
925
+ '''Adapted from the WACV24 Amazon challenge'''
926
+
927
+ def __init__(self, new_h, new_w):
928
+ super().__init__()
929
+ self.new_h = new_h
930
+ self.new_w = new_w
931
+ self.aspect_ratio = new_w / new_h
932
+
933
+ def forward(self, item):
934
+ item['video'] = self.resize_and_pad(item['video'])
935
+ return item
936
+
937
+ def resize_and_pad(self, rgb: torch.Tensor):
938
+ _, _, height, width = rgb.shape
939
+ current_aspect_ratio = width / height
940
+ if current_aspect_ratio > self.aspect_ratio:
941
+ scaled_height = round(self.new_w / current_aspect_ratio)
942
+ rgb = torchvision.transforms.functional.resize(rgb, (scaled_height, self.new_w), antialias=None)
943
+ top = (self.new_h - scaled_height) // 2
944
+ bottom = self.new_h - (scaled_height + top)
945
+ rgb = torch.nn.ConstantPad2d((0, 0, top, bottom), 0)(rgb)
946
+ elif current_aspect_ratio < self.aspect_ratio:
947
+ scaled_width = round(self.new_h*current_aspect_ratio)
948
+ rgb = torchvision.transforms.functional.resize(rgb, (self.new_h, scaled_width), antialias=None)
949
+ left = (self.new_w - scaled_width) // 2
950
+ right = self.new_w - (scaled_width + left)
951
+ rgb = torch.nn.ConstantPad2d((left, right, 0, 0), 0)(rgb)
952
+ return rgb
953
+
954
+
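A quick shape trace of the letterbox logic above on an assumed 360x640 frame resized into a 224x224 canvas:

import torch

lb = ResizeAndLetterboxPad(new_h=224, new_w=224)
rgb = torch.zeros(1, 3, 360, 640)   # (T, C, H, W), wider than 1:1
out = lb.resize_and_pad(rgb)        # resized to (126, 224), then padded with 49 px top/bottom
print(out.shape)                    # torch.Size([1, 3, 224, 224])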
955
+ class ResampleResizeLetterboxPad(torch.nn.Module):
956
+
957
+ def __init__(self, afps, vfps, new_h, new_w) -> None:
958
+ super().__init__()
959
+ self.transforms = torchvision.transforms.Compose([
960
+ ResampleAudio(new_fps=afps),
961
+ ResampleRGB(new_fps=vfps),
962
+ ResizeAndLetterboxPad(new_h=new_h, new_w=new_w)
963
+ ])
964
+
965
+ def forward(self, x: dict) -> dict:
966
+ return self.transforms(x)
967
+
968
+ class DoNothing(torch.nn.Module):
969
+ def __init__(self, *args, **kwargs) -> None:
970
+ super().__init__()
971
+
972
+ def forward(self, x: dict) -> dict:
973
+ return x
974
+
975
+
976
+ if __name__ == '__main__':
977
+ grid = make_class_grid(-1, 1, 21)
978
+ grid = make_class_grid(-2, 2, 41)
979
+ print('grid:', grid)
980
+ print('value quantization:', quantize_offset(grid, 0.06))
981
+ v_fps = 25.0
982
+ duration = 10.0
983
+
984
+ input = {
985
+ 'video': torch.randint(0, 256, (int(duration * v_fps), 3, 720//2, 1280//2), dtype=torch.uint8),
986
+ 'audio': torch.arange(221184-1).float(),
987
+ 'targets': {},
988
+ 'meta': {
989
+ 'video': {'duration': [duration], 'fps': [v_fps]},
990
+ 'audio': {'duration': [duration], 'framerate': [22050.0]},
991
+ 'subtitles': {'duration': []},
992
+ 'cc': {'duration': []},
993
+ },
994
+ 'path': '/home/nvme/data/vggsound/video/-5cWCaoEDlE_261000_271000.mp4',
995
+ 'split': 'train',
996
+ }
997
+
998
+ print(input['audio'].shape, input['video'].shape)
999
+
1000
+ fn = EqualifyFromRight(clip_max_len_sec=10)
1001
+ input = fn(input)
1002
+ print(input['audio'].shape, input['video'].shape)
1003
+
1004
+ fn = RGBSpatialCrop((224, 224), is_random=True)
1005
+ # fn = RGBSpatialCrop((112, 112), is_random=True)
1006
+ input = fn(input)
1007
+ print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
1008
+
1009
+ fn = Resize((224, 224))
1010
+ input = fn(input)
1011
+ print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
1012
+
1013
+ fn = GenerateMultipleSegments(segment_size_vframes=16, n_segments=14,
1014
+ is_start_random=False, audio_jitter_sec=0.05, step_size_seg=0.5)
1015
+ input = fn(input)
1016
+ print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
1017
+
1018
+ fn = RandomApplyColorDistortion(p_gray_scale=0.5, p_color_jitter=0.5, s=1.0)
1019
+ input = fn(input)
1020
+ print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
1021
+
1022
+ fn = RGBToFloatToZeroOne()
1023
+ input = fn(input)
1024
+ print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
1025
+ print(input['meta'])
1026
+
1027
+ fn = RGBNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
1028
+ input = fn(input)
1029
+ print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
1030
+ print(input['video'].mean(dim=(0, 2, 3)))
1031
+ print(input['meta'])
1032
+
1033
+ fn = AudioRandomReverb(p=1.0)
1034
+ input = fn(input)
1035
+
1036
+ fn = AudioRandomVolume(p=1.0, gain=2.0, gain_type='amplitude')
1037
+ input = fn(input)
1038
+ print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
1039
+
1040
+ fn = AudioRandomPitchShift(p=1.0, shift=1000)
1041
+ input = fn(input)
1042
+ print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
1043
+
1044
+ fn = AudioRandomLowpassFilter(p=1.0, cutoff_freq=100)
1045
+ input = fn(input)
1046
+ print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
1047
+
1048
+ fn = AudioRandomGaussNoise(p=1.0, amplitude=0.01)
1049
+ input = fn(input)
1050
+ print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
1051
+
1052
+ fn = AudioLog()
1053
+ input = fn(input)
1054
+ print(input['audio'].shape, input['video'].shape, input['meta']['audio'])
1055
+
1056
+ # audio only
1057
+ input = {
1058
+ 'audio': torch.arange(221184).float(),
1059
+ 'meta': {
1060
+ 'video': {'duration': [10.0], 'fps': [10.0]},
1061
+ 'audio': {'duration': [11.0], 'framerate': [22050.0]},
1062
+ 'subtitles': {'duration': []},
1063
+ 'cc': {'duration': []}
1064
+ },
1065
+ 'path': '/home/nvme/data/vggsound/video/-5cWCaoEDlE_261000_271000.mp4'
1066
+ }
1067
+
1068
+ print(input['audio'].shape)
1069
+
1070
+ fn = AudioLog()
1071
+ input = fn(input)
1072
+ print(input['audio'].shape, input['meta']['audio'])
1073
+ print(input['meta'])
1074
+ print(input['audio'].min(), input['audio'].max())
modules/dataset/vggsound.py ADDED
@@ -0,0 +1,394 @@
1
+ import csv
2
+ import logging
3
+ import os
4
+ import random
5
+ import sys
6
+ from collections import Counter
7
+ from glob import glob
8
+ from pathlib import Path
9
+
10
+ import torch
11
+
12
+ sys.path.insert(0, '.') # nopep8
13
+ from dataset.dataset_utils import get_fixed_offsets, get_video_and_audio, subsample_dataset
14
+
15
+
16
+ class VGGSound(torch.utils.data.Dataset):
17
+
18
+ def __init__(self,
19
+ split,
20
+ vids_dir,
21
+ transforms=None,
22
+ to_filter_bad_examples=True,
23
+ splits_path='./data',
24
+ meta_path='./data/vggsound.csv',
25
+ seed=1337,
26
+ load_fixed_offsets_on=['valid', 'test'],
27
+ vis_load_backend='read_video',
28
+ size_ratio=None,
29
+ attr_annot_path=None,
30
+ max_attr_per_vid=None):
31
+ super().__init__()
32
+ self.max_clip_len_sec = None
33
+ self.split = split
34
+ self.vids_dir = vids_dir
35
+ self.transforms = transforms
36
+ self.to_filter_bad_examples = to_filter_bad_examples
37
+ self.splits_path = splits_path
38
+ self.meta_path = meta_path
39
+ self.seed = seed
40
+ self.load_fixed_offsets_on = [] if load_fixed_offsets_on is None else load_fixed_offsets_on
41
+ self.vis_load_backend = vis_load_backend
42
+ self.size_ratio = size_ratio
43
+
44
+ vggsound_meta = list(csv.reader(open(meta_path), quotechar='"'))
45
+
46
+ # filter "bad" examples
47
+ if to_filter_bad_examples:
48
+ vggsound_meta = self.filter_bad_examples(vggsound_meta)
49
+
50
+ unique_classes = sorted(list(set(row[2] for row in vggsound_meta)))
51
+ self.label2target = {label: target for target, label in enumerate(unique_classes)}
52
+ self.target2label = {target: label for label, target in self.label2target.items()}
53
+ self.video2target = {row[0]: self.label2target[row[2]] for row in vggsound_meta}
54
+
55
+ split_clip_ids_path = os.path.join(splits_path, f'vggsound_{split}.txt')
56
+ if not os.path.exists(split_clip_ids_path):
57
+ self.make_split_files()
58
+ # the ugly string converts ['AdfsGsfII2yQ', '1'] into `AdfsGsfII2yQ_1000_11000`
59
+ meta_available = set([f'{r[0]}_{int(r[1])*1000}_{(int(r[1])+10)*1000}' for r in vggsound_meta])
60
+ within_split = set(open(split_clip_ids_path).read().splitlines())
61
+ clip_paths = [os.path.join(vids_dir, v + '.mp4') for v in meta_available.intersection(within_split)]
62
+ clip_paths = sorted(clip_paths)
63
+
64
+ if split in self.load_fixed_offsets_on:
65
+ logging.info(f'Using fixed offset for {split}')
66
+ self.vid2offset_params = get_fixed_offsets(transforms, split, splits_path, 'vggsound')
67
+
68
+ # making sure that all classes have at least one example
69
+ counter = Counter([self.video2target[Path(p).stem[:11]] for p in clip_paths])
70
+ assert all(counter[c] > 0 for c in self.target2label.keys()), \
71
+ f'Some classes have 0 count: {dict(counter)}'
72
+
73
+ self.dataset = clip_paths
74
+ self.dataset = subsample_dataset(self.dataset, size_ratio, shuffle=split == 'train')
75
+
76
+ def filter_bad_examples(self, vggsound_meta):
77
+ bad = set()
78
+ base_path = Path('./data/filtered_examples_vggsound')
79
+ lists = [open(p).read().splitlines() for p in sorted(glob(str(base_path / '*.txt')))]
80
+ for s in lists:
81
+ bad = bad.union(s)
82
+ # the ugly string converts '---g-f_I2yQ', '1' into `---g-f_I2yQ_1000_11000`
83
+ vggsound_meta = [r for r in vggsound_meta if f'{r[0]}_{int(r[1])*1000}_{(int(r[1])+10)*1000}' not in bad]
84
+ return vggsound_meta
85
+
86
+ def __getitem__(self, index):
87
+ path = self.dataset[index]
88
+ rgb, audio, meta = self.load_media(path)
89
+ item = self.make_datapoint(path, rgb, audio, meta)
90
+ if self.transforms is not None:
91
+ item = self.transforms(item)
92
+ return item
93
+
94
+ def make_datapoint(self, path, rgb, audio, meta):
95
+ # (Tv, 3, H, W) in [0, 255], (Ta, C) in [-1, 1]
96
+ target = self.video2target[Path(path).stem[:11]]
97
+ item = {
98
+ 'video': rgb,
99
+ 'audio': audio,
100
+ 'meta': meta,
101
+ 'path': str(path),
102
+ 'targets': {'vggsound_target': target, 'vggsound_label': self.target2label[target]},
103
+ 'split': self.split,
104
+ }
105
+
106
+ if self.split in self.load_fixed_offsets_on:
107
+ unique_id = Path(path).stem
108
+ offset_params = self.vid2offset_params[unique_id]
109
+ item['targets']['offset_sec'] = self.vid2offset_params[unique_id]['offset_sec']
110
+ item['targets']['v_start_i_sec'] = self.vid2offset_params[unique_id]['v_start_i_sec']
111
+ if 'oos_target' in offset_params:
112
+ item['targets']['offset_target'] = {
113
+ 'oos': offset_params['oos_target'], 'offset': item['targets']['offset_sec'],
114
+ }
115
+
116
+ return item
117
+
118
+ def load_media(self, path):
119
+ rgb, audio, meta = get_video_and_audio(path, get_meta=True, end_sec=self.max_clip_len_sec)
120
+ return rgb, audio, meta
121
+
122
+ def make_split_files(self):
123
+ if self.to_filter_bad_examples:
124
+ logging.warning('`to_filter_bad_examples` is True. `make_split_files` expects otherwise')
125
+
126
+ logging.info(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')
127
+ # The downloaded videos (some went missing on YouTube and are no longer available)
128
+ available_vid_paths = sorted(glob(os.path.join(self.vids_dir, '*.mp4')))
129
+ logging.info(f'The number of clips available after download: {len(available_vid_paths)}')
130
+
131
+ # original (full) train and test sets
132
+ vggsound_meta = list(csv.reader(open(self.meta_path), quotechar='"'))
133
+ train_vids = {row[0] for row in vggsound_meta if row[3] == 'train'}
134
+ test_vids = {row[0] for row in vggsound_meta if row[3] == 'test'}
135
+
136
+ # # the cleaned test set
137
+ # vggsound_meta_test_v2 = list(csv.reader(open(self.meta_path_clean_test), quotechar='"'))
138
+ logging.info(f'The number of videos in vggsound train set: {len(train_vids)}')
139
+ logging.info(f'The number of videos in vggsound test set: {len(test_vids)}')
140
+
141
+ # class counts in test set. We would like to have the same distribution in valid
142
+ unique_classes = sorted(list(set(row[2] for row in vggsound_meta)))
143
+ label2target = {label: target for target, label in enumerate(unique_classes)}
144
+ video2target = {row[0]: label2target[row[2]] for row in vggsound_meta}
145
+ test_vid_classes = [video2target[vid] for vid in test_vids]
146
+ test_target2count = Counter(test_vid_classes)
147
+
148
+ # now given the counts from test set, sample the same count for validation and the rest leave in train
149
+ train_vids_wo_valid, valid_vids = set(), set()
150
+ for target, label in enumerate(label2target.keys()):
151
+ class_train_vids = [vid for vid in sorted(list(train_vids)) if video2target[vid] == target]
152
+ random.Random(self.seed).shuffle(class_train_vids)
153
+ count = test_target2count[target]
154
+ valid_vids.update(class_train_vids[:count])
155
+ train_vids_wo_valid.update(class_train_vids[count:])
156
+
157
+ # make file with a list of available test videos (each video should contain timestamps as well)
158
+ train_i = valid_i = test_i = 0
159
+ with open(os.path.join(self.splits_path, 'vggsound_train.txt'), 'w') as train_file, \
160
+ open(os.path.join(self.splits_path, 'vggsound_valid.txt'), 'w') as valid_file, \
161
+ open(os.path.join(self.splits_path, 'vggsound_test.txt'), 'w') as test_file:
162
+ # open(os.path.join(self.splits_path, 'vggsound_test_v2.txt'), 'w') as test_file_v2:
163
+ for path in available_vid_paths:
164
+ path = path.replace('.mp4', '')
165
+ vid_name = Path(path).name
166
+ # 'zyTX_1BXKDE_16000_26000'[:11] -> 'zyTX_1BXKDE'
167
+ if vid_name[:11] in train_vids_wo_valid:
168
+ train_file.write(vid_name + '\n')
169
+ train_i += 1
170
+ elif vid_name[:11] in valid_vids:
171
+ valid_file.write(vid_name + '\n')
172
+ valid_i += 1
173
+ elif vid_name[:11] in test_vids:
174
+ test_file.write(vid_name + '\n')
175
+ test_i += 1
176
+ # else:
177
+ # raise Exception(f'Clip {vid_name} is neither in train, valid nor test. Strange.')
178
+
179
+ logging.info(f'Put {train_i} clips to the train set and saved it to ./data/vggsound_train.txt')
180
+ logging.info(f'Put {valid_i} clips to the valid set and saved it to ./data/vggsound_valid.txt')
181
+ logging.info(f'Put {test_i} clips to the test set and saved it to ./data/vggsound_test.txt')
182
+
183
+ def __len__(self):
184
+ return len(self.dataset)
185
+
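For clarity, the clip-id construction used above (the "ugly string") sketched as a tiny helper; the helper name is hypothetical and only mirrors the f-string in the code:

def clip_id(youtube_id: str, start_sec: str) -> str:
    start_ms = int(start_sec) * 1000
    return f'{youtube_id}_{start_ms}_{start_ms + 10000}'   # VGGSound clips are 10 s long

assert clip_id('---g-f_I2yQ', '1') == '---g-f_I2yQ_1000_11000'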
186
+ class VGGSoundSparse(VGGSound):
187
+ '''
188
+ The same as VGGSound, except the list of videos is filtered for sparse sounds (sparse_meta_path)
189
+ '''
190
+
191
+ def __init__(self, split, vids_dir, transforms=None, to_filter_bad_examples=True,
192
+ splits_path='./data', meta_path='./data/vggsound.csv',
193
+ sparse_meta_path='./data/sparse_classes.csv', seed=1337, load_fixed_offsets_on=['valid', 'test'],
194
+ vis_load_backend='read_video', size_ratio=None, attr_annot_path=None, max_attr_per_vid=None):
195
+ super().__init__(split, vids_dir, transforms, to_filter_bad_examples, splits_path, meta_path, seed,
196
+ load_fixed_offsets_on, vis_load_backend, size_ratio)
197
+ self.sparse_meta_path = sparse_meta_path
198
+ sparse_meta = list(csv.reader(open(sparse_meta_path), quotechar='"', delimiter='\t'))
199
+ sparse_classes = set([row[0] for row in sparse_meta if row[1] == 'y'])
200
+ label2new_target = {label: target for target, label in enumerate(sorted(list(sparse_classes)))}
201
+ new_target2label = {target: label for label, target in label2new_target.items()}
202
+
203
+ sparse_dataset = []
204
+ video2new_target = {}
205
+ for path in self.dataset:
206
+ vid_id = Path(path).stem[:11]
207
+ vid_target = self.video2target[vid_id]
208
+ vid_label = self.target2label[vid_target]
209
+ if vid_label in sparse_classes:
210
+ sparse_dataset.append(path)
211
+ video2new_target[vid_id] = label2new_target[vid_label]
212
+
213
+ self.dataset = sparse_dataset
214
+ logging.debug(f'Filtered VGGSound dataset to sparse classes from {Path(sparse_meta_path).name}.')
215
+
216
+ # redefining the label <-> target variable
217
+ self.label2target = label2new_target
218
+ self.target2label = new_target2label
219
+ self.video2target = video2new_target
220
+ logging.debug('Redefined label <-> target mapping to sparse classes.')
221
+
222
+ counter = Counter([self.video2target[Path(p).stem[:11]] for p in self.dataset])
223
+ assert len(self.dataset) < 1000 or all(counter[c] > 0 for c in self.target2label.keys()), \
224
+ f'Some classes have 0 count: {dict(counter)}'
225
+
226
+
227
+ class VGGSoundSparsePicked(VGGSoundSparse):
228
+
229
+ def __init__(self, split, vids_dir, transforms=None, to_filter_bad_examples=True,
230
+ splits_path='./data', meta_path='./data/vggsound.csv',
231
+ sparse_meta_path='./data/picked_sparse_classes.csv', seed=1337,
232
+ load_fixed_offsets_on=['valid', 'test'], vis_load_backend='read_video', size_ratio=None,
233
+ attr_annot_path=None, max_attr_per_vid=None):
234
+ super().__init__(split, vids_dir, transforms, to_filter_bad_examples, splits_path,
235
+ meta_path, sparse_meta_path, seed, load_fixed_offsets_on, vis_load_backend,
236
+ size_ratio)
237
+
238
+
239
+ class VGGSoundSparsePickedCleanTest(VGGSoundSparse):
240
+
241
+ def __init__(self, split, vids_dir, transforms=None, to_filter_bad_examples=True,
242
+ splits_path='./data', meta_path='./data/vggsound.csv',
243
+ sparse_meta_path='./data/picked_sparse_classes.csv', seed=1337,
244
+ load_fixed_offsets_on=['valid', 'test'], vis_load_backend='read_video', size_ratio=None,
245
+ attr_annot_path=None, max_attr_per_vid=None):
246
+ super().__init__(split, vids_dir, transforms, to_filter_bad_examples, splits_path,
247
+ meta_path, sparse_meta_path, seed, load_fixed_offsets_on, vis_load_backend,
248
+ size_ratio)
249
+
250
+ def filter_bad_examples(self, vggsound_meta):
251
+ bad = set()
252
+ base_path1 = Path('./data/filtered_examples_vggsound')
253
+ base_path2 = Path('./data/filtered_examples_vggsound_extra')
254
+ lists1 = [open(p).read().splitlines() for p in sorted(glob(str(base_path1 / '*.txt')))]
255
+ lists2 = [open(p).read().splitlines() for p in sorted(glob(str(base_path2 / '*.txt')))]
256
+ lists = lists1 + lists2
257
+ for s in lists:
258
+ bad = bad.union(s)
259
+ # the ugly string converts '---g-f_I2yQ', '1' into `---g-f_I2yQ_1000_11000`
260
+ vggsound_meta = [r for r in vggsound_meta if f'{r[0]}_{int(r[1])*1000}_{(int(r[1])+10)*1000}' not in bad]
261
+ return vggsound_meta
262
+
263
+
264
+ class VGGSoundSparsePickedCleanTestFixedOffsets(VGGSoundSparse):
265
+ '''This dataset only operates on manually annotated fixed offsets. Meant for evaluation purpose.'''
266
+
267
+ def __init__(self, split, vids_dir, transforms=None, to_filter_bad_examples=True,
268
+ splits_path='./data', meta_path='./data/vggsound.csv',
269
+ sparse_meta_path='./data/picked_sparse_classes.csv', seed=1337,
270
+ load_fixed_offsets_on=['valid', 'test'], vis_load_backend='read_video', size_ratio=None,
271
+ attr_annot_path=None, max_attr_per_vid=None):
272
+ super().__init__(split, vids_dir, transforms, to_filter_bad_examples, splits_path,
273
+ meta_path, sparse_meta_path, seed, load_fixed_offsets_on, vis_load_backend,
274
+ size_ratio)
275
+ # redefine the dataset to only use fixed offsets
276
+ fix_off_path = './data/vggsound_sparse_clean_fixed_offsets.csv'
277
+ self.vid2offset_params = {}
278
+ # lines have the format: dataset_name,video_id,vstart_sec,offset_sec,is_sync
279
+ reader = csv.reader(open(fix_off_path))
280
+ next(reader) # skipping the header
281
+ for _, vi, st, of, sy in reader:
282
+ # FIXME: if there are more than one fixed offset per video, we will need to change the logic here
283
+ assert vi not in self.vid2offset_params, 'offsets from other splits will override each other'
284
+ if sy == '1':
285
+ self.vid2offset_params[vi] = {'offset_sec': float(of), 'v_start_i_sec': float(st)}
286
+ logging.debug(f'Loaded {len(self.vid2offset_params)} fixed offsets from {Path(fix_off_path).name}.')
287
+ # since we care only about the videos in the fixed offset split, we filter the dataset
288
+ self.dataset = [p for p in self.dataset if Path(p).stem in self.vid2offset_params]
289
+ logging.debug(f'Filtered VGGSoundSparse dataset to fixed offsets from {Path(fix_off_path).name}.')
290
+
291
+
292
+ class LongerVGGSound(VGGSound):
293
+ '''Differs from the parent class only in the extra filtering it applies: where the parent
294
+ builds its splits by filtering out clips shorter than 9 s, this class additionally filters out
295
+ clips shorter than 9.5 s.
296
+ '''
297
+
298
+ def __init__(self,
299
+ split,
300
+ vids_dir,
301
+ transforms=None,
302
+ to_filter_bad_examples=True,
303
+ splits_path='./data',
304
+ meta_path='./data/vggsound.csv',
305
+ seed=1337,
306
+ load_fixed_offsets_on=['valid', 'test'],
307
+ vis_load_backend='read_video',
308
+ size_ratio=None,
309
+ attr_annot_path=None,
310
+ max_attr_per_vid=None):
311
+ # size_ratio=None and fixed_offsets_on=[] to avoid double subsampling and loading fixed offsets
312
+ super().__init__(split, vids_dir, transforms, to_filter_bad_examples, splits_path, meta_path, seed,
313
+ [], vis_load_backend, None, attr_annot_path, max_attr_per_vid)
314
+ # redefining the load_fixed_offsets_on because the parent class does not load them (see above)
315
+ self.load_fixed_offsets_on = load_fixed_offsets_on
316
+ # doing the extra filtering for longer than 9.5 sec
317
+ if to_filter_bad_examples:
318
+ filt_ex_path = Path('./data/filtered_examples_vggsound_shorter/less_than_9.5s.txt')
319
+ bad = set([p for p in open(filt_ex_path).read().splitlines()])
320
+ self.dataset = [p for p in self.dataset if Path(p).stem not in bad]
321
+
322
+ # longer vid clips require new fixed offsets, loading them here again (VGGSound loads them in init)
323
+ if split in self.load_fixed_offsets_on:
324
+ logging.info(f'Using fixed offset for {split}')
325
+ self.vid2offset_params = get_fixed_offsets(transforms, split, splits_path, 'vggsound')
326
+
327
+ self.dataset = subsample_dataset(self.dataset, size_ratio, shuffle=split == 'train')
328
+ logging.info(f'{split} has {len(self.dataset)} items')
329
+
330
+
331
+
332
+ if __name__ == '__main__':
333
+ from omegaconf import OmegaConf
334
+ from scripts.train_utils import get_transforms
335
+ from utils.utils import cfg_sanity_check_and_patch
336
+ cfg = OmegaConf.load('./configs/transformer.yaml')
337
+ vis_load_backend = 'read_video'
338
+
339
+ # set logging for info level
340
+ logging.basicConfig(level=logging.INFO)
341
+
342
+ transforms = get_transforms(cfg)
343
+
344
+ vids_path = 'PLACEHOLDER'
345
+
346
+ # cfg.data.dataset.params.size_ratio = 0.1
347
+
348
+ cfg_sanity_check_and_patch(cfg)
349
+
350
+ datasets = {
351
+ 'train': VGGSound('train', vids_path, transforms['train'], vis_load_backend=vis_load_backend,
352
+ size_ratio=cfg.data.dataset.params.size_ratio),
353
+ 'valid': VGGSound('valid', vids_path, transforms['test'], vis_load_backend=vis_load_backend),
354
+ 'test': VGGSound('test', vids_path, transforms['test'], vis_load_backend=vis_load_backend),
355
+ }
356
+ for phase in ['train', 'valid', 'test']:
357
+ print(phase, len(datasets[phase]))
358
+
359
+ print(datasets['train'][0]['audio'].shape, datasets['train'][0]['video'].shape)
360
+ print(datasets['train'][0]['meta'])
361
+ print(datasets['valid'][0]['audio'].shape, datasets['valid'][0]['video'].shape)
362
+ print(datasets['valid'][0]['meta'])
363
+ print(datasets['test'][0]['audio'].shape, datasets['test'][0]['video'].shape)
364
+ print(datasets['test'][0]['meta'])
365
+
366
+ # for i in range(300, 1000):
367
+ # datasets['train'][i]['path']
368
+ # print(datasets['train'][0]['audio'].shape, datasets['train'][0]['video'].shape)
369
+ # print(datasets['train'][0]['meta'])
370
+
371
+ datasets = {
372
+ 'train': VGGSoundSparse('train', vids_path, transforms['train']),
373
+ 'valid': VGGSoundSparse('valid', vids_path, transforms['test']),
374
+ 'test': VGGSoundSparse('test', vids_path, transforms['test']),
375
+ }
376
+ for phase in ['train', 'valid', 'test']:
377
+ print(phase, len(datasets[phase]))
378
+
379
+ print(datasets['train'][0]['audio'].shape, datasets['train'][0]['video'].shape)
380
+ print(datasets['train'][0]['meta'])
381
+ print(datasets['valid'][0]['audio'].shape, datasets['valid'][0]['video'].shape)
382
+ print(datasets['valid'][0]['meta'])
383
+ print(datasets['test'][0]['audio'].shape, datasets['test'][0]['video'].shape)
384
+ print(datasets['test'][0]['meta'])
385
+
386
+ datasets = {
387
+ 'test': VGGSoundSparsePickedCleanTestFixedOffsets('test', vids_path, transforms['test']),
388
+ }
389
+ for phase in ['test']:
390
+ print(phase, len(datasets[phase]))
391
+
392
+ for i in range(len(datasets['test'])):
393
+ datasets['test'][i]['path']
394
+ print(datasets['test'][i]['audio'].shape, datasets['test'][i]['video'].shape)
modules/model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Auto-generated __init__.py
modules/model/modules/bridges.py ADDED
@@ -0,0 +1,178 @@
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+
5
+ class BridgeBase(torch.nn.Module):
6
+
7
+ def __init__(self) -> None:
8
+ super().__init__()
9
+ self.bridge = None
10
+
11
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
12
+ try:
13
+ x = self.bridge(x)
14
+ return x
15
+ except TypeError as e:
16
+ raise TypeError('This class cannot be called on its own; please use a class that inherits from it', e)
17
+
18
+
19
+ class ConvBridgeBase(BridgeBase):
20
+
21
+ def __init__(self, block, **kwargs) -> None:
22
+ super().__init__()
23
+ self.bridge = torch.nn.Sequential(
24
+ block(**kwargs),
25
+ torch.nn.GELU(),
26
+ )
27
+
28
+
29
+ class AvgPoolBridgeBase(BridgeBase):
30
+
31
+ def __init__(self, block, **kwargs) -> None:
32
+ super().__init__()
33
+ # a list [1, 2, 3] specified in an OmegaConf config has type ListConfig, which PyTorch does not accept
34
+ # interestingly, conv layers don't care
35
+ for k in ['kernel_size', 'stride']:
36
+ kwargs[k] = list(kwargs[k])
37
+ self.bridge = block(**kwargs)
38
+
39
+
40
+ class ConvBridgeAudio(ConvBridgeBase):
41
+
42
+ def __init__(self, **kwargs) -> None:
43
+ super().__init__(block=torch.nn.Conv2d, **kwargs)
44
+
45
+
46
+ class ConvBridgeVisual(ConvBridgeBase):
47
+
48
+ def __init__(self, **kwargs) -> None:
49
+ super().__init__(block=torch.nn.Conv3d, **kwargs)
50
+
51
+
52
+ class AvgPoolBridgeVisual(AvgPoolBridgeBase):
53
+
54
+ def __init__(self, **kwargs) -> None:
55
+ super().__init__(block=torch.nn.AvgPool3d, **kwargs)
56
+
57
+
58
+ class AvgPoolBridgeAudio(AvgPoolBridgeBase):
59
+
60
+ def __init__(self, **kwargs) -> None:
61
+ super().__init__(block=torch.nn.AvgPool2d, **kwargs)
62
+
63
+
64
+ class DoNothingBridge(BridgeBase):
65
+
66
+ def __init__(self, **kwargs) -> None:
67
+ super().__init__()
68
+ self.bridge = torch.nn.Identity(**kwargs)
69
+
70
+
71
+ class AppendZerosToHidden(BridgeBase):
72
+
73
+ def __init__(self, target_hidden_size, dim) -> None:
74
+ super().__init__()
75
+ self.target_hidden_size = target_hidden_size
76
+ self.dim = dim
77
+
78
+ def forward(self, x):
79
+ d_res = self.target_hidden_size - x.shape[self.dim]
80
+ # going to insert the new dimension into the x.shape output
81
+ shape_target = list(x.shape[:self.dim]) + [d_res] + list(x.shape[self.dim+1:])
82
+ # creating the zeros to append to x
83
+ zeros = torch.zeros(shape_target).to(x)
84
+ x = torch.cat([x, zeros], self.dim)
85
+ return x
86
+
87
+
88
+ class SpatialpoolConvTemporalpool(torch.nn.Module):
89
+ '''Similar to S3D but with slightly different kernel for F.avg_pool3d.
90
+ To be used in the visual branch of AVCLIP'''
91
+
92
+ def __init__(self, **kwargs) -> None:
93
+ super().__init__()
94
+ self.conv = torch.nn.Conv3d(**kwargs)
95
+
96
+ def forward(self, x: torch.Tensor):
97
+ B, t, d, h, w = x.shape
98
+ x = x.permute(0, 2, 1, 3, 4) # (B, d, t, h, w)
99
+ # pool as in S3D but without temporal pooling (2-->1, h, w)
100
+ x = F.avg_pool3d(x, (1, h, w), stride=1) # (B, d, t, 1, 1)
101
+ x = self.conv(x) # (B, D, t, 1, 1)
102
+ x = x.view(B, self.conv.out_channels, t) # squeeze the spatial dimensions
103
+ x = x.mean(dim=-1) # temporal pooling
104
+ return x # (B, d)
105
+
106
+
107
+ class FrequencypoolConvTemporalpool(torch.nn.Module):
108
+ '''Similar to the visual branch of S3D, which is a stack of spatial pool, conv, temporal pool blocks.
109
+ Instead, this is a stack of frequency pool, conv, temporal pooling blocks.
110
+ To be used in the audio branch of AVCLIP'''
111
+
112
+ def __init__(self, **kwargs) -> None:
113
+ super().__init__()
114
+ self.conv = torch.nn.Conv2d(**kwargs)
115
+
116
+ def forward(self, x: torch.Tensor):
117
+ B, d, f, t = x.shape
118
+ # frequency pooling (f-->1)
119
+ x = F.avg_pool2d(x, (f, 1), stride=1) # (B, d, 1, t)
120
+ x = self.conv(x) # (B, D, 1, t)
121
+ x = x.view(B, self.conv.out_channels, t) # squeeze the frequency dimension
122
+ x = x.mean(dim=-1) # temporal pooling
123
+ return x # (B, d)
124
+
125
+
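A short usage sketch for the two pooling bridges above (channel and kernel sizes are assumptions chosen to match the toy tensors in the __main__ block below):

import torch

sp_v = SpatialpoolConvTemporalpool(in_channels=512, out_channels=512, kernel_size=1)
fp_a = FrequencypoolConvTemporalpool(in_channels=512, out_channels=512, kernel_size=1)
v = torch.rand(2, 50, 512, 7, 7)     # (B, t, d, h, w)
a = torch.rand(2, 512, 9, 27)        # (B, d, f, t)
print(sp_v(v).shape, fp_a(a).shape)  # torch.Size([2, 512]) torch.Size([2, 512])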
126
+ if __name__ == '__main__':
127
+ v = torch.rand(2, 50, 512, 7, 7)
128
+ a = torch.rand(2, 512, 9, 27)
129
+
130
+ in_channels = 512
131
+ out_channels = 512
132
+ kernel_size_v = [1, 7, 7]
133
+ kernel_size_a = [9, 1]
134
+ stride = 1
135
+ bias = True
136
+
137
+ conv_bridge_a = ConvBridgeAudio(in_channels=in_channels, out_channels=out_channels,
138
+ kernel_size=kernel_size_a, stride=stride, bias=bias)
139
+ conv_bridge_v = ConvBridgeVisual(in_channels=in_channels, out_channels=out_channels,
140
+ kernel_size=kernel_size_v, stride=stride, bias=bias)
141
+ avg_bridge_a = AvgPoolBridgeAudio(kernel_size=kernel_size_a)
142
+ avg_bridge_v = AvgPoolBridgeVisual(kernel_size=kernel_size_v)
143
+ i_bridge_a = DoNothingBridge(some_arg=123)
144
+ i_bridge_v = DoNothingBridge(some_arg=123)
145
+ h_bridge_a = AppendZerosToHidden(target_hidden_size=1024, dim=1)
146
+ h_bridge_v = AppendZerosToHidden(target_hidden_size=1024, dim=1)
147
+
148
+ print('v', v.shape)
149
+ print('conv_v(v)', conv_bridge_v(v.permute(0, 2, 1, 3, 4)).permute(0, 2, 1, 3, 4).shape)
150
+ print()
151
+
152
+ print('a', a.shape)
153
+ print('conv_a(a)', conv_bridge_a(a).shape)
154
+ print()
155
+
156
+ print('v', v.shape)
157
+ print('avg3d(v)', avg_bridge_v(v.permute(0, 2, 1, 3, 4)).permute(0, 2, 1, 3, 4).shape)
158
+ print()
159
+
160
+ print('a', a.shape)
161
+ print('avg2d(a)', avg_bridge_a(a).shape)
162
+ print()
163
+
164
+ print('v', v.shape)
165
+ print('i(v)', i_bridge_v(v.permute(0, 2, 1, 3, 4)).permute(0, 2, 1, 3, 4).shape)
166
+ print()
167
+
168
+ print('a', a.shape)
169
+ print('i(a)', i_bridge_a(a).shape)
170
+ print()
171
+
172
+ print('v', v.shape)
173
+ print('h(v)', h_bridge_v(v.permute(0, 2, 1, 3, 4)).permute(0, 2, 1, 3, 4).shape)
174
+ print()
175
+
176
+ print('a', a.shape)
177
+ print('h(a)', h_bridge_a(a).shape)
178
+ print()
modules/model/modules/feat_extractors/audio/ast.py ADDED
@@ -0,0 +1,279 @@
1
+ import logging
2
+ import torch
3
+ from torch import nn
4
+ # importing modified version of AST
5
+ from model.modules.feat_extractors.audio.hf_src.modeling_ast import ASTForAudioClassification, ASTConfig
6
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
7
+
8
+ from model.modules.feat_extractors.visual.motionformer import (AveragePooling, BaseEncoderLayer,
9
+ TemporalTransformerEncoderLayer)
10
+ from utils.utils import check_if_file_exists_else_download
11
+
12
+
13
+ class AST(torch.nn.Module):
14
+ def __init__(self,
15
+ extract_features: bool = False,
16
+ ckpt_path: str = None,
17
+ feat_type: str = None,
18
+ max_spec_t: int = None,
19
+ factorize_freq_time: bool = None,
20
+ agg_freq_module: str = None,
21
+ agg_time_module: str = None,
22
+ add_global_repr: bool = True,
23
+ agg_segments_module: str = None,
24
+ max_segments: int = None,
25
+ ) -> None:
26
+ '''
27
+ extract_features: if True, then the model will return the features instead of head's output
28
+ ckpt_path: either a model name from the HuggingFace model hub or a path to an AVCLIP checkpoint (.pt)
29
+ feat_type: if extract_features is True, this parameter specifies the type of features to return
30
+ max_spec_t: if specified, then the model (pos emb) will be patched to support this length of spec
31
+ factorize_freq_time: if True, then the model will use a factorized freq/time aggregation
32
+ agg_freq_module: if specified, then the model will use this module for freq aggregation
33
+ agg_time_module: if specified, then the model will use this module for time aggregation
34
+ add_global_repr: if True, adds a global representation to the features (aggregation on segments)
35
+ agg_segments_module: if specified, then the model will use this module for segments aggregation
36
+ max_segments: if specified, the initialization of PE in the global agg module will use this value.
37
+ This should correspond to the max number of segments per video (if None, 16 is used)
38
+ '''
39
+ super().__init__()
40
+ self.extract_features = extract_features
41
+ self.ckpt_path = ckpt_path
42
+ self.max_spec_t = max_spec_t
43
+ self.max_segments = max_segments
44
+
45
+ # depending on whether the feat extractor was pre-trained contrastively or not, we need to
46
+ # load the state dict differently.
47
+
48
+ # if ckpt is specified, then load the model from the HuggingFace model hub, otherwise init a new model
49
+ if ckpt_path == 'MIT/ast-finetuned-audioset-10-10-0.4593':
50
+ revision = 'c1c0c66' # fixing the revision for compatibility (V4.27.4)
51
+ self.config = ASTConfig.from_pretrained(ckpt_path, revision=revision)
52
+ full_model = ASTForAudioClassification.from_pretrained(ckpt_path, revision=revision)
53
+ logging.info(f'Loaded AST from {ckpt_path}')
54
+ else:
55
+ self.config = ASTConfig()
56
+ self.config.num_labels = 527 # 2 by default, audioset has 527 labels
57
+ full_model = ASTForAudioClassification(self.config)
58
+ logging.info('Initialized AST from scratch with the AST AudioSet config')
59
+
60
+ was_pt_on_avclip = ckpt_path is not None and ckpt_path.endswith('.pt')
61
+
62
+ # feature extractor
63
+ self.ast = full_model.audio_spectrogram_transformer
64
+
65
+ if self.extract_features:
66
+ # assign `feat_type` (use default if not specified)
67
+ self.feat_type = 'last_hidden_state' if feat_type is None else feat_type
68
+ # define adapters if needed
69
+ self.factorize_freq_time = factorize_freq_time
70
+ # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
71
+ transf_enc_layer_kwargs = dict(
72
+ d_model=self.config.hidden_size, nhead=self.config.num_attention_heads,
73
+ dim_feedforward=self.config.intermediate_size, activation=nn.GELU(), batch_first=True,
74
+ dropout=self.config.attention_probs_dropout_prob, layer_norm_eps=1e-6, norm_first=True,
75
+ )
76
+ if factorize_freq_time:
77
+ self.feat_type = 'last_hidden_state' # this feat_type supports factorization
78
+ # frequency aggregation
79
+ if agg_freq_module == 'TransformerEncoderLayer':
80
+ self.freq_attn_agg = FrequencyTransformerEncoderLayer(**transf_enc_layer_kwargs)
81
+ elif agg_freq_module == 'AveragePooling':
82
+ self.freq_attn_agg = AveragePooling(avg_pattern='BS D f t -> BS D t',
83
+ then_permute_pattern='BS D t -> BS t D')
84
+ # time aggregation
85
+ if agg_time_module == 'TransformerEncoderLayer':
86
+ self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
87
+ elif agg_time_module == 'AveragePooling':
88
+ self.temp_attn_agg = AveragePooling(avg_pattern='BS t D -> BS D')
89
+ elif 'Identity' in agg_time_module:
90
+ self.temp_attn_agg = nn.Identity()
91
+ # define a global aggregation layer (aggregate over segments)
92
+ self.add_global_repr = add_global_repr
93
+ if add_global_repr:
94
+ if agg_segments_module == 'TransformerEncoderLayer':
95
+ # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
96
+ # we need to add pos emb (PE) because previously we added the same PE for each segment
97
+ pos_max_len = max_segments if max_segments is not None else 16 # 16 = 10sec//0.64sec + 1
98
+ self.global_attn_agg = TemporalTransformerEncoderLayer(
99
+ add_pos_emb=True, pos_emb_drop=self.config.hidden_dropout_prob,
100
+ pos_max_len=pos_max_len, **transf_enc_layer_kwargs
101
+ )
102
+ elif agg_segments_module == 'AveragePooling':
103
+ self.global_attn_agg = AveragePooling(avg_pattern='B S D -> B D')
104
+ else:
105
+ self.classifier = full_model.classifier
106
+
107
+ # AST.device fails with AttributeError. This is a workaround
108
+ self.device = full_model.device
109
+
110
+ # pre-trained on 12*101+2=1214 tokens, but we have less (e.g. 12*6+2=74)
111
+ self.patch_position_emb()
112
+
113
+ if was_pt_on_avclip:
114
+ # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
115
+ # and keep only the state_dict of the feat extractor
116
+ check_if_file_exists_else_download(self.ckpt_path)
117
+ ckpt = torch.load(ckpt_path, map_location='cpu')
118
+ ckpt_weights = dict()
119
+ for k, v in ckpt['state_dict'].items():
120
+ if k.startswith(('module.a_encoder.', 'a_encoder.')):
121
+ k = k.replace('module.', '').replace('a_encoder.', '')
122
+ ckpt_weights[k] = v
123
+ _load_status = self.load_state_dict(ckpt_weights, strict=False)
124
+ if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
125
+ logging.warning(f'Loading exact afeat_extractor ckpt from {self.ckpt_path} failed. \n' \
126
+ f'Missing keys ({len(_load_status.missing_keys)}): ' \
127
+ f'{_load_status.missing_keys}, \n' \
128
+ f'Unexpected keys ({len(_load_status.unexpected_keys)}): ' \
129
+ f'{_load_status.unexpected_keys} \n' \
130
+ f'temp_attn_agg are expected to be missing if ckpt was pt contrastively.')
131
+ else:
132
+ logging.info(f'Loading afeat_extractor ckpt from {self.ckpt_path} succeeded.')
133
+
134
+ # print the number of parameters
135
+ logging.info(f'AST: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}')
136
+
137
+ def forward(self, x: torch.Tensor, for_loop: bool = False, cont_mask: torch.Tensor = None,
138
+ **ast_kwargs) -> torch.Tensor:
139
+ '''
140
+ x: (B, S, T, F) where S is the number of segments, T the number of time frames, F the number of (mel) frequency bins,
141
+ ast_kwargs: additional arguments for the AST model
142
+ cont_mask: (B, S, T, F) where 0s are the values to be masked out
143
+ if `for_loop=True`, we use a for loop to extract features for each segment separately.
144
+ if `for_loop=False`, we extract features for all segments at once.
145
+ Using the for loop is slower but more memory efficient, while using all segments at once
146
+ is faster but uses more memory.
147
+ Using the for loop allows controlling the memory footprint by varying the number of videos in a
148
+ batch (batch size) rather than the number of segments in a video.
149
+ '''
150
+ B, S, T, F = x.shape
151
+
152
+ if for_loop:
153
+ assert cont_mask is None, 'cont_mask is not supported with for_loop=True'
154
+ orig_shape_s = (B, 1, T, F)
155
+ # NOTE: since x is (B, S, T, F), and forward_segments expects (BS, T, F).
156
+ # (B, S, T, F)[:, s] is (B, T, F) or (BS, T, F) if S=1.
157
+ x = torch.cat(
158
+ [self.forward_segments(x[:, s], orig_shape_s, **ast_kwargs).unsqueeze(1) for s in range(S)],
159
+ dim=1)
160
+ else:
161
+ orig_shape = (B, S, T, F)
162
+ x = x.view(B * S, T, F)
163
+ if cont_mask is not None:
164
+ cont_mask = cont_mask.reshape(B * S, T, F)
165
+ # AST expects a tensor of shape (B*S, T, F).
166
+ x = self.forward_segments(x, orig_shape=orig_shape, cont_mask=cont_mask, **ast_kwargs)
167
+ # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D))
168
+ x = x.view(B, S, *x.shape[1:])
169
+ # x now is of shape (B, S, D) or (B, S, t, D) if `self.temp_attn_agg` is `Identity`
170
+
171
+ global_x = None
172
+ if self.extract_features and self.add_global_repr: # lazy execution, throws AttributeError
173
+ assert len(x.shape) == 3, f'Local representation should be (B, S, D) {x.shape}'
174
+ global_x = self.global_attn_agg(x) # (B, D)
175
+
176
+ return x, global_x # x is (B, S, ...), global_x is (B, D) or None
177
+
178
+ def forward_segments(self, x, orig_shape: tuple, cont_mask: torch.Tensor = None, **ast_kwargs):
179
+ '''x is (BS, T, F), where S is the number of segments; cont_mask is (BS, T, F): 0s to be masked out'''
180
+ # 'pooler_output': (B, D); or 'last_hidden_state': (B, T, D) where T is [CLS, DISTILL, <tokens>]
181
+ # x_mask is (B, T) where 0s are the values to be masked out
182
+ x, x_mask = self.ast(x, cont_mask=cont_mask, **ast_kwargs)
183
+
184
+ if self.extract_features:
185
+ x = self.get_features_by_type(x)
186
+ if self.factorize_freq_time:
187
+ x = self.restore_freq_temp_dims(x, orig_shape) # (BS, D, f, t) <- (B*S, T, D)
188
+ if cont_mask is not None:
189
+ # duplicating the mask for the latent dimension (D) to be compatible with the next func
190
+ x_mask = x_mask.unsqueeze(-1).expand(-1, -1, self.config.hidden_size)
191
+ x_mask = self.restore_freq_temp_dims(x_mask, orig_shape) # (BS, D, f, t) <- (B*S, T, D)
192
+ # again removing the latent
193
+ x_mask = x_mask[:, 0, :, :]
194
+ else:
195
+ x_mask = None
196
+ x = self.freq_attn_agg(x, x_mask) # (BS, t, D)
197
+ x = self.temp_attn_agg(x) # (BS, D) or (BS, t, D) if self.temp_attn_agg is Identity
198
+ else:
199
+ x = x['pooler_output']
200
+ x = self.classifier(x)
201
+ return x
202
+
203
+ def get_features_by_type(self, x: BaseModelOutputWithPooling) -> torch.Tensor:
204
+ if self.feat_type == 'pooler_output':
205
+ return x['pooler_output'] # (B, D)
206
+ elif self.feat_type == 'CLS':
207
+ return x['last_hidden_state'][:, 0, :] # (B, D)
208
+ elif self.feat_type == 'last_hidden_state':
209
+ return x['last_hidden_state'] # (B, 2+T, D)
210
+ elif self.feat_type == 'last_hidden_state_no_AUX':
211
+ return x['last_hidden_state'][:, 2:, :] # (B, T, D) removing CLS and distill tokens
212
+ else:
213
+ raise ValueError(f'Unknown feature type: {self.feat_type}')
214
+
215
+ def restore_freq_temp_dims(self, feats, orig_shape: tuple):
216
+ '''
217
+ feats are of shape (B*S, T, D)
218
+ where T = 2 + f * t (if feat_type == 'last_hidden_state')
219
+ where T = f * t (if feat_type == 'last_hidden_state_no_AUX')
220
+ Our goal is to make them of shape (B*S, D, f, t) where f and t are dimensions after patching.
221
+ From `self.ast.embeddings.patch_embeddings`, it follows that we could reshape feats:
222
+ `feats.transpose(1, 2).view(B*S, D, f, t)`
223
+
224
+ (A similar function is defined for RGB features in `motionformer.py`)
225
+ '''
226
+ B, S, T, F = orig_shape
227
+ D = self.config.hidden_size
228
+
229
+ # num patches in each dimension
230
+ f, t = self.ast.embeddings.get_shape(self.config)
231
+
232
+ if self.feat_type == 'last_hidden_state':
233
+ feats = feats[:, 2:, :] # removing CLS and distill tokens
234
+
235
+ feats = feats.permute(0, 2, 1) # (B*S, D, T)
236
+ feats = feats.view(B * S, D, f, t) # (B*S, D, f, t)
237
+
238
+ return feats
239
+
240
+ def patch_position_emb(self):
241
+ if self.max_spec_t is not None:
242
+ self.config.max_length = self.max_spec_t
243
+ f, t = self.ast.embeddings.get_shape(self.config)
244
+ shortened = self.ast.embeddings.position_embeddings[:, :f*t+2].clone() # +2 for CLS and distill tokens
245
+ self.ast.embeddings.position_embeddings = torch.nn.Parameter(shortened).to(self.device)
246
+
247
+ def to(self, device):
248
+ '''AST.device fails with AttributeError. This is a workaround. '''
249
+ self.device = torch.device(device)
250
+ return super().to(device)
251
+
252
+
253
+ class FrequencyTransformerEncoderLayer(BaseEncoderLayer):
254
+ ''' This layer is used to aggregate the features along the frequency axis.
255
+ It follows the same logic as spatio-temporal aggregation in visual feature extractor.
256
+ Thus, it is recommended to check the definition of `BaseEncoderLayer` in `motionformer.py` '''
257
+
258
+ def __init__(self, *args, **kwargs):
259
+ super().__init__(*args, **kwargs)
260
+
261
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
262
+ ''' x: (B*S, D, f, t); if specified x_mask (B*S, f, t), 0s are the values to be masked out '''
263
+ BS, D, f, t = x.shape
264
+
265
+ # time as a batch dimension
266
+ x = x.permute(0, 3, 2, 1) # (B*S, t, f, D)
267
+ x = x.reshape(BS * t, f, D) # .view() fails with non-contiguous memory
268
+ # similar to mask
269
+ if x_mask is not None:
270
+ x_mask = x_mask.permute(0, 2, 1) # (B*S, t, f)
271
+ x_mask = x_mask.reshape(BS * t, f)
272
+
273
+ # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation
274
+ x = super().forward(x=x, x_mask=x_mask) # (B*S*t, D)
275
+
276
+ # reshape back to (B*S, t, D)
277
+ x = x.view(BS, t, D)
278
+
279
+ return x # (B*S, t, D)
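A minimal usage sketch of the `AST` feature extractor defined above. The import path, constructor arguments and spectrogram sizes are assumptions inferred from the docstrings (with `max_spec_t=66` the patched positional embedding covers 12*6+2=74 tokens, matching the example in the comments above); output shapes follow the comments in `forward`.

    import torch
    from model.modules.feat_extractors.audio.ast import AST

    # factorized freq/time aggregation plus a global (per-video) representation
    afeat_extractor = AST(
        extract_features=True, max_spec_t=66, factorize_freq_time=True,
        agg_freq_module='TransformerEncoderLayer', agg_time_module='TransformerEncoderLayer',
        add_global_repr=True, agg_segments_module='TransformerEncoderLayer', max_segments=14,
    )
    spec = torch.rand(2, 14, 66, 128)                # (B, S, T, F): 14 segments, 66 time frames, 128 mel bins
    local_repr, global_repr = afeat_extractor(spec)  # (B, S, D) and (B, D)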
modules/model/modules/feat_extractors/audio/hf_src/modeling_ast.py ADDED
@@ -0,0 +1,662 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 MIT and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Modified by v-iashin to support token masking
17
+
18
+ """ PyTorch Audio Spectrogram Transformer (AST) model."""
19
+
20
+ import math
21
+ from typing import Dict, List, Optional, Set, Tuple, Union
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
30
+ from transformers.modeling_utils import PreTrainedModel
31
+ from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
32
+ from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import ASTConfig
33
+ from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
34
+
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+ # General docstring
39
+ _CONFIG_FOR_DOC = "ASTConfig"
40
+
41
+ # Base docstring
42
+ _CHECKPOINT_FOR_DOC = "MIT/ast-finetuned-audioset-10-10-0.4593"
43
+ _EXPECTED_OUTPUT_SHAPE = [1, 1214, 768]
44
+
45
+ # Audio classification docstring
46
+ _SEQ_CLASS_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"
47
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'Speech'"
48
+ _SEQ_CLASS_EXPECTED_LOSS = 0.17
49
+
50
+
51
+ AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
52
+ "MIT/ast-finetuned-audioset-10-10-0.4593",
53
+ # See all Audio Spectrogram Transformer models at https://huggingface.co/models?filter=ast
54
+ ]
55
+
56
+
57
+ class ASTEmbeddings(nn.Module):
58
+ """
59
+ Construct the CLS token, position and patch embeddings.
60
+ """
61
+
62
+ def __init__(self, config: ASTConfig) -> None:
63
+ super().__init__()
64
+
65
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
66
+ self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
67
+ self.patch_embeddings = ASTPatchEmbeddings(config)
68
+
69
+ frequency_out_dimension, time_out_dimension = self.get_shape(config)
70
+ num_patches = frequency_out_dimension * time_out_dimension
71
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
72
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
73
+ self.config = config
74
+
75
+ def get_shape(self, config):
76
+ # see Karpathy's cs231n blog on how to calculate the output dimensions
77
+ # https://cs231n.github.io/convolutional-networks/#conv
78
+ frequency_out_dimension = (config.num_mel_bins - config.patch_size) // config.frequency_stride + 1
79
+ time_out_dimension = (config.max_length - config.patch_size) // config.time_stride + 1
80
+
81
+ return frequency_out_dimension, time_out_dimension
82
+
83
+ def forward(self, input_values: torch.Tensor) -> torch.Tensor:
84
+ batch_size = input_values.shape[0]
85
+ embeddings = self.patch_embeddings(input_values)
86
+
87
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
88
+ distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
89
+ embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
90
+ embeddings = embeddings + self.position_embeddings
91
+ embeddings = self.dropout(embeddings)
92
+
93
+ return embeddings
94
+
95
+
96
+ class ASTPatchEmbeddings(nn.Module):
97
+ """
98
+ This class turns `input_values` into the initial `hidden_states` (patch embeddings) of shape `(batch_size,
99
+ seq_length, hidden_size)` to be consumed by a Transformer.
100
+ """
101
+
102
+ def __init__(self, config):
103
+ super().__init__()
104
+
105
+ patch_size = config.patch_size
106
+ frequency_stride = config.frequency_stride
107
+ time_stride = config.time_stride
108
+
109
+ self.projection = nn.Conv2d(
110
+ 1, config.hidden_size, kernel_size=(patch_size, patch_size), stride=(frequency_stride, time_stride)
111
+ )
112
+
113
+ def forward(self, input_values: torch.Tensor) -> torch.Tensor:
114
+ input_values = input_values.unsqueeze(1)
115
+ input_values = input_values.transpose(2, 3)
116
+ embeddings = self.projection(input_values).flatten(2).transpose(1, 2)
117
+ return embeddings
118
+
119
+
120
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->AST
121
+ class ASTSelfAttention(nn.Module):
122
+ def __init__(self, config: ASTConfig) -> None:
123
+ super().__init__()
124
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
125
+ raise ValueError(
126
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
127
+ f"heads {config.num_attention_heads}."
128
+ )
129
+
130
+ self.num_attention_heads = config.num_attention_heads
131
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
132
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
133
+
134
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
135
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
136
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
137
+
138
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
139
+
140
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
141
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
142
+ x = x.view(new_x_shape)
143
+ return x.permute(0, 2, 1, 3)
144
+
145
+ def forward(
146
+ self, hidden_states, tok_mask: Optional[torch.Tensor] = None,
147
+ head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
148
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
149
+ mixed_query_layer = self.query(hidden_states)
150
+
151
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
152
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
153
+ query_layer = self.transpose_for_scores(mixed_query_layer)
154
+
155
+ # Take the dot product between "query" and "key" to get the raw attention scores.
156
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
157
+
158
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
159
+
160
+ # apply masking if provided, tok_mask is (BS, N): 1s - keep; attention_scores is (BS, H, N, N)
161
+ if tok_mask is not None:
162
+ BS, N = tok_mask.shape
163
+ attention_scores = attention_scores.masked_fill(tok_mask.view(BS, 1, 1, N) == 0, float('-inf'))
164
+
165
+ # Normalize the attention scores to probabilities.
166
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
167
+
168
+ # This is actually dropping out entire tokens to attend to, which might
169
+ # seem a bit unusual, but is taken from the original Transformer paper.
170
+ attention_probs = self.dropout(attention_probs)
171
+
172
+ # Mask heads if we want to
173
+ if head_mask is not None:
174
+ attention_probs = attention_probs * head_mask
175
+
176
+ context_layer = torch.matmul(attention_probs, value_layer)
177
+
178
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
179
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
180
+ context_layer = context_layer.view(new_context_layer_shape)
181
+
182
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
183
+
184
+ return outputs
185
+
186
+
187
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST
188
+ class ASTSelfOutput(nn.Module):
189
+ """
190
+ The residual connection is defined in ASTLayer instead of here (as is the case with other models), due to the
191
+ layernorm applied before each block.
192
+ """
193
+
194
+ def __init__(self, config: ASTConfig) -> None:
195
+ super().__init__()
196
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
197
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
198
+
199
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
200
+ hidden_states = self.dense(hidden_states)
201
+ hidden_states = self.dropout(hidden_states)
202
+
203
+ return hidden_states
204
+
205
+
206
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->AST
207
+ class ASTAttention(nn.Module):
208
+ def __init__(self, config: ASTConfig) -> None:
209
+ super().__init__()
210
+ self.attention = ASTSelfAttention(config)
211
+ self.output = ASTSelfOutput(config)
212
+ self.pruned_heads = set()
213
+
214
+ def prune_heads(self, heads: Set[int]) -> None:
215
+ if len(heads) == 0:
216
+ return
217
+ heads, index = find_pruneable_heads_and_indices(
218
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
219
+ )
220
+
221
+ # Prune linear layers
222
+ self.attention.query = prune_linear_layer(self.attention.query, index)
223
+ self.attention.key = prune_linear_layer(self.attention.key, index)
224
+ self.attention.value = prune_linear_layer(self.attention.value, index)
225
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
226
+
227
+ # Update hyper params and store pruned heads
228
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
229
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
230
+ self.pruned_heads = self.pruned_heads.union(heads)
231
+
232
+ def forward(
233
+ self,
234
+ hidden_states: torch.Tensor,
235
+ tok_mask: Optional[torch.Tensor] = None,
236
+ head_mask: Optional[torch.Tensor] = None,
237
+ output_attentions: bool = False,
238
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
239
+ self_outputs = self.attention(hidden_states, tok_mask, head_mask, output_attentions)
240
+
241
+ attention_output = self.output(self_outputs[0], hidden_states)
242
+
243
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
244
+ return outputs
245
+
246
+
247
+ # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST
248
+ class ASTIntermediate(nn.Module):
249
+ def __init__(self, config: ASTConfig) -> None:
250
+ super().__init__()
251
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
252
+ if isinstance(config.hidden_act, str):
253
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
254
+ else:
255
+ self.intermediate_act_fn = config.hidden_act
256
+
257
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
258
+ hidden_states = self.dense(hidden_states)
259
+ hidden_states = self.intermediate_act_fn(hidden_states)
260
+
261
+ return hidden_states
262
+
263
+
264
+ # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->AST
265
+ class ASTOutput(nn.Module):
266
+ def __init__(self, config: ASTConfig) -> None:
267
+ super().__init__()
268
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
269
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
270
+
271
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
272
+ hidden_states = self.dense(hidden_states)
273
+ hidden_states = self.dropout(hidden_states)
274
+
275
+ hidden_states = hidden_states + input_tensor
276
+
277
+ return hidden_states
278
+
279
+
280
+ # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST
281
+ class ASTLayer(nn.Module):
282
+ """This corresponds to the Block class in the timm implementation."""
283
+
284
+ def __init__(self, config: ASTConfig) -> None:
285
+ super().__init__()
286
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
287
+ self.seq_len_dim = 1
288
+ self.attention = ASTAttention(config)
289
+ self.intermediate = ASTIntermediate(config)
290
+ self.output = ASTOutput(config)
291
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
292
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
293
+
294
+ def forward(
295
+ self,
296
+ hidden_states: torch.Tensor,
297
+ tok_mask: Optional[torch.Tensor] = None,
298
+ head_mask: Optional[torch.Tensor] = None,
299
+ output_attentions: bool = False,
300
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
301
+ self_attention_outputs = self.attention(
302
+ self.layernorm_before(hidden_states), # in AST, layernorm is applied before self-attention
303
+ tok_mask,
304
+ head_mask,
305
+ output_attentions=output_attentions,
306
+ )
307
+ attention_output = self_attention_outputs[0]
308
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
309
+
310
+ # first residual connection
311
+ hidden_states = attention_output + hidden_states
312
+
313
+ # in AST, layernorm is also applied after self-attention
314
+ layer_output = self.layernorm_after(hidden_states)
315
+ layer_output = self.intermediate(layer_output)
316
+
317
+ # second residual connection is done here
318
+ layer_output = self.output(layer_output, hidden_states)
319
+
320
+ outputs = (layer_output,) + outputs
321
+
322
+ return outputs
323
+
324
+
325
+ # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->AST
326
+ class ASTEncoder(nn.Module):
327
+ def __init__(self, config: ASTConfig) -> None:
328
+ super().__init__()
329
+ self.config = config
330
+ self.layer = nn.ModuleList([ASTLayer(config) for _ in range(config.num_hidden_layers)])
331
+ self.gradient_checkpointing = False
332
+
333
+ def forward(
334
+ self,
335
+ hidden_states: torch.Tensor,
336
+ tok_mask: Optional[torch.Tensor] = None,
337
+ head_mask: Optional[torch.Tensor] = None,
338
+ output_attentions: bool = False,
339
+ output_hidden_states: bool = False,
340
+ return_dict: bool = True,
341
+ ) -> Union[tuple, BaseModelOutput]:
342
+ all_hidden_states = () if output_hidden_states else None
343
+ all_self_attentions = () if output_attentions else None
344
+
345
+ for i, layer_module in enumerate(self.layer):
346
+ if output_hidden_states:
347
+ all_hidden_states = all_hidden_states + (hidden_states,)
348
+
349
+ layer_head_mask = head_mask[i] if head_mask is not None else None
350
+
351
+ if self.gradient_checkpointing and self.training:
352
+
353
+ def create_custom_forward(module):
354
+ def custom_forward(*inputs):
355
+ return module(*inputs, output_attentions)
356
+
357
+ return custom_forward
358
+
359
+ layer_outputs = torch.utils.checkpoint.checkpoint(
360
+ create_custom_forward(layer_module),
361
+ hidden_states,
362
+ tok_mask,
363
+ layer_head_mask,
364
+ )
365
+ else:
366
+ layer_outputs = layer_module(hidden_states, tok_mask, layer_head_mask, output_attentions)
367
+
368
+ hidden_states = layer_outputs[0]
369
+
370
+ if output_attentions:
371
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
372
+
373
+ if output_hidden_states:
374
+ all_hidden_states = all_hidden_states + (hidden_states,)
375
+
376
+ if not return_dict:
377
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
378
+ return BaseModelOutput(
379
+ last_hidden_state=hidden_states,
380
+ hidden_states=all_hidden_states,
381
+ attentions=all_self_attentions,
382
+ )
383
+
384
+
385
+ class ASTPreTrainedModel(PreTrainedModel):
386
+ """
387
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
388
+ models.
389
+ """
390
+
391
+ config_class = ASTConfig
392
+ base_model_prefix = "audio_spectrogram_transformer"
393
+ main_input_name = "input_values"
394
+ supports_gradient_checkpointing = True
395
+
396
+ # Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights
397
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
398
+ """Initialize the weights"""
399
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
400
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
401
+ # `trunc_normal_cpu` not implemented in `half` issues
402
+ module.weight.data = nn.init.trunc_normal_(
403
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
404
+ ).to(module.weight.dtype)
405
+ if module.bias is not None:
406
+ module.bias.data.zero_()
407
+ elif isinstance(module, nn.LayerNorm):
408
+ module.bias.data.zero_()
409
+ module.weight.data.fill_(1.0)
410
+
411
+ # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST
412
+ def _set_gradient_checkpointing(self, module: ASTEncoder, value: bool = False) -> None:
413
+ if isinstance(module, ASTEncoder):
414
+ module.gradient_checkpointing = value
415
+
416
+
417
+ AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r"""
418
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
419
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
420
+ behavior.
421
+
422
+ Parameters:
423
+ config ([`ASTConfig`]):
424
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
425
+ load the weights associated with the model, only the configuration. Check out the
426
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
427
+ """
428
+
429
+ AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING = r"""
430
+ Args:
431
+ input_values (`torch.FloatTensor` of shape `(batch_size, max_length, num_mel_bins)`):
432
+ Float values of the mel spectrogram. They can be obtained using [`AutoFeatureExtractor`]. See
433
+ [`ASTFeatureExtractor.__call__`] for details.
434
+
435
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
436
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
437
+
438
+ - 1 indicates the head is **not masked**,
439
+ - 0 indicates the head is **masked**.
440
+
441
+ output_attentions (`bool`, *optional*):
442
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
443
+ tensors for more detail.
444
+ output_hidden_states (`bool`, *optional*):
445
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
446
+ more detail.
447
+ return_dict (`bool`, *optional*):
448
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
449
+ """
450
+
451
+
452
+ @add_start_docstrings(
453
+ "The bare AST Model transformer outputting raw hidden-states without any specific head on top.",
454
+ AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING,
455
+ )
456
+ class ASTModel(ASTPreTrainedModel):
457
+ def __init__(self, config: ASTConfig):
458
+ super().__init__(config)
459
+ self.config = config
460
+
461
+ self.embeddings = ASTEmbeddings(config)
462
+ self.encoder = ASTEncoder(config)
463
+
464
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
465
+
466
+ # Initialize weights and apply final processing
467
+ self.post_init()
468
+
469
+ def get_input_embeddings(self) -> ASTPatchEmbeddings:
470
+ return self.embeddings.patch_embeddings
471
+
472
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
473
+ """
474
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
475
+ class PreTrainedModel
476
+ """
477
+ for layer, heads in heads_to_prune.items():
478
+ self.encoder.layer[layer].attention.prune_heads(heads)
479
+
480
+ @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
481
+ @add_code_sample_docstrings(
482
+ checkpoint=_CHECKPOINT_FOR_DOC,
483
+ output_type=BaseModelOutputWithPooling,
484
+ config_class=_CONFIG_FOR_DOC,
485
+ modality="audio",
486
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
487
+ )
488
+ def forward(
489
+ self,
490
+ input_values: Optional[torch.Tensor] = None,
491
+ cont_mask: Optional[torch.Tensor] = None,
492
+ head_mask: Optional[torch.Tensor] = None,
493
+ output_attentions: Optional[bool] = None,
494
+ output_hidden_states: Optional[bool] = None,
495
+ return_dict: Optional[bool] = None,
496
+ ):
497
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
498
+ output_hidden_states = (
499
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
500
+ )
501
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
502
+
503
+ if input_values is None:
504
+ raise ValueError("You have to specify input_values")
505
+
506
+ # Prepare head mask if needed
507
+ # 1.0 in head_mask indicate we keep the head
508
+ # attention_probs has shape bsz x n_heads x N x N
509
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
510
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
511
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
512
+
513
+ embedding_output = self.embeddings(input_values)
514
+
515
+ # transforms the mask that has spectrogram dims to the token masking which is obtained after patching.
516
+ # Due to the ovelap in patching, getting the token mask from spectrogram mask is not straightforward,
517
+ # because one 16x16 content patch is encoded in two tokens if stride is <16. So, to get the mask for
518
+ # tokens I will apply the patching func (self.embeddings) to the tensor with infinities at the masked
519
+ # content position. For infs, the patching fn will return nans, which I'll use to get the token mask.
520
+ if cont_mask is not None:
521
+ indicator = torch.ones_like(input_values).to(input_values.dtype)
522
+ # replace content mask (0s) with infs
523
+ indicator[~cont_mask] = torch.inf
524
+ # apply patching; now nans are where the content mask was
525
+ with torch.no_grad():
526
+ indicator = self.embeddings(indicator) # BS, N, D
527
+ # replace nans with 0s; these are the tokens that correspond to the masked content
528
+ tok_mask = ~torch.isnan(indicator)
529
+ # since all values in the D-dimension (latent) will also be nans, we can just use the first el
530
+ tok_mask = tok_mask[:, :, 0] # (BS, 2+num_patches) -- 2 is from CLS and DISTIL tokens
531
+ else:
532
+ tok_mask = None
533
+
534
+ encoder_outputs = self.encoder(
535
+ embedding_output,
536
+ tok_mask=tok_mask,
537
+ head_mask=head_mask,
538
+ output_attentions=output_attentions,
539
+ output_hidden_states=output_hidden_states,
540
+ return_dict=return_dict,
541
+ )
542
+ sequence_output = encoder_outputs[0]
543
+ sequence_output = self.layernorm(sequence_output)
544
+
545
+ pooled_output = (sequence_output[:, 0] + sequence_output[:, 1]) / 2
546
+
547
+ if not return_dict:
548
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
549
+
550
+ return BaseModelOutputWithPooling(
551
+ last_hidden_state=sequence_output,
552
+ pooler_output=pooled_output,
553
+ hidden_states=encoder_outputs.hidden_states,
554
+ attentions=encoder_outputs.attentions,
555
+ ), tok_mask
556
+
557
+
558
+ class ASTMLPHead(nn.Module):
559
+ def __init__(self, config: ASTConfig):
560
+ super().__init__()
561
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
562
+ self.dense = nn.Linear(
563
+ config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
564
+
565
+ def forward(self, hidden_state):
566
+ hidden_state = self.layernorm(hidden_state)
567
+ hidden_state = self.dense(hidden_state)
568
+ return hidden_state
569
+
570
+
571
+ @add_start_docstrings(
572
+ """
573
+ Audio Spectrogram Transformer model with an audio classification head on top (a linear layer on top of the pooled
574
+ output) e.g. for datasets like AudioSet, Speech Commands v2.
575
+ """,
576
+ AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING,
577
+ )
578
+ class ASTForAudioClassification(ASTPreTrainedModel):
579
+ def __init__(self, config: ASTConfig) -> None:
580
+ super().__init__(config)
581
+
582
+ self.num_labels = config.num_labels
583
+ self.audio_spectrogram_transformer = ASTModel(config)
584
+
585
+ # Classifier head
586
+ self.classifier = ASTMLPHead(config)
587
+
588
+ # Initialize weights and apply final processing
589
+ self.post_init()
590
+
591
+ @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
592
+ @add_code_sample_docstrings(
593
+ checkpoint=_SEQ_CLASS_CHECKPOINT,
594
+ output_type=SequenceClassifierOutput,
595
+ config_class=_CONFIG_FOR_DOC,
596
+ modality="audio",
597
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
598
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
599
+ )
600
+ def forward(
601
+ self,
602
+ input_values: Optional[torch.Tensor] = None,
603
+ cont_mask: Optional[torch.Tensor] = None,
604
+ head_mask: Optional[torch.Tensor] = None,
605
+ labels: Optional[torch.Tensor] = None,
606
+ output_attentions: Optional[bool] = None,
607
+ output_hidden_states: Optional[bool] = None,
608
+ return_dict: Optional[bool] = None,
609
+ ) -> Union[tuple, SequenceClassifierOutput]:
610
+ r"""
611
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
612
+ Labels for computing the audio classification/regression loss. Indices should be in `[0, ...,
613
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
614
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
615
+ """
616
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
617
+
618
+ outputs = self.audio_spectrogram_transformer(
619
+ input_values,
620
+ cont_mask=cont_mask,
621
+ head_mask=head_mask,
622
+ output_attentions=output_attentions,
623
+ output_hidden_states=output_hidden_states,
624
+ return_dict=return_dict,
625
+ )
626
+
627
+ pooled_output = outputs[1]
628
+ logits = self.classifier(pooled_output)
629
+
630
+ loss = None
631
+ if labels is not None:
632
+ if self.config.problem_type is None:
633
+ if self.num_labels == 1:
634
+ self.config.problem_type = "regression"
635
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
636
+ self.config.problem_type = "single_label_classification"
637
+ else:
638
+ self.config.problem_type = "multi_label_classification"
639
+
640
+ if self.config.problem_type == "regression":
641
+ loss_fct = MSELoss()
642
+ if self.num_labels == 1:
643
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
644
+ else:
645
+ loss = loss_fct(logits, labels)
646
+ elif self.config.problem_type == "single_label_classification":
647
+ loss_fct = CrossEntropyLoss()
648
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
649
+ elif self.config.problem_type == "multi_label_classification":
650
+ loss_fct = BCEWithLogitsLoss()
651
+ loss = loss_fct(logits, labels)
652
+
653
+ if not return_dict:
654
+ output = (logits,) + outputs[2:]
655
+ return ((loss,) + output) if loss is not None else output
656
+
657
+ return SequenceClassifierOutput(
658
+ loss=loss,
659
+ logits=logits,
660
+ hidden_states=outputs.hidden_states,
661
+ attentions=outputs.attentions,
662
+ )
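The functional difference from the upstream HuggingFace implementation is the `cont_mask`/`tok_mask` plumbing above (spectrogram-level masks are converted to token-level masks after patching). A minimal sketch of exercising it in isolation; it assumes the `ASTConfig` defaults (`max_length=1024`, `num_mel_bins=128`) and randomly initialized weights:

    import torch
    from model.modules.feat_extractors.audio.hf_src.modeling_ast import ASTConfig, ASTModel

    config = ASTConfig()
    model = ASTModel(config).eval()

    spec = torch.rand(2, config.max_length, config.num_mel_bins)  # (B, T, F)
    cont_mask = torch.ones_like(spec, dtype=torch.bool)
    cont_mask[:, 400:, :] = False                                 # e.g. treat everything after frame 400 as padding

    with torch.no_grad():
        outputs, tok_mask = model(spec, cont_mask=cont_mask)      # (BaseModelOutputWithPooling, token mask)
    # outputs.last_hidden_state: (B, 2 + num_patches, D)
    # tok_mask: (B, 2 + num_patches), 0 for tokens derived from masked spectrogram content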
modules/model/modules/feat_extractors/audio/resnet.py ADDED
@@ -0,0 +1,249 @@
1
+ import sys
2
+ from pathlib import Path
3
+ import logging
4
+ import einops
5
+
6
+ import torch
7
+ from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
8
+
9
+ sys.path.append('.') # nopep8
10
+
11
+ from utils.utils import check_if_file_exists_else_download
12
+ from model.modules.feat_extractors.audio.ast import FrequencyTransformerEncoderLayer
13
+ from model.modules.feat_extractors.visual.motionformer import AveragePooling, TemporalTransformerEncoderLayer
14
+
15
+
16
+ class ResNetAudio(ResNet):
17
+
18
+ def __init__(self, arch_name, num_classes, extract_features, ckpt_path=None, **kwargs):
19
+
20
+ if arch_name == 'resnet18':
21
+ block = BasicBlock
22
+ layers = [2, 2, 2, 2]
23
+ elif arch_name == 'resnet34':
24
+ block = BasicBlock
25
+ layers = [3, 4, 6, 3]
26
+ elif arch_name == 'resnet50':
27
+ block = Bottleneck
28
+ layers = [3, 4, 6, 3]
29
+ elif arch_name == 'resnet101':
30
+ block = Bottleneck
31
+ layers = [3, 4, 23, 3]
32
+ elif arch_name == 'resnet152':
33
+ block = Bottleneck
34
+ layers = [3, 8, 36, 3]
35
+ else:
36
+ raise NotImplementedError
37
+
38
+ super().__init__(block, layers, num_classes, **kwargs)
39
+
40
+ # replacing the old conv1 to the new one (RGB - 3; spectrogram - 1)
41
+ conv1 = self.conv1
42
+ self.conv1 = torch.nn.Conv2d(1, conv1.out_channels, conv1.kernel_size,
43
+ conv1.stride, conv1.padding, bias=conv1.bias)
44
+ self.extract_features = extract_features
45
+ self.embed_dim = self.fc.in_features
46
+
47
+ # load the ckpt
48
+ load_state_dict_resnet(self, ckpt_path, prefix='afeat_extractor.')
49
+
50
+ def _forward_impl(self, x):
51
+ # See note [TorchScript super()]
52
+ x = self.conv1(x)
53
+ x = self.bn1(x)
54
+ x = self.relu(x)
55
+ x = self.maxpool(x)
56
+
57
+ x = self.layer1(x)
58
+ x = self.layer2(x)
59
+ x = self.layer3(x)
60
+ x = self.layer4(x)
61
+
62
+ if self.extract_features:
63
+ return x
64
+
65
+ x = self.avgpool(x)
66
+ x = torch.flatten(x, 1)
67
+ x = self.fc(x)
68
+
69
+ return x
70
+
71
+ def forward(self, x):
72
+ return super().forward(x)
73
+
74
+
75
+ class ResNet18AudioFeatures(ResNetAudio):
76
+
77
+ # ckpt_path should default to None; otherwise, an error is thrown when no pre-training is desired
78
+ def __init__(self,
79
+ extract_features: bool = False,
80
+ ckpt_path: str = None,
81
+ feat_type: str = None,
82
+ max_spec_t: int = None,
83
+ factorize_freq_time: bool = None,
84
+ agg_freq_module: str = None,
85
+ agg_time_module: str = None,
86
+ add_global_repr: bool = True,
87
+ agg_segments_module: str = None,
88
+ max_segments: int = None,
89
+ ) -> None:
90
+ super().__init__(arch_name='resnet18', num_classes=308, extract_features=extract_features,
91
+ ckpt_path=ckpt_path)
92
+ assert extract_features, 'Not implemented otherwise'
93
+ self.extract_features = extract_features
94
+ self.feat_type = feat_type
95
+ self.max_spec_t = max_spec_t
96
+ # similar to s3d
97
+ self.nhead = 8
98
+ self.mlp_ratio = 4
99
+ self.drop_rate = 0.0
100
+
101
+ if ckpt_path is not None:
102
+ ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
103
+ was_pt_on_vgs_cls = 'ResNetAudio-' in Path(ckpt_path).stem
104
+ if was_pt_on_vgs_cls:
105
+ self.load_state_dict(ckpt['model'], strict=True)
106
+
107
+ # saving some memory
108
+ if extract_features:
109
+ self.avgpool = torch.nn.Identity()
110
+ self.fc = torch.nn.Identity()
111
+
112
+ # define adapters if needed
113
+ self.factorize_freq_time = factorize_freq_time
114
+ # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
115
+ transf_enc_layer_kwargs = dict(
116
+ d_model=self.embed_dim, nhead=self.nhead, dim_feedforward=self.mlp_ratio*self.embed_dim,
117
+ activation=torch.nn.GELU(), batch_first=True, dropout=self.drop_rate, layer_norm_eps=1e-6,
118
+ norm_first=True,
119
+ )
120
+ if factorize_freq_time:
121
+ self.feat_type = 'last_hidden_state' # this feat_type supports factorization
122
+ # frequency aggregation
123
+ if agg_freq_module == 'TransformerEncoderLayer':
124
+ self.freq_attn_agg = FrequencyTransformerEncoderLayer(**transf_enc_layer_kwargs)
125
+ elif agg_freq_module == 'AveragePooling':
126
+ self.freq_attn_agg = AveragePooling(avg_pattern='BS D f t -> BS D t',
127
+ then_permute_pattern='BS D t -> BS t D')
128
+ # time aggregation
129
+ if agg_time_module == 'TransformerEncoderLayer':
130
+ self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
131
+ elif agg_time_module == 'AveragePooling':
132
+ self.temp_attn_agg = AveragePooling(avg_pattern='BS t D -> BS D')
133
+ elif 'Identity' in agg_time_module:
134
+ self.temp_attn_agg = torch.nn.Identity()
135
+ # define a global aggregation layer (aggregate over segments)
136
+ self.add_global_repr = add_global_repr
137
+ if add_global_repr:
138
+ if agg_segments_module == 'TransformerEncoderLayer':
139
+ # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
140
+ # we need to add pos emb (PE) because previously we added the same PE for each segment
141
+ pos_max_len = max_segments if max_segments is not None else 16 # 16 = 10sec//0.64sec + 1
142
+ self.global_attn_agg = TemporalTransformerEncoderLayer(
143
+ add_pos_emb=True, pos_emb_drop=self.drop_rate,
144
+ pos_max_len=pos_max_len, **transf_enc_layer_kwargs
145
+ )
146
+ elif agg_segments_module == 'AveragePooling':
147
+ self.global_attn_agg = AveragePooling(avg_pattern='B S D -> B D')
148
+
149
+ # do not keep fc to save memory
150
+ self.fc = torch.nn.Identity()
151
+
152
+ if ckpt_path is not None:
153
+ ckpt = ckpt['state_dict']
154
+ was_pt_on_avclip = any('a_encoder.' in k[0] or 'v_encoder.' in k[0] for k in ckpt.items())
155
+ assert was_pt_on_vgs_cls is False, f'Unexpected ckpt: {ckpt_path}'
156
+ if was_pt_on_avclip:
157
+ ckpt_weights = dict()
158
+ for k, v in ckpt.items():
159
+ if k.startswith(('module.a_encoder.', 'a_encoder.')):
160
+ k = k.replace('module.', '').replace('a_encoder.', '')
161
+ ckpt_weights[k] = v
162
+ _load_status = self.load_state_dict(ckpt_weights, strict=False)
163
+ if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
164
+ logging.warning(f'Loading exact ckpt from {ckpt_path} failed. \n' \
165
+ f'Missing keys ({len(_load_status.missing_keys)}): ' \
166
+ f'{_load_status.missing_keys}, \n' \
167
+ f'Unexpected keys ({len(_load_status.unexpected_keys)}): ' \
168
+ f'{_load_status.unexpected_keys} \n' \
169
+ f'freq_attn_agg are expected to be unexpected if ckpt was pt contrastively '\
170
+ f'as well as fc could be missing because we use features, not a classifier.')
171
+ else:
172
+ logging.info(f'Loading ResNet ckpt from {ckpt_path} succeeded.')
173
+
174
+ # print the number of parameters
175
+ logging.info(f'afeat_extractor: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}')
176
+
177
+ def forward(self, x: torch.Tensor, for_loop: bool = False, cont_mask: torch.Tensor = None):
178
+ assert for_loop is False and cont_mask is None, 'Not implemented'
179
+ B, S, T, F = x.shape
180
+
181
+ # (BS, D) <- (B, S, T, F)
182
+ x = self.forward_segments(x)
183
+
184
+ # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D))
185
+ x = x.view(B, S, *x.shape[1:])
186
+ # x now is of shape (B, S, D) or (B, S, t, D) if `self.temp_attn_agg` is `Identity`
187
+
188
+ global_x = None
189
+ if self.extract_features and self.add_global_repr: # lazy execution, throws AttributeError
190
+ assert len(x.shape) == 3, f'Local representation should be (B, S, D) {x.shape}'
191
+ global_x = self.global_attn_agg(x) # (B, D)
192
+
193
+ return x, global_x # x is (B, S, ...), global_x is (B, D) or None
194
+
195
+ def forward_segments(self, x):
196
+ x = einops.rearrange(x, 'B S T F -> (B S) 1 F T')
197
+ # (BS, D, f, t) <- (BS, 1, F, T)
198
+ x = super().forward(x)
199
+
200
+ if self.extract_features:
201
+ if self.factorize_freq_time:
202
+ x = self.freq_attn_agg(x) # (BS, t, D)
203
+ x = self.temp_attn_agg(x) # (BS, D) or (BS, t, D) if self.temp_attn_agg is Identity
204
+ return x
205
+
206
+
207
+ def load_state_dict_resnet(model, ckpt_path, prefix):
208
+ if ckpt_path is not None:
209
+ check_if_file_exists_else_download(ckpt_path)
210
+ ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
211
+ ckpt = ckpt.get('model', ckpt.get('state_dict', ckpt))
212
+ # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
213
+ # and keep only the state_dict of the feat extractor
214
+ # FIXME: this is a bit hacky, but it works
215
+ was_pt_on_avclip = any('a_encoder.' in k[0] or 'v_encoder.' in k[0] for k in ckpt.items())
216
+ if not was_pt_on_avclip:
217
+ model.load_state_dict(ckpt)
218
+ logging.info(f'Loading ResNet ckpt from {ckpt_path} succeeded.')
219
+ # if was_pt_on_avclip:
220
+ # ckpt_weights = dict()
221
+ # for k, v in ckpt.items():
222
+ # if k.startswith(('module.a_encoder.', 'a_encoder.')):
223
+ # k = k.replace('module.', '').replace('a_encoder.', '')
224
+ # ckpt_weights[k] = v
225
+ # _load_status = model.load_state_dict(ckpt_weights, strict=False)
226
+ # if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
227
+ # logging.warning(f'Loading exact ckpt from {ckpt_path} failed. \n' \
228
+ # f'Missing keys ({len(_load_status.missing_keys)}): ' \
229
+ # f'{_load_status.missing_keys}, \n' \
230
+ # f'Unexpected keys ({len(_load_status.unexpected_keys)}): ' \
231
+ # f'{_load_status.unexpected_keys} \n' \
232
+ # f'freq_attn_agg are expected to be unexpected if ckpt was pt contrastively '\
233
+ # f'as well as fc could be missing because we use features, not a classifier.')
234
+ # else:
235
+ # logging.info(f'Loading ResNet ckpt from {ckpt_path} succeeded.')
236
+ # else:
237
+ # model.load_state_dict(ckpt)
238
+ # logging.info(f'Loading ResNet ckpt from {ckpt_path} succeeded.')
239
+
240
+
241
+ if __name__ == '__main__':
242
+ B = 2
243
+ ckpt_path = './model/modules/feat_extractors/audio/22-06-24T08-10-33/ResNetAudio-22-06-24T08-10-33.pt'
244
+ afeat_extractor = ResNet18AudioFeatures(ckpt_path=ckpt_path)
245
+ # x = torch.rand(B, 1, 257, 1551)
246
+ # x = torch.rand(B, 1, 257, 1379)
247
+ x = torch.rand(B, 1, 128, 66)
248
+ x = afeat_extractor(x)
249
+ print(x.shape)
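When used as a segment-level feature extractor (rather than through the `__main__` block above), the class expects a (B, S, T, F) batch and returns a (local, global) pair, mirroring `AST`. A hedged sketch with assumed constructor arguments and spectrogram sizes:

    import torch
    from model.modules.feat_extractors.audio.resnet import ResNet18AudioFeatures

    afeat_extractor = ResNet18AudioFeatures(
        extract_features=True, factorize_freq_time=True,
        agg_freq_module='TransformerEncoderLayer', agg_time_module='TransformerEncoderLayer',
        add_global_repr=True, agg_segments_module='TransformerEncoderLayer', max_segments=14,
    )
    spec = torch.rand(2, 14, 66, 128)                 # (B, S, T, F)
    local_repr, global_repr = afeat_extractor(spec)   # (B, S, D) and (B, D)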
modules/model/modules/feat_extractors/train_clip_src/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ import sys
2
+ sys.path.append('model/modules/feat_extractors')
3
+ sys.path.append('model/modules/feat_extractors/train_clip_src')
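The two `sys.path.append` calls let the vendored training code keep upstream-style absolute imports. Assuming the process is started from the directory that contains `model/` (as the relative paths suggest), `open_clip` and its siblings become importable as top-level packages:

    import sys
    sys.path.append('model/modules/feat_extractors/train_clip_src')  # same effect as the line above
    import open_clip                                                  # the vendored copy below is now importable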
modules/model/modules/feat_extractors/train_clip_src/open_clip/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from .coca_model import CoCa
2
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
3
+ from .factory import create_model, create_model_from_pretrained, get_tokenizer, create_loss
4
+ from .factory import list_models, add_model_config, get_model_config, load_checkpoint
5
+ from .loss import ClipLoss, DistillClipLoss, CoCaLoss
6
+ from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
7
+ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
8
+ from .openai import load_openai_model, list_openai_models
9
+ from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
10
+ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
11
+ from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
12
+ from .tokenizer import SimpleTokenizer, tokenize, decode
13
+ from .transform import image_transform, AugmentationCfg
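This `__init__.py` re-exports the public API of the vendored `open_clip` package. Assuming the upstream factory signatures are preserved in this trimmed copy (not verified against the shortened `factory.py`), typical usage would look like:

    import open_clip  # resolvable via the sys.path additions in train_clip_src/__init__.py

    print(open_clip.list_models()[:5])            # names of the JSON configs shipped in model_configs/
    model = open_clip.create_model('ViT-B-16')    # randomly initialized CLIP built from its config
    tokenizer = open_clip.get_tokenizer('ViT-B-16')
    tokens = tokenizer(['a dog barking', 'a violin playing'])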
modules/model/modules/feat_extractors/train_clip_src/open_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
modules/model/modules/feat_extractors/train_clip_src/open_clip/coca_model.py ADDED
@@ -0,0 +1,458 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+ from dataclasses import dataclass
8
+
9
+ from .transformer import (
10
+ LayerNormFp32,
11
+ LayerNorm,
12
+ QuickGELU,
13
+ MultimodalTransformer,
14
+ )
15
+ from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
16
+
17
+ try:
18
+ from transformers import (
19
+ BeamSearchScorer,
20
+ LogitsProcessorList,
21
+ TopPLogitsWarper,
22
+ TopKLogitsWarper,
23
+ RepetitionPenaltyLogitsProcessor,
24
+ MinLengthLogitsProcessor,
25
+ MaxLengthCriteria,
26
+ StoppingCriteriaList
27
+ )
28
+
29
+ GENERATION_TYPES = {
30
+ "top_k": TopKLogitsWarper,
31
+ "top_p": TopPLogitsWarper,
32
+ "beam_search": "beam_search"
33
+ }
34
+ _has_transformers = True
35
+ except ImportError as e:
36
+ GENERATION_TYPES = {
37
+ "top_k": None,
38
+ "top_p": None,
39
+ "beam_search": "beam_search"
40
+ }
41
+ _has_transformers = False
42
+
43
+
44
+ @dataclass
45
+ class MultimodalCfg(CLIPTextCfg):
46
+ mlp_ratio: int = 4
47
+ dim_head: int = 64
48
+ heads: int = 8
49
+ n_queries: int = 256
50
+ attn_pooler_heads: int = 8
51
+
52
+
53
+ def _build_text_decoder_tower(
54
+ embed_dim,
55
+ multimodal_cfg,
56
+ quick_gelu: bool = False,
57
+ cast_dtype: Optional[torch.dtype] = None,
58
+ ):
59
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
60
+ act_layer = QuickGELU if quick_gelu else nn.GELU
61
+ norm_layer = (
62
+ LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
63
+ )
64
+
65
+ decoder = MultimodalTransformer(
66
+ context_length=multimodal_cfg.context_length,
67
+ width=multimodal_cfg.width,
68
+ heads=multimodal_cfg.heads,
69
+ layers=multimodal_cfg.layers,
70
+ ls_init_value=multimodal_cfg.ls_init_value,
71
+ output_dim=embed_dim,
72
+ act_layer=act_layer,
73
+ norm_layer=norm_layer,
74
+ )
75
+
76
+ return decoder
77
+
78
+
79
+ class CoCa(nn.Module):
80
+ def __init__(
81
+ self,
82
+ embed_dim,
83
+ multimodal_cfg: MultimodalCfg,
84
+ text_cfg: CLIPTextCfg,
85
+ vision_cfg: CLIPVisionCfg,
86
+ quick_gelu: bool = False,
87
+ cast_dtype: Optional[torch.dtype] = None,
88
+ pad_id: int = 0,
89
+ ):
90
+ super().__init__()
91
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
92
+ text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
93
+ vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
94
+
95
+ self.text = _build_text_tower(
96
+ embed_dim=embed_dim,
97
+ text_cfg=text_cfg,
98
+ quick_gelu=quick_gelu,
99
+ cast_dtype=cast_dtype,
100
+ )
101
+
102
+ vocab_size = (
103
+ text_cfg.vocab_size # for hf models
104
+ if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
105
+ else text_cfg.vocab_size
106
+ )
107
+
108
+ self.visual = _build_vision_tower(
109
+ embed_dim=embed_dim,
110
+ vision_cfg=vision_cfg,
111
+ quick_gelu=quick_gelu,
112
+ cast_dtype=cast_dtype,
113
+ )
114
+
115
+ self.text_decoder = _build_text_decoder_tower(
116
+ vocab_size,
117
+ multimodal_cfg=multimodal_cfg,
118
+ quick_gelu=quick_gelu,
119
+ cast_dtype=cast_dtype,
120
+ )
121
+
122
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
123
+ self.pad_id = pad_id
124
+
125
+ @torch.jit.ignore
126
+ def set_grad_checkpointing(self, enable=True):
127
+ self.visual.set_grad_checkpointing(enable)
128
+ self.text.set_grad_checkpointing(enable)
129
+ self.text_decoder.set_grad_checkpointing(enable)
130
+
131
+ def _encode_image(self, images, normalize=True):
132
+ image_latent, tokens_embs = self.visual(images)
133
+ image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
134
+ return image_latent, tokens_embs
135
+
136
+ def _encode_text(self, text, normalize=True, embed_cls=True):
137
+ text = text[:, :-1] if embed_cls else text # make space for CLS token
138
+ text_latent, token_emb = self.text(text)
139
+ text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
140
+ return text_latent, token_emb
141
+
142
+ def encode_image(self, images, normalize=True):
143
+ image_latent, _ = self._encode_image(images, normalize=normalize)
144
+ return image_latent
145
+
146
+ def encode_text(self, text, normalize=True, embed_cls=True):
147
+ text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
148
+ return text_latent
149
+
150
+ def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
151
+ text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
152
+ if image_latent is None or image_embs is None:
153
+ image_latent, image_embs = self._encode_image(image)
154
+
155
+ # TODO: add assertion to avoid bugs?
156
+ labels = text[:, -token_embs.shape[1]:]
157
+
158
+ logits = self.text_decoder(image_embs, token_embs)
159
+ return {
160
+ "image_features": image_latent,
161
+ "text_features": text_latent,
162
+ "logits": logits,
163
+ "labels": labels,
164
+ "logit_scale": self.logit_scale.exp()
165
+ }
166
+
167
+ def generate(
168
+ self,
169
+ image,
170
+ text=None,
171
+ seq_len=30,
172
+ max_seq_len=77,
173
+ temperature=1.,
174
+ generation_type="beam_search",
175
+ top_p=0.1, # keep tokens in the 1 - top_p quantile
176
+ top_k=1, # keeps the top_k most probable tokens
177
+ pad_token_id=None,
178
+ eos_token_id=None,
179
+ sot_token_id=None,
180
+ num_beams=6,
181
+ num_beam_groups=3,
182
+ min_seq_len=5,
183
+ stopping_criteria=None,
184
+ repetition_penalty=1.0,
185
+ fixed_output_length=False # if True output.shape == (batch_size, seq_len)
186
+ ):
187
+ # taking many ideas and components from HuggingFace GenerationMixin
188
+ # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
189
+ assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
190
+ assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
191
+
192
+ with torch.no_grad():
193
+ sot_token_id = 49406 if sot_token_id is None else sot_token_id
194
+ eos_token_id = 49407 if eos_token_id is None else eos_token_id
195
+ pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
196
+ logit_processor = LogitsProcessorList(
197
+ [
198
+ MinLengthLogitsProcessor(min_seq_len, eos_token_id),
199
+ RepetitionPenaltyLogitsProcessor(repetition_penalty),
200
+ ]
201
+ )
202
+
203
+ if stopping_criteria is None:
204
+ stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
205
+
206
+ stopping_criteria = StoppingCriteriaList(
207
+ stopping_criteria
208
+ )
209
+
210
+ device = image.device
211
+
212
+ if generation_type == "beam_search":
213
+ output = self._generate_beamsearch(
214
+ image_inputs = image,
215
+ pad_token_id=pad_token_id,
216
+ eos_token_id=eos_token_id,
217
+ sot_token_id=sot_token_id,
218
+ num_beams=num_beams,
219
+ num_beam_groups=num_beam_groups,
220
+ min_seq_len=min_seq_len,
221
+ stopping_criteria=stopping_criteria,
222
+ logit_processor=logit_processor,
223
+ )
224
+ if fixed_output_length and output.shape[1] < seq_len:
225
+ return torch.cat(
226
+ (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
227
+ dim=1
228
+ )
229
+ return output
230
+
231
+ elif generation_type == "top_p":
232
+ logit_warper = GENERATION_TYPES[generation_type](top_p)
233
+ elif generation_type == "top_k":
234
+ logit_warper = GENERATION_TYPES[generation_type](top_k)
235
+ else:
236
+ raise ValueError(
237
+ f"generation_type has to be one of "
238
+ f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
239
+ )
240
+
241
+ image_latent, image_embs = self._encode_image(image)
242
+
243
+ if text is None:
244
+ text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
245
+
246
+ was_training = self.training
247
+ num_dims = len(text.shape)
248
+
249
+ if num_dims == 1:
250
+ text = text[None, :]
251
+
252
+ cur_len = text.shape[1]
253
+ self.eval()
254
+ out = text
255
+
256
+ while True:
257
+ x = out[:, -max_seq_len:]
258
+ cur_len = x.shape[1]
259
+ logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
260
+ mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
261
+ sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
262
+
263
+ if mask.all():
264
+ if not fixed_output_length:
265
+ break
266
+ else:
267
+ logits = logits[~mask, :]
268
+ filtered_logits = logit_processor(x[~mask, :], logits)
269
+ filtered_logits = logit_warper(x[~mask, :], filtered_logits)
270
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
271
+
272
+ if (cur_len + 1 == seq_len):
273
+ sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
274
+ else:
275
+ sample[~mask, :] = torch.multinomial(probs, 1)
276
+
277
+ out = torch.cat((out, sample), dim=-1)
278
+
279
+ cur_len += 1
280
+
281
+ if stopping_criteria(out, None):
282
+ break
283
+
284
+ if num_dims == 1:
285
+ out = out.squeeze(0)
286
+
287
+ self.train(was_training)
288
+ return out
289
+
290
+ def _generate_beamsearch(
291
+ self,
292
+ image_inputs,
293
+ pad_token_id=None,
294
+ eos_token_id=None,
295
+ sot_token_id=None,
296
+ num_beams=6,
297
+ num_beam_groups=3,
298
+ min_seq_len=5,
299
+ stopping_criteria=None,
300
+ logit_processor=None,
301
+ logit_warper=None,
302
+ ):
303
+ device = image_inputs.device
304
+ batch_size = image_inputs.shape[0]
305
+ image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
306
+ image_latent, image_embs = self._encode_image(image_inputs)
307
+
308
+ input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
309
+ input_ids = input_ids * sot_token_id
310
+ beam_scorer = BeamSearchScorer(
311
+ batch_size=batch_size,
312
+ num_beams=num_beams,
313
+ device=device,
314
+ num_beam_groups=num_beam_groups,
315
+ )
316
+ # instantiate logits processors
317
+ logits_processor = (
318
+ LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
319
+ if logit_processor is None
320
+ else logit_processor
321
+ )
322
+
323
+ batch_size = len(beam_scorer._beam_hyps)
324
+ num_beams = beam_scorer.num_beams
325
+ num_beam_groups = beam_scorer.num_beam_groups
326
+ num_sub_beams = num_beams // num_beam_groups
327
+ batch_beam_size, cur_len = input_ids.shape
328
+ beam_indices = None
329
+
330
+ if num_beams * batch_size != batch_beam_size:
331
+ raise ValueError(
332
+ f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
333
+ )
334
+
335
+ beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
336
+ # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
337
+ # the same group don't produce same tokens everytime.
338
+ beam_scores[:, ::num_sub_beams] = 0
339
+ beam_scores = beam_scores.view((batch_size * num_beams,))
340
+
341
+ while True:
342
+
343
+ # predicted tokens in cur_len step
344
+ current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
345
+
346
+ # indices which will form the beams in the next time step
347
+ reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
348
+
349
+ # do one decoder step on all beams of all sentences in batch
350
+ model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
351
+ outputs = self(
352
+ model_inputs['images'],
353
+ model_inputs['text'],
354
+ embed_cls=False,
355
+ image_latent=image_latent,
356
+ image_embs=image_embs
357
+ )
358
+
359
+ for beam_group_idx in range(num_beam_groups):
360
+ group_start_idx = beam_group_idx * num_sub_beams
361
+ group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
362
+ group_size = group_end_idx - group_start_idx
363
+
364
+ # indices of beams of current group among all sentences in batch
365
+ batch_group_indices = []
366
+
367
+ for batch_idx in range(batch_size):
368
+ batch_group_indices.extend(
369
+ [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
370
+ )
371
+ group_input_ids = input_ids[batch_group_indices]
372
+
373
+ # select outputs of beams of the current group only
374
+ next_token_logits = outputs['logits'][batch_group_indices, -1, :]
375
+ vocab_size = next_token_logits.shape[-1]
376
+
377
+ next_token_scores_processed = logits_processor(
378
+ group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
379
+ )
380
+ next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
381
+ next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
382
+
383
+ # reshape for beam search
384
+ next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
385
+
386
+ next_token_scores, next_tokens = torch.topk(
387
+ next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
388
+ )
389
+
390
+ next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
391
+ next_tokens = next_tokens % vocab_size
392
+
393
+ # stateless
394
+ process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
395
+ beam_outputs = beam_scorer.process(
396
+ group_input_ids,
397
+ next_token_scores,
398
+ next_tokens,
399
+ next_indices,
400
+ pad_token_id=pad_token_id,
401
+ eos_token_id=eos_token_id,
402
+ beam_indices=process_beam_indices,
403
+ )
404
+ beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
405
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
406
+ beam_idx = beam_outputs["next_beam_indices"]
407
+
408
+ input_ids[batch_group_indices] = group_input_ids[beam_idx]
409
+ group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
410
+ current_tokens[batch_group_indices] = group_input_ids[:, -1]
411
+
412
+ # (beam_idx // group_size) -> batch_idx
413
+ # (beam_idx % group_size) -> offset of idx inside the group
414
+ reordering_indices[batch_group_indices] = (
415
+ num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
416
+ )
417
+
418
+ input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
419
+
420
+ # increase cur_len
421
+ cur_len = cur_len + 1
422
+ if beam_scorer.is_done or stopping_criteria(input_ids, None):
423
+ break
424
+
425
+ final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
426
+ sequence_outputs = beam_scorer.finalize(
427
+ input_ids,
428
+ beam_scores,
429
+ next_tokens,
430
+ next_indices,
431
+ pad_token_id=pad_token_id,
432
+ eos_token_id=eos_token_id,
433
+ max_length=stopping_criteria.max_length,
434
+ beam_indices=final_beam_indices,
435
+ )
436
+ return sequence_outputs['sequences']
437
+
438
+
439
+ def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
440
+ if past:
441
+ input_ids = input_ids[:, -1].unsqueeze(-1)
442
+
443
+ attention_mask = kwargs.get("attention_mask", None)
444
+ position_ids = kwargs.get("position_ids", None)
445
+
446
+ if attention_mask is not None and position_ids is None:
447
+ # create position_ids on the fly for batch generation
448
+ position_ids = attention_mask.long().cumsum(-1) - 1
449
+ position_ids.masked_fill_(attention_mask == 0, 1)
450
+ else:
451
+ position_ids = None
452
+ return {
453
+ "text": input_ids,
454
+ "images": image_inputs,
455
+ "past_key_values": past,
456
+ "position_ids": position_ids,
457
+ "attention_mask": attention_mask,
458
+ }
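A minimal sketch (not part of the uploaded file) of the sampling path that CoCa.generate() above takes for generation_type "top_k" or "top_p": the last-step logits are passed through a LogitsProcessorList and a logits warper before multinomial sampling. The partial sequence, vocabulary size and logits below are illustrative stand-ins; only torch and transformers are assumed.

import torch
import torch.nn.functional as F
from transformers import (
    LogitsProcessorList, MinLengthLogitsProcessor,
    RepetitionPenaltyLogitsProcessor, TopKLogitsWarper,
)

vocab_size, eos_token_id = 49408, 49407          # assumed CLIP BPE vocabulary (matches the defaults above)
generated = torch.tensor([[49406, 320, 1125]])   # hypothetical partial sequence starting with the SOT token
logits = torch.randn(1, vocab_size)              # stand-in for the text_decoder output at the last position

processors = LogitsProcessorList([
    MinLengthLogitsProcessor(5, eos_token_id),   # same processors generate() builds
    RepetitionPenaltyLogitsProcessor(1.0),
])
warper = TopKLogitsWarper(top_k=1)               # i.e. GENERATION_TYPES["top_k"](top_k)

filtered = warper(generated, processors(generated, logits))
probs = F.softmax(filtered / 1.0, dim=-1)        # temperature = 1.0
next_token = torch.multinomial(probs, 1)         # appended to the running sequence each step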
modules/model/modules/feat_extractors/train_clip_src/open_clip/constants.py ADDED
@@ -0,0 +1,2 @@
1
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
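For reference, a minimal sketch (assuming torchvision is available) of how these OpenAI CLIP per-channel statistics are typically consumed, i.e. as the normalization step of an image preprocessing pipeline such as the image_transform imported in factory.py below:

from torchvision import transforms

OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),                                              # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
])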
modules/model/modules/feat_extractors/train_clip_src/open_clip/factory.py ADDED
@@ -0,0 +1,193 @@
1
+ import json
2
+ import re
3
+ from copy import deepcopy
4
+ from pathlib import Path
5
+ from typing import Optional, Tuple, Union
6
+
7
+ import torch
8
+
9
+ from utils.utils import instantiate_from_config
10
+
11
+ from .model import convert_to_custom_text_state_dict, resize_pos_embed
12
+ from .loss import AVCLIPLoss, ClipLoss, DistillClipLoss, CoCaLoss, MultilevelAVCLIPLoss
13
+ from .transform import image_transform
14
+ from .tokenizer import HFTokenizer, tokenize
15
+
16
+
17
+ HF_HUB_PREFIX = 'hf-hub:'
18
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
19
+ _MODEL_CONFIGS = {} # dictionary (model_name: config) of model architecture configs
20
+
21
+
22
+ def _natural_key(string_):
23
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
24
+
25
+
26
+ def _rescan_model_configs():
27
+ global _MODEL_CONFIGS
28
+
29
+ config_ext = ('.json',)
30
+ config_files = []
31
+ for config_path in _MODEL_CONFIG_PATHS:
32
+ if config_path.is_file() and config_path.suffix in config_ext:
33
+ config_files.append(config_path)
34
+ elif config_path.is_dir():
35
+ for ext in config_ext:
36
+ config_files.extend(config_path.glob(f'*{ext}'))
37
+
38
+ for cf in config_files:
39
+ with open(cf, 'r') as f:
40
+ model_cfg = json.load(f)
41
+ if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
42
+ _MODEL_CONFIGS[cf.stem] = model_cfg
43
+
44
+ _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
45
+
46
+
47
+ _rescan_model_configs() # initial populate of model config registry
48
+
49
+
50
+ def list_models():
51
+ """ enumerate available model architectures based on config files """
52
+ return list(_MODEL_CONFIGS.keys())
53
+
54
+
55
+ def add_model_config(path):
56
+ """ add model config path or file and update registry """
57
+ if not isinstance(path, Path):
58
+ path = Path(path)
59
+ _MODEL_CONFIG_PATHS.append(path)
60
+ _rescan_model_configs()
61
+
62
+
63
+ def get_model_config(model_name):
64
+ if model_name in _MODEL_CONFIGS:
65
+ return deepcopy(_MODEL_CONFIGS[model_name])
66
+ else:
67
+ return None
68
+
69
+
70
+ def get_tokenizer(model_name):
71
+ if model_name.startswith(HF_HUB_PREFIX):
72
+ tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
73
+ else:
74
+ config = get_model_config(model_name)
75
+ tokenizer = HFTokenizer(
76
+ config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
77
+ return tokenizer
78
+
79
+
80
+ def load_state_dict(checkpoint_path: str, map_location='cpu'):
81
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
82
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
83
+ state_dict = checkpoint['state_dict']
84
+ else:
85
+ state_dict = checkpoint
86
+ if next(iter(state_dict.items()))[0].startswith('module'):
87
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
88
+ return state_dict
89
+
90
+
91
+ def load_checkpoint(model, checkpoint_path, strict=True):
92
+ state_dict = load_state_dict(checkpoint_path)
93
+ # detect old format and make compatible with new format
94
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
95
+ state_dict = convert_to_custom_text_state_dict(state_dict)
96
+ resize_pos_embed(state_dict, model)
97
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
98
+ return incompatible_keys
99
+
100
+
101
+ def create_model(cfg, device: Union[str, torch.device] = 'cpu'):
102
+ if isinstance(device, str):
103
+ device = torch.device(device)
104
+ model = instantiate_from_config(cfg.model)
105
+ model.to(device=device)
106
+ return model
107
+
108
+
109
+ def create_loss(args):
110
+ if args.distill:
111
+ return DistillClipLoss(
112
+ local_loss=args.local_loss,
113
+ gather_with_grad=args.gather_with_grad,
114
+ cache_labels=True,
115
+ rank=args.rank,
116
+ world_size=args.world_size,
117
+ )
118
+ elif 'coca' in args.model.target.split('.')[-1].lower():
119
+ return CoCaLoss(
120
+ caption_loss_weight=args.coca_caption_loss_weight,
121
+ clip_loss_weight=args.coca_contrastive_loss_weight,
122
+ local_loss=args.local_loss,
123
+ gather_with_grad=args.gather_with_grad,
124
+ cache_labels=True,
125
+ rank=args.rank,
126
+ world_size=args.world_size,
127
+ )
128
+ elif 'multilevel' in args.model.target.split('.')[-1].lower():
129
+ return MultilevelAVCLIPLoss(
130
+ local_loss=args.local_loss,
131
+ gather_with_grad=args.gather_with_grad,
132
+ cache_labels=True,
133
+ rank=args.rank,
134
+ world_size=args.world_size,
135
+ )
136
+ elif 'avclip' in args.model.target.split('.')[-1].lower():
137
+ return AVCLIPLoss(
138
+ local_loss=args.local_loss,
139
+ gather_with_grad=args.gather_with_grad,
140
+ cache_labels=True,
141
+ rank=args.rank,
142
+ world_size=args.world_size,
143
+ )
144
+
145
+ return ClipLoss(
146
+ local_loss=args.local_loss,
147
+ gather_with_grad=args.gather_with_grad,
148
+ cache_labels=True,
149
+ rank=args.rank,
150
+ world_size=args.world_size,
151
+ )
152
+
153
+
154
+ def create_model_from_pretrained(
155
+ model_name: str,
156
+ pretrained: Optional[str] = None,
157
+ precision: str = 'fp32',
158
+ device: Union[str, torch.device] = 'cpu',
159
+ jit: bool = False,
160
+ force_quick_gelu: bool = False,
161
+ force_custom_text: bool = False,
162
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
163
+ return_transform: bool = True,
164
+ image_mean: Optional[Tuple[float, ...]] = None,
165
+ image_std: Optional[Tuple[float, ...]] = None,
166
+ cache_dir: Optional[str] = None,
167
+ ):
168
+ model = create_model(
169
+ model_name,
170
+ pretrained,
171
+ precision=precision,
172
+ device=device,
173
+ jit=jit,
174
+ force_quick_gelu=force_quick_gelu,
175
+ force_custom_text=force_custom_text,
176
+ force_image_size=force_image_size,
177
+ cache_dir=cache_dir,
178
+ require_pretrained=True,
179
+ )
180
+
181
+ if not return_transform:
182
+ return model
183
+
184
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
185
+ image_std = image_std or getattr(model.visual, 'image_std', None)
186
+ preprocess = image_transform(
187
+ model.visual.image_size,
188
+ is_train=False,
189
+ mean=image_mean,
190
+ std=image_std,
191
+ )
192
+
193
+ return model, preprocess
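A short sketch of the config-registry workflow defined in factory.py above (the folder name is hypothetical, and the import path should be adjusted to this repository's module layout): JSON files containing embed_dim, vision_cfg and text_cfg are picked up by _rescan_model_configs() and become addressable by their file stem.

from open_clip.factory import add_model_config, get_model_config, list_models  # adjust to this repo's module path

add_model_config("my_model_configs/")    # hypothetical folder with extra *.json configs
print(list_models())                     # e.g. ['RN50', 'ViT-B-16', ...] plus any custom configs
cfg = get_model_config("ViT-B-16")       # deep copy of the parsed JSON, or None if unknown
print(cfg["embed_dim"], cfg["vision_cfg"]["patch_size"])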
modules/model/modules/feat_extractors/train_clip_src/open_clip/generation_utils.py ADDED
File without changes
modules/model/modules/feat_extractors/train_clip_src/open_clip/hf_configs.py ADDED
@@ -0,0 +1,45 @@
1
+ # HF architecture dict:
2
+ arch_dict = {
3
+ # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4
+ "roberta": {
5
+ "config_names": {
6
+ "context_length": "max_position_embeddings",
7
+ "vocab_size": "vocab_size",
8
+ "width": "hidden_size",
9
+ "heads": "num_attention_heads",
10
+ "layers": "num_hidden_layers",
11
+ "layer_attr": "layer",
12
+ "token_embeddings_attr": "embeddings"
13
+ },
14
+ "pooler": "mean_pooler",
15
+ },
16
+ # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17
+ "xlm-roberta": {
18
+ "config_names": {
19
+ "context_length": "max_position_embeddings",
20
+ "vocab_size": "vocab_size",
21
+ "width": "hidden_size",
22
+ "heads": "num_attention_heads",
23
+ "layers": "num_hidden_layers",
24
+ "layer_attr": "layer",
25
+ "token_embeddings_attr": "embeddings"
26
+ },
27
+ "pooler": "mean_pooler",
28
+ },
29
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30
+ "mt5": {
31
+ "config_names": {
32
+ # unlimited seqlen
33
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35
+ "context_length": "",
36
+ "vocab_size": "vocab_size",
37
+ "width": "d_model",
38
+ "heads": "num_heads",
39
+ "layers": "num_layers",
40
+ "layer_attr": "block",
41
+ "token_embeddings_attr": "embed_tokens"
42
+ },
43
+ "pooler": "mean_pooler",
44
+ },
45
+ }
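A minimal sketch of how this mapping is consumed (HFTextEncoder in hf_model.py below reads the same keys via getattr). It assumes transformers is installed, the roberta-base config can be downloaded, and arch_dict is imported from the file above.

from transformers import AutoConfig

from open_clip.hf_configs import arch_dict  # adjust to this repo's module path

config = AutoConfig.from_pretrained("roberta-base")     # config.model_type == "roberta"
names = arch_dict[config.model_type]["config_names"]

width = getattr(config, names["width"])    # hidden_size -> 768
layers = getattr(config, names["layers"])  # num_hidden_layers -> 12
print(width, layers)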
modules/model/modules/feat_extractors/train_clip_src/open_clip/hf_model.py ADDED
@@ -0,0 +1,176 @@
1
+ """ huggingface model adapter
2
+
3
+ Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in a CLIP model.
4
+ """
5
+
6
+ import re
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch import TensorType
11
+
12
+ try:
13
+ import transformers
14
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
15
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
16
+ BaseModelOutputWithPoolingAndCrossAttentions
17
+ except ImportError as e:
18
+ transformers = None
19
+
20
+
21
+ class BaseModelOutput:
22
+ pass
23
+
24
+
25
+ class PretrainedConfig:
26
+ pass
27
+
28
+ from .hf_configs import arch_dict
29
+
30
+
31
+ # utils
32
+ def _camel2snake(s):
33
+ return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
34
+
35
+
36
+ # TODO: ?last - for gpt-like models
37
+ _POOLERS = {}
38
+
39
+
40
+ def register_pooler(cls):
41
+ """Decorator registering pooler class"""
42
+ _POOLERS[_camel2snake(cls.__name__)] = cls
43
+ return cls
44
+
45
+
46
+ @register_pooler
47
+ class MeanPooler(nn.Module):
48
+ """Mean pooling"""
49
+
50
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
51
+ masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
52
+ return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
53
+
54
+
55
+ @register_pooler
56
+ class MaxPooler(nn.Module):
57
+ """Max pooling"""
58
+
59
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
60
+ masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1) == 0, -torch.inf)  # mask out padded positions before max-pooling
61
+ return masked_output.max(1).values
62
+
63
+
64
+ @register_pooler
65
+ class ClsPooler(nn.Module):
66
+ """CLS token pooling"""
67
+
68
+ def __init__(self, use_pooler_output=True):
69
+ super().__init__()
70
+ self.cls_token_position = 0
71
+ self.use_pooler_output = use_pooler_output
72
+
73
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
74
+ if (self.use_pooler_output and
75
+ isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
76
+ (x.pooler_output is not None)
77
+ ):
78
+ return x.pooler_output
79
+
80
+ return x.last_hidden_state[:, self.cls_token_position, :]
81
+
82
+
83
+ class HFTextEncoder(nn.Module):
84
+ """HuggingFace model adapter"""
85
+ output_tokens: torch.jit.Final[bool]
86
+
87
+ def __init__(
88
+ self,
89
+ model_name_or_path: str,
90
+ output_dim: int,
91
+ config: PretrainedConfig = None,
92
+ pooler_type: str = None,
93
+ proj: str = None,
94
+ pretrained: bool = True,
95
+ output_tokens: bool = False,
96
+ ):
97
+ super().__init__()
98
+ self.output_tokens = output_tokens
99
+ self.output_dim = output_dim
100
+
101
+ # TODO: find better way to get this information
102
+ uses_transformer_pooler = (pooler_type == "cls_pooler")
103
+
104
+ if transformers is None:
105
+ raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
106
+ if config is None:
107
+ self.config = AutoConfig.from_pretrained(model_name_or_path)
108
+ create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
109
+ AutoModel.from_config, self.config)
110
+ # TODO: do all model configs have this attribute? PretrainedConfig defines it, so presumably yes.
111
+ if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
112
+ self.transformer = create_func(model_args)
113
+ self.transformer = self.transformer.encoder
114
+ else:
115
+ self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
116
+ else:
117
+ self.config = config
118
+ self.transformer = AutoModel.from_config(config)
119
+ if pooler_type is None: # get default arch pooler
120
+ pooler_type = (arch_dict[self.config.model_type]["pooler"])
121
+
122
+ self.pooler = _POOLERS[pooler_type]()
123
+
124
+ d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
125
+ if (d_model == output_dim) and (proj is None): # do we always need a proj?
126
+ self.proj = nn.Identity()
127
+ elif proj == 'linear':
128
+ self.proj = nn.Linear(d_model, output_dim, bias=False)
129
+ elif proj == 'mlp':
130
+ hidden_size = (d_model + output_dim) // 2
131
+ self.proj = nn.Sequential(
132
+ nn.Linear(d_model, hidden_size, bias=False),
133
+ nn.GELU(),
134
+ nn.Linear(hidden_size, output_dim, bias=False),
135
+ )
136
+
137
+ def forward(self, x: TensorType):
138
+ attn_mask = (x != self.config.pad_token_id).long()
139
+ out = self.transformer(input_ids=x, attention_mask=attn_mask)
140
+ pooled_out = self.pooler(out, attn_mask)
141
+ projected = self.proj(pooled_out)
142
+
143
+ seq_len = out.last_hidden_state.shape[1]
144
+ tokens = (
145
+ out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
146
+ if type(self.pooler) == ClsPooler
147
+ else out.last_hidden_state
148
+ )
149
+
150
+ if self.output_tokens:
151
+ return projected, tokens
152
+ return projected
153
+
154
+ def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
155
+ if not unlocked_layers: # full freezing
156
+ for n, p in self.transformer.named_parameters():
157
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
158
+ return
159
+
160
+ encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
161
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
162
+ print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
163
+ embeddings = getattr(
164
+ self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
165
+ modules = [embeddings, *layer_list][:-unlocked_layers]
166
+ # freeze layers
167
+ for module in modules:
168
+ for n, p in module.named_parameters():
169
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
170
+
171
+ @torch.jit.ignore
172
+ def set_grad_checkpointing(self, enable=True):
173
+ self.transformer.gradient_checkpointing_enable()
174
+
175
+ def init_parameters(self):
176
+ pass
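A minimal sketch of the pooler registry defined above: @register_pooler keys each class by its snake_case name, and HFTextEncoder instantiates it via _POOLERS[pooler_type](). LastTokenPooler is a hypothetical example, not part of the file; register_pooler and _POOLERS are assumed to be imported from hf_model.py.

import torch
import torch.nn as nn

@register_pooler
class LastTokenPooler(nn.Module):
    """Pool by taking the last non-padded token of each sequence."""

    def forward(self, x, attention_mask):
        last_idx = attention_mask.sum(dim=1) - 1                 # (B,) index of last valid token
        batch_idx = torch.arange(x.last_hidden_state.shape[0])
        return x.last_hidden_state[batch_idx, last_idx, :]       # (B, width)

pooler = _POOLERS["last_token_pooler"]()   # same lookup HFTextEncoder performs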
modules/model/modules/feat_extractors/train_clip_src/open_clip/loss.py ADDED
@@ -0,0 +1,229 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ try:
6
+ import torch.distributed.nn
7
+ from torch import distributed as dist
8
+
9
+ has_distributed = True
10
+ except ImportError:
11
+ has_distributed = False
12
+
13
+
14
+ def gather_features(
15
+ image_features,
16
+ text_features,
17
+ local_loss=False,
18
+ gather_with_grad=False,
19
+ rank=0,
20
+ world_size=1,
21
+ ):
22
+ assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
23
+ # We gather tensors from all gpus
24
+ if gather_with_grad:
25
+ all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
26
+ all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
27
+ else:
28
+ gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
29
+ gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
30
+ dist.all_gather(gathered_image_features, image_features)
31
+ dist.all_gather(gathered_text_features, text_features)
32
+ if not local_loss:
33
+ # ensure grads for local rank when all_* features don't have a gradient
34
+ gathered_image_features[rank] = image_features
35
+ gathered_text_features[rank] = text_features
36
+ all_image_features = torch.cat(gathered_image_features, dim=0)
37
+ all_text_features = torch.cat(gathered_text_features, dim=0)
38
+
39
+ return all_image_features, all_text_features
40
+
41
+
42
+ class ClipLoss(nn.Module):
43
+
44
+ def __init__(
45
+ self,
46
+ local_loss=False,
47
+ gather_with_grad=False,
48
+ cache_labels=False,
49
+ rank=0,
50
+ world_size=1,
51
+ ):
52
+ super().__init__()
53
+ self.local_loss = local_loss
54
+ self.gather_with_grad = gather_with_grad
55
+ self.cache_labels = cache_labels
56
+ self.rank = rank
57
+ self.world_size = world_size
58
+
59
+ # cache state
60
+ self.prev_num_logits = 0
61
+ self.labels = {}
62
+
63
+ def get_ground_truth(self, device, num_logits) -> torch.Tensor:
64
+ # calculate ground-truth labels and cache them if enabled
65
+ if self.prev_num_logits != num_logits or device not in self.labels:
66
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
67
+ if self.world_size > 1 and self.local_loss:
68
+ labels = labels + num_logits * self.rank
69
+ if self.cache_labels:
70
+ self.labels[device] = labels
71
+ self.prev_num_logits = num_logits
72
+ else:
73
+ labels = self.labels[device]
74
+ return labels
75
+
76
+ def get_logits(self, image_features, text_features, logit_scale):
77
+ if self.world_size > 1:
78
+ all_image_features, all_text_features = gather_features(
79
+ image_features, text_features,
80
+ self.local_loss, self.gather_with_grad, self.rank, self.world_size)
81
+
82
+ if self.local_loss:
83
+ logits_per_image = logit_scale * image_features @ all_text_features.T
84
+ logits_per_text = logit_scale * text_features @ all_image_features.T
85
+ else:
86
+ logits_per_image = logit_scale * all_image_features @ all_text_features.T
87
+ logits_per_text = logits_per_image.T
88
+ else:
89
+ logits_per_image = logit_scale * image_features @ text_features.T
90
+ logits_per_text = logit_scale * text_features @ image_features.T
91
+
92
+ return logits_per_image, logits_per_text
93
+
94
+ def forward(self, image_features, text_features, logit_scale, output_dict=False):
95
+ device = image_features.device
96
+ logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
97
+
98
+ labels = self.get_ground_truth(device, logits_per_image.shape[0])
99
+
100
+ total_loss = (
101
+ F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)
102
+ ) / 2
103
+
104
+ return {"contrastive_loss": total_loss} if output_dict else total_loss
105
+
106
+
107
+ class CoCaLoss(ClipLoss):
108
+ def __init__(
109
+ self,
110
+ caption_loss_weight,
111
+ clip_loss_weight,
112
+ pad_id=0, # pad_token for open_clip custom tokenizer
113
+ local_loss=False,
114
+ gather_with_grad=False,
115
+ cache_labels=False,
116
+ rank=0,
117
+ world_size=1,
118
+ ):
119
+ super().__init__(
120
+ local_loss=local_loss,
121
+ gather_with_grad=gather_with_grad,
122
+ cache_labels=cache_labels,
123
+ rank=rank,
124
+ world_size=world_size,
125
+ )
126
+
127
+ self.clip_loss_weight = clip_loss_weight
128
+ self.caption_loss_weight = caption_loss_weight
129
+ self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
130
+
131
+ def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
132
+ clip_loss = super().forward(image_features, text_features, logit_scale)
133
+ clip_loss = self.clip_loss_weight * clip_loss
134
+
135
+ caption_loss = self.caption_loss(
136
+ logits.permute(0, 2, 1),
137
+ labels,
138
+ )
139
+ caption_loss = caption_loss * self.caption_loss_weight
140
+
141
+ if output_dict:
142
+ return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
143
+
144
+ return clip_loss, caption_loss
145
+
146
+
147
+ class DistillClipLoss(ClipLoss):
148
+
149
+ def dist_loss(self, teacher_logits, student_logits):
150
+ return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
151
+
152
+ def forward(
153
+ self,
154
+ image_features,
155
+ text_features,
156
+ logit_scale,
157
+ dist_image_features,
158
+ dist_text_features,
159
+ dist_logit_scale,
160
+ output_dict=False,
161
+ ):
162
+ logits_per_image, logits_per_text = \
163
+ self.get_logits(image_features, text_features, logit_scale)
164
+
165
+ dist_logits_per_image, dist_logits_per_text = \
166
+ self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
167
+
168
+ labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
169
+
170
+ contrastive_loss = (
171
+ F.cross_entropy(logits_per_image, labels) +
172
+ F.cross_entropy(logits_per_text, labels)
173
+ ) / 2
174
+
175
+ distill_loss = (
176
+ self.dist_loss(dist_logits_per_image, logits_per_image) +
177
+ self.dist_loss(dist_logits_per_text, logits_per_text)
178
+ ) / 2
179
+
180
+ if output_dict:
181
+ return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
182
+
183
+ return contrastive_loss, distill_loss
184
+
185
+
186
+ class AVCLIPLoss(ClipLoss):
187
+ '''This loss resembles the CLIP loss; it simply renames the variables (rgb/audio instead of image/text).'''
188
+
189
+ def __init__(self, local_loss=False, gather_with_grad=False, cache_labels=False, rank=0, world_size=1):
190
+ super().__init__(local_loss, gather_with_grad, cache_labels, rank, world_size)
191
+
192
+ def forward(self, rgb_features, audio_features, logit_scale, output_dict=False):
193
+ return super().forward(rgb_features, audio_features, logit_scale, output_dict)
194
+
195
+
196
+ class MultilevelAVCLIPLoss(nn.Module):
197
+
198
+ def __init__(self, local_loss=False, gather_with_grad=False, cache_labels=False, rank=0, world_size=1):
199
+ super().__init__()
200
+ self.segment_avclip_loss = AVCLIPLoss(
201
+ local_loss, gather_with_grad, cache_labels, rank, world_size)
202
+ self.global_avclip_loss = AVCLIPLoss(
203
+ local_loss, gather_with_grad, cache_labels, rank, world_size)
204
+
205
+ def forward(self, rgb_features, audio_features, logit_scales, output_dict=False):
206
+ segment_rgb, global_rgb = rgb_features
207
+ segment_audio, global_audio = audio_features
208
+ segment_logit_scale, global_logit_scale = logit_scales
209
+
210
+ assert segment_rgb.shape[0] == segment_audio.shape[0], f'{segment_rgb.shape} != {segment_audio.shape}'
211
+ B, S, D = segment_rgb.shape
212
+
213
+ # (B*S, D) <- (B, S, D)
214
+ segment_loss = self.segment_avclip_loss(segment_rgb.view(B*S, D), segment_audio.view(B*S, D),
215
+ segment_logit_scale, output_dict)
216
+ global_loss = None
217
+ if global_rgb is not None:
218
+ global_loss = self.global_avclip_loss(global_rgb, global_audio, global_logit_scale, output_dict)
219
+
220
+ if output_dict:
221
+ losses = {'segment_contrastive_loss': segment_loss['contrastive_loss']}
222
+ if global_rgb is not None:
223
+ losses['global_contrastive_loss'] = global_loss['contrastive_loss']
224
+ return losses
225
+ else:
226
+ if global_loss is None:
227
+ return segment_loss
228
+ else:
229
+ return segment_loss, global_loss
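A minimal, single-process sketch of ClipLoss above (with world_size=1, gather_features is never called and the targets are simply the batch diagonal). Feature tensors are random stand-ins, and ClipLoss is assumed to be imported from loss.py.

import torch
import torch.nn.functional as F

batch, dim = 8, 512
image_features = F.normalize(torch.randn(batch, dim), dim=-1)
text_features = F.normalize(torch.randn(batch, dim), dim=-1)
logit_scale = torch.tensor(1 / 0.07)     # i.e. logit_scale.exp() of the learned log-scale parameter

loss_fn = ClipLoss(local_loss=False, gather_with_grad=False, cache_labels=True, rank=0, world_size=1)
loss = loss_fn(image_features, text_features, logit_scale)
print(loss.item())                       # symmetric cross-entropy over image->text and text->image logits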
modules/model/modules/feat_extractors/train_clip_src/open_clip/model.py ADDED
@@ -0,0 +1,883 @@
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ from dataclasses import dataclass
6
+ import logging
7
+ import math
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ from omegaconf import OmegaConf
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn
15
+
16
+ from utils.utils import instantiate_from_config
17
+
18
+ from .hf_model import HFTextEncoder
19
+ from .modified_resnet import ModifiedResNet
20
+ from .timm_model import TimmModel
21
+ from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
22
+ from .utils import to_2tuple
23
+
24
+
25
+ @dataclass
26
+ class CLIPVisionCfg:
27
+ layers: Union[Tuple[int, int, int, int], int] = 12
28
+ width: int = 768
29
+ head_width: int = 64
30
+ mlp_ratio: float = 4.0
31
+ patch_size: int = 16
32
+ image_size: Union[Tuple[int, int], int] = 224
33
+ ls_init_value: Optional[float] = None # layer scale initial value
34
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
35
+ input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
36
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
37
+ attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
38
+ n_queries: int = 256 # n_queries for attentional pooler
39
+ attn_pooler_heads: int = 8 # n heads for attentional_pooling
40
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
41
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
42
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
43
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
44
+ timm_proj_bias: bool = False # enable bias final projection
45
+ timm_drop: float = 0. # head dropout
46
+ timm_drop_path: Optional[float] = None # backbone stochastic depth
47
+ output_tokens: bool = False
48
+
49
+
50
+ @dataclass
51
+ class CLIPTextCfg:
52
+ context_length: int = 77
53
+ vocab_size: int = 49408
54
+ width: int = 512
55
+ heads: int = 8
56
+ layers: int = 12
57
+ ls_init_value: Optional[float] = None # layer scale initial value
58
+ hf_model_name: str = None
59
+ hf_tokenizer_name: str = None
60
+ hf_model_pretrained: bool = True
61
+ proj: str = 'mlp'
62
+ pooler_type: str = 'mean_pooler'
63
+ embed_cls: bool = False
64
+ pad_id: int = 0
65
+ output_tokens: bool = False
66
+
67
+
68
+ def get_cast_dtype(precision: str):
69
+ cast_dtype = None
70
+ if precision == 'bf16':
71
+ cast_dtype = torch.bfloat16
72
+ elif precision == 'fp16':
73
+ cast_dtype = torch.float16
74
+ return cast_dtype
75
+
76
+
77
+ def _build_vision_tower(
78
+ embed_dim: int,
79
+ vision_cfg: CLIPVisionCfg,
80
+ quick_gelu: bool = False,
81
+ cast_dtype: Optional[torch.dtype] = None
82
+ ):
83
+ if isinstance(vision_cfg, dict):
84
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
85
+
86
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
87
+ # memory efficient in recent PyTorch releases (>= 1.10).
88
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
89
+ act_layer = QuickGELU if quick_gelu else nn.GELU
90
+
91
+ if vision_cfg.timm_model_name:
92
+ visual = TimmModel(
93
+ vision_cfg.timm_model_name,
94
+ pretrained=vision_cfg.timm_model_pretrained,
95
+ pool=vision_cfg.timm_pool,
96
+ proj=vision_cfg.timm_proj,
97
+ proj_bias=vision_cfg.timm_proj_bias,
98
+ drop=vision_cfg.timm_drop,
99
+ drop_path=vision_cfg.timm_drop_path,
100
+ embed_dim=embed_dim,
101
+ image_size=vision_cfg.image_size,
102
+ )
103
+ act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
104
+ elif isinstance(vision_cfg.layers, (tuple, list)):
105
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
106
+ visual = ModifiedResNet(
107
+ layers=vision_cfg.layers,
108
+ output_dim=embed_dim,
109
+ heads=vision_heads,
110
+ image_size=vision_cfg.image_size,
111
+ width=vision_cfg.width,
112
+ )
113
+ else:
114
+ vision_heads = vision_cfg.width // vision_cfg.head_width
115
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
116
+ visual = VisionTransformer(
117
+ image_size=vision_cfg.image_size,
118
+ patch_size=vision_cfg.patch_size,
119
+ width=vision_cfg.width,
120
+ layers=vision_cfg.layers,
121
+ heads=vision_heads,
122
+ mlp_ratio=vision_cfg.mlp_ratio,
123
+ ls_init_value=vision_cfg.ls_init_value,
124
+ patch_dropout=vision_cfg.patch_dropout,
125
+ input_patchnorm=vision_cfg.input_patchnorm,
126
+ global_average_pool=vision_cfg.global_average_pool,
127
+ attentional_pool=vision_cfg.attentional_pool,
128
+ n_queries=vision_cfg.n_queries,
129
+ attn_pooler_heads=vision_cfg.attn_pooler_heads,
130
+ output_tokens=vision_cfg.output_tokens,
131
+ output_dim=embed_dim,
132
+ act_layer=act_layer,
133
+ norm_layer=norm_layer,
134
+ )
135
+
136
+ return visual
137
+
138
+
139
+ def _build_text_tower(
140
+ embed_dim: int,
141
+ text_cfg: CLIPTextCfg,
142
+ quick_gelu: bool = False,
143
+ cast_dtype: Optional[torch.dtype] = None,
144
+ ):
145
+ if isinstance(text_cfg, dict):
146
+ text_cfg = CLIPTextCfg(**text_cfg)
147
+
148
+ if text_cfg.hf_model_name:
149
+ text = HFTextEncoder(
150
+ text_cfg.hf_model_name,
151
+ output_dim=embed_dim,
152
+ proj=text_cfg.proj,
153
+ pooler_type=text_cfg.pooler_type,
154
+ pretrained=text_cfg.hf_model_pretrained,
155
+ output_tokens=text_cfg.output_tokens,
156
+ )
157
+ else:
158
+ act_layer = QuickGELU if quick_gelu else nn.GELU
159
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
160
+
161
+ text = TextTransformer(
162
+ context_length=text_cfg.context_length,
163
+ vocab_size=text_cfg.vocab_size,
164
+ width=text_cfg.width,
165
+ heads=text_cfg.heads,
166
+ layers=text_cfg.layers,
167
+ ls_init_value=text_cfg.ls_init_value,
168
+ output_dim=embed_dim,
169
+ embed_cls=text_cfg.embed_cls,
170
+ output_tokens=text_cfg.output_tokens,
171
+ pad_id=text_cfg.pad_id,
172
+ act_layer=act_layer,
173
+ norm_layer=norm_layer,
174
+ )
175
+ return text
176
+
177
+
178
+ class CLIP(nn.Module):
179
+ output_dict: torch.jit.Final[bool]
180
+
181
+ def __init__(
182
+ self,
183
+ embed_dim: int,
184
+ vision_cfg: CLIPVisionCfg,
185
+ text_cfg: CLIPTextCfg,
186
+ quick_gelu: bool = False,
187
+ cast_dtype: Optional[torch.dtype] = None,
188
+ output_dict: bool = False,
189
+ ):
190
+ super().__init__()
191
+ self.output_dict = output_dict
192
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
193
+
194
+ text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
195
+ self.transformer = text.transformer
196
+ self.vocab_size = text.vocab_size
197
+ self.token_embedding = text.token_embedding
198
+ self.positional_embedding = text.positional_embedding
199
+ self.ln_final = text.ln_final
200
+ self.text_projection = text.text_projection
201
+ self.register_buffer('attn_mask', text.attn_mask, persistent=False)
202
+
203
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
204
+
205
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
206
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
207
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
208
+
209
+ @torch.jit.ignore
210
+ def set_grad_checkpointing(self, enable=True):
211
+ self.visual.set_grad_checkpointing(enable)
212
+ self.transformer.grad_checkpointing = enable
213
+
214
+ def encode_image(self, image, normalize: bool = False):
215
+ features = self.visual(image)
216
+ return F.normalize(features, dim=-1) if normalize else features
217
+
218
+ def encode_text(self, text, normalize: bool = False):
219
+ cast_dtype = self.transformer.get_cast_dtype()
220
+
221
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
222
+
223
+ x = x + self.positional_embedding.to(cast_dtype)
224
+ x = x.permute(1, 0, 2) # NLD -> LND
225
+ x = self.transformer(x, attn_mask=self.attn_mask)
226
+ x = x.permute(1, 0, 2) # LND -> NLD
227
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
228
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
229
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
230
+ return F.normalize(x, dim=-1) if normalize else x
231
+
232
+ def forward(self, image, text):
233
+ image_features = self.encode_image(image, normalize=True)
234
+ text_features = self.encode_text(text, normalize=True)
235
+ if self.output_dict:
236
+ return {
237
+ "image_features": image_features,
238
+ "text_features": text_features,
239
+ "logit_scale": self.logit_scale.exp()
240
+ }
241
+ return image_features, text_features, self.logit_scale.exp()
242
+
243
+
244
+ class CustomTextCLIP(nn.Module):
245
+ output_dict: torch.jit.Final[bool]
246
+
247
+ def __init__(
248
+ self,
249
+ embed_dim: int,
250
+ vision_cfg: CLIPVisionCfg,
251
+ text_cfg: CLIPTextCfg,
252
+ quick_gelu: bool = False,
253
+ cast_dtype: Optional[torch.dtype] = None,
254
+ output_dict: bool = False,
255
+ ):
256
+ super().__init__()
257
+ self.output_dict = output_dict
258
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
259
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
260
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
261
+
262
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
263
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
264
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
265
+
266
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
267
+ self.text.lock(unlocked_layers, freeze_layer_norm)
268
+
269
+ @torch.jit.ignore
270
+ def set_grad_checkpointing(self, enable=True):
271
+ self.visual.set_grad_checkpointing(enable)
272
+ self.text.set_grad_checkpointing(enable)
273
+
274
+ def encode_image(self, image, normalize: bool = False):
275
+ features = self.visual(image)
276
+ return F.normalize(features, dim=-1) if normalize else features
277
+
278
+ def encode_text(self, text, normalize: bool = False):
279
+ features = self.text(text)
280
+ return F.normalize(features, dim=-1) if normalize else features
281
+
282
+ def forward(self, image, text):
283
+ image_features = self.encode_image(image, normalize=True)
284
+ text_features = self.encode_text(text, normalize=True)
285
+ if self.output_dict:
286
+ return {
287
+ "image_features": image_features,
288
+ "text_features": text_features,
289
+ "logit_scale": self.logit_scale.exp()
290
+ }
291
+ return image_features, text_features, self.logit_scale.exp()
292
+
293
+
294
+ def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
295
+ """Convert applicable model parameters to low-precision (bf16 or fp16)"""
296
+
297
+ def _convert_weights(l):
298
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
299
+ l.weight.data = l.weight.data.to(dtype)
300
+ if l.bias is not None:
301
+ l.bias.data = l.bias.data.to(dtype)
302
+
303
+ if isinstance(l, (nn.MultiheadAttention, Attention)):
304
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
305
+ tensor = getattr(l, attr)
306
+ if tensor is not None:
307
+ tensor.data = tensor.data.to(dtype)
308
+
309
+ for name in ["text_projection", "proj"]:
310
+ if hasattr(l, name):
311
+ attr = getattr(l, name)
312
+ if attr is not None:
313
+ attr.data = attr.data.to(dtype)
314
+
315
+ model.apply(_convert_weights)
316
+
317
+
318
+ convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
319
+
320
+
321
+ # used to maintain checkpoint compatibility
322
+ def convert_to_custom_text_state_dict(state_dict: dict):
323
+ if 'text_projection' in state_dict:
324
+ # old format state_dict, move text tower -> .text
325
+ new_state_dict = {}
326
+ for k, v in state_dict.items():
327
+ if any(k.startswith(p) for p in (
328
+ 'text_projection',
329
+ 'positional_embedding',
330
+ 'token_embedding',
331
+ 'transformer',
332
+ 'ln_final',
333
+ )):
334
+ k = 'text.' + k
335
+ new_state_dict[k] = v
336
+ return new_state_dict
337
+ return state_dict
338
+
339
+
340
+ def build_model_from_openai_state_dict(
341
+ state_dict: dict,
342
+ quick_gelu=True,
343
+ cast_dtype=torch.float16,
344
+ ):
345
+ vit = "visual.proj" in state_dict
346
+
347
+ if vit:
348
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
349
+ vision_layers = len(
350
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
351
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
352
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
353
+ image_size = vision_patch_size * grid_size
354
+ else:
355
+ counts: list = [
356
+ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
357
+ vision_layers = tuple(counts)
358
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
359
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
360
+ vision_patch_size = None
361
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
362
+ image_size = output_width * 32
363
+
364
+ embed_dim = state_dict["text_projection"].shape[1]
365
+ context_length = state_dict["positional_embedding"].shape[0]
366
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
367
+ transformer_width = state_dict["ln_final.weight"].shape[0]
368
+ transformer_heads = transformer_width // 64
369
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
370
+
371
+ vision_cfg = CLIPVisionCfg(
372
+ layers=vision_layers,
373
+ width=vision_width,
374
+ patch_size=vision_patch_size,
375
+ image_size=image_size,
376
+ )
377
+ text_cfg = CLIPTextCfg(
378
+ context_length=context_length,
379
+ vocab_size=vocab_size,
380
+ width=transformer_width,
381
+ heads=transformer_heads,
382
+ layers=transformer_layers,
383
+ )
384
+ model = CLIP(
385
+ embed_dim,
386
+ vision_cfg=vision_cfg,
387
+ text_cfg=text_cfg,
388
+ quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
389
+ cast_dtype=cast_dtype,
390
+ )
391
+
392
+ for key in ["input_resolution", "context_length", "vocab_size"]:
393
+ state_dict.pop(key, None)
394
+
395
+ convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
396
+ model.load_state_dict(state_dict)
397
+ return model.eval()
398
+
399
+
400
+ def trace_model(model, batch_size=256, device=torch.device('cpu')):
401
+ model.eval()
402
+ image_size = model.visual.image_size
403
+ example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
404
+ example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
405
+ model = torch.jit.trace_module(
406
+ model,
407
+ inputs=dict(
408
+ forward=(example_images, example_text),
409
+ encode_text=(example_text,),
410
+ encode_image=(example_images,)
411
+ ))
412
+ model.visual.image_size = image_size
413
+ return model
414
+
415
+
416
+ def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
417
+ # Rescale the grid of position embeddings when loading from state_dict
418
+ old_pos_embed = state_dict.get('visual.positional_embedding', None)
419
+ if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
420
+ return
421
+ grid_size = to_2tuple(model.visual.grid_size)
422
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
423
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
424
+ if new_seq_len == old_pos_embed.shape[0]:
425
+ return
426
+
427
+ if extra_tokens:
428
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
429
+ else:
430
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
431
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
432
+
433
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
434
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
435
+ pos_emb_img = F.interpolate(
436
+ pos_emb_img,
437
+ size=grid_size,
438
+ mode=interpolation,
439
+ antialias=antialias,
440
+ align_corners=False,
441
+ )
442
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
443
+ if pos_emb_tok is not None:
444
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
445
+ else:
446
+ new_pos_embed = pos_emb_img
447
+ state_dict['visual.positional_embedding'] = new_pos_embed
448
+
449
+ class AVCLIP(nn.Module):
450
+
451
+ def __init__(self, n_embd: int, afeat_extractor: OmegaConf, vfeat_extractor: OmegaConf,
452
+ aproj: OmegaConf, vproj: OmegaConf, init_scale: float = 0.07, clamp_scale_min: float = 0.001,
453
+ clamp_scale_max: float = 0.5, gather_for_loss: bool = False):
454
+ super().__init__()
455
+ self.output_dict = True
456
+ self.n_embd = n_embd
457
+
458
+ # loading audio and rgb towers
459
+ self.v_encoder = instantiate_from_config(vfeat_extractor)
460
+ self.a_encoder = instantiate_from_config(afeat_extractor)
461
+ # projection layers to account for different feature dimensions
462
+ self.aproj = instantiate_from_config(aproj)
463
+ self.vproj = instantiate_from_config(vproj)
464
+
465
+ self.clamp_scale_min, self.clamp_scale_max = clamp_scale_min, clamp_scale_max
466
+ self.init_scale = init_scale
467
+ self.logit_scale = nn.Parameter(torch.ones([]) * self.init_scale) # NOTE: stored directly as a temperature (OpenCLIP stores log(1/T) and exponentiates)
468
+
469
+ self.gather_for_loss = gather_for_loss
470
+
471
+ # self.ln_final = text.ln_final # perhaps only useful for transformer towers
472
+ # self.register_buffer('attn_mask', text.attn_mask, persistent=False)
473
+
474
+ def forward(self, vis: torch.Tensor, aud: torch.Tensor, alpha: float = 0.0, for_loop: bool = False,
475
+ world_size=1):
476
+ '''
477
+ Args:
478
+ vis (torch.Tensor): RGB frames (B, S, C, Tv, H, W)
479
+ aud (torch.Tensor): audio spectrograms (B, S, Ta, F)
480
+ alpha (float): linear interpolation coefficient for pseudo-targets (with 0.0 targets are 1-hot)
481
+ for_loop (bool): whether to process each segment in a for loop or all at once
482
+ (speed-memory tradeoff)
483
+ Returns:
484
+ rgb_features (tuple(torch.Tensor)): local (B, S, D) and global (B, D) or None RGB features
485
+ audio_features (tuple(torch.Tensor)): local (B, S, D) and global (B, D) or None audio features
486
+ logit_scale (tuple(torch.Tensor)): local and global logit scales (1, )
487
+ '''
488
+ assert alpha == 0.0, f'alpha={alpha} not supported yet'
489
+ logit_scales = self.clamp_logit_scales()
490
+ vfeat, _, afeat, _ = self.encode_streams(vis, aud, for_loop, do_norm=True)
491
+
492
+ if world_size > 1 and self.gather_for_loss: # gather all features
493
+ vfeat_all = torch.cat(torch.distributed.nn.all_gather(vfeat), dim=0)
494
+ afeat_all = torch.cat(torch.distributed.nn.all_gather(afeat), dim=0)
495
+ else:
496
+ vfeat_all = vfeat
497
+ afeat_all = afeat
498
+
499
+ loss_avc, _ = self.compute_loss(vfeat, afeat, vfeat_all.mT, afeat_all.mT, self.logit_scale, alpha=0)
500
+ out = {
501
+ 'rgb_features': (vfeat, None), 'audio_features': (afeat, None),
502
+ 'logit_scales': logit_scales,
503
+ 'losses': {'segment_contrastive_loss': loss_avc},
504
+ }
505
+ return out
506
+
507
+ def compute_loss(self, vfeat, afeat, vfeat_all, afeat_all, scale, alpha=0.0, vfeat_m=None, afeat_m=None):
508
+ '''For Multi-level contrastive learning, the losses are made the same way for all levels'''
509
+ sim_v2a = vfeat @ afeat_all / scale
510
+ sim_a2v = afeat @ vfeat_all / scale
511
+ sim_v2a_t, sim_a2v_t = self._make_targets(sim_v2a, vfeat_all, afeat_all, scale, alpha, vfeat_m, afeat_m)
512
+ loss = self._loss(sim_v2a, sim_a2v, sim_v2a_t, sim_a2v_t)
513
+ return loss, (sim_v2a, sim_a2v)
514
+
515
+ @torch.no_grad()
516
+ def _make_targets(self, sim_v2a, vfeat_all, afeat_all, scale, alpha, vfeat_m, afeat_m):
517
+ # NOTE: for simplicity, we assume that sim_v2a.shape[0] == sim_a2v.shape[0]
518
+ # NOTE: sim_targets is not square (sim_v2a.shape is (bsize, bsize+Qsize) )
519
+ sim_targets = torch.eye(*sim_v2a.shape, device=sim_v2a.device, dtype=sim_v2a.dtype)
520
+ sim_v2a_targets = sim_targets
521
+ sim_a2v_targets = sim_targets
522
+ return sim_v2a_targets, sim_a2v_targets
523
+
524
+ def _loss(self, sim_v2a, sim_a2v, sim_v2a_targets, sim_a2v_targets):
525
+ loss_v2a = F.cross_entropy(sim_v2a, sim_v2a_targets)
526
+ loss_a2v = F.cross_entropy(sim_a2v, sim_a2v_targets)
527
+ return (loss_v2a + loss_a2v) / 2
528
+
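+ # --- illustrative sketch (added for exposition, not part of the original file) ---
+ # The segment-level loss above is a symmetric InfoNCE with 1-hot (identity) targets:
+ # similarities are divided by a temperature and fed to cross-entropy in both directions.
+ # Assumes PyTorch >= 1.10 (cross_entropy with probability targets); sizes are arbitrary.
+ def _sketch_symmetric_contrastive_loss():
+     import torch
+     import torch.nn.functional as F
+ 
+     B, D, temperature = 4, 16, 0.07
+     vfeat = F.normalize(torch.randn(B, D), dim=-1)
+     afeat = F.normalize(torch.randn(B, D), dim=-1)
+     sim_v2a = vfeat @ afeat.T / temperature
+     sim_a2v = afeat @ vfeat.T / temperature
+     targets = torch.eye(B)
+     loss = (F.cross_entropy(sim_v2a, targets) + F.cross_entropy(sim_a2v, targets)) / 2
+     return loss
+ # --- end of sketch ---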
529
+ def encode_streams(self, vis, aud, for_loop, do_norm=True):
530
+ # (B*S, D), (B, D) or None; because `flatten_to_2D = True`
531
+ flatten_to_2D = True
532
+ vfeat, _ = self.encode_stream(vis, self.v_encoder, self.vproj, do_norm, flatten_to_2D, for_loop)
533
+ afeat, _ = self.encode_stream(aud, self.a_encoder, self.aproj, do_norm, flatten_to_2D, for_loop)
534
+ return vfeat, None, afeat, None
535
+
536
+ def encode_stream(self, x, feat_extractor_fn, proj_fn, do_norm, flatten_to_2D, for_loop):
537
+ # x is (B, S, ...)
538
+ segment_x, _ = feat_extractor_fn(x, for_loop) # segment_x: (B, S, D), global_x: (B, D)
539
+ if flatten_to_2D:
540
+ B, S, D = segment_x.shape
541
+ segment_x = segment_x.view(B*S, D) # flatten batch and segment dims
542
+ segment_x = proj_fn(segment_x)
543
+ segment_x = F.normalize(segment_x, dim=-1) if do_norm else segment_x
544
+ # this variant has no global branch, so only segment features are returned (global slot is None)
545
+ return segment_x, None # (B*S, D), (B, D) or None
546
+
547
+ def forward_for_logging(self, vis, aud, for_momentum=False, for_loop=False, do_norm=True):
548
+ '''
549
+ Runs the forward pass but keeps certain tensors in memory for logging purposes (i.e. it duplicates some of the forward code).
550
+ NOTE: to be used outside of this module, most likely during logging
551
+
552
+ Args:
553
+ vis (torch.Tensor): RGB frames (B, S, C, Tv, H, W)
554
+ aud (torch.Tensor): audio spectrograms (B, S, Ta, F)
555
+ '''
556
+ flatten_to_2D = True
557
+ out = dict()
558
+
559
+ # NOTE: encode_streams is used directly here; the factorized encode_stream calls below are kept for reference
560
+ # (B*S, D), (B, D) or None;
561
+ # vfeat, _ = self.encode_stream(vis, self.v_encoder, self.vproj, do_norm, flatten_to_2D, for_loop)
562
+ # afeat, _ = self.encode_stream(aud, self.a_encoder, self.aproj, do_norm, flatten_to_2D, for_loop)
563
+ vfeat, _, afeat, _ = self.encode_streams(vis, aud, for_loop, do_norm)
564
+ # cache features (for 0-shot evaluation)
565
+ out['segment_vfeat'] = vfeat.clone()
566
+ out['segment_afeat'] = afeat.clone()
567
+ # and similarity matrices (for visualization) (B*S, B*S)
568
+ out['segment_sim_v2a'] = out['segment_vfeat'] @ out['segment_afeat'].mT / self.logit_scale
569
+ out['segment_sim_a2v'] = out['segment_afeat'] @ out['segment_vfeat'].mT / self.logit_scale
570
+ # self
571
+ out['segment_sim_v2v'] = out['segment_vfeat'] @ out['segment_vfeat'].mT / self.logit_scale
572
+ out['segment_sim_a2a'] = out['segment_afeat'] @ out['segment_afeat'].mT / self.logit_scale
573
+
574
+ # compute losses
575
+ loss, _ = self.compute_loss(vfeat, afeat, vfeat.mT, afeat.mT, self.logit_scale)
576
+ out['segment_contrastive_loss'] = loss
577
+ return out
578
+
579
+ @torch.no_grad()
580
+ def clamp_logit_scales(self):
581
+ self.logit_scale.clamp_(self.clamp_scale_min, self.clamp_scale_max)
582
+ return (self.logit_scale, None)
583
+
584
+
585
+ class MultilevelMoCoCLIP(nn.Module):
586
+
587
+ def __init__(self, n_embd: int, queue_size: int, momentum: float,
588
+ afeat_extractor: OmegaConf, vfeat_extractor: OmegaConf, aproj: OmegaConf, vproj: OmegaConf,
589
+ init_scale: float = 0.07, clamp_scale_min: float = 0.001, clamp_scale_max: float = 0.5):
590
+ super().__init__()
591
+ self.output_dict = True
592
+ self.n_embd = n_embd
593
+ self.momentum = momentum
594
+ self.to_add_global_repr = afeat_extractor.params.add_global_repr
595
+
596
+ # loading audio and rgb towers
597
+ self.v_encoder = instantiate_from_config(vfeat_extractor)
598
+ self.a_encoder = instantiate_from_config(afeat_extractor)
599
+ # projection layers to account for different feature dimensions
600
+ self.segment_aproj = instantiate_from_config(aproj)
601
+ self.segment_vproj = instantiate_from_config(vproj)
602
+ self.global_aproj = instantiate_from_config(aproj) if self.to_add_global_repr else None
603
+ self.global_vproj = instantiate_from_config(vproj) if self.to_add_global_repr else None
604
+
605
+ self.clamp_scale_min, self.clamp_scale_max = clamp_scale_min, clamp_scale_max
606
+ self.init_scale = init_scale
607
+ self.segment_logit_scale = nn.Parameter(torch.ones([]) * self.init_scale) # NOTE: stored directly as a temperature (OpenCLIP stores log(1/T) and exponentiates)
608
+ self.global_logit_scale = nn.Parameter(torch.ones([]) * self.init_scale) if self.to_add_global_repr else None
609
+
610
+ # create momentum models
611
+ self.v_encoder_m = instantiate_from_config(vfeat_extractor)
612
+ self.a_encoder_m = instantiate_from_config(afeat_extractor)
613
+ self.segment_aproj_m = instantiate_from_config(aproj)
614
+ self.segment_vproj_m = instantiate_from_config(vproj)
615
+ self.global_aproj_m = instantiate_from_config(aproj) if self.to_add_global_repr else None
616
+ self.global_vproj_m = instantiate_from_config(vproj) if self.to_add_global_repr else None
617
+
618
+ self.model_pairs = [
619
+ [self.v_encoder, self.v_encoder_m], [self.segment_vproj, self.segment_vproj_m],
620
+ [self.a_encoder, self.a_encoder_m], [self.segment_aproj, self.segment_aproj_m],
621
+ ]
622
+ if self.to_add_global_repr:
623
+ self.model_pairs += [
624
+ [self.global_aproj, self.global_aproj_m], [self.global_vproj, self.global_vproj_m],
625
+ ]
626
+
627
+ self.copy_params()
628
+
629
+ self.segment_queue_size = queue_size * afeat_extractor.params.max_segments # scaled by # of segments
630
+ self.global_queue_size = queue_size if self.to_add_global_repr else None
631
+ self.init_Qs(self.segment_queue_size, self.global_queue_size, self.n_embd)
632
+
633
+ # self.ln_final = text.ln_final # perhaps only useful for transformer towers
634
+ # self.register_buffer('attn_mask', text.attn_mask, persistent=False)
635
+
636
+ def forward(self, vis: torch.Tensor, aud: torch.Tensor, alpha: float = 0.0, for_loop: bool = False,
637
+ world_size=None):
638
+ '''
639
+ Args:
640
+ vis (torch.Tensor): RGB frames (B, S, C, Tv, H, W)
641
+ aud (torch.Tensor): audio spectrograms (B, S, Ta, F)
642
+ alpha (float): linear interpolation coefficient for pseudo-targets (with 0.0 targets are 1-hot)
643
+ for_loop (bool): whether to process each segment in a for loop or all at once
644
+ (speed-memory tradeoff)
645
+ Returns:
646
+ rgb_features (tuple(torch.Tensor)): local (B, S, D) and global (B, D) or None RGB features
647
+ audio_features (tuple(torch.Tensor)): local (B, S, D) and global (B, D) or None audio features
648
+ logit_scale (tuple(torch.Tensor)): local and global logit scales (1, )
649
+ '''
650
+ logit_scales = self.clamp_logit_scales()
651
+ to_add_global_repr = self.to_add_global_repr # for readability only
652
+
653
+ feats = self.encode_streams(vis, aud, for_momentum=False, for_loop=for_loop, do_norm=True)
654
+ segment_vfeat, global_vfeat, segment_afeat, global_afeat = feats
655
+
656
+ # get momentum features
657
+ with torch.no_grad():
658
+ if self.training:
659
+ self._momentum_update()
660
+ feats_m = self.encode_streams(vis, aud, for_momentum=True, for_loop=for_loop, do_norm=True)
661
+ segment_vfeat_m, global_vfeat_m, segment_afeat_m, global_afeat_m = feats_m
662
+
663
+ # cat with queue to extend the list of negatives
664
+ segment_vfeat_all = torch.cat([segment_vfeat_m.t(), self.segment_v_queue.clone().detach()], dim=1)
665
+ segment_afeat_all = torch.cat([segment_afeat_m.t(), self.segment_a_queue.clone().detach()], dim=1)
666
+ if to_add_global_repr:
667
+ global_vfeat_all = torch.cat([global_vfeat_m.t(), self.global_v_queue.clone().detach()], dim=1)
668
+ global_afeat_all = torch.cat([global_afeat_m.t(), self.global_a_queue.clone().detach()], dim=1)
669
+
670
+ segment_loss_avc, _ = self.compute_loss(segment_vfeat, segment_afeat, segment_vfeat_all,
671
+ segment_afeat_all, self.segment_logit_scale,
672
+ alpha, segment_vfeat_m, segment_afeat_m)
673
+
674
+ global_loss_avc = None
675
+ if to_add_global_repr:
676
+ global_loss_avc, _ = self.compute_loss(global_vfeat, global_afeat, global_vfeat_all,
677
+ global_afeat_all, self.global_logit_scale,
678
+ alpha, global_vfeat_m, global_afeat_m)
679
+
680
+ if self.training:
681
+ self._multilevel_dequeue_and_enqueue(segment_vfeat_m, segment_afeat_m, global_vfeat_m, global_afeat_m)
682
+ else:
683
+ raise Exception('This module is used only during training. Call the encoding methods (e.g. encode_streams) directly instead.')
684
+
685
+ out = {
686
+ 'rgb_features': (segment_vfeat, global_vfeat),
687
+ 'audio_features': (segment_afeat, global_afeat), 'logit_scales': logit_scales,
688
+ 'losses': {'segment_contrastive_loss': segment_loss_avc},
689
+ }
690
+ if global_loss_avc is not None:
691
+ out['losses']['global_contrastive_loss'] = global_loss_avc
692
+ return out
693
+
694
+ def compute_loss(self, vfeat, afeat, vfeat_all, afeat_all, scale, alpha=0.0, vfeat_m=None, afeat_m=None):
695
+ '''For Multi-level contrastive learning, the losses are made the same way for all levels'''
696
+ sim_v2a = vfeat @ afeat_all / scale
697
+ sim_a2v = afeat @ vfeat_all / scale
698
+ sim_v2a_t, sim_a2v_t = self._make_targets(sim_v2a, vfeat_all, afeat_all, scale, alpha, vfeat_m, afeat_m)
699
+ loss = self._loss(sim_v2a, sim_a2v, sim_v2a_t, sim_a2v_t)
700
+ return loss, (sim_v2a, sim_a2v)
701
+
702
+ @torch.no_grad()
703
+ def _make_targets(self, sim_v2a, vfeat_all, afeat_all, scale, alpha, vfeat_m, afeat_m):
704
+ # NOTE: for simplicity, we assume that sim_v2a.shape[0] == sim_a2v.shape[0]
705
+ # NOTE: sim_targets is not square (sim_v2a.shape is (bsize, bsize+Qsize) )
706
+ sim_targets = torch.eye(*sim_v2a.shape, device=sim_v2a.device, dtype=sim_v2a.dtype)
707
+ # the ALBEF alpha trick
708
+ if alpha > 0.0:
709
+ sim_v2a_m = vfeat_m @ afeat_all / scale
710
+ sim_a2v_m = afeat_m @ vfeat_all / scale
711
+ sim_v2a_targets = alpha * F.softmax(sim_v2a_m, dim=1) + (1 - alpha) * sim_targets
712
+ sim_a2v_targets = alpha * F.softmax(sim_a2v_m, dim=1) + (1 - alpha) * sim_targets
713
+ else:
714
+ sim_v2a_targets = sim_targets
715
+ sim_a2v_targets = sim_targets
716
+ return sim_v2a_targets, sim_a2v_targets
717
+
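+ # --- illustrative sketch (added for exposition, not part of the original file) ---
+ # The ALBEF-style trick above softens the 1-hot targets with the momentum model's own
+ # similarity distribution: targets = alpha * softmax(sim_m) + (1 - alpha) * eye.
+ # Toy numbers below; the queue dimension is omitted for brevity.
+ def _sketch_albef_soft_targets(alpha=0.4):
+     import torch
+     import torch.nn.functional as F
+ 
+     sim_m = torch.tensor([[4.0, 1.0, 0.0],
+                           [0.0, 5.0, 1.0],
+                           [1.0, 0.0, 3.0]])       # momentum-model similarities (B, B)
+     hard = torch.eye(3)
+     soft = alpha * F.softmax(sim_m, dim=1) + (1 - alpha) * hard
+     assert torch.allclose(soft.sum(dim=1), torch.ones(3))  # rows remain valid distributions
+     return soft
+ # --- end of sketch ---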
718
+ def _loss(self, sim_v2a, sim_a2v, sim_v2a_targets, sim_a2v_targets):
719
+ loss_v2a = F.cross_entropy(sim_v2a, sim_v2a_targets)
720
+ loss_a2v = F.cross_entropy(sim_a2v, sim_a2v_targets)
721
+ return (loss_v2a + loss_a2v) / 2
722
+
723
+ def encode_streams(self, vis, aud, for_momentum, for_loop, do_norm=True):
724
+ # (B*S, D), (B, D) or None; because `flatten_to_2D = True`
725
+ flatten_to_2D = True
726
+ segment_vfeat, global_vfeat = self.encode_visual(vis, for_momentum, self.to_add_global_repr, do_norm,
727
+ for_loop=for_loop, flatten_to_2D=flatten_to_2D)
728
+ segment_afeat, global_afeat = self.encode_audio(aud, for_momentum, self.to_add_global_repr, do_norm,
729
+ for_loop=for_loop, flatten_to_2D=flatten_to_2D)
730
+ return segment_vfeat, global_vfeat, segment_afeat, global_afeat
731
+
732
+ def encode_audio(self, x, for_momentum: bool = False, do_global: bool = True, do_norm: bool = True,
733
+ flatten_to_2D=True, for_loop=False):
734
+ # define callables
735
+ encode_fn = self.a_encoder_m if for_momentum else self.a_encoder
736
+ segment_proj_fn = self.segment_aproj_m if for_momentum else self.segment_aproj
737
+ global_proj_fn = self.global_aproj_m if for_momentum else self.global_aproj
738
+ # do the encoding
739
+ return self.encode_stream(x, encode_fn, segment_proj_fn, global_proj_fn, do_global, do_norm,
740
+ flatten_to_2D, for_loop)
741
+
742
+ def encode_visual(self, x, for_momentum: bool = False, do_global: bool = True, do_norm: bool = True,
743
+ flatten_to_2D=True, for_loop=False):
744
+ # define callables
745
+ encode_fn = self.v_encoder_m if for_momentum else self.v_encoder
746
+ segment_proj_fn = self.segment_vproj_m if for_momentum else self.segment_vproj
747
+ global_proj_fn = self.global_vproj_m if for_momentum else self.global_vproj
748
+ # do the encoding
749
+ return self.encode_stream(x, encode_fn, segment_proj_fn, global_proj_fn, do_global, do_norm,
750
+ flatten_to_2D, for_loop)
751
+
752
+ def encode_stream(self, x, feat_extractor_fn, segment_proj_fn, global_proj_fn, do_global, do_norm,
753
+ flatten_to_2D, for_loop):
754
+ # x is (B, S, ...)
755
+ segment_x, global_x = feat_extractor_fn(x, for_loop) # segment_x: (B, S, D), global_x: (B, D)
756
+ if flatten_to_2D:
757
+ B, S, D = segment_x.shape
758
+ segment_x = segment_x.view(B*S, D) # flatten batch and segment dims
759
+ segment_x = segment_proj_fn(segment_x)
760
+ segment_x = F.normalize(segment_x, dim=-1) if do_norm else segment_x
761
+ # do_global is passed in to avoid computing global features when not needed (e.g. during eval)
762
+ if do_global and self.to_add_global_repr:
763
+ global_x = global_proj_fn(global_x)
764
+ global_x = F.normalize(global_x, dim=-1) if do_norm else global_x
765
+ return segment_x, global_x # (B*S, D), (B, D) or None
766
+
767
+ def forward_for_logging(self, vis, aud, for_momentum=False, for_loop=False, do_norm=True):
768
+ '''
769
+ Runs the forward pass but keeps certain tensors in memory for logging purposes (i.e. it duplicates some of the forward code).
770
+ NOTE: to be used outside of this module, most likely during logging
771
+
772
+ Args:
773
+ vis (torch.Tensor): RGB frames (B, S, C, Tv, H, W)
774
+ aud (torch.Tensor): audio spectrograms (B, S, Ta, F)
775
+ '''
776
+ flatten_to_2D = True
777
+
778
+ out = dict()
779
+
780
+ # factorizing self.encode_streams into encode_visual/encode_audio to avoid unnecessary computations
781
+ # (B*S, D), (B, D) or None;
782
+ segment_vfeat, global_vfeat = self.encode_visual(vis, for_momentum, self.to_add_global_repr, do_norm,
783
+ flatten_to_2D, for_loop)
784
+ segment_afeat, global_afeat = self.encode_audio(aud, for_momentum, self.to_add_global_repr, do_norm,
785
+ flatten_to_2D, for_loop)
786
+ # cache features (for 0-shot evaluation)
787
+ out['segment_vfeat'] = segment_vfeat.clone()
788
+ out['segment_afeat'] = segment_afeat.clone()
789
+ # and similarity matrices (for visualization) (B*S, B*S)
790
+ out['segment_sim_v2a'] = out['segment_vfeat'] @ out['segment_afeat'].mT / self.segment_logit_scale
791
+ out['segment_sim_a2v'] = out['segment_afeat'] @ out['segment_vfeat'].mT / self.segment_logit_scale
792
+ # self
793
+ out['segment_sim_v2v'] = out['segment_vfeat'] @ out['segment_vfeat'].mT / self.segment_logit_scale
794
+ out['segment_sim_a2a'] = out['segment_afeat'] @ out['segment_afeat'].mT / self.segment_logit_scale
795
+
796
+ # compute losses
797
+ segment_loss, _ = self.compute_loss(segment_vfeat, segment_afeat, segment_vfeat.mT, segment_afeat.mT,
798
+ self.segment_logit_scale)
799
+ out['segment_contrastive_loss'] = segment_loss
800
+ if self.to_add_global_repr:
801
+ global_loss, _ = self.compute_loss(global_vfeat, global_afeat, global_vfeat.mT, global_afeat.mT,
802
+ self.global_logit_scale)
803
+ out['global_contrastive_loss'] = global_loss
804
+
805
+ return out
806
+
807
+
808
+ @torch.no_grad()
809
+ def clamp_logit_scales(self):
810
+ self.segment_logit_scale.clamp_(self.clamp_scale_min, self.clamp_scale_max)
811
+ if self.to_add_global_repr:
812
+ self.global_logit_scale.clamp_(self.clamp_scale_min, self.clamp_scale_max)
813
+ return (self.segment_logit_scale, self.global_logit_scale)
814
+
815
+ @torch.no_grad()
816
+ def copy_params(self):
817
+ for model_pair in self.model_pairs:
818
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
819
+ param_m.data.copy_(param.data) # initialize
820
+ param_m.requires_grad = False # not update by gradient
821
+
822
+ @torch.no_grad()
823
+ def _momentum_update(self):
824
+ for model_pair in self.model_pairs:
825
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
826
+ param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
827
+
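+ # --- illustrative sketch (added for exposition, not part of the original file) ---
+ # copy_params / _momentum_update above implement a MoCo-style EMA: the momentum copy is
+ # initialised from the online weights, frozen, and then blended as
+ # m_new = momentum * m_old + (1 - momentum) * online. Toy modules below.
+ def _sketch_momentum_update(momentum=0.995):
+     import torch
+ 
+     online = torch.nn.Linear(8, 8)
+     target = torch.nn.Linear(8, 8)
+     for p, p_m in zip(online.parameters(), target.parameters()):
+         p_m.data.copy_(p.data)           # initialise the momentum copy
+         p_m.requires_grad = False        # it is never updated by gradients
+     with torch.no_grad():
+         for p, p_m in zip(online.parameters(), target.parameters()):
+             p_m.data = p_m.data * momentum + p.data * (1.0 - momentum)
+     return target
+ # --- end of sketch ---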
828
+ @torch.no_grad()
829
+ def _multilevel_dequeue_and_enqueue(self, segment_vfeat_m, segment_afeat_m, global_vfeat_m, global_afeat_m):
830
+ if self.segment_queue_size > 0:
831
+ self._dequeue_and_enqueue(segment_vfeat_m, segment_afeat_m, 'segment_')
832
+ if self.to_add_global_repr and self.global_queue_size > 0:
833
+ self._dequeue_and_enqueue(global_vfeat_m, global_afeat_m, 'global_')
834
+
835
+ @torch.no_grad()
836
+ def _dequeue_and_enqueue(self, vfeat, afeat, level_prefix_: str):
837
+ # gather keys before updating queue
838
+ if torch.distributed.is_initialized():
839
+ vfeats = concat_all_gather(vfeat)
840
+ afeats = concat_all_gather(afeat)
841
+ else:
842
+ vfeats = vfeat
843
+ afeats = afeat
844
+
845
+ batch_size = vfeats.shape[0]
846
+ queue_size = getattr(self, level_prefix_ + 'queue_size')
847
+
848
+ # same as `ptr = int(self.segment_queue_ptr)` but allows accessing the attribute by string
849
+ ptr = int(getattr(self, level_prefix_ + 'queue_ptr'))
850
+ assert queue_size % batch_size == 0, f'For simplicity: {queue_size} % {batch_size} == 0'
851
+
852
+ # replace the keys at ptr (dequeue and enqueue)
853
+ getattr(self, level_prefix_ + 'v_queue')[:, ptr:ptr + batch_size] = vfeats.T
854
+ getattr(self, level_prefix_ + 'a_queue')[:, ptr:ptr + batch_size] = afeats.T
855
+ ptr = (ptr + batch_size) % queue_size # move pointer
856
+
857
+ getattr(self, level_prefix_ + 'queue_ptr')[0] = ptr
858
+
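+ # --- illustrative sketch (added for exposition, not part of the original file) ---
+ # The queues above are ring buffers of shape (D, Q): each step overwrites `batch_size`
+ # columns at the pointer with the new momentum features and advances the pointer modulo Q,
+ # assuming Q is divisible by the batch size. Dimensions below are toy values.
+ def _sketch_ring_buffer_update():
+     import torch
+ 
+     D, Q, batch_size = 16, 8, 4
+     queue = torch.nn.functional.normalize(torch.randn(D, Q), dim=0)
+     ptr = 0
+     feats = torch.nn.functional.normalize(torch.randn(batch_size, D), dim=-1)
+     assert Q % batch_size == 0
+     queue[:, ptr:ptr + batch_size] = feats.T     # dequeue the oldest keys, enqueue the new ones
+     ptr = (ptr + batch_size) % Q                 # advance the pointer
+     return queue, ptr
+ # --- end of sketch ---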
859
+ def init_Qs(self, segment_queue_size: int, global_queue_size: int, n_embd: int):
860
+ # create the queues; TODO: flip the dimensions, yikes!
861
+ self.register_buffer('segment_v_queue', torch.randn(n_embd, segment_queue_size))
862
+ self.register_buffer('segment_a_queue', torch.randn(n_embd, segment_queue_size))
863
+ self.register_buffer('segment_queue_ptr', torch.zeros(1, dtype=torch.long))
864
+ self.segment_v_queue = nn.functional.normalize(self.segment_v_queue, dim=0)
865
+ self.segment_a_queue = nn.functional.normalize(self.segment_a_queue, dim=0)
866
+ if self.to_add_global_repr:
867
+ self.register_buffer('global_v_queue', torch.randn(n_embd, global_queue_size))
868
+ self.register_buffer('global_a_queue', torch.randn(n_embd, global_queue_size))
869
+ self.register_buffer('global_queue_ptr', torch.zeros(1, dtype=torch.long))
870
+ self.global_v_queue = nn.functional.normalize(self.global_v_queue, dim=0)
871
+ self.global_a_queue = nn.functional.normalize(self.global_a_queue, dim=0)
872
+
873
+ @torch.no_grad()
874
+ def concat_all_gather(tensor):
875
+ """
876
+ Performs all_gather operation on the provided tensors.
877
+ *** Warning ***: torch.distributed.all_gather has no gradient.
878
+ """
879
+ tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
880
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
881
+
882
+ output = torch.cat(tensors_gather, dim=0)
883
+ return output
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN101-quickgelu.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 23,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN101.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 23,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50-quickgelu.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 6,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 6,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50x16.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 384,
5
+ "layers": [
6
+ 6,
7
+ 8,
8
+ 18,
9
+ 8
10
+ ],
11
+ "width": 96,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 768,
18
+ "heads": 12,
19
+ "layers": 12
20
+ }
21
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50x4.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 288,
5
+ "layers": [
6
+ 4,
7
+ 6,
8
+ 10,
9
+ 6
10
+ ],
11
+ "width": 80,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 640,
18
+ "heads": 10,
19
+ "layers": 12
20
+ }
21
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/RN50x64.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 448,
5
+ "layers": [
6
+ 3,
7
+ 15,
8
+ 36,
9
+ 10
10
+ ],
11
+ "width": 128,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 1024,
18
+ "heads": 16,
19
+ "layers": 12
20
+ }
21
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-16-plus-240.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 240,
5
+ "layers": 12,
6
+ "width": 896,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 640,
13
+ "heads": 10,
14
+ "layers": 12
15
+ }
16
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-16-plus.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 896,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 640,
13
+ "heads": 10,
14
+ "layers": 12
15
+ }
16
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-16.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 512,
13
+ "heads": 8,
14
+ "layers": 12
15
+ }
16
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-32-plus-256.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 256,
5
+ "layers": 12,
6
+ "width": 896,
7
+ "patch_size": 32
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 640,
13
+ "heads": 10,
14
+ "layers": 12
15
+ }
16
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-32-quickgelu.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": 12,
7
+ "width": 768,
8
+ "patch_size": 32
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 512,
14
+ "heads": 8,
15
+ "layers": 12
16
+ }
17
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-B-32.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 32
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 512,
13
+ "heads": 8,
14
+ "layers": 12
15
+ }
16
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-H-14.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 32,
6
+ "width": 1280,
7
+ "head_width": 80,
8
+ "patch_size": 14
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1024,
14
+ "heads": 16,
15
+ "layers": 24
16
+ }
17
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-H-16.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 32,
6
+ "width": 1280,
7
+ "head_width": 80,
8
+ "patch_size": 16
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1024,
14
+ "heads": 16,
15
+ "layers": 24
16
+ }
17
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-14-280.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 280,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 14
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-14-336.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 336,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 14
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-14.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 14
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-16-320.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 320,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }
modules/model/modules/feat_extractors/train_clip_src/open_clip/model_configs/ViT-L-16.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }