wli3221134 committed
Commit 34146f0 · verified · 1 Parent(s): 851ce73

Upload 5 files

Files changed (5)
  1. app.py +64 -15
  2. dataset.py +130 -0
  3. inference.py +2 -0
  4. llama_nar.py +571 -0
  5. model.py +379 -0
app.py CHANGED
@@ -1,15 +1,56 @@
 import gradio as gr
 import os
-# import inference
+import dataset
+import torch
+from model import Wav2Vec2BERT_Llama
 
+# init device
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# init model
+model = Wav2Vec2BERT_Llama().to(device)
+checkpoint_path = "ckpt/model_checkpoint.pth"
+if os.path.exists(checkpoint_path):
+    checkpoint = torch.load(checkpoint_path, map_location=device)
+    model_state_dict = checkpoint['model_state_dict']
+
+    # Reconcile the 'module.' prefix between DataParallel and plain checkpoints
+    if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
+        model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
+    elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
+        model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
+
+    model.load_state_dict(model_state_dict)
+    model.eval()
+else:
+    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
+
+
+def detect(dataset, model):
+    # Move the collated features to the target device (no prediction is produced here yet)
+    with torch.no_grad():
+        for batch in dataset:
+            main_features = {
+                'input_features': batch['main_features']['input_features'].to(device),
+                'attention_mask': batch['main_features']['attention_mask'].to(device)
+            }
+            prompt_features = [{
+                'input_features': pf['input_features'].to(device),
+                'attention_mask': pf['attention_mask'].to(device)
+            } for pf in batch['prompt_features']]
 
-def audio_deepfake_detection(demonstration_paths, audio_path):
+
+def audio_deepfake_detection(demonstration_paths, demonstration_type, audio_path):
     """Audio deepfake detection function"""
     # Replace with your actual detection logic
     print("Demonstration audio paths: {}".format(demonstration_paths))
     print("Query audio path: {}".format(audio_path))
-
+
+    # Build the demo dataset (renamed to avoid shadowing the imported `dataset` module)
+    demo_dataset = dataset.DemoDataset(demonstration_paths, audio_path)
     # Example return value, modify according to your model
-    result = inference.detect(demonstration_paths, audio_path)
+    result = detect(demo_dataset, model)
 
     # Return detection results and confidence scores
     return {
@@ -33,14 +74,22 @@ with gr.Blocks() as demo:
     """
     )
 
-    # Demonstration audio input component
-    demonstration_audio_input = gr.Audio(
-        sources=["upload"],
-        label="Demonstration Audios",
-        type="filepath",
-    )
+    # Create container for demonstration audio
+    with gr.Row():
+        # Demonstration audio file upload
+        demonstration_audio_input = gr.File(
+            file_count="multiple",
+            file_types=["audio"],
+            label="Demonstration Audios",
+        )
+        # Add demonstration type selection
+        demonstration_type = gr.Dropdown(
+            choices=["bonafide", "spoof"],
+            value="bonafide",
+            label="Demonstration Label",
+        )
 
-    # Audio input component
+    # Query audio input component
     query_audio_input = gr.Audio(
         sources=["upload"],
         label="Query Audio (Audio for Detection)",
@@ -56,7 +105,7 @@ with gr.Blocks() as demo:
     # Set click event
     submit_btn.click(
         fn=audio_deepfake_detection,
-        inputs=[demonstration_audio_input, query_audio_input],
+        inputs=[demonstration_audio_input, demonstration_type, query_audio_input],
         outputs=[output_labels]
     )
 
@@ -64,10 +113,10 @@ with gr.Blocks() as demo:
     gr.Markdown("## Test Examples")
     gr.Examples(
         examples=[
-            ["examples/real_audio.wav", "examples/query_audio.wav"],
-            ["examples/fake_audio.wav", "examples/query_audio.wav"],
+            ["examples/real_audio.wav", "bonafide", "examples/query_audio.wav"],
+            ["examples/fake_audio.wav", "spoof", "examples/query_audio.wav"],
         ],
-        inputs=[demonstration_audio_input, query_audio_input],
+        inputs=[demonstration_audio_input, demonstration_type, query_audio_input],
     )
 
 if __name__ == "__main__":
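A note on the new inputs: `gr.File(file_count="multiple")` passes the callback a list with one entry per uploaded file, while the query `gr.Audio` (presumably `type="filepath"`) passes a single path. A small hedged sketch (assuming Gradio 4.x behavior; the helper name is illustrative and not part of this commit) of normalizing that list before building `DemoDataset`:

# Hypothetical helper: normalize the gr.File output to plain path strings.
def to_paths(files):
    if files is None:
        return []
    if not isinstance(files, list):
        files = [files]
    # Entries may be path strings or objects exposing a .name path, depending on the Gradio version.
    return [f if isinstance(f, str) else f.name for f in files]

# Possible usage inside audio_deepfake_detection:
#   demo_dataset = dataset.DemoDataset(to_paths(demonstration_paths), audio_path)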
dataset.py ADDED
@@ -0,0 +1,130 @@
import torch
from torch.utils.data import Dataset
from transformers import AutoFeatureExtractor
import os
import librosa
import numpy as np


class DemoDataset(Dataset):
    def __init__(self, demonstration_paths, query_path, sample_rate=16000):
        self.sample_rate = sample_rate
        self.query_path = query_path

        # Convert to list if single path
        if isinstance(demonstration_paths, str):
            self.demonstration_paths = [demonstration_paths]
        else:
            self.demonstration_paths = demonstration_paths

        # Load feature extractor
        self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")

        print(f'Number of demonstration audios: {len(self.demonstration_paths)}')
        print(f'Query audio: {self.query_path}')

    def load_pad(self, path, max_length=64000):
        """Load and pad audio file"""
        X, sr = librosa.load(path, sr=self.sample_rate)
        X = self.pad(X, max_length)
        return X

    def pad(self, x, max_len=64000):
        """Pad audio to fixed length"""
        x_len = x.shape[0]
        if x_len >= max_len:
            return x[:max_len]
        pad_length = max_len - x_len
        return np.concatenate([x, np.zeros(pad_length)], axis=0)

    def __len__(self):
        return 1  # Only one query audio

    def __getitem__(self, idx):
        # Load query audio
        query_waveform = self.load_pad(self.query_path)
        query_waveform = torch.from_numpy(query_waveform).float()
        if len(query_waveform.shape) == 1:
            query_waveform = query_waveform.unsqueeze(0)

        # Extract features for query audio
        main_features = self.feature_extractor(
            query_waveform,
            sampling_rate=self.sample_rate,
            padding=True,
            return_attention_mask=True,
            return_tensors="pt"
        )

        # Process demonstration audios
        prompt_features = []
        for demo_path in self.demonstration_paths:
            # Load demonstration audio
            demo_waveform = self.load_pad(demo_path)
            demo_waveform = torch.from_numpy(demo_waveform).float()
            if len(demo_waveform.shape) == 1:
                demo_waveform = demo_waveform.unsqueeze(0)

            # Extract features
            prompt_feature = self.feature_extractor(
                demo_waveform,
                sampling_rate=self.sample_rate,
                padding=True,
                return_attention_mask=True,
                return_tensors="pt"
            )
            prompt_features.append(prompt_feature)

        return {
            'main_features': main_features,
            'prompt_features': prompt_features,
            'file_name': os.path.basename(self.query_path),
            'file_path': self.query_path
        }


def collate_fn(batch):
    """
    Collate function for dataloader
    Args:
        batch: List containing dictionaries with:
            - main_features: feature extractor output
            - prompt_features: list of feature extractor outputs
            - file_name: file name
            - file_path: file path
    """
    batch_size = len(batch)

    # Process main features
    main_features_keys = batch[0]['main_features'].keys()
    main_features = {}
    for key in main_features_keys:
        main_features[key] = torch.cat([item['main_features'][key] for item in batch], dim=0)

    # Get number of prompts
    num_prompts = len(batch[0]['prompt_features'])

    # Process prompt features
    prompt_features = []
    for i in range(num_prompts):
        prompt_feature = {}
        for key in main_features_keys:
            prompt_feature[key] = torch.cat([item['prompt_features'][i][key] for item in batch], dim=0)
        prompt_features.append(prompt_feature)

    # Collect file names and paths
    file_names = [item['file_name'] for item in batch]
    file_paths = [item['file_path'] for item in batch]

    return {
        'main_features': main_features,
        'prompt_features': prompt_features,
        'file_names': file_names,
        'file_paths': file_paths
    }


if __name__ == '__main__':
    # Test the dataset
    demo_paths = ["examples/demo1.wav", "examples/demo2.wav"]
    query_path = "examples/query.wav"

    dataset = DemoDataset(demo_paths, query_path)
    print(dataset[0])
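Since `__getitem__` already returns `return_tensors="pt"` features, `collate_fn` only concatenates them along the batch dimension. A minimal sketch of wiring `DemoDataset` into a `DataLoader` with that collate function (the example paths are placeholders):

from torch.utils.data import DataLoader
from dataset import DemoDataset, collate_fn

# Placeholder paths; replace with real audio files.
demo_paths = ["examples/demo1.wav", "examples/demo2.wav"]
query_path = "examples/query.wav"

loader = DataLoader(
    DemoDataset(demo_paths, query_path),
    batch_size=1,            # the dataset exposes a single query item
    shuffle=False,
    collate_fn=collate_fn,   # concatenates main/prompt features along the batch dim
)

for batch in loader:
    print(batch['file_names'])
    print(list(batch['main_features'].keys()))
    print(len(batch['prompt_features']), "prompt feature dict(s)")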
inference.py ADDED
@@ -0,0 +1,2 @@
def detect(dataset):
    pass
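`inference.py` is only a stub at this commit. A hedged sketch of what `detect` could look like, assuming the batch layout produced by `dataset.collate_fn` and the output keys (`avg_logits`) defined by `Wav2Vec2BERT_Llama` in `model.py`; the class-index meaning is an assumption:

import torch
import torch.nn.functional as F


def detect(dataloader, model):
    """Illustrative only: run the model over the collated batches and
    return softmax scores for each query file."""
    device = next(model.parameters()).device
    results = []
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            main_features = {
                'input_features': batch['main_features']['input_features'].to(device),
                'attention_mask': batch['main_features']['attention_mask'].to(device),
            }
            prompt_features = [{
                'input_features': pf['input_features'].to(device),
                'attention_mask': pf['attention_mask'].to(device),
            } for pf in batch['prompt_features']]

            outputs = model({
                'main_features': main_features,
                'prompt_features': prompt_features,
                'prompt_labels': None,  # the demo path carries no prompt labels
            })
            probs = F.softmax(outputs['avg_logits'], dim=-1)  # [batch, 2]
            # Which index means "spoof" depends on how the checkpoint was trained.
            results.append({
                'file': batch['file_names'][0],
                'scores': probs[0].tolist(),
            })
    return results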
llama_nar.py ADDED
@@ -0,0 +1,571 @@
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
import torch
import torch.nn.functional as F
import numpy as np
import os
import torch.nn as nn
from typing import List, Optional, Tuple, Union
import math
import warnings  # used by LlamaNAREmb.forward

from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from torchmetrics.classification import MulticlassAccuracy
from transformers.models.llama.modeling_llama import BaseModelOutputWithPast


# sinusoidal positional encoding
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :] * 1.0
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class LlamaAdaptiveRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # The gamma parameter
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # (B, Seq_Len, Dim) * (B, Seq_Len, 1) = (B, Seq_Len, Dim)
        # rsqrt: 1 / sqrt(x)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        # (Dim) * (B, Seq_Len, Dim) = (B, Seq_Len, Dim)
        return self.weight * self._norm(x.float()).type_as(x)


class MultiEmbedding(nn.Module):
    """Embedding for multiple quantization layers, summing up the embeddings of each layer."""

    def __init__(
        self,
        num_embeddings=1028,
        embedding_dim=1024,
        num_quantization_layers=8,
    ):
        super().__init__()
        self.embeddings = nn.ModuleList(
            [
                nn.Embedding(num_embeddings, embedding_dim)
                for _ in range(num_quantization_layers)
            ]
        )

        # initialize embeddings
        for i in range(num_quantization_layers):
            self.embeddings[i].weight.data.normal_(mean=0.0, std=0.02)
        self._is_hf_initialized = True  # disable automatic init

    def forward(self, input_ids):
        """Input: [num_quant, B, T] -> Output: [B, T, H]"""
        num_quant, B, T = input_ids.shape
        summed_embeddings = torch.zeros(
            B, T, self.embeddings[0].embedding_dim, device=input_ids.device
        )
        for i in range(num_quant):
            summed_embeddings += self.embeddings[i](input_ids[i])
        return summed_embeddings


class LlamaNARDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        """Override to adaptive layer norm"""
        super().__init__(config, layer_idx)  # init attention, mlp, etc.
        self.input_layernorm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class LlamaNAR(LlamaModel):
    def __init__(
        self,
        hidden_size=1024,
        num_heads=16,
        num_layers=16,
        config=LlamaConfig(0, 256, 1024, 1, 1),
    ):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [
                LlamaNARDecoderLayer(
                    config=LlamaConfig(
                        hidden_size=hidden_size,
                        num_attention_heads=num_heads,
                        max_position_embeddings=4096,
                        intermediate_size=hidden_size * 4,
                    ),
                    layer_idx=i,
                )
                for i in range(num_layers)
            ]
        )

        self.norm = LlamaAdaptiveRMSNorm(hidden_size)

        self.multi_embedding = MultiEmbedding(
            num_quantization_layers=8, embedding_dim=hidden_size
        )

        self.post_init()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # create noncausal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        def _expand_mask(
            mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
        ):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len

            expanded_mask = (
                mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
            )

            inverted_mask = 1.0 - expanded_mask

            return inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(dtype).min
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        length: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length, num_quant = input_ids.shape
        input_ids = input_ids.permute(2, 0, 1)  # [num_quant, B, T]
        inputs_embeds = self.multi_embedding(input_ids)

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            if self.gradient_checkpointing and self.training:
                raise NotImplementedError

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        return hidden_states


class LlamaNAREmb(LlamaModel):
    """LlamaNAR model that works directly with embeddings input.

    This variant of LlamaNAR takes pre-computed embeddings as input
    instead of token IDs that need to be embedded.
    """

    def __init__(
        self,
        hidden_size=1024,
        num_heads=16,
        num_layers=16,
        config=LlamaConfig(0, 256, 1024, 1, 1),
    ):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [
                LlamaNARDecoderLayer(
                    config=LlamaConfig(
                        hidden_size=hidden_size,
                        num_attention_heads=num_heads,
                        max_position_embeddings=4096,
                        intermediate_size=hidden_size * 4,
                    ),
                    layer_idx=i,
                )
                for i in range(num_layers)
            ]
        )

        self.norm = LlamaAdaptiveRMSNorm(hidden_size)

        self.post_init()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # create noncausal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        def _expand_mask(
            mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
        ):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len

            expanded_mask = (
                mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
            )

            inverted_mask = 1.0 - expanded_mask

            return inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(dtype).min
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        Returns:
            hidden_states: Tensor of shape (batch_size, sequence_length, hidden_size)
        """

        if inputs_embeds is None:
            raise ValueError("inputs_embeds must be provided for LlamaNAREmb")

        if input_ids is not None:
            warnings.warn("input_ids is ignored in LlamaNAREmb, use inputs_embeds instead")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length, hidden_size = inputs_embeds.shape

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            if self.gradient_checkpointing and self.training:
                raise NotImplementedError

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        return hidden_states


if __name__ == '__main__':
    config = LlamaConfig(hidden_size=1024, num_attention_heads=8, num_hidden_layers=8)

    model = LlamaNAR(config=config)

    # Mock input data
    batch_size = 2
    seq_length = 10
    n_q = 8
    input_ids = torch.randint(0, 1028, (batch_size, seq_length, n_q))  # random input IDs
    inputs_embeds = torch.randn(batch_size, seq_length, config.hidden_size)  # random input embeddings
    attention_mask = torch.ones(batch_size, seq_length)  # all positions visible
    length = torch.tensor([4, 10])  # input lengths

    # Forward pass (LlamaNAR.forward returns the final hidden states)
    hidden_states = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_attentions=True,
        output_hidden_states=True,
        length=length,
    )

    # Print output shape
    print("Hidden States Shape:", hidden_states.shape)
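For reference, a quick shape check of the two standalone helper modules above (the sizes are illustrative, not the ones used in `model.py`):

import torch
from llama_nar import SinusoidalPosEmb, MultiEmbedding

# Sinusoidal embedding: a 1-D tensor of positions -> [len, dim]
pos = SinusoidalPosEmb(dim=8)
print(pos(torch.arange(4)).shape)        # torch.Size([4, 8])

# MultiEmbedding sums one embedding table per quantization layer:
# input [num_quant, B, T] -> output [B, T, H]
emb = MultiEmbedding(num_embeddings=1028, embedding_dim=16, num_quantization_layers=8)
ids = torch.randint(0, 1028, (8, 2, 5))
print(emb(ids).shape)                    # torch.Size([2, 5, 16])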
model.py ADDED
@@ -0,0 +1,379 @@
import torch
import torch.nn as nn
from transformers import Wav2Vec2BertModel
from llama_nar import LlamaNAREmb
from transformers import LlamaConfig
import time
import torch.nn.functional as F


class Wav2Vec2BERT_Llama(nn.Module):
    def __init__(self):
        super().__init__()

        # 1. Load the pretrained model
        self.wav2vec2bert = Wav2Vec2BertModel.from_pretrained("/mntcephfs/lab_data/wangli/pretrain/w2v-bert-2.0/", output_hidden_states=True)

        # 2. Selectively freeze parameters
        for name, param in self.wav2vec2bert.named_parameters():
            # Freeze all FFN1 blocks (keep FFN2 adaptable)
            if 'ffn1' in name:
                param.requires_grad = False

            # Freeze the K/V projections in multi-head attention
            if any(proj in name for proj in ['linear_k', 'linear_v']):
                param.requires_grad = False

            # Freeze the distance_embedding
            if 'distance_embedding' in name:
                param.requires_grad = False

            # Freeze all convolution-related modules
            if any(conv_name in name for conv_name in [
                'conv_module', 'pointwise_conv', 'depthwise_conv',
                'feature_extractor', 'pos_conv_embed', 'conv_layers'
            ]):
                param.requires_grad = False

        # 3. Use a smaller Llama model
        self.llama_nar = LlamaNAREmb(
            config=LlamaConfig(
                hidden_size=512,
                num_attention_heads=8,
                num_hidden_layers=8,
            ),
            num_heads=8,
            num_layers=8,
            hidden_size=512
        )

        # 4. Down-projection layer
        self.projection = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LayerNorm(512)
        )

        # 5. Simplified classification head
        self.classifier = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 2)
        )

        # 6. Smaller label-embedding dimension
        self.label_embedding = nn.Embedding(num_embeddings=2, embedding_dim=512)

        # 7. Simplified feature-processing layer
        self.feature_processor = nn.Sequential(
            nn.Linear(512, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # 8. Smaller special-token dimension
        self.special_tokens = nn.Parameter(torch.randn(4, 512))

    def _fuse_layers(self, hidden_states):
        # Feature fusion
        def downsample_sequence(sequence, factor=10):
            """Downsample the sequence along the time axis"""
            batch_size, seq_len, hidden_size = sequence.shape
            # Make sure the sequence length is divisible by the factor
            new_len = seq_len // factor
            padded_len = new_len * factor

            if seq_len > padded_len:
                sequence = sequence[:, :padded_len, :]

            # Reshape and average-pool: [batch_size, new_len, factor, hidden_size]
            reshaped = sequence.reshape(batch_size, new_len, factor, hidden_size)
            downsampled = torch.mean(reshaped, dim=2)  # [batch_size, new_len, hidden_size]
            return downsampled

        # 1. Take the last hidden layer and downsample it
        last_layer = hidden_states[-1]  # [batch_size, seq_len, 1024]
        downsampled_features = downsample_sequence(last_layer)  # [batch_size, seq_len//10, 1024]

        # 2. Project down to 512 dimensions
        projected_features = self.projection(downsampled_features)  # [batch_size, seq_len//10, 512]

        return projected_features  # no unsqueeze needed; the sequence dimension is preserved

    def forward(self, batch):
        main_output = self.wav2vec2bert(
            **batch['main_features']
        )

        fused_features = self._fuse_layers(main_output.hidden_states)
        fused_features = self.feature_processor(fused_features)

        if ('prompt_labels' in batch and
                batch['prompt_labels'] is not None and
                'prompt_features' in batch and
                batch['prompt_features'] and
                len(batch['prompt_features']) > 0):

            batch_size, num_prompts = batch['prompt_labels'].shape

            # Reshape features for batched processing
            prompt_features = batch['prompt_features']
            all_prompt_outputs = []

            for i in range(num_prompts):
                prompt_output = self.wav2vec2bert(
                    **prompt_features[i]
                )
                all_prompt_outputs.append(self._fuse_layers(prompt_output.hidden_states))

            if all_prompt_outputs:
                fused_prompts = torch.stack([
                    self.feature_processor(p) for p in all_prompt_outputs
                ], dim=1)  # [batch_size, num_prompts, seq_len, hidden_size]

                # Get label embeddings and expand them to the prompt sequence length
                label_embs = self.label_embedding(batch['prompt_labels'])  # [batch_size, num_prompts, 512]

                prompt_embeddings = []
                for i in range(batch_size):
                    sequence = []

                    # Add the demonstration prompts
                    for j in range(num_prompts):
                        prompt_seq_len = fused_prompts[i, j].size(0)  # sequence length of the current prompt

                        sequence.append(self.special_tokens[1].expand(1, -1))  # [PROMPT]
                        sequence.append(self.special_tokens[2].expand(1, -1))  # [AUDIO]
                        sequence.append(fused_prompts[i, j])  # [seq_len, hidden_size]
                        sequence.append(self.special_tokens[3].expand(1, -1))  # [LABEL]

                        # Expand the label embedding to the same length as the audio features
                        expanded_label = label_embs[i, j].unsqueeze(0).expand(prompt_seq_len, -1)
                        sequence.append(expanded_label)  # [seq_len, hidden_size]

                        sequence.append(self.special_tokens[0].expand(1, -1))  # [SEP]

                    # Add the main features to be predicted
                    main_seq_len = fused_features[i].size(0)  # sequence length of the main features
                    sequence.append(self.special_tokens[1].expand(1, -1))  # [PROMPT]
                    sequence.append(self.special_tokens[2].expand(1, -1))  # [AUDIO]
                    sequence.append(fused_features[i])  # [main_seq_len, hidden_size]
                    sequence.append(self.special_tokens[3].expand(1, -1))  # [LABEL]
                    # Use zero vectors at the prediction positions, same length as the main features
                    sequence.append(torch.zeros(main_seq_len, fused_features.size(-1)).to(fused_features.device))

                    prompt_embeddings.append(torch.cat(sequence, dim=0))

                prompt_embeddings = torch.stack(prompt_embeddings, dim=0)

        else:
            # Simplified handling for the no-prompt case
            batch_size = fused_features.size(0)
            main_seq_len = fused_features.size(1)  # main feature sequence length

            # Build the sequence [batch_size, total_len, hidden_size]
            prompt_embeddings = torch.cat([
                self.special_tokens[1].expand(batch_size, 1, -1),  # [PROMPT]
                self.special_tokens[2].expand(batch_size, 1, -1),  # [AUDIO]
                fused_features,  # [batch_size, main_seq_len, hidden_size]
                self.special_tokens[3].expand(batch_size, 1, -1),  # [LABEL]
                torch.zeros(batch_size, main_seq_len, fused_features.size(-1)).to(fused_features.device)  # prediction positions
            ], dim=1)

        # Feed into llama_nar
        output = self.llama_nar(inputs_embeds=prompt_embeddings)

        # Take the outputs at the prediction positions (the last main_seq_len positions)
        pred_pos_embeddings = output[:, -main_seq_len:, :]  # [batch_size, main_seq_len, hidden_size]
        # Classify each frame
        frame_logits = self.classifier(pred_pos_embeddings)  # [batch_size, main_seq_len, 2]

        # Also return utterance-level logits obtained by averaging the prediction positions
        avg_embedding = torch.mean(pred_pos_embeddings, dim=1)  # [batch_size, hidden_size]
        avg_logits = self.classifier(avg_embedding)  # [batch_size, 2]

        return {
            'frame_logits': frame_logits,  # per-frame prediction scores
            'avg_logits': avg_logits  # utterance-level prediction scores
        }


if __name__ == '__main__':
    import torch
    from torch.utils.data import DataLoader
    from dataset.train_MultiDataset import train_MultiDataset, collate_fn
    from tqdm import tqdm
    import time

    # Select device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\n=== Device: {device} ===")

    # Initialize the model
    print("\n=== Initializing model ===")
    model = Wav2Vec2BERT_Llama().to(device)
    model.eval()  # evaluation mode

    # Print the parameter structure of wav2vec2bert
    print("\n=== Wav2Vec2BERT parameter structure ===")
    w2v_params_by_layer = {}
    total_trainable = 0
    total_frozen = 0

    for name, param in model.wav2vec2bert.named_parameters():
        # Top-level layer name
        layer_name = name.split('.')[0]
        if layer_name not in w2v_params_by_layer:
            w2v_params_by_layer[layer_name] = {
                'trainable_params': 0,
                'frozen_params': 0,
                'parameter_names': []
            }

        # Count parameters
        if param.requires_grad:
            w2v_params_by_layer[layer_name]['trainable_params'] += param.numel()
            total_trainable += param.numel()
        else:
            w2v_params_by_layer[layer_name]['frozen_params'] += param.numel()
            total_frozen += param.numel()

        w2v_params_by_layer[layer_name]['parameter_names'].append(name)

    # Per-layer details
    print("\nPer-layer parameter statistics:")
    for layer_name, info in w2v_params_by_layer.items():
        trainable_m = info['trainable_params'] / 1e6
        frozen_m = info['frozen_params'] / 1e6
        total_m = (info['trainable_params'] + info['frozen_params']) / 1e6

        print(f"\n{layer_name}:")
        print(f"  - total parameters: {total_m:.2f}M")
        print(f"  - trainable parameters: {trainable_m:.2f}M")
        print(f"  - frozen parameters: {frozen_m:.2f}M")
        print(f"  - parameter names:")
        for param_name in info['parameter_names']:
            print(f"    * {param_name}")

    # Overall statistics
    print("\n=== Overall statistics ===")
    print(f"Total trainable parameters: {total_trainable/1e6:.2f}M")
    print(f"Total frozen parameters: {total_frozen/1e6:.2f}M")
    print(f"Total parameters: {(total_trainable + total_frozen)/1e6:.2f}M")
    print(f"Trainable parameter ratio: {total_trainable/(total_trainable + total_frozen)*100:.2f}%")

    # Parameter counts per module
    wav2vec2bert_params = sum(p.numel() for p in model.wav2vec2bert.parameters())
    llama_params = sum(p.numel() for p in model.llama_nar.parameters())
    other_params = sum(p.numel() for name, p in model.named_parameters()
                       if not name.startswith('wav2vec2bert.') and not name.startswith('llama_nar.'))

    total_params = wav2vec2bert_params + llama_params + other_params

    print(f"\n=== Parameter counts ===")
    print(f"Wav2Vec2BERT parameters: {wav2vec2bert_params:,} ({wav2vec2bert_params/1e6:.2f}M)")
    print(f"LlamaNAR parameters: {llama_params:,} ({llama_params/1e6:.2f}M)")
    print(f"Other module parameters: {other_params:,} ({other_params/1e6:.2f}M)")
    print(f"Total parameters: {total_params:,} ({total_params/1e6:.2f}M)")

    # Percentages
    print(f"\n=== Parameter share ===")
    print(f"Wav2Vec2BERT: {wav2vec2bert_params/total_params*100:.2f}%")
    print(f"LlamaNAR: {llama_params/total_params*100:.2f}%")
    print(f"Other modules: {other_params/total_params*100:.2f}%")

    # Runtime and memory test
    print("\n=== Runtime and memory test (batch_size=4) ===")
    batch_size = 4
    total_samples = 600000

    # Clear the GPU cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        initial_memory = torch.cuda.memory_allocated() / 1024 / 1024
        print(f"Initial GPU memory usage: {initial_memory:.2f}MB")

    # Initialize the dataset
    print("\nInitializing dataset...")
    ds = train_MultiDataset(max_prompts=3)

    # Create the DataLoader
    dl = DataLoader(ds,
                    batch_size=batch_size,
                    shuffle=True,
                    collate_fn=collate_fn,
                    num_workers=4)

    print(f"\nDataset size: {len(ds)}")
    print(f"Number of batches: {len(dl)}")

    # Average time per batch
    num_test_batches = 10
    total_time = 0
    max_memory = 0

    print(f"\nMeasuring the average runtime over {num_test_batches} batches...")
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dl, total=num_test_batches)):
            if i >= num_test_batches:
                break

            # Handle the dict-typed features correctly
            main_features = {
                'input_features': batch['main_features']['input_features'].to(device),
                'attention_mask': batch['main_features']['attention_mask'].to(device)
            }

            prompt_features = [{
                'input_features': pf['input_features'].to(device),
                'attention_mask': pf['attention_mask'].to(device)
            } for pf in batch['prompt_features']]

            labels = batch['labels'].to(device)
            prompt_labels = batch['prompt_labels'].to(device)

            # Record the start time
            start_time = time.time()

            # Forward pass
            outputs = model({
                'main_features': main_features,
                'prompt_features': prompt_features,
                'prompt_labels': prompt_labels
            })

            # Make sure GPU work has finished
            if torch.cuda.is_available():
                torch.cuda.synchronize()

            # Record the end time and memory usage
            end_time = time.time()
            total_time += (end_time - start_time)

            if torch.cuda.is_available():
                current_memory = torch.cuda.memory_allocated() / 1024 / 1024
                max_memory = max(max_memory, current_memory)

            # Print details of the first batch
            if i == 0:
                print("\n=== First batch details ===")
                print(f"Main feature shape: {main_features['input_features'].shape}")
                print(f"Main mask shape: {main_features['attention_mask'].shape}")
                print(f"Prompt feature shape: {prompt_features[0]['input_features'].shape}")
                print(f"Prompt mask shape: {prompt_features[0]['attention_mask'].shape}")
                print(f"Label shape: {labels.shape}")
                print(f"Prompt label shape: {prompt_labels.shape}")
                print(f"Frame logits shape: {outputs['frame_logits'].shape}")
                print(f"Avg logits shape: {outputs['avg_logits'].shape}")
                print(f"Avg logits range: [{outputs['avg_logits'].min().item():.3f}, {outputs['avg_logits'].max().item():.3f}]")

    # Compute and print statistics
    avg_time = total_time / num_test_batches
    print(f"\n=== Performance statistics ===")
    print(f"Average time per batch: {avg_time:.4f}s")
    print(f"Estimated time for {total_samples} samples: {(total_samples/batch_size*avg_time/3600):.2f}h")
    if torch.cuda.is_available():
        print(f"Peak GPU memory usage: {max_memory:.2f}MB")
        print(f"GPU memory growth: {max_memory - initial_memory:.2f}MB")

    print("\nTest finished!")
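The model returns both frame-level and utterance-level logits. A short hedged sketch of turning them into probabilities (the tensors below are placeholders with the documented shapes, and which class index means "spoof" depends on how the checkpoint was trained):

import torch
import torch.nn.functional as F

# Placeholder for the dict returned by Wav2Vec2BERT_Llama.forward(batch)
outputs = {
    'frame_logits': torch.randn(2, 30, 2),  # [batch, main_seq_len, 2]
    'avg_logits': torch.randn(2, 2),        # [batch, 2]
}

frame_probs = F.softmax(outputs['frame_logits'], dim=-1)  # per-frame scores
utt_probs = F.softmax(outputs['avg_logits'], dim=-1)      # utterance-level scores

# Utterance-level decision from the pooled head; the frame scores can help
# localize suspicious regions within the query audio.
print(utt_probs.argmax(dim=-1), frame_probs.mean(dim=1))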