import torch
import torch.nn as nn
from transformers import Wav2Vec2BertModel, LlamaConfig

from llama_nar import LlamaNAREmb


class Wav2Vec2BERT_Llama(nn.Module):
    def __init__(self):
        super().__init__()

        # 1. Load the pretrained speech encoder
        self.wav2vec2bert = Wav2Vec2BertModel.from_pretrained(
            "facebook/w2v-bert-2.0", output_hidden_states=True)

        # 2. Selectively freeze parameters
        for name, param in self.wav2vec2bert.named_parameters():
            # Freeze all FFN1 blocks (keep FFN2 trainable for adaptation)
            if 'ffn1' in name:
                param.requires_grad = False
            # Freeze the K and V projections of multi-head attention
            if any(proj in name for proj in ['linear_k', 'linear_v']):
                param.requires_grad = False
            # Freeze the distance embedding
            if 'distance_embedding' in name:
                param.requires_grad = False
            # Freeze all convolution-related modules
            if any(conv_name in name for conv_name in [
                'conv_module', 'pointwise_conv', 'depthwise_conv',
                'feature_extractor', 'pos_conv_embed', 'conv_layers'
            ]):
                param.requires_grad = False

        # 3. Reduced-size Llama backbone
        self.llama_nar = LlamaNAREmb(
            config=LlamaConfig(
                hidden_size=512,
                num_attention_heads=8,
                num_hidden_layers=8,
            ),
            num_heads=8,
            num_layers=8,
            hidden_size=512
        )

        # 4. Down-projection from the encoder width (1024) to the Llama width (512)
        self.projection = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LayerNorm(512)
        )

        # 5. Lightweight binary classification head
        self.classifier = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 2)
        )

        # 6. Label embedding matching the 512-dim hidden size
        self.label_embedding = nn.Embedding(num_embeddings=2, embedding_dim=512)

        # 7. Feature post-processing layer
        self.feature_processor = nn.Sequential(
            nn.Linear(512, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # 8. Learnable special tokens: [SEP], [PROMPT], [AUDIO], [LABEL]
        self.special_tokens = nn.Parameter(torch.randn(4, 512))

    def _fuse_layers(self, hidden_states):
        """Fuse encoder hidden states: downsample the last layer, then project it."""

        def downsample_sequence(sequence, factor=10):
            """Downsample a sequence along time by mean pooling."""
            batch_size, seq_len, hidden_size = sequence.shape
            # Truncate so the sequence length is divisible by the factor
            new_len = seq_len // factor
            padded_len = new_len * factor
            if seq_len > padded_len:
                sequence = sequence[:, :padded_len, :]
            # Reshape to [batch_size, new_len, factor, hidden_size] and average-pool
            reshaped = sequence.reshape(batch_size, new_len, factor, hidden_size)
            downsampled = torch.mean(reshaped, dim=2)  # [batch_size, new_len, hidden_size]
            return downsampled

        # 1. Take the last hidden layer and downsample it along time
        last_layer = hidden_states[-1]  # [batch_size, seq_len, 1024]
        downsampled_features = downsample_sequence(last_layer)  # [batch_size, seq_len//10, 1024]
        # 2. Project down to 512 dimensions
        projected_features = self.projection(downsampled_features)  # [batch_size, seq_len//10, 512]

        # No unsqueeze needed: the sequence dimension is already preserved
        return projected_features

    def forward(self, batch):
        main_output = self.wav2vec2bert(**batch['main_features'])
        fused_features = self._fuse_layers(main_output.hidden_states)
        fused_features = self.feature_processor(fused_features)

        if ('prompt_labels' in batch and batch['prompt_labels'] is not None
                and 'prompt_features' in batch and batch['prompt_features']
                and len(batch['prompt_features']) > 0):
            batch_size, num_prompts = batch['prompt_labels'].shape

            # Encode each prompt with the shared speech encoder
            prompt_features = batch['prompt_features']
            all_prompt_outputs = []
            for i in range(num_prompts):
                prompt_output = self.wav2vec2bert(**prompt_features[i])
                all_prompt_outputs.append(self._fuse_layers(prompt_output.hidden_states))

            # num_prompts >= 1 here, so all_prompt_outputs is never empty
            fused_prompts = torch.stack([
                self.feature_processor(p) for p in all_prompt_outputs
            ], dim=1)  # [batch_size, num_prompts, seq_len, hidden_size]

            # Label embeddings for the prompts
            label_embs = self.label_embedding(batch['prompt_labels'])  # [batch_size, num_prompts, 512]

            prompt_embeddings = []
            for i in range(batch_size):
                sequence = []
                # Append the in-context example prompts
                for j in range(num_prompts):
                    prompt_seq_len = fused_prompts[i, j].size(0)  # length of this prompt
                    sequence.append(self.special_tokens[1].expand(1, -1))  # [PROMPT]
                    sequence.append(self.special_tokens[2].expand(1, -1))  # [AUDIO]
                    sequence.append(fused_prompts[i, j])  # [seq_len, hidden_size]
                    sequence.append(self.special_tokens[3].expand(1, -1))  # [LABEL]
                    # Broadcast the label embedding to the length of the audio features
                    expanded_label = label_embs[i, j].unsqueeze(0).expand(prompt_seq_len, -1)
                    sequence.append(expanded_label)  # [seq_len, hidden_size]
                    sequence.append(self.special_tokens[0].expand(1, -1))  # [SEP]

                # Append the main features to be classified
                # (identical across the batch, since fused_features is one tensor)
                main_seq_len = fused_features[i].size(0)
                sequence.append(self.special_tokens[1].expand(1, -1))  # [PROMPT]
                sequence.append(self.special_tokens[2].expand(1, -1))  # [AUDIO]
                sequence.append(fused_features[i])  # [main_seq_len, hidden_size]
                sequence.append(self.special_tokens[3].expand(1, -1))  # [LABEL]
                # Zero vectors mark the prediction positions, one per main frame
                sequence.append(torch.zeros(
                    main_seq_len, fused_features.size(-1)).to(fused_features.device))

                prompt_embeddings.append(torch.cat(sequence, dim=0))
            prompt_embeddings = torch.stack(prompt_embeddings, dim=0)
        else:
            # Simplified handling for the no-prompt case
            batch_size = fused_features.size(0)
            main_seq_len = fused_features.size(1)

            # Build the sequence [batch_size, total_len, hidden_size]
            prompt_embeddings = torch.cat([
                self.special_tokens[1].expand(batch_size, 1, -1),  # [PROMPT]
                self.special_tokens[2].expand(batch_size, 1, -1),  # [AUDIO]
                fused_features,  # [batch_size, main_seq_len, hidden_size]
                self.special_tokens[3].expand(batch_size, 1, -1),  # [LABEL]
                torch.zeros(batch_size, main_seq_len,
                            fused_features.size(-1)).to(fused_features.device)  # prediction slots
            ], dim=1)

        # Run the non-autoregressive Llama over the assembled sequence
        output = self.llama_nar(inputs_embeds=prompt_embeddings)

        # Outputs at the prediction positions (the last main_seq_len positions)
        pred_pos_embeddings = output[:, -main_seq_len:, :]  # [batch_size, main_seq_len, hidden_size]

        # Per-frame classification
        frame_logits = self.classifier(pred_pos_embeddings)  # [batch_size, main_seq_len, 2]

        # Also produce utterance-level logits from the mean prediction embedding
        avg_embedding = torch.mean(pred_pos_embeddings, dim=1)  # [batch_size, hidden_size]
        avg_logits = self.classifier(avg_embedding)  # [batch_size, 2]

        return {
            'frame_logits': frame_logits,  # per-frame scores
            'avg_logits': avg_logits       # utterance-level scores
        }


if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from dataset.train_MultiDataset import train_MultiDataset, collate_fn
    from tqdm import tqdm
    import time

    # Select the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\n=== Device: {device} ===")

    # Initialize the model
    print("\n=== Initializing model ===")
    model = Wav2Vec2BERT_Llama().to(device)
    model.eval()  # evaluation mode

    # Print the parameter structure of wav2vec2bert
    print("\n=== Wav2Vec2BERT parameter structure ===")
    w2v_params_by_layer = {}
    total_trainable = 0
    total_frozen = 0

    for name, param in model.wav2vec2bert.named_parameters():
        # Group by top-level module name
        layer_name = name.split('.')[0]
        if layer_name not in w2v_params_by_layer:
            w2v_params_by_layer[layer_name] = {
                'trainable_params': 0,
                'frozen_params': 0,
                'parameter_names': []
            }

        # Accumulate parameter counts
        if param.requires_grad:
            w2v_params_by_layer[layer_name]['trainable_params'] += param.numel()
            total_trainable += param.numel()
        else:
            w2v_params_by_layer[layer_name]['frozen_params'] += param.numel()
            total_frozen += param.numel()
        w2v_params_by_layer[layer_name]['parameter_names'].append(name)

    # Per-layer statistics (parameter counts in millions, not megabytes)
    print("\nPer-layer parameter statistics:")
    for layer_name, info in w2v_params_by_layer.items():
        trainable_m = info['trainable_params'] / 1e6
        frozen_m = info['frozen_params'] / 1e6
        total_m = (info['trainable_params'] + info['frozen_params']) / 1e6
        print(f"\n{layer_name}:")
        print(f"  - Total params: {total_m:.2f}M")
        print(f"  - Trainable params: {trainable_m:.2f}M")
        print(f"  - Frozen params: {frozen_m:.2f}M")
        print("  - Parameter names:")
        for param_name in info['parameter_names']:
            print(f"    * {param_name}")

    # Overall statistics
    print("\n=== Overall statistics ===")
    print(f"Total trainable params: {total_trainable/1e6:.2f}M")
    print(f"Total frozen params: {total_frozen/1e6:.2f}M")
    print(f"Total params: {(total_trainable + total_frozen)/1e6:.2f}M")
    print(f"Trainable ratio: {total_trainable/(total_trainable + total_frozen)*100:.2f}%")

    # Parameter counts per module
    wav2vec2bert_params = sum(p.numel() for p in model.wav2vec2bert.parameters())
    llama_params = sum(p.numel() for p in model.llama_nar.parameters())
    other_params = sum(p.numel() for name, p in model.named_parameters()
                       if not name.startswith('wav2vec2bert.')
                       and not name.startswith('llama_nar.'))
    total_params = wav2vec2bert_params + llama_params + other_params

    print("\n=== Parameter counts ===")
    print(f"Wav2Vec2BERT: {wav2vec2bert_params:,} ({wav2vec2bert_params/1e6:.2f}M)")
    print(f"LlamaNAR: {llama_params:,} ({llama_params/1e6:.2f}M)")
    print(f"Other modules: {other_params:,} ({other_params/1e6:.2f}M)")
    print(f"Total: {total_params:,} ({total_params/1e6:.2f}M)")

    # Percentages
    print("\n=== Parameter share ===")
    print(f"Wav2Vec2BERT: {wav2vec2bert_params/total_params*100:.2f}%")
    print(f"LlamaNAR: {llama_params/total_params*100:.2f}%")
    print(f"Other modules: {other_params/total_params*100:.2f}%")

    # Measure runtime and memory usage
    print("\n=== Runtime and memory test (batch_size=4) ===")
    batch_size = 4
    total_samples = 600000

    # Clear the GPU cache and record the baseline memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        initial_memory = torch.cuda.memory_allocated() / 1024 / 1024
        print(f"Initial GPU memory usage: {initial_memory:.2f}MB")

    # Initialize the dataset
    print("\nInitializing dataset...")
    ds = train_MultiDataset(max_prompts=3)

    # Create the DataLoader
    dl = DataLoader(ds,
                    batch_size=batch_size,
                    shuffle=True,
                    collate_fn=collate_fn,
                    num_workers=4)

    print(f"\nDataset size: {len(ds)}")
    print(f"Number of batches: {len(dl)}")

    # Average the runtime over a few batches
    num_test_batches = 10
    total_time = 0
    max_memory = 0

    print(f"\nTiming {num_test_batches} batches...")
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dl, total=num_test_batches)):
            if i >= num_test_batches:
                break

            # Move the dict-valued features to the device
            main_features = {
                'input_features': batch['main_features']['input_features'].to(device),
                'attention_mask': batch['main_features']['attention_mask'].to(device)
            }
            prompt_features = [{
                'input_features': pf['input_features'].to(device),
                'attention_mask': pf['attention_mask'].to(device)
            } for pf in batch['prompt_features']]

            labels = batch['labels'].to(device)
            prompt_labels = batch['prompt_labels'].to(device)

            # Record the start time
            start_time = time.time()

            # Forward pass
            outputs = model({
                'main_features': main_features,
                'prompt_features': prompt_features,
                'prompt_labels': prompt_labels
            })

            # Make sure all queued GPU work has finished before timing
            if torch.cuda.is_available():
                torch.cuda.synchronize()

            # Record the end time and memory usage
            end_time = time.time()
            total_time += (end_time - start_time)

            if torch.cuda.is_available():
                current_memory = torch.cuda.memory_allocated() / 1024 / 1024
                max_memory = max(max_memory, current_memory)

            # Print details for the first batch
            if i == 0:
                print("\n=== First batch details ===")
                print(f"Main features shape: {main_features['input_features'].shape}")
                print(f"Main mask shape: {main_features['attention_mask'].shape}")
                print(f"Prompt features shape: {prompt_features[0]['input_features'].shape}")
                print(f"Prompt mask shape: {prompt_features[0]['attention_mask'].shape}")
                print(f"Labels shape: {labels.shape}")
                print(f"Prompt labels shape: {prompt_labels.shape}")
                # The model returns a dict, so report each head separately
                print(f"Frame logits shape: {outputs['frame_logits'].shape}")
                print(f"Avg logits shape: {outputs['avg_logits'].shape}")
                print(f"Avg logits range: [{outputs['avg_logits'].min().item():.3f}, "
                      f"{outputs['avg_logits'].max().item():.3f}]")

    # Compute and print timing statistics
    avg_time = total_time / num_test_batches
    print("\n=== Performance statistics ===")
    print(f"Average time per batch: {avg_time:.4f}s")
    print(f"Estimated time for {total_samples} samples: "
          f"{(total_samples/batch_size*avg_time/3600):.2f}h")
    if torch.cuda.is_available():
        print(f"Peak GPU memory usage: {max_memory:.2f}MB")
        print(f"GPU memory growth: {max_memory - initial_memory:.2f}MB")

    print("\nTest complete!")
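
    # --- Minimal loss sketch (illustrative only) ---
    # A hedged example of how the two output heads could be combined into a
    # training loss, evaluated here on the last test batch. The following are
    # assumptions, not taken from this model's training code: frame-level
    # targets are derived by broadcasting each utterance label to all of its
    # frames, and the 0.5/0.5 weighting is an arbitrary placeholder.
    frame_logits = outputs['frame_logits']  # [batch_size, main_seq_len, 2]
    avg_logits = outputs['avg_logits']      # [batch_size, 2]
    targets = labels.long()                 # assumes binary utterance labels
    # Hypothetical per-frame targets: repeat the utterance label per frame
    frame_targets = targets.unsqueeze(1).expand(-1, frame_logits.size(1))
    frame_loss = nn.functional.cross_entropy(
        frame_logits.reshape(-1, 2), frame_targets.reshape(-1))
    utt_loss = nn.functional.cross_entropy(avg_logits, targets)
    loss = 0.5 * frame_loss + 0.5 * utt_loss  # placeholder weighting
    print(f"\nIllustrative combined loss on the last batch: {loss.item():.4f}")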