tanthinhdt committed on
Commit
3e1357a
1 Parent(s): ac8572e

Upload model

Files changed (9)
  1. README.md +199 -0
  2. config.json +101 -0
  3. configuration.py +181 -0
  4. encoder.py +110 -0
  5. generation_config.json +4 -0
  6. model.safetensors +3 -0
  7. modelling.py +753 -0
  8. resnet.py +216 -0
  9. utils.py +166 -0
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
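
The "How to Get Started with the Model" section above is still a placeholder. Based on the `auto_map` entries in `config.json` below, the checkpoint is meant to be loaded through the Transformers Auto classes with `trust_remote_code=True`. A minimal, untested sketch follows; the repository id is a placeholder assumption, and the custom code additionally requires `fairseq`, `peft`, and `bitsandbytes` to be installed:

```python
from transformers import AutoConfig, AutoModelForVideoClassification

# Placeholder repository id -- replace with the actual Hub repo this commit belongs to.
repo_id = "tanthinhdt/ViAVSP-LLM_v2.0"

# config.json maps AutoConfig -> configuration.AVSPLLMConfig and
# AutoModelForVideoClassification -> modelling.AVSPLLMModel, hence trust_remote_code.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForVideoClassification.from_pretrained(
    repo_id,
    config=config,
    trust_remote_code=True,
)
model.eval()
```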
config.json ADDED
@@ -0,0 +1,101 @@
+ {
+ "activation_dropout": 0.1,
+ "activation_fn": "gelu",
+ "apply_mask": false,
+ "arch": "avsp_llm",
+ "architectures": [
+ "AVSPLLMModel"
+ ],
+ "attention_dropout": 0.1,
+ "audio_dropout": 0.0,
+ "audio_feat_dim": 104,
+ "auto_map": {
+ "AutoConfig": "configuration.AVSPLLMConfig",
+ "AutoModelForVideoClassification": "modelling.AVSPLLMModel"
+ },
+ "conv_bias": false,
+ "conv_feature_layers": "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
+ "conv_pos": 128,
+ "conv_pos_groups": 16,
+ "data": null,
+ "decoder_activation_dropout": 0.0,
+ "decoder_attention_dropout": 0.1,
+ "decoder_attention_heads": 4,
+ "decoder_dropout": 0.1,
+ "decoder_embed_dim": 2560,
+ "decoder_ffn_embed_dim": 3072,
+ "decoder_layerdrop": 0.0,
+ "decoder_layers": 6,
+ "decoder_learned_pos": false,
+ "decoder_normalize_before": false,
+ "dropout": 0.1,
+ "dropout_features": 0.0,
+ "dropout_input": 0.0,
+ "encoder_attention_heads": 16,
+ "encoder_embed_dim": 1024,
+ "encoder_ffn_embed_dim": 4096,
+ "encoder_layerdrop": 0.0,
+ "encoder_layers": 24,
+ "extractor_mode": "default",
+ "feature_ds_rate": 1,
+ "feature_grad_mult": 1.0,
+ "final_dim": 256,
+ "final_dropout": 0.1,
+ "freeze_finetune_updates": 0,
+ "ignored_weights": [],
+ "input_modality": "video",
+ "label_rate": 25,
+ "latent_temp": [
+ 2.0,
+ 0.5,
+ 0.999995
+ ],
+ "layer_norm_first": false,
+ "layerdrop": 0.0,
+ "llm_ckpt_path": "vilm/vinallama-2.7b",
+ "logit_temp": 0.1,
+ "mask_channel_length": 64,
+ "mask_channel_min_space": 1,
+ "mask_channel_other": 0.0,
+ "mask_channel_prob": 0.5,
+ "mask_channel_selection": "static",
+ "mask_length": 10,
+ "mask_length_audio": 10,
+ "mask_length_image": 10,
+ "mask_min_space": 1,
+ "mask_other": 0.0,
+ "mask_prob": 0.5,
+ "mask_prob_audio": 0.65,
+ "mask_prob_image": 0.65,
+ "mask_selection": "static",
+ "masking_type": "input",
+ "masking_updates": 0,
+ "max_target_positions": 2048,
+ "modality_dropout": 0.0,
+ "modality_fuse": "concat",
+ "model_type": "avsp_llm",
+ "no_mask_channel_overlap": false,
+ "no_mask_overlap": false,
+ "no_pretrained_weights": false,
+ "no_scale_embedding": true,
+ "no_token_positional_embeddings": false,
+ "normalize": false,
+ "num_classes": 2004,
+ "num_frames": 16,
+ "num_frozen_layers": 0,
+ "pretrained": "tanthinhdt/ViAVSP-LLM_v1.0",
+ "resnet_relu_type": "prelu",
+ "resnet_weights": null,
+ "sample_rate": 25,
+ "selection_type": "same_other_seq",
+ "share_decoder_input_output_embed": false,
+ "sim_type": "cosine",
+ "skip_masked": false,
+ "skip_nomask": false,
+ "sub_encoder_layers": 0,
+ "target_glu": false,
+ "torch_dtype": "float32",
+ "transformers_version": "4.41.2",
+ "untie_final_proj": false,
+ "w2v_args": null
+ }
configuration.py ADDED
@@ -0,0 +1,181 @@
1
+ from typing import Tuple
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class AVHubertConfig(PretrainedConfig):
6
+ model_type = "av_hubert"
7
+
8
+ def __init__(
9
+ self,
10
+ label_rate: int = 25,
11
+ sample_rate: int = 25,
12
+ input_modality: str = "video",
13
+ extractor_mode: str = "default",
14
+ encoder_layers: int = 24,
15
+ encoder_embed_dim: int = 1024,
16
+ encoder_ffn_embed_dim: int = 4096,
17
+ encoder_attention_heads: int = 16,
18
+ activation_fn: str = "gelu",
19
+ dropout: float = 0.1,
20
+ attention_dropout: float = 0.1,
21
+ activation_dropout: float = 0.1,
22
+ encoder_layerdrop: float = 0.0,
23
+ dropout_input: float = 0.0,
24
+ dropout_features: float = 0.0,
25
+ final_dim: int = 256,
26
+ untie_final_proj: bool = False,
27
+ layer_norm_first: bool = False,
28
+ conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
29
+ conv_bias: bool = False,
30
+ logit_temp: float = 0.1,
31
+ target_glu: bool = False,
32
+ feature_grad_mult: float = 1.0,
33
+ mask_length_audio: int = 10,
34
+ mask_prob_audio: float = 0.65,
35
+ mask_length_image: int = 10,
36
+ mask_prob_image: float = 0.65,
37
+ mask_selection: str = "static",
38
+ mask_other: float = 0.0,
39
+ no_mask_overlap: bool = False,
40
+ mask_min_space: int = 1,
41
+ mask_channel_length: int = 64,
42
+ mask_channel_prob: float = 0.5,
43
+ mask_channel_selection: str = "static",
44
+ mask_channel_other: float = 0.0,
45
+ no_mask_channel_overlap: bool = False,
46
+ mask_channel_min_space: int = 1,
47
+ conv_pos: int = 128,
48
+ conv_pos_groups: int = 16,
49
+ latent_temp: Tuple[float, float, float] = (2.0, 0.5, 0.999995),
50
+ skip_masked: bool = False,
51
+ skip_nomask: bool = False,
52
+ resnet_relu_type: str = "prelu",
53
+ resnet_weights: str = None,
54
+ sim_type: str = "cosine",
55
+ sub_encoder_layers: int = 0,
56
+ audio_feat_dim: int = 104,
57
+ modality_dropout: float = 0.0,
58
+ audio_dropout: float = 0.0,
59
+ modality_fuse: str = "concat",
60
+ selection_type: str = "same_other_seq",
61
+ masking_type: str = "input",
62
+ decoder_embed_dim: int = 2560,
63
+ decoder_ffn_embed_dim: int = 3072,
64
+ decoder_layers: int = 6,
65
+ decoder_layerdrop: float = 0.0,
66
+ decoder_attention_heads: int = 4,
67
+ decoder_learned_pos: bool = False,
68
+ decoder_normalize_before: bool = False,
69
+ no_token_positional_embeddings: bool = False,
70
+ decoder_dropout: float = 0.1,
71
+ decoder_attention_dropout: float = 0.1,
72
+ decoder_activation_dropout: float = 0.0,
73
+ max_target_positions: int = 2048,
74
+ share_decoder_input_output_embed: bool = False,
75
+ no_scale_embedding: bool = True,
76
+ num_classes: int = 2004,
77
+ **kwargs,
78
+ ) -> None:
79
+ super().__init__(**kwargs)
80
+ self.label_rate = label_rate
81
+ self.sample_rate = sample_rate
82
+ self.input_modality = input_modality
83
+ self.extractor_mode = extractor_mode
84
+ self.encoder_layers = encoder_layers
85
+ self.encoder_embed_dim = encoder_embed_dim
86
+ self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
87
+ self.encoder_attention_heads = encoder_attention_heads
88
+ self.activation_fn = activation_fn
89
+ self.dropout = dropout
90
+ self.attention_dropout = attention_dropout
91
+ self.activation_dropout = activation_dropout
92
+ self.encoder_layerdrop = encoder_layerdrop
93
+ self.dropout_input = dropout_input
94
+ self.dropout_features = dropout_features
95
+ self.final_dim = final_dim
96
+ self.untie_final_proj = untie_final_proj
97
+ self.layer_norm_first = layer_norm_first
98
+ self.conv_feature_layers = conv_feature_layers
99
+ self.conv_bias = conv_bias
100
+ self.logit_temp = logit_temp
101
+ self.target_glu = target_glu
102
+ self.feature_grad_mult = feature_grad_mult
103
+ self.mask_length_audio = mask_length_audio
104
+ self.mask_prob_audio = mask_prob_audio
105
+ self.mask_length_image = mask_length_image
106
+ self.mask_prob_image = mask_prob_image
107
+ self.mask_selection = mask_selection
108
+ self.mask_other = mask_other
109
+ self.no_mask_overlap = no_mask_overlap
110
+ self.mask_min_space = mask_min_space
111
+ self.mask_channel_length = mask_channel_length
112
+ self.mask_channel_prob = mask_channel_prob
113
+ self.mask_channel_selection = mask_channel_selection
114
+ self.mask_channel_other = mask_channel_other
115
+ self.no_mask_channel_overlap = no_mask_channel_overlap
116
+ self.mask_channel_min_space = mask_channel_min_space
117
+ self.conv_pos = conv_pos
118
+ self.conv_pos_groups = conv_pos_groups
119
+ self.latent_temp = latent_temp
120
+ self.skip_masked = skip_masked
121
+ self.skip_nomask = skip_nomask
122
+ self.resnet_relu_type = resnet_relu_type
123
+ self.resnet_weights = resnet_weights
124
+ self.sim_type = sim_type
125
+ self.sub_encoder_layers = sub_encoder_layers
126
+ self.audio_feat_dim = audio_feat_dim
127
+ self.modality_dropout = modality_dropout
128
+ self.audio_dropout = audio_dropout
129
+ self.modality_fuse = modality_fuse
130
+ self.selection_type = selection_type
131
+ self.masking_type = masking_type
132
+ self.decoder_embed_dim = decoder_embed_dim
133
+ self.decoder_ffn_embed_dim = decoder_ffn_embed_dim
134
+ self.decoder_layers = decoder_layers
135
+ self.decoder_layerdrop = decoder_layerdrop
136
+ self.decoder_attention_heads = decoder_attention_heads
137
+ self.decoder_learned_pos = decoder_learned_pos
138
+ self.decoder_normalize_before = decoder_normalize_before
139
+ self.no_token_positional_embeddings = no_token_positional_embeddings
140
+ self.decoder_dropout = decoder_dropout
141
+ self.decoder_attention_dropout = decoder_attention_dropout
142
+ self.decoder_activation_dropout = decoder_activation_dropout
143
+ self.max_target_positions = max_target_positions
144
+ self.share_decoder_input_output_embed = share_decoder_input_output_embed
145
+ self.no_scale_embedding = no_scale_embedding
146
+ self.num_classes = num_classes
147
+ self.feature_ds_rate = 1
148
+
149
+
150
+ class AVSPLLMConfig(AVHubertConfig):
151
+ model_type = "avsp_llm"
152
+
153
+ def __init__(
154
+ self,
155
+ llm_ckpt_path: str = "vilm/vinallama-2.7b",
156
+ no_pretrained_weights: bool = False,
157
+ final_dropout: float = 0.1,
158
+ apply_mask: bool = False,
159
+ mask_length: int = 10,
160
+ mask_prob: float = 0.5,
161
+ masking_updates: int = 0,
162
+ layerdrop: float = 0.0,
163
+ normalize: bool = False,
164
+ data: str = None,
165
+ w2v_args: dict = None,
166
+ freeze_finetune_updates: int = 0,
167
+ **kwargs,
168
+ ) -> None:
169
+ super().__init__(**kwargs)
170
+ self.llm_ckpt_path = llm_ckpt_path
171
+ self.no_pretrained_weights = no_pretrained_weights
172
+ self.final_dropout = final_dropout
173
+ self.apply_mask = apply_mask
174
+ self.mask_length = mask_length
175
+ self.mask_prob = mask_prob
176
+ self.masking_updates = masking_updates
177
+ self.layerdrop = layerdrop
178
+ self.normalize = normalize
179
+ self.data = data
180
+ self.w2v_args = w2v_args
181
+ self.freeze_finetune_updates = freeze_finetune_updates
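
For reference, a small sketch of instantiating the configuration classes defined above. It assumes `configuration.py` is importable from the working directory; inside the repository the class is normally resolved through the `auto_map` entry in `config.json` instead:

```python
from configuration import AVSPLLMConfig  # assumes configuration.py is on the path

# Override a couple of defaults; everything else falls back to the AVHubertConfig defaults.
config = AVSPLLMConfig(
    llm_ckpt_path="vilm/vinallama-2.7b",
    decoder_embed_dim=2560,
    freeze_finetune_updates=0,
)
print(config.model_type)         # "avsp_llm"
print(config.encoder_embed_dim)  # 1024, inherited from AVHubertConfig
config_dict = config.to_dict()   # standard PretrainedConfig serialization
```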
encoder.py ADDED
@@ -0,0 +1,110 @@
1
+ import math
2
+ import torch
3
+ import numpy as np
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from typing import List, Optional, Tuple
7
+ from .configuration import AVHubertConfig
8
+ from fairseq.utils import index_put
9
+ from fairseq.modules import LayerNorm, SamePad
10
+ from fairseq.models.wav2vec.wav2vec2 import TransformerSentenceEncoderLayer
11
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
12
+
13
+
14
+ class TransformerEncoder(nn.Module):
15
+ def __init__(self, config: AVHubertConfig) -> None:
16
+ super().__init__()
17
+
18
+ self.dropout = config.dropout
19
+ self.embedding_dim = config.encoder_embed_dim
20
+
21
+ self.pos_conv = nn.Conv1d(
22
+ self.embedding_dim,
23
+ self.embedding_dim,
24
+ kernel_size=config.conv_pos,
25
+ padding=config.conv_pos // 2,
26
+ groups=config.conv_pos_groups,
27
+ )
28
+ dropout = 0
29
+ std = math.sqrt((4 * (1.0 - dropout)) / (config.conv_pos * self.embedding_dim))
30
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
31
+ nn.init.constant_(self.pos_conv.bias, 0)
32
+
33
+ self.pos_conv = nn.utils.weight_norm(
34
+ self.pos_conv, name="weight", dim=2
35
+ )
36
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(config.conv_pos), nn.GELU())
37
+
38
+ self.layers = nn.ModuleList(
39
+ [
40
+ TransformerSentenceEncoderLayer(
41
+ embedding_dim=self.embedding_dim,
42
+ ffn_embedding_dim=config.encoder_ffn_embed_dim,
43
+ num_attention_heads=config.encoder_attention_heads,
44
+ dropout=self.dropout,
45
+ attention_dropout=config.attention_dropout,
46
+ activation_dropout=config.activation_dropout,
47
+ activation_fn=config.activation_fn,
48
+ layer_norm_first=config.layer_norm_first,
49
+ )
50
+ for _ in range(config.encoder_layers)
51
+ ]
52
+ )
53
+
54
+ self.layer_norm_first = config.layer_norm_first
55
+ self.layer_norm = LayerNorm(self.embedding_dim)
56
+ self.layerdrop = config.encoder_layerdrop
57
+
58
+ self.apply(init_bert_params)
59
+
60
+ def forward(
61
+ self,
62
+ x: torch.Tensor,
63
+ padding_mask: Optional[torch.Tensor] = None,
64
+ layer: Optional[int] = None,
65
+ ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
66
+ x, layer_results = self.extract_features(x, padding_mask, layer)
67
+ if self.layer_norm_first and layer is None:
68
+ x = self.layer_norm(x)
69
+ return x, layer_results
70
+
71
+ def extract_features(
72
+ self,
73
+ x: torch.Tensor,
74
+ padding_mask: Optional[torch.Tensor] = None,
75
+ tgt_layer: Optional[int] = None,
76
+ ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
77
+ if padding_mask is not None:
78
+ x = index_put(x, padding_mask, 0)
79
+
80
+ x_conv = self.pos_conv(x.transpose(1, 2))
81
+ x_conv = x_conv.transpose(1, 2)
82
+ x = x + x_conv
83
+
84
+ if not self.layer_norm_first:
85
+ x = self.layer_norm(x)
86
+
87
+ x = F.dropout(x, p=self.dropout, training=self.training)
88
+
89
+ # B x T x C -> T x B x C
90
+ x = x.transpose(0, 1)
91
+
92
+ layer_results = []
93
+ r = None
94
+ for i, layer in enumerate(self.layers):
95
+ dropout_probability = np.random.random()
96
+ if not self.training or (dropout_probability > self.layerdrop):
97
+ x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
98
+ if tgt_layer is not None:
99
+ layer_results.append((x, z))
100
+ if i == tgt_layer:
101
+ r = x
102
+ break
103
+
104
+ if r is not None:
105
+ x = r
106
+
107
+ # T x B x C -> B x T x C
108
+ x = x.transpose(0, 1)
109
+
110
+ return x, layer_results
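
A quick smoke-test sketch for the `TransformerEncoder` above. It assumes `fairseq` is installed (the module imports fairseq layers) and that `encoder.py`/`configuration.py` are importable as a package, since `encoder.py` uses a relative import; the package name below is hypothetical:

```python
import torch
from avsp_llm.configuration import AVHubertConfig  # hypothetical package name
from avsp_llm.encoder import TransformerEncoder

config = AVHubertConfig(encoder_layers=2)            # small stack for a quick check
encoder = TransformerEncoder(config).eval()

x = torch.randn(2, 50, config.encoder_embed_dim)     # (batch, frames, embed_dim)
padding_mask = torch.zeros(2, 50, dtype=torch.bool)  # no padded frames
with torch.no_grad():
    out, layer_results = encoder(x, padding_mask=padding_mask)
print(out.shape)  # torch.Size([2, 50, 1024])
```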
generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+ "_from_model_config": true,
+ "transformers_version": "4.41.2"
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8879d30edde4184fc1a4bd5bf54c0468b309c1dad2ff7b1b820dabbde5f44ec9
+ size 3126662252
modelling.py ADDED
@@ -0,0 +1,753 @@
1
+ import torch
2
+ import logging
3
+ import contextlib
4
+ import numpy as np
5
+ import torch.nn as nn
6
+ from .resnet import ResNetEncoder
7
+ from .utils import compute_mask_indices
8
+ from .encoder import TransformerEncoder
9
+ from .configuration import AVHubertConfig, AVSPLLMConfig
10
+ from typing import Optional, Tuple, List, Dict, Any
11
+ from peft import get_peft_model, LoraConfig
12
+ from fairseq.modules import GradMultiply, LayerNorm
13
+ from transformers.modeling_outputs import CausalLMOutputWithPast
14
+ from transformers import (
15
+ FeatureExtractionMixin,
16
+ PreTrainedModel,
17
+ BitsAndBytesConfig,
18
+ AutoModelForCausalLM,
19
+ GenerationConfig,
20
+ )
21
+
22
+
23
+ class AVHubertFeatureExtractor(FeatureExtractionMixin):
24
+ def __init__(self, **kwargs):
25
+ super().__init__(**kwargs)
26
+
27
+
28
+ class AVSPLLMFeatureExtractor(AVHubertFeatureExtractor):
29
+ def __init__(self, **kwargs):
30
+ super().__init__(**kwargs)
31
+
32
+
33
+ class AVHubertVideoFeatureEncoder(nn.Module):
34
+ def __init__(self, config: AVHubertConfig) -> None:
35
+ super().__init__()
36
+ self.resnet = ResNetEncoder(relu_type=config.resnet_relu_type)
37
+ self.proj = nn.Linear(self.resnet.backend_out, config.encoder_embed_dim)
38
+ self.encoder = (
39
+ TransformerEncoder(config)
40
+ if config.sub_encoder_layers > 0
41
+ else None
42
+ )
43
+
44
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
45
+ x = self.resnet(x)
46
+ x = self.proj(x.transpose(1, 2))
47
+ if self.encoder is not None:
48
+ x = self.encoder(x)[0].transpose(1, 2)
49
+ else:
50
+ x = x.transpose(1, 2)
51
+ return x
52
+
53
+
54
+ class AVHubertAudioFeatureEncoder(nn.Module):
55
+ def __init__(self, config: AVHubertConfig) -> None:
56
+ super().__init__()
57
+ self.proj = nn.Linear(config.audio_feat_dim, config.encoder_embed_dim)
58
+ self.encoder = (
59
+ TransformerEncoder(config)
60
+ if config.sub_encoder_layers > 0
61
+ else None
62
+ )
63
+
64
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
65
+ x = self.proj(x.transpose(1, 2))
66
+ if self.encoder is not None:
67
+ x = self.encoder(x)[0].transpose(1, 2)
68
+ else:
69
+ x = x.transpose(1, 2)
70
+ return x
71
+
72
+
73
+ class AVHubertModel(PreTrainedModel):
74
+ config_class = AVHubertConfig
75
+
76
+ def __init__(
77
+ self,
78
+ config: AVHubertConfig = AVHubertConfig(),
79
+ dictionaries: List = [None],
80
+ ) -> None:
81
+ super().__init__(config=config)
82
+ label_rate = config.label_rate
83
+ feature_ds_rate = config.feature_ds_rate
84
+ sample_rate = config.sample_rate
85
+ self.feat2tar_ratio = label_rate * feature_ds_rate / sample_rate
86
+
87
+ self.feature_extractor_video = AVHubertVideoFeatureEncoder(config)
88
+ self.feature_extractor_audio = AVHubertAudioFeatureEncoder(config)
89
+
90
+ if config.modality_fuse == "concat":
91
+ self.encoder_embed_dim = config.encoder_embed_dim * 2
92
+ elif config.modality_fuse == "add":
93
+ self.encoder_embed_dim = config.encoder_embed_dim
94
+
95
+ self.post_extract_proj = (
96
+ nn.Linear(self.encoder_embed_dim, config.encoder_embed_dim)
97
+ if self.encoder_embed_dim != config.encoder_embed_dim
98
+ else None
99
+ )
100
+
101
+ self.dropout_input = nn.Dropout(config.dropout_input)
102
+ self.dropout_features = nn.Dropout(config.dropout_features)
103
+
104
+ if self.config.final_dim > 0:
105
+ final_dim = config.final_dim
106
+ else:
107
+ final_dim = config.encoder_embed_dim
108
+
109
+ self.mask_emb = nn.Parameter(
110
+ torch.FloatTensor(config.audio_feat_dim).uniform_()
111
+ if config.masking_type == "input"
112
+ else torch.FloatTensor(config.encoder_embed_dim).uniform_()
113
+ )
114
+
115
+ self.encoder = TransformerEncoder(self.config)
116
+ self.layer_norm = LayerNorm(self.encoder_embed_dim)
117
+
118
+ self.target_glu = None
119
+ if config.target_glu:
120
+ self.target_glu = nn.Sequential(
121
+ nn.Linear(config.final_dim, config.final_dim * 2),
122
+ nn.GLU(),
123
+ )
124
+
125
+ if config.untie_final_proj:
126
+ self.final_proj = nn.Linear(
127
+ config.encoder_embed_dim,
128
+ final_dim * len(dictionaries),
129
+ )
130
+ else:
131
+ self.final_proj = nn.Linear(config.encoder_embed_dim, final_dim)
132
+
133
+ # modules below are not needed during fine-tuning
134
+ if any([d is None for d in dictionaries]):
135
+ self.num_classes = config.num_classes
136
+ else:
137
+ self.num_classes = sum([len(d) for d in dictionaries])
138
+ self.label_embs_concat = nn.Parameter(
139
+ torch.FloatTensor(self.num_classes, final_dim)
140
+ )
141
+ nn.init.uniform_(self.label_embs_concat)
142
+
143
+ def apply_input_mask(
144
+ self,
145
+ x: torch.Tensor,
146
+ padding_mask: torch.Tensor,
147
+ target_list: List[torch.Tensor],
148
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
149
+ B, C, T = x.shape[:3]
150
+ is_audio = True if len(x.shape) == 3 else False
151
+
152
+ if is_audio:
153
+ mask_prob = self.config.mask_prob_audio
154
+ mask_length = self.config.mask_length_audio
155
+ else:
156
+ mask_prob = self.config.mask_prob_image
157
+ mask_length = self.config.mask_length_image
158
+
159
+ if mask_prob > 0:
160
+ mask_indices, starts, ends, batch_indexes = compute_mask_indices(
161
+ (B, T),
162
+ padding_mask,
163
+ mask_prob,
164
+ mask_length,
165
+ self.config.mask_selection,
166
+ self.config.mask_other,
167
+ min_masks=2,
168
+ no_overlap=self.config.no_mask_overlap,
169
+ min_space=self.config.mask_min_space,
170
+ )
171
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
172
+ x = x.transpose(1, 2).contiguous() # [B, T, C, H, W]
173
+ if B == 1:
174
+ x[mask_indices] = 0
175
+ elif is_audio:
176
+ x[mask_indices] = self.mask_emb
177
+ elif self.config.selection_type == "same_other_seq":
178
+ perm = (torch.arange(B) + torch.randint(low=1, high=B, size=(1,))) % B
179
+ x_perm = x[perm]
180
+ x[mask_indices] = x_perm[mask_indices]
181
+ elif self.config.selection_type == "same_seq":
182
+ batch_indexes_, other_indexes = [], []
183
+ for batch_index, start, end in zip(batch_indexes, starts, ends):
184
+ length = end - start
185
+ other_start = np.setdiff1d(
186
+ np.arange(T), np.arange(max(0, start - length), end)
187
+ )
188
+ if len(other_start) > 0:
189
+ other_start = np.random.choice(other_start, size=1)
190
+ else:
191
+ other_start = 0
192
+ other_end = other_start + length
193
+ other_indexes.append(
194
+ np.arange(other_start, other_end).clip(max=T - 1)
195
+ )
196
+ batch_indexes_.append(
197
+ np.zeros([length], dtype=np.int64) + batch_index
198
+ )
199
+ batch_indexes = np.concatenate(batch_indexes_)
200
+ other_indexes = np.concatenate(other_indexes)
201
+ x[mask_indices] = x[batch_indexes, other_indexes]
202
+ x = x.transpose(1, 2).contiguous()
203
+ else:
204
+ mask_indices = None
205
+
206
+ if self.config.mask_channel_prob > 0:
207
+ logging.info("No mask channel prob for input masking")
208
+ return x, mask_indices
209
+
210
+ def apply_feature_mask(
211
+ self,
212
+ x: torch.Tensor,
213
+ padding_mask: torch.Tensor,
214
+ target_list: List[torch.Tensor],
215
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
216
+ B, T, C = x.shape
217
+ assert all((
218
+ self.config.mask_prob_audio == self.config.mask_prob_image,
219
+ self.config.mask_length_audio == self.config.mask_length_image,
220
+ )), "masking prob/length for image/audio be same for feature masking"
221
+
222
+ mask_prob = self.config.mask_prob_audio
223
+ mask_length = self.config.mask_length_image
224
+ if mask_prob > 0:
225
+ mask_indices, _, _, _ = compute_mask_indices(
226
+ (B, T),
227
+ padding_mask,
228
+ mask_prob,
229
+ mask_length,
230
+ self.config.mask_selection,
231
+ self.config.mask_other,
232
+ min_masks=2,
233
+ no_overlap=self.config.no_mask_overlap,
234
+ min_space=self.config.mask_min_space,
235
+ )
236
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
237
+ x[mask_indices] = self.mask_emb
238
+ else:
239
+ mask_indices = None
240
+
241
+ if self.config.mask_channel_prob > 0:
242
+ mask_channel_indices, _, _, _ = compute_mask_indices(
243
+ (B, C),
244
+ None,
245
+ self.config.mask_channel_prob,
246
+ self.config.mask_channel_length,
247
+ self.config.mask_channel_selection,
248
+ self.config.mask_channel_other,
249
+ no_overlap=self.config.no_mask_channel_overlap,
250
+ min_space=self.config.mask_channel_min_space,
251
+ )
252
+ mask_channel_indices = (
253
+ torch.from_numpy(mask_channel_indices)
254
+ .to(x.device)
255
+ .unsqueeze(1)
256
+ .expand(-1, T, -1)
257
+ )
258
+ x[mask_channel_indices] = 0
259
+
260
+ return x, mask_indices
261
+
262
+ def forward_features(
263
+ self,
264
+ source: Dict[str, torch.Tensor],
265
+ modality: str,
266
+ ) -> torch.Tensor:
267
+ extractor = getattr(self, f"feature_extractor_{modality}")
268
+ if self.config.feature_grad_mult > 0:
269
+ features = extractor(source)
270
+ if self.config.feature_grad_mult != 1.0:
271
+ features = GradMultiply.apply(features, self.config.feature_grad_mult)
272
+ else:
273
+ with torch.no_grad():
274
+ features = extractor(source)
275
+ return features
276
+
277
+ def forward_targets(
278
+ self,
279
+ features: torch.Tensor,
280
+ mask_indices: torch.Tensor,
281
+ target_list: List[torch.Tensor],
282
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
283
+ # Trim features to ensure labels exist and then get aligned labels
284
+ feat_tsz = features.size(2)
285
+ targ_tsz = min([t.size(1) for t in target_list])
286
+ if self.feat2tar_ratio * feat_tsz > targ_tsz:
287
+ feat_tsz = int(targ_tsz / self.feat2tar_ratio)
288
+ features = features[..., :feat_tsz]
289
+ if mask_indices is not None:
290
+ mask_indices = mask_indices[..., :feat_tsz]
291
+ target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
292
+ target_list = [t[:, target_inds.long()] for t in target_list]
293
+ return features, mask_indices, target_list
294
+
295
+ def forward_padding_mask(
296
+ self,
297
+ features: torch.Tensor,
298
+ padding_mask: torch.Tensor,
299
+ ) -> torch.Tensor:
300
+ extra = padding_mask.size(1) % features.size(1)
301
+ if extra > 0:
302
+ padding_mask = padding_mask[:, :-extra]
303
+ padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
304
+ padding_mask = padding_mask.all(-1)
305
+ return padding_mask
306
+
307
+ def compute_logits(self, feats: torch.Tensor, emb_mat: torch.Tensor) -> torch.Tensor:
308
+ # feats: [B, T, F], emb_mat: [V, F]
309
+ if self.config.sim_type == "dot":
310
+ logits = torch.matmul(feats, emb_mat.transpose(0, 1))
311
+ elif self.config.sim_type == "cosine":
312
+ batch_size, timesteps, emb_dim = feats.size()
313
+ feats_ = feats.view(-1, emb_dim)
314
+ # [B*T, V]
315
+ nom = (feats_.unsqueeze(dim=1) * emb_mat.unsqueeze(dim=0)).sum(dim=-1)
316
+ # [B*T, V]
317
+ denom = (
318
+ (feats_**2).sum(dim=-1).sqrt().unsqueeze(dim=1)
319
+ * (emb_mat**2).sum(dim=-1).sqrt().unsqueeze(dim=0)
320
+ )
321
+ logits = (nom / denom.clamp(min=1e-6)).view(batch_size, timesteps, -1)
322
+ else:
323
+ raise NotImplementedError
324
+ logits = logits / self.config.logit_temp
325
+ return logits
326
+
327
+ def forward(
328
+ self,
329
+ source: Dict[str, torch.Tensor],
330
+ target_list: Optional[List[torch.Tensor]] = None,
331
+ padding_mask: Optional[torch.Tensor] = None,
332
+ mask: bool = True,
333
+ features_only: bool = False,
334
+ output_layer: Optional[int] = None,
335
+ ) -> Dict[str, torch.Tensor]:
336
+ """output layer is 1-based"""
337
+ src_audio, src_video = source["audio"], source["video"]
338
+ if mask and self.config.masking_type == "input":
339
+ src_video, mask_indices_video = self.apply_input_mask(
340
+ src_video, padding_mask, target_list
341
+ )
342
+ src_audio, mask_indices_audio = self.apply_input_mask(
343
+ src_audio, padding_mask, target_list
344
+ )
345
+ mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video)
346
+ else:
347
+ src_audio, src_video, mask_indices = src_audio, src_video, None
348
+
349
+ # [B, F, T]
350
+ features_audio = self.forward_features(src_audio, modality="audio")
351
+ features_video = self.forward_features(src_video, modality="video")
352
+
353
+ if self.training:
354
+ modality_drop_prob, audio_drop_prob = np.random.random(), np.random.random()
355
+ if modality_drop_prob < self.config.modality_dropout:
356
+ if audio_drop_prob < self.config.audio_dropout:
357
+ features_audio = 0 * features_audio
358
+ else:
359
+ features_video = 0 * features_video
360
+
361
+ if self.config.modality_fuse == "concat":
362
+ features = torch.cat([features_audio, features_video], dim=1)
363
+ elif self.config.modality_fuse == "add":
364
+ features = features_audio + features_video
365
+
366
+ if target_list is not None:
367
+ features, mask_indices, target_list = self.forward_targets(
368
+ features, mask_indices, target_list
369
+ )
370
+
371
+ features_pen = features.float().pow(2).mean()
372
+
373
+ features = features.transpose(1, 2)
374
+ features = self.layer_norm(features)
375
+
376
+ if padding_mask is not None:
377
+ padding_mask = self.forward_padding_mask(features, padding_mask)
378
+
379
+ if self.post_extract_proj is not None:
380
+ features = self.post_extract_proj(features)
381
+
382
+ features = self.dropout_input(features)
383
+ if self.config.masking_type == "feature" and mask:
384
+ x, mask_indices = self.apply_feature_mask(
385
+ features, padding_mask, target_list
386
+ )
387
+ else:
388
+ x = features
389
+
390
+ # feature: (B, T, D), float
391
+ # target: (B, T), long
392
+ # x: (B, T, D), float
393
+ # padding_mask: (B, T), bool
394
+ # mask_indices: (B, T), bool
395
+ x, _ = self.encoder(
396
+ x,
397
+ padding_mask=padding_mask,
398
+ layer=None if output_layer is None else output_layer - 1,
399
+ )
400
+
401
+ if features_only:
402
+ return {"x": x, "padding_mask": padding_mask, "features": features}
403
+
404
+ label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
405
+ proj_x = self.final_proj(x)
406
+ if self.config.untie_final_proj:
407
+ proj_x_list = proj_x.chunk(len(self.num_classes), dim=-1)
408
+ else:
409
+ proj_x_list = [proj_x for _ in self.num_classes]
410
+
411
+ # [[B*T, V]]
412
+ logit_list = [
413
+ self.compute_logits(proj, emb).view(-1, num_class)
414
+ for proj, emb, num_class in zip(
415
+ proj_x_list, label_embs_list, self.num_classes
416
+ )
417
+ ]
418
+
419
+ mask = torch.logical_and(mask_indices, ~padding_mask).view(-1)
420
+ unmask = torch.logical_and(~mask_indices, ~padding_mask).view(-1) # [B*T]
421
+ logit_m_list = [logit[mask] for logit in logit_list]
422
+ logit_u_list = [logit[unmask] for logit in logit_list]
423
+ target_m_list = [target.view(-1)[mask].long() for target in target_list]
424
+ target_u_list = [target.view(-1)[unmask].long() for target in target_list]
425
+
426
+ return {
427
+ "logit_m_list": logit_m_list,
428
+ "logit_u_list": logit_u_list,
429
+ "target_m_list": target_m_list,
430
+ "target_u_list": target_u_list,
431
+ "padding_mask": padding_mask,
432
+ "features_pen": features_pen,
433
+ }
434
+
435
+ def extract_features(
436
+ self,
437
+ source: Dict[str, torch.Tensor],
438
+ padding_mask: Optional[torch.Tensor] = None,
439
+ mask: bool = False,
440
+ ret_conv: bool = False,
441
+ output_layer: Optional[int] = None,
442
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
443
+ res = self.forward(
444
+ source,
445
+ padding_mask=padding_mask,
446
+ mask=mask,
447
+ features_only=True,
448
+ output_layer=output_layer,
449
+ )
450
+ feature = res["features"] if ret_conv else res["x"]
451
+ return feature, res["padding_mask"]
452
+
453
+ def extract_units(
454
+ self,
455
+ source: Dict[str, torch.Tensor],
456
+ padding_mask: torch.Tensor = None,
457
+ mask: bool = False,
458
+ ret_conv: bool = False,
459
+ output_layer: Optional[int] = None,
460
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
461
+ res = self.forward(
462
+ source,
463
+ padding_mask=padding_mask,
464
+ mask=mask,
465
+ features_only=True,
466
+ output_layer=None,
467
+ )
468
+
469
+ feature = res["features"] if ret_conv else res["x"]
470
+ proj_x = self.final_proj(feature)
471
+ # B T
472
+ units = (
473
+ torch
474
+ .matmul(proj_x, self.label_embs_concat.transpose(0, 1))
475
+ .argmax(dim=-1)
476
+ )
477
+ return units
478
+
479
+ def extract_finetune(
480
+ self,
481
+ source: Dict[str, torch.Tensor],
482
+ padding_mask: torch.Tensor = None,
483
+ mask: bool = False,
484
+ ret_conv: bool = False,
485
+ output_layer: Optional[int] = None,
486
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
487
+ src_audio, src_video = source["audio"], source["video"]
488
+ if mask and self.config.masking_type == "input":
489
+ src_video, _ = self.apply_input_mask(
490
+ src_video, padding_mask, target_list=None
491
+ )
492
+ src_audio, _ = self.apply_input_mask(
493
+ src_audio, padding_mask, target_list=None
494
+ )
495
+ else:
496
+ src_audio, src_video, _ = src_audio, src_video, None
497
+
498
+ # features: [B, F, T]
499
+ if src_audio is not None and src_video is None:
500
+ features_audio = self.forward_features(
501
+ src_audio, modality="audio"
502
+ )
503
+ features_video = features_audio.new_zeros(
504
+ features_audio.size(0),
505
+ self.encoder_embed_dim,
506
+ features_audio.size(-1)
507
+ )
508
+ elif src_audio is None and src_video is not None:
509
+ features_video = self.forward_features(src_video, modality="video")
510
+ features_audio = features_video.new_zeros(
511
+ features_video.size(0),
512
+ self.encoder_embed_dim,
513
+ features_video.size(-1)
514
+ )
515
+ elif src_audio is not None and src_video is not None:
516
+ features_video = self.forward_features(src_video, modality="video")
517
+ features_audio = self.forward_features(
518
+ src_audio, modality="audio"
519
+ )
520
+
521
+ if self.config.modality_fuse == "concat":
522
+ features = torch.cat([features_audio, features_video], dim=1)
523
+ elif self.config.modality_fuse == "add":
524
+ features = features_audio + features_video
525
+
526
+ features = features.transpose(1, 2)
527
+ features = self.layer_norm(features)
528
+ unmasked_features = features.clone()
529
+
530
+ if padding_mask is not None:
531
+ padding_mask = self.forward_padding_mask(features, padding_mask)
532
+
533
+ if self.post_extract_proj is not None:
534
+ features = self.post_extract_proj(features)
535
+
536
+ features = self.dropout_input(features)
537
+ unmasked_features = self.dropout_features(unmasked_features)
538
+
539
+ # feature: (B, T, D), float
540
+ # target: (B, T), long
541
+ # x: (B, T, D), float
542
+ # padding_mask: (B, T), bool
543
+ # mask_indices: (B, T), bool
544
+ x, _ = self.encoder(
545
+ features,
546
+ padding_mask=padding_mask,
547
+ layer=None if output_layer is None else output_layer - 1,
548
+ )
549
+
550
+ return x, padding_mask
551
+
552
+ def get_extra_losses(
553
+ self,
554
+ net_output: Dict[str, torch.Tensor],
555
+ ) -> Tuple[List[torch.Tensor], List[str]]:
556
+ extra_losses = []
557
+ names = []
558
+ if "features_pen" in net_output:
559
+ extra_losses.append(net_output["features_pen"])
560
+ names.append("features_pen")
561
+
562
+ return extra_losses, names
563
+
564
+ def remove_pretraining_modules(self) -> None:
565
+ self.target_glu = None
566
+ self.final_proj = None
567
+
568
+ def compute_nce(
569
+ self,
570
+ x: torch.Tensor,
571
+ pos: torch.Tensor,
572
+ negs: torch.Tensor,
573
+ ) -> torch.Tensor:
574
+ neg_is_pos = (pos == negs).all(-1)
575
+ pos = pos.unsqueeze(0)
576
+ targets = torch.cat([pos, negs], dim=0)
577
+
578
+ logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
579
+ logits /= self.config.logit_temp
580
+ if neg_is_pos.any():
581
+ logits[1:][neg_is_pos] = float("-inf")
582
+ logits = logits.transpose(0, 1) # (num_x, num_cls+1)
583
+ return logits
584
+
585
+
586
+ class HubertEncoderWrapper(nn.Module):
587
+ def __init__(
588
+ self,
589
+ config: AVHubertConfig,
590
+ dictionaries: List = [None],
591
+ ) -> None:
592
+ super().__init__()
593
+ self.w2v_model = AVHubertModel(config, dictionaries)
594
+
595
+ def forward(
596
+ self,
597
+ source: Dict[str, torch.Tensor],
598
+ padding_mask: torch.Tensor,
599
+ **kwargs,
600
+ ) -> Dict[str, torch.Tensor]:
601
+ w2v_args = {
602
+ "source": source,
603
+ "padding_mask": padding_mask,
604
+ }
605
+ x, padding_mask = self.w2v_model.extract_finetune(**w2v_args)
606
+ return {
607
+ "encoder_out": x, # T x B x C
608
+ "encoder_padding_mask": padding_mask, # B x T
609
+ "padding_mask": padding_mask,
610
+ }
611
+
612
+ def reorder_encoder_out(
613
+ self,
614
+ encoder_out: Dict[str, torch.Tensor],
615
+ new_order: torch.Tensor,
616
+ ) -> Dict[str, torch.Tensor]:
617
+ if encoder_out["encoder_out"] is not None:
618
+ encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
619
+ 1, new_order
620
+ )
621
+ if encoder_out["encoder_padding_mask"] is not None:
622
+ encoder_out["encoder_padding_mask"] = encoder_out[
623
+ "encoder_padding_mask"
624
+ ].index_select(0, new_order)
625
+ if encoder_out["padding_mask"] is not None:
626
+ encoder_out["padding_mask"] = encoder_out["padding_mask"].index_select(
627
+ 0, new_order
628
+ )
629
+ return encoder_out
630
+
631
+
632
+ class AVSPLLMModel(PreTrainedModel):
633
+ config_class = AVSPLLMConfig
634
+
635
+ def __init__(
636
+ self,
637
+ config: AVSPLLMConfig = AVSPLLMConfig(),
638
+ dictionaries: List = [None],
639
+ ) -> None:
640
+ super().__init__(config=config)
641
+ self.encoder = HubertEncoderWrapper(config, dictionaries)
642
+ self.encoder.w2v_model.remove_pretraining_modules()
643
+
644
+ self.avfeat_to_llm = nn.Linear(
645
+ config.encoder_embed_dim, config.decoder_embed_dim
646
+ )
647
+
648
+ bnb_config = BitsAndBytesConfig(
649
+ load_in_4bit=True,
650
+ bnb_4bit_use_double_quant=True,
651
+ bnb_4bit_quant_type="nf4",
652
+ bnb_4bit_compute_dtype=torch.bfloat16,
653
+ )
654
+ decoder_4bit = AutoModelForCausalLM.from_pretrained(
655
+ config.llm_ckpt_path,
656
+ quantization_config=bnb_config,
657
+ )
658
+ lora_config = LoraConfig(
659
+ r=16,
660
+ lora_alpha=32,
661
+ target_modules=["q_proj", "v_proj", "k_proj"],
662
+ lora_dropout=0.05,
663
+ bias="none",
664
+ task_type="CAUSAL_LM",
665
+ )
666
+ self.decoder = get_peft_model(decoder_4bit, lora_config)
667
+ self.decoder.print_trainable_parameters()
668
+
669
+ def forward(
670
+ self,
671
+ source: Dict[str, torch.Tensor],
672
+ target_list: torch.Tensor,
673
+ padding_mask: torch.Tensor,
674
+ **kwargs,
675
+ ) -> CausalLMOutputWithPast:
676
+ ft = self.config.freeze_finetune_updates <= kwargs.get("num_updates", -1)
677
+ with torch.no_grad() if not ft else contextlib.ExitStack():
678
+ output = self.encoder(source, padding_mask, **kwargs)
679
+
680
+ output["encoder_out"] = self.avfeat_to_llm(output["encoder_out"])
681
+ cluster_counts = source["cluster_counts"][0] # tensor list
682
+
683
+ results_tensor = []
684
+ start_idx = 0
685
+ for cluster_num in cluster_counts:
686
+ end_idx = start_idx + cluster_num
687
+ slice = output["encoder_out"][:, start_idx:end_idx, :]
688
+ mean_tensor = torch.mean(slice, dim=1, keepdim=True)
689
+ results_tensor.append(mean_tensor)
690
+ start_idx = end_idx
691
+
692
+ assert cluster_counts.sum().item() == output["encoder_out"].size()[1], \
693
+ f"{cluster_counts.sum().item()} != {output['encoder_out'].size()[1]}"
694
+
695
+ reduced_enc_out = torch.cat(results_tensor, dim=1)
696
+ B, T, D = reduced_enc_out.size()
697
+
698
+ instruction = source["text"]
699
+ instruction_embedding = self.decoder.model.model.embed_tokens(instruction)
700
+
701
+ labels = target_list.clone()
702
+ labels_embedding = self.decoder.model.model.embed_tokens(labels)
703
+
704
+ llm_input = torch.cat(
705
+ (instruction_embedding, reduced_enc_out, labels_embedding), dim=1
706
+ )
707
+ llm_labels = labels.clone()
708
+ llm_labels[llm_labels == 0] = -100
709
+
710
+ _, instruction_embedding_t, _ = instruction_embedding.size()
711
+ target_ids = (
712
+ torch.full((B, T + instruction_embedding_t), -100).long().to(labels.device)
713
+ )
714
+ llm_labels = torch.cat((target_ids, llm_labels), dim=1)
715
+ return self.decoder(
716
+ inputs_embeds=llm_input, labels=llm_labels, return_dict=True
717
+ )
718
+
719
+ @torch.no_grad()
720
+ def generate(
721
+ self,
722
+ inputs: Optional[Dict[str, torch.Tensor]] = None,
723
+ generation_config: Optional[GenerationConfig] = None,
724
+ **kwargs,
725
+ ) -> Any:
726
+ output = self.encoder(**inputs)
727
+ output["encoder_out"] = self.avfeat_to_llm(output["encoder_out"])
728
+ cluster_counts = inputs["source"]["cluster_counts"][0] # tensor list
729
+
730
+ results_tensor = []
731
+ start_idx = 0
732
+
733
+ for cluster_num in cluster_counts:
734
+ end_idx = start_idx + cluster_num
735
+ slice = output["encoder_out"][:, start_idx:end_idx, :]
736
+ mean_tensor = torch.mean(slice, dim=1, keepdim=True)
737
+ results_tensor.append(mean_tensor)
738
+ start_idx = end_idx
739
+
740
+ assert cluster_counts.sum().item() == output["encoder_out"].size()[1]
741
+
742
+ reduced_enc_out = torch.cat(results_tensor, dim=1)
743
+ B, T, D = reduced_enc_out.size()
744
+ instruction = inputs["source"]["text"]
745
+ instruction_embedding = self.decoder.model.model.embed_tokens(instruction)
746
+ llm_input = torch.cat((instruction_embedding, reduced_enc_out), dim=1)
747
+
748
+ self.decoder.config.use_cache = True
749
+ return self.decoder.generate(
750
+ inputs_embeds=llm_input,
751
+ **generation_config,
752
+ **kwargs,
753
+ )
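
One step worth calling out in `AVSPLLMModel.forward`/`generate` above is the cluster pooling: encoder frames are mean-pooled per segment according to `cluster_counts` before being concatenated with the instruction embeddings and fed to the LLM. A self-contained sketch of just that pooling step, with illustrative tensor sizes (plain PyTorch, no repository imports):

```python
import torch

# Illustrative sizes: batch of 1, 10 encoder frames projected to the 2560-dim LLM space.
encoder_out = torch.randn(1, 10, 2560)
cluster_counts = torch.tensor([3, 2, 5])  # must sum to the number of frames (10)

pooled = []
start_idx = 0
for cluster_num in cluster_counts:
    end_idx = start_idx + cluster_num
    segment = encoder_out[:, start_idx:end_idx, :]    # frames belonging to one cluster
    pooled.append(segment.mean(dim=1, keepdim=True))  # one token per cluster
    start_idx = end_idx

reduced_enc_out = torch.cat(pooled, dim=1)
print(reduced_enc_out.shape)  # torch.Size([1, 3, 2560]) -> one embedding per cluster
```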
resnet.py ADDED
@@ -0,0 +1,216 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from collections import OrderedDict
5
+
6
+
7
+ def conv3x3(in_channels: int, out_channels: int, stride: int = 1) -> nn.Conv2d:
8
+ return nn.Conv2d(
9
+ in_channels=in_channels,
10
+ out_channels=out_channels,
11
+ kernel_size=3,
12
+ stride=stride,
13
+ padding=1,
14
+ bias=False
15
+ )
16
+
17
+
18
+ def downsample_basic_block(
19
+ in_channels: int,
20
+ out_channels: int,
21
+ stride: int,
22
+ ) -> nn.Sequential:
23
+ return nn.Sequential(
24
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
25
+ nn.BatchNorm2d(out_channels),
26
+ )
27
+
28
+
29
+ def downsample_basic_block_v2(
30
+ in_channels: int,
31
+ out_channels: int,
32
+ stride: int,
33
+ ) -> nn.Sequential:
34
+ return nn.Sequential(
35
+ nn.AvgPool2d(
36
+ kernel_size=stride,
37
+ stride=stride,
38
+ ceil_mode=True,
39
+ count_include_pad=False,
40
+ ),
41
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
42
+ nn.BatchNorm2d(out_channels),
43
+ )
44
+
45
+
46
+ class BasicBlock(nn.Module):
47
+ expansion = 1
48
+
49
+ def __init__(
50
+ self,
51
+ in_channels: int,
52
+ channels: int,
53
+ stride: int = 1,
54
+ downsample: nn.Sequential = None,
55
+ relu_type: str = "relu",
56
+ ) -> None:
57
+ super(BasicBlock, self).__init__()
58
+ assert relu_type in ["relu", "prelu"]
59
+
60
+ self.conv1 = conv3x3(in_channels, channels, stride)
61
+ self.bn1 = nn.BatchNorm2d(channels)
62
+
63
+ if relu_type == "relu":
64
+ self.relu1 = nn.ReLU(inplace=True)
65
+ self.relu2 = nn.ReLU(inplace=True)
66
+ elif relu_type == "prelu":
67
+ self.relu1 = nn.PReLU(num_parameters=channels)
68
+ self.relu2 = nn.PReLU(num_parameters=channels)
69
+ else:
70
+ raise Exception("relu type not implemented")
71
+
72
+ self.conv2 = conv3x3(channels, channels)
73
+ self.bn2 = nn.BatchNorm2d(channels)
74
+
75
+ self.downsample = downsample
76
+ self.stride = stride
77
+
78
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
79
+ residual = x
80
+ out = self.conv1(x)
81
+ out = self.bn1(out)
82
+ out = self.relu1(out)
83
+ out = self.conv2(out)
84
+ out = self.bn2(out)
85
+ if self.downsample is not None:
86
+ residual = self.downsample(x)
87
+ out += residual
88
+ out = self.relu2(out)
89
+ return out
90
+
91
+
92
+ class ResNet(nn.Module):
93
+ def __init__(
94
+ self,
95
+ block: nn.Module,
96
+ layers: list,
97
+ relu_type: str = "relu",
98
+ gamma_zero: bool = False,
99
+ avg_pool_downsample: bool = False,
100
+ ) -> None:
101
+ self.in_channels = 64
102
+ self.relu_type = relu_type
103
+ self.gamma_zero = gamma_zero
104
+ self.downsample_block = (
105
+ downsample_basic_block_v2 if avg_pool_downsample else downsample_basic_block
106
+ )
107
+
108
+ super(ResNet, self).__init__()
109
+ self.layer1 = self._make_layer(block, 64, layers[0])
110
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
111
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
112
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
113
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
114
+
115
+ for m in self.modules():
116
+ if isinstance(m, nn.Conv2d):
117
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
118
+ m.weight.data.normal_(0, math.sqrt(2.0 / n))
119
+ elif isinstance(m, nn.BatchNorm2d):
120
+ m.weight.data.fill_(1)
121
+ m.bias.data.zero_()
122
+
123
+ if self.gamma_zero:
124
+ for m in self.modules():
125
+ if isinstance(m, BasicBlock):
126
+ m.bn2.weight.data.zero_()
127
+
128
+ def _make_layer(
129
+ self,
130
+ block: nn.Module,
131
+ channels: int,
132
+ n_blocks: int,
133
+ stride: int = 1,
134
+ ) -> nn.Sequential:
135
+ downsample = None
136
+ if stride != 1 or self.in_channels != channels * block.expansion:
137
+ downsample = self.downsample_block(
138
+ in_channels=self.in_channels,
139
+ out_channels=channels * block.expansion,
140
+ stride=stride,
141
+ )
142
+
143
+ layers = [
144
+ block(
145
+ self.in_channels, channels, stride, downsample, relu_type=self.relu_type
146
+ )
147
+ ]
148
+ self.in_channels = channels * block.expansion
149
+ for _ in range(1, n_blocks):
150
+ layers.append(block(self.in_channels, channels, relu_type=self.relu_type))
151
+
152
+ return nn.Sequential(*layers)
153
+
154
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
155
+ x = self.layer1(x)
156
+ x = self.layer2(x)
157
+ x = self.layer3(x)
158
+ x = self.layer4(x)
159
+ x = self.avgpool(x)
160
+ x = x.view(x.size(0), -1)
161
+ return x
162
+
163
+
164
+ class ResNetEncoder(nn.Module):
165
+ def __init__(self, relu_type: str, weight_file: str = None) -> None:
166
+ super(ResNetEncoder, self).__init__()
167
+ self.frontend_out = 64
168
+ self.backend_out = 512
169
+ frontend_relu = (
170
+ nn.PReLU(num_parameters=self.frontend_out)
171
+ if relu_type == "prelu"
172
+ else nn.ReLU()
173
+ )
174
+
175
+ self.frontend3D = nn.Sequential(
176
+ nn.Conv3d(
177
+ 1,
178
+ self.frontend_out,
179
+ kernel_size=(5, 7, 7),
180
+ stride=(1, 2, 2),
181
+ padding=(2, 3, 3),
182
+ bias=False,
183
+ ),
184
+ nn.BatchNorm3d(self.frontend_out),
185
+ frontend_relu,
186
+ nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
187
+ )
188
+ self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=relu_type)
189
+
190
+ if weight_file is not None:
191
+ model_state_dict = torch.load(weight_file, map_location=torch.device("cpu"))
192
+ model_state_dict = model_state_dict["model_state_dict"]
193
+ frontend_state_dict, trunk_state_dict = OrderedDict(), OrderedDict()
194
+ for key, val in model_state_dict.items():
195
+ new_key = ".".join(key.split(".")[1:])
196
+ if "frontend3D" in key:
197
+ frontend_state_dict[new_key] = val
198
+ if "trunk" in key:
199
+ trunk_state_dict[new_key] = val
200
+ self.frontend3D.load_state_dict(frontend_state_dict)
201
+ self.trunk.load_state_dict(trunk_state_dict)
202
+
203
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
204
+ B, C, T, H, W = x.size()
205
+ x = self.frontend3D(x)
206
+ Tnew = x.shape[2]
207
+ x = self.convert_3D_to_2D(x)
208
+ x = self.trunk(x)
209
+ x = x.view(B, Tnew, x.size(1))
210
+ x = x.transpose(1, 2).contiguous()
211
+ return x
212
+
213
+ def convert_3D_to_2D(self, x: torch.Tensor) -> torch.Tensor:
214
+ n_batches, n_channels, s_time, sx, sy = x.shape
215
+ x = x.transpose(1, 2).contiguous()
216
+ return x.reshape(n_batches * s_time, n_channels, sx, sy)
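
A shape-check sketch for the visual front-end above; `resnet.py` has no repository-internal imports, so this only assumes the file is importable. The 88×88 crop size is an example, not a requirement of the code:

```python
import torch
from resnet import ResNetEncoder  # assumes resnet.py is on the path

encoder = ResNetEncoder(relu_type="prelu").eval()

# Dummy lip-region clip: (batch, channels=1, frames, height, width).
clip = torch.randn(2, 1, 16, 88, 88)
with torch.no_grad():
    feats = encoder(clip)
print(feats.shape)  # torch.Size([2, 512, 16]) -> (batch, backend_out, frames)
```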
utils.py ADDED
@@ -0,0 +1,166 @@
1
+ import torch
2
+ import numpy as np
3
+ from typing import Tuple, Optional
4
+
5
+
6
+ def find_runs(x: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
7
+ """Find runs of consecutive items in an array."""
8
+
9
+ # ensure array
10
+ x = np.asanyarray(x)
11
+ if x.ndim != 1:
12
+ raise ValueError("only 1D array supported")
13
+ n = x.shape[0]
14
+
15
+ # handle empty array
16
+ if n == 0:
17
+ return np.array([]), np.array([]), np.array([])
18
+ else:
19
+ # find run starts
20
+ loc_run_start = np.empty(n, dtype=bool)
21
+ loc_run_start[0] = True
22
+ np.not_equal(x[:-1], x[1:], out=loc_run_start[1:])
23
+ run_starts = np.nonzero(loc_run_start)[0]
24
+
25
+ # find run values
26
+ run_values = x[loc_run_start]
27
+
28
+ # find run lengths
29
+ run_lengths = np.diff(np.append(run_starts, n))
30
+
31
+ return run_values, run_starts, run_lengths
32
+
33
+
34
+ def compute_mask_indices(
35
+ shape: Tuple[int, int],
36
+ padding_mask: Optional[torch.Tensor],
37
+ mask_prob: float,
38
+ mask_length: int,
39
+ mask_type: str = "static",
40
+ mask_other: float = 0.0,
41
+ min_masks: int = 0,
42
+ no_overlap: bool = False,
43
+ min_space: int = 0,
44
+ ) -> np.ndarray:
45
+ """
46
+ Computes random mask spans for a given shape
47
+ Args:
48
+ shape: the shape for which to compute masks.
49
+ should be of size 2 where first element is batch size and 2nd is timesteps
50
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
51
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
52
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
53
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
54
+ mask_type: how to compute mask lengths
55
+ static = fixed size
56
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
57
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
58
+ poisson = sample from poisson distribution with lambda = mask length
59
+ min_masks: minimum number of masked spans
60
+ no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
61
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
62
+ """
63
+ bsz, all_sz = shape
64
+ mask = np.full((bsz, all_sz), False)
65
+
66
+ all_num_mask = int(
67
+ # add a random number for probabilistic rounding
68
+ mask_prob * all_sz / float(mask_length)
69
+ + np.random.rand()
70
+ )
71
+
72
+ all_num_mask = max(min_masks, all_num_mask)
73
+
74
+ mask_idcs = []
75
+ for i in range(bsz):
76
+ if padding_mask is not None:
77
+ sz = all_sz - padding_mask[i].long().sum().item()
78
+ num_mask = int(
79
+ # add a random number for probabilistic rounding
80
+ mask_prob * sz / float(mask_length)
81
+ + np.random.rand()
82
+ )
83
+ num_mask = max(min_masks, num_mask)
84
+ else:
85
+ sz = all_sz
86
+ num_mask = all_num_mask
87
+
88
+ if mask_type == "static":
89
+ lengths = np.full(num_mask, mask_length)
90
+ elif mask_type == "uniform":
91
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
92
+ elif mask_type == "normal":
93
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
94
+ lengths = [max(1, int(round(x))) for x in lengths]
95
+ elif mask_type == "poisson":
96
+ lengths = np.random.poisson(mask_length, size=num_mask)
97
+ lengths = [int(round(x)) for x in lengths]
98
+ else:
99
+ raise Exception("unknown mask selection " + mask_type)
100
+
101
+ if sum(lengths) == 0:
102
+ lengths[0] = min(mask_length, sz - 1)
103
+
104
+ if no_overlap:
105
+ mask_idc = []
106
+
107
+ def arrange(s, e, length, keep_length):
108
+ span_start = np.random.randint(s, e - length)
109
+ mask_idc.extend(span_start + i for i in range(length))
110
+
111
+ new_parts = []
112
+ if span_start - s - min_space >= keep_length:
113
+ new_parts.append((s, span_start - min_space + 1))
114
+ if e - span_start - keep_length - min_space > keep_length:
115
+ new_parts.append((span_start + length + min_space, e))
116
+ return new_parts
117
+
118
+ parts = [(0, sz)]
119
+ min_length = min(lengths)
120
+ for length in sorted(lengths, reverse=True):
121
+ lens = np.fromiter(
122
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
123
+ int,
124
+ )
125
+ l_sum = np.sum(lens)
126
+ if l_sum == 0:
127
+ break
128
+ probs = lens / np.sum(lens)
129
+ c = np.random.choice(len(parts), p=probs)
130
+ s, e = parts.pop(c)
131
+ parts.extend(arrange(s, e, length, min_length))
132
+ mask_idc = np.asarray(mask_idc)
133
+ else:
134
+ min_len = min(lengths)
135
+ if sz - min_len <= num_mask:
136
+ min_len = sz - num_mask - 1
137
+
138
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
139
+
140
+ mask_idc = np.asarray(
141
+ [
142
+ mask_idc[j] + offset
143
+ for j in range(len(mask_idc))
144
+ for offset in range(lengths[j])
145
+ ]
146
+ )
147
+
148
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
149
+
150
+ min_len = min([len(m) for m in mask_idcs])
151
+ batch_indexes, starts, ends = [], [], []
152
+ for i, mask_idc in enumerate(mask_idcs):
153
+ if len(mask_idc) > min_len:
154
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
155
+ mask[i, mask_idc] = True
156
+ vals, run_starts, run_lengths = find_runs(mask[i])
157
+ start_indices, lengths = run_starts[vals == True], run_lengths[vals == True]
158
+ starts.append(start_indices)
159
+ ends.append(start_indices + lengths)
160
+ batch_indexes.append(np.zeros([len(start_indices)]) + i)
161
+ return (
162
+ mask,
163
+ np.concatenate(starts).astype(np.int64),
164
+ np.concatenate(ends).astype(np.int64),
165
+ np.concatenate(batch_indexes).astype(np.int64),
166
+ )
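
Finally, a usage sketch for the masking helpers above, assuming `utils.py` is importable from the working directory:

```python
import torch
from utils import compute_mask_indices, find_runs  # assumes utils.py is on the path

# Mask roughly 50% of 100 timesteps for a batch of 2, in spans of length 10.
padding_mask = torch.zeros(2, 100, dtype=torch.bool)  # no padded positions
mask, starts, ends, batch_indexes = compute_mask_indices(
    shape=(2, 100),
    padding_mask=padding_mask,
    mask_prob=0.5,
    mask_length=10,
    mask_type="static",
    min_masks=2,
)
print(mask.shape, mask.sum())  # (2, 100) boolean mask and how many frames it covers

# find_runs recovers the start index and length of each masked span in a row.
values, run_starts, run_lengths = find_runs(mask[0])
print(run_starts[values], run_lengths[values])
```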