File size: 16,240 Bytes
34146f0
 
 
 
 
 
 
53d1d01
34146f0
 
 
 
 
 
 
46689af
34146f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
import torch
import torch.nn as nn
from transformers import Wav2Vec2BertModel
from llama_nar import LlamaNAREmb
from transformers import LlamaConfig
import time
import torch.nn.functional as F
from huggingface_hub import hf_hub_download


class Wav2Vec2BERT_Llama(nn.Module):
    """Two-class frame classifier: a w2v-BERT 2.0 encoder feeding a small LlamaNAR model.

    Audio is encoded by ``Wav2Vec2BertModel``, mean-pooled over time,
    projected from 1024 to 512 dimensions and laid out as a sequence of
    embeddings (optionally preceded by labelled prompt examples) that is
    passed to ``LlamaNAREmb``.  The outputs at the trailing "prediction"
    slots are classified into two classes, both per frame and averaged.
    """

    def __init__(self):
        """Build the encoder, the LlamaNAR backbone and the task-specific layers."""
        super().__init__()

        # 1. Load the pretrained speech encoder (hidden states are requested
        #    because _fuse_layers reads them).
        self.wav2vec2bert = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", output_hidden_states=True)

        # 2. Selectively freeze encoder parameters.  A parameter is frozen when
        #    its name contains any of these substrings:
        #      - 'ffn1': the first feed-forward block (ffn2 stays trainable)
        #      - 'linear_k' / 'linear_v': attention key/value projections
        #      - 'distance_embedding'
        #      - everything convolution related
        frozen_markers = (
            'ffn1',
            'linear_k', 'linear_v',
            'distance_embedding',
            'conv_module', 'pointwise_conv', 'depthwise_conv',
            'feature_extractor', 'pos_conv_embed', 'conv_layers',
        )
        for name, param in self.wav2vec2bert.named_parameters():
            if any(marker in name for marker in frozen_markers):
                param.requires_grad = False

        # 3. A reduced-size Llama backbone.
        self.llama_nar = LlamaNAREmb(
            config=LlamaConfig(
                hidden_size=512,
                num_attention_heads=8,
                num_hidden_layers=8,
            ),
            num_heads=8,
            num_layers=8,
            hidden_size=512
        )

        # 4. Projection from the encoder width (1024) down to 512.
        self.projection = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LayerNorm(512)
        )

        # 5. Lightweight two-class classification head.
        self.classifier = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 2)
        )

        # 6. Embedding of the two prompt labels into the 512-d space.
        self.label_embedding = nn.Embedding(num_embeddings=2, embedding_dim=512)

        # 7. Feature post-processing applied to every projected audio sequence.
        self.feature_processor = nn.Sequential(
            nn.Linear(512, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # 8. Learned special tokens; forward() uses the rows as
        #    0 = [SEP], 1 = [PROMPT], 2 = [AUDIO], 3 = [LABEL].
        self.special_tokens = nn.Parameter(torch.randn(4, 512))

    def _fuse_layers(self, hidden_states):
        """Downsample the last encoder layer in time and project it.

        Args:
            hidden_states: sequence of encoder hidden states; only the last
                entry, shaped ``[batch, seq_len, hidden]``, is used.

        Returns:
            Tensor ``[batch, max(seq_len // 10, 1), proj_dim]`` for a
            non-empty input sequence.
        """
        def downsample_sequence(sequence, factor=10):
            """Mean-pool consecutive groups of ``factor`` frames (remainder dropped)."""
            batch_size, seq_len, hidden_size = sequence.shape
            new_len = seq_len // factor

            if new_len == 0 and seq_len > 0:
                # Fewer than `factor` frames: pool what is there into a single
                # frame instead of returning an empty sequence, which would
                # leave the model with no prediction slot.
                return sequence.mean(dim=1, keepdim=True)

            # Drop trailing frames that do not fill a complete group.
            usable_len = new_len * factor
            if seq_len > usable_len:
                sequence = sequence[:, :usable_len, :]

            # [batch, new_len, factor, hidden] -> average over the group axis.
            grouped = sequence.reshape(batch_size, new_len, factor, hidden_size)
            return grouped.mean(dim=2)

        # 1. Take the last layer and downsample it along time.
        pooled = downsample_sequence(hidden_states[-1])

        # 2. Project to the model width; the sequence axis is preserved.
        return self.projection(pooled)

    def forward(self, batch):
        """Classify the main audio, optionally conditioned on labelled prompts.

        Args:
            batch: mapping with
                - ``'main_features'``: kwargs forwarded to the encoder;
                - ``'prompt_features'`` (optional): list with one encoder
                  kwargs dict per prompt;
                - ``'prompt_labels'`` (optional): 2-D integer tensor
                  ``[batch, num_prompts]`` indexing ``label_embedding``.

        Returns:
            dict with ``'frame_logits'`` ``[batch, main_seq_len, 2]`` and
            ``'avg_logits'`` ``[batch, 2]`` (classifier applied to the mean of
            the prediction-slot embeddings).
        """
        main_output = self.wav2vec2bert(
            **batch['main_features']
        )
        fused_features = self.feature_processor(
            self._fuse_layers(main_output.hidden_states)
        )
        batch_size, main_seq_len, hidden_size = fused_features.shape

        # Special tokens broadcast once to [batch, 1, hidden].
        sep_tok, prompt_tok, audio_tok, label_tok = (
            self.special_tokens[k].expand(batch_size, 1, -1) for k in range(4)
        )

        prompt_labels = batch['prompt_labels'] if 'prompt_labels' in batch else None
        prompt_features = batch['prompt_features'] if 'prompt_features' in batch else None

        pieces = []
        if prompt_labels is not None and prompt_features:
            num_prompts = prompt_labels.shape[1]

            # Encode every prompt first, then post-process, keeping the
            # original call order of the sub-modules.
            encoded_prompts = [
                self._fuse_layers(self.wav2vec2bert(**prompt_features[j]).hidden_states)
                for j in range(num_prompts)
            ]
            # Kept as a list rather than stacked: prompts may downsample to
            # different lengths, which torch.stack cannot handle.
            processed_prompts = [self.feature_processor(p) for p in encoded_prompts]

            label_embs = self.label_embedding(prompt_labels)  # [batch, num_prompts, hidden]

            # Each prompt contributes:
            # [PROMPT][AUDIO] <audio frames> [LABEL] <label repeated per frame> [SEP]
            for j, prompt_seq in enumerate(processed_prompts):
                prompt_len = prompt_seq.size(1)
                repeated_label = label_embs[:, j].unsqueeze(1).expand(-1, prompt_len, -1)
                pieces += [prompt_tok, audio_tok, prompt_seq, label_tok, repeated_label, sep_tok]

        # The query: [PROMPT][AUDIO] <main frames> [LABEL] <zero prediction slots>.
        # With no prompts the sequence consists of this part only.
        prediction_slots = torch.zeros(
            batch_size, main_seq_len, hidden_size,
            device=fused_features.device, dtype=fused_features.dtype,
        )
        pieces += [prompt_tok, audio_tok, fused_features, label_tok, prediction_slots]

        prompt_embeddings = torch.cat(pieces, dim=1)  # [batch, total_len, hidden]

        output = self.llama_nar(inputs_embeds=prompt_embeddings)

        # The prediction slots are the last main_seq_len positions.  An explicit
        # start index is used because `-0:` would select the whole sequence.
        pred_start = output.size(1) - main_seq_len
        pred_pos_embeddings = output[:, pred_start:, :]  # [batch, main_seq_len, hidden]

        # Per-frame logits.
        frame_logits = self.classifier(pred_pos_embeddings)  # [batch, main_seq_len, 2]

        # Utterance-level logits from the time-averaged embedding.
        avg_embedding = torch.mean(pred_pos_embeddings, dim=1)  # [batch, hidden]
        avg_logits = self.classifier(avg_embedding)  # [batch, 2]

        return {
            'frame_logits': frame_logits,  # score for every frame
            'avg_logits': avg_logits       # score for the whole input
        }


if __name__ == '__main__':
    # Smoke test / benchmark: prints parameter statistics for the model, then
    # times a few forward passes on batches from the training dataset.
    import torch
    from torch.utils.data import DataLoader
    from dataset.train_MultiDataset import train_MultiDataset, collate_fn
    from tqdm import tqdm
    import time

    # Select the device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\n=== 使用设备: {device} ===")

    # Build the model in evaluation mode.
    print("\n=== 初始化模型 ===")
    model = Wav2Vec2BERT_Llama().to(device)
    model.eval()

    # Group the encoder parameters by their top-level module name.
    print("\n=== Wav2Vec2BERT 参数结构 ===")
    w2v_params_by_layer = {}
    total_trainable = 0
    total_frozen = 0

    for name, param in model.wav2vec2bert.named_parameters():
        # The first dotted component is used as the "layer" name.
        layer_name = name.split('.')[0]
        if layer_name not in w2v_params_by_layer:
            w2v_params_by_layer[layer_name] = {
                'trainable_params': 0,
                'frozen_params': 0,
                'parameter_names': []
            }

        # Count trainable and frozen elements separately.
        if param.requires_grad:
            w2v_params_by_layer[layer_name]['trainable_params'] += param.numel()
            total_trainable += param.numel()
        else:
            w2v_params_by_layer[layer_name]['frozen_params'] += param.numel()
            total_frozen += param.numel()

        w2v_params_by_layer[layer_name]['parameter_names'].append(name)

    # Per-layer details.  NOTE: the "MB" figures are element counts divided by
    # 1024**2, not bytes.
    print("\n各层参数统计:")
    for layer_name, info in w2v_params_by_layer.items():
        trainable_mb = info['trainable_params'] / 1024 / 1024
        frozen_mb = info['frozen_params'] / 1024 / 1024
        total_mb = (info['trainable_params'] + info['frozen_params']) / 1024 / 1024

        print(f"\n{layer_name}:")
        print(f"  - 总参数量: {total_mb:.2f}MB")
        print(f"  - 可训练参数: {trainable_mb:.2f}MB")
        print(f"  - 冻结参数: {frozen_mb:.2f}MB")
        print(f"  - 参数名称:")
        for param_name in info['parameter_names']:
            print(f"    * {param_name}")

    # Overall encoder statistics.
    print("\n=== 总体统计 ===")
    print(f"可训练参数总量: {total_trainable/1024/1024:.2f}MB")
    print(f"冻结参数总量: {total_frozen/1024/1024:.2f}MB")
    print(f"参数总量: {(total_trainable + total_frozen)/1024/1024:.2f}MB")
    print(f"可训练参数占比: {total_trainable/(total_trainable + total_frozen)*100:.2f}%")

    # Parameter counts per sub-module.
    wav2vec2bert_params = sum(p.numel() for p in model.wav2vec2bert.parameters())
    llama_params = sum(p.numel() for p in model.llama_nar.parameters())
    other_params = sum(p.numel() for name, p in model.named_parameters()
                      if not name.startswith('wav2vec2bert.') and not name.startswith('llama_nar.'))

    total_params = wav2vec2bert_params + llama_params + other_params

    print(f"\n=== 参数量统计 ===")
    print(f"Wav2Vec2BERT参数量: {wav2vec2bert_params:,} ({wav2vec2bert_params/1024/1024:.2f}MB)")
    print(f"LlamaNAR参数量: {llama_params:,} ({llama_params/1024/1024:.2f}MB)")
    print(f"其他模块参数量: {other_params:,} ({other_params/1024/1024:.2f}MB)")
    print(f"总参数量: {total_params:,} ({total_params/1024/1024:.2f}MB)")

    # Share of each sub-module.
    print(f"\n=== 参数量占比 ===")
    print(f"Wav2Vec2BERT: {wav2vec2bert_params/total_params*100:.2f}%")
    print(f"LlamaNAR: {llama_params/total_params*100:.2f}%")
    print(f"其他模块: {other_params/total_params*100:.2f}%")

    # Timing and memory measurement.
    print("\n=== 测试运行时间和内存使用 (batch_size=4) ===")
    batch_size = 4
    total_samples = 600000

    # Start from an empty CUDA cache and record the baseline allocation.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        initial_memory = torch.cuda.memory_allocated() / 1024 / 1024
        print(f"初始GPU内存使用: {initial_memory:.2f}MB")

    # Dataset.
    print("\n初始化数据集...")
    ds = train_MultiDataset(max_prompts=3)

    # DataLoader.
    dl = DataLoader(ds,
                   batch_size=batch_size,
                   shuffle=True,
                   collate_fn=collate_fn,
                   num_workers=4)

    print(f"\n数据集大小: {len(ds)}")
    print(f"批次数量: {len(dl)}")

    # Average the forward time over at most num_test_batches batches.
    num_test_batches = 10
    processed_batches = 0  # batches actually timed (the loader may be shorter)
    total_time = 0
    max_memory = 0

    print(f"\n测试{num_test_batches}个batch的平均运行时间...")
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dl, total=num_test_batches)):
            if i >= num_test_batches:
                break

            # Move the feature dictionaries to the target device.
            main_features = {
                'input_features': batch['main_features']['input_features'].to(device),
                'attention_mask': batch['main_features']['attention_mask'].to(device)
            }

            prompt_features = [{
                'input_features': pf['input_features'].to(device),
                'attention_mask': pf['attention_mask'].to(device)
            } for pf in batch['prompt_features']]

            labels = batch['labels'].to(device)
            prompt_labels = batch['prompt_labels'].to(device)

            start_time = time.time()

            # Forward pass.
            outputs = model({
                'main_features': main_features,
                'prompt_features': prompt_features,
                'prompt_labels': prompt_labels
            })

            # Wait for queued GPU work so the timing covers the computation.
            if torch.cuda.is_available():
                torch.cuda.synchronize()

            end_time = time.time()
            total_time += (end_time - start_time)
            processed_batches += 1

            if torch.cuda.is_available():
                current_memory = torch.cuda.memory_allocated() / 1024 / 1024
                max_memory = max(max_memory, current_memory)

            # Details for the first batch only.
            if i == 0:
                print("\n=== 第一个Batch的详细信息 ===")
                print(f"主特征形状: {main_features['input_features'].shape}")
                print(f"主掩码形状: {main_features['attention_mask'].shape}")
                print(f"Prompt特征形状: {prompt_features[0]['input_features'].shape}")
                print(f"Prompt掩码形状: {prompt_features[0]['attention_mask'].shape}")
                print(f"标签形状: {labels.shape}")
                print(f"Prompt标签形状: {prompt_labels.shape}")
                # The model returns a dict of tensors, so report each entry
                # (calling .shape on the dict itself raises AttributeError).
                for key, value in outputs.items():
                    print(f"模型输出形状 ({key}): {value.shape}")
                    print(f"输出logits范围 ({key}): [{value.min().item():.3f}, {value.max().item():.3f}]")

    # Summary statistics, averaged over the batches that were actually run.
    avg_time = total_time / processed_batches if processed_batches else 0.0
    print(f"\n=== 性能统计 ===")
    print(f"平均每个batch处理时间: {avg_time:.4f}秒")
    print(f"估计处理{total_samples}个样本需要: {(total_samples/batch_size*avg_time/3600):.2f}小时")
    if torch.cuda.is_available():
        print(f"最大GPU内存使用: {max_memory:.2f}MB")
        print(f"GPU内存增长: {max_memory - initial_memory:.2f}MB")

    print("\n测试完成!")