import torch | |
import torch.nn as nn | |
from transformers import Wav2Vec2BertModel | |
from llama_nar import LlamaNAREmb | |
from transformers import LlamaConfig | |
import time | |
import torch.nn.functional as F | |
class Wav2Vec2BERT_Llama(nn.Module):
    """Frame-level 2-class classifier: Wav2Vec2-BERT features fed to a small NAR Llama.

    ``forward`` encodes the main utterance and, optionally, labelled prompt
    utterances that act as in-context examples. It then time-downsamples the
    encoder's last hidden layer, projects it to 512 dims and lays everything
    out as one embedding sequence::

        per prompt j: [PROMPT] [AUDIO] <prompt j frames> [LABEL] <label j emb, once per frame> [SEP]
        main:         [PROMPT] [AUDIO] <main frames>     [LABEL] <zeros, once per frame>

    The Llama outputs at the trailing zero positions are classified per frame
    (``frame_logits``) and after averaging over time (``avg_logits``).
    """

    # Row indices into ``self.special_tokens``.
    _SEP_TOKEN = 0
    _PROMPT_TOKEN = 1
    _AUDIO_TOKEN = 2
    _LABEL_TOKEN = 3

    # Substrings of Wav2Vec2-BERT parameter names whose parameters are frozen.
    # Matching is a plain substring test. NOTE(review): some entries
    # ('feature_extractor', 'pos_conv_embed', 'conv_layers') may not occur in
    # Wav2Vec2-BERT at all -- confirm against wav2vec2bert.named_parameters().
    _FROZEN_PARAM_PATTERNS = (
        'ffn1',                    # every first FFN (ffn2 stays trainable)
        'linear_k', 'linear_v',    # attention key / value projections
        'distance_embedding',      # relative position embedding
        'conv_module', 'pointwise_conv', 'depthwise_conv',  # convolution modules
        'feature_extractor', 'pos_conv_embed', 'conv_layers',
    )

    def __init__(self,
                 pretrained_path="/mntcephfs/lab_data/wangli/pretrain/w2v-bert-2.0/",
                 downsample_factor=10):
        """Build the encoder, the NAR Llama and the small heads.

        Args:
            pretrained_path: checkpoint location passed to
                ``Wav2Vec2BertModel.from_pretrained``.
            downsample_factor: number of consecutive encoder frames averaged
                into one token (must be >= 1).

        Raises:
            ValueError: if ``downsample_factor`` < 1.
        """
        super().__init__()
        if downsample_factor < 1:
            raise ValueError(f"downsample_factor must be >= 1, got {downsample_factor}")
        self.downsample_factor = downsample_factor

        # 1. Pretrained encoder. All hidden states are requested; only the last is used.
        self.wav2vec2bert = Wav2Vec2BertModel.from_pretrained(
            pretrained_path, output_hidden_states=True
        )
        # 2. Selective freezing (FFN1, K/V projections, distance embedding, convs).
        for name, param in self.wav2vec2bert.named_parameters():
            if any(pattern in name for pattern in self._FROZEN_PARAM_PATTERNS):
                param.requires_grad = False
        # 3. Reduced-size non-autoregressive Llama (512 hidden, 8 heads, 8 layers).
        self.llama_nar = LlamaNAREmb(
            config=LlamaConfig(
                hidden_size=512,
                num_attention_heads=8,
                num_hidden_layers=8,
            ),
            num_heads=8,
            num_layers=8,
            hidden_size=512
        )
        # 4. Projection from the 1024-dim encoder output down to 512.
        self.projection = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LayerNorm(512)
        )
        # 5. Classification head (2 classes), shared by frame and averaged logits.
        self.classifier = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 2)
        )
        # 6. Embedding of prompt labels; 2 entries, so labels must be 0 or 1.
        self.label_embedding = nn.Embedding(num_embeddings=2, embedding_dim=512)
        # 7. Post-projection feature processing, applied to main and prompt features.
        self.feature_processor = nn.Sequential(
            nn.Linear(512, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        # 8. Learned special tokens; rows are indexed by the _*_TOKEN constants.
        self.special_tokens = nn.Parameter(torch.randn(4, 512))

    @staticmethod
    def _downsample_sequence(sequence, factor=10):
        """Average non-overlapping windows of ``factor`` frames along dim 1.

        Args:
            sequence: tensor of shape [batch, seq_len, hidden].
            factor: window length.

        Returns:
            Tensor [batch, seq_len // factor, hidden]. Trailing frames that do
            not fill a whole window are dropped. A non-empty sequence shorter
            than one window is averaged into a single frame rather than
            becoming empty.
        """
        batch_size, seq_len, hidden_size = sequence.shape
        new_len = seq_len // factor
        if new_len == 0:
            # Avoid a zero-length output, which would leave nothing to classify
            # and make the time-averaged logits NaN.
            return sequence.mean(dim=1, keepdim=True) if seq_len > 0 else sequence
        trimmed = sequence[:, :new_len * factor, :]
        windows = trimmed.reshape(batch_size, new_len, factor, hidden_size)
        return windows.mean(dim=2)

    def _fuse_layers(self, hidden_states):
        """Downsample the encoder's last hidden layer and project it to 512 dims.

        Args:
            hidden_states: sequence of per-layer encoder outputs. Only the last
                entry is used, and its last dim must be 1024 (the projection's
                input size).

        Returns:
            Tensor [batch, seq_len // downsample_factor (min 1), 512].

        NOTE(review): the attention mask is not consulted here, so padded
        frames are averaged together with real ones.
        """
        downsampled = self._downsample_sequence(hidden_states[-1], self.downsample_factor)
        return self.projection(downsampled)

    def _token(self, index, batch_size):
        """Return special token ``index`` broadcast (as a view) to [batch_size, 1, 512]."""
        return self.special_tokens[index].expand(batch_size, 1, -1)

    def _prompt_segments(self, prompt_features, prompt_labels, batch_size):
        """Encode the prompts and return their sequence segments, in order.

        Args:
            prompt_features: indexable collection; entry j holds keyword
                arguments for the encoder for prompt j.
            prompt_labels: integer tensor [batch, num_prompts] with values in {0, 1}.
            batch_size: batch dimension of the main features.

        Returns:
            List of [batch, *, 512] tensors to be concatenated along dim 1.
        """
        num_prompts = prompt_labels.shape[1]
        # Encode every prompt first, then post-process: same call order as the
        # main path, so dropout draws happen in a stable order.
        fused_prompts = [
            self._fuse_layers(self.wav2vec2bert(**prompt_features[j]).hidden_states)
            for j in range(num_prompts)
        ]
        fused_prompts = [self.feature_processor(p) for p in fused_prompts]
        label_embs = self.label_embedding(prompt_labels)  # [batch, num_prompts, 512]

        segments = []
        for j, prompt in enumerate(fused_prompts):
            segments.extend([
                self._token(self._PROMPT_TOKEN, batch_size),
                self._token(self._AUDIO_TOKEN, batch_size),
                prompt,                                                   # [batch, L_j, 512]
                self._token(self._LABEL_TOKEN, batch_size),
                # The label embedding is repeated once per prompt frame.
                label_embs[:, j].unsqueeze(1).expand(-1, prompt.size(1), -1),
                self._token(self._SEP_TOKEN, batch_size),
            ])
        return segments

    def forward(self, batch):
        """Run the model on one batch.

        Args:
            batch: mapping with
                - 'main_features': keyword arguments for the encoder;
                - optional 'prompt_features': list, one encoder-kwargs mapping
                  per prompt;
                - optional 'prompt_labels': integer tensor [batch, num_prompts].
                Prompts are used only when both prompt entries are present and
                non-empty.

        Returns:
            dict with 'frame_logits' [batch, main_len, 2] and
            'avg_logits' [batch, 2] (classifier on the time-averaged outputs).
        """
        main_output = self.wav2vec2bert(**batch['main_features'])
        fused_features = self.feature_processor(self._fuse_layers(main_output.hidden_states))
        batch_size, main_seq_len, hidden_size = fused_features.shape

        segments = []
        prompt_labels = batch.get('prompt_labels')
        prompt_features = batch.get('prompt_features')
        if prompt_labels is not None and prompt_features:
            segments.extend(self._prompt_segments(prompt_features, prompt_labels, batch_size))

        segments.extend([
            self._token(self._PROMPT_TOKEN, batch_size),
            self._token(self._AUDIO_TOKEN, batch_size),
            fused_features,                                               # [batch, main_len, 512]
            self._token(self._LABEL_TOKEN, batch_size),
            # Prediction slots: zeros matching the features' dtype and device.
            fused_features.new_zeros(batch_size, main_seq_len, hidden_size),
        ])
        prompt_embeddings = torch.cat(segments, dim=1)

        output = self.llama_nar(inputs_embeds=prompt_embeddings)
        # Slice from an explicit start index: `-main_seq_len:` would return the
        # whole sequence when main_seq_len == 0.
        pred_pos_embeddings = output[:, output.size(1) - main_seq_len:, :]

        frame_logits = self.classifier(pred_pos_embeddings)     # [batch, main_len, 2]
        avg_embedding = pred_pos_embeddings.mean(dim=1)         # [batch, 512]
        avg_logits = self.classifier(avg_embedding)             # [batch, 2]
        return {
            'frame_logits': frame_logits,   # per-frame scores
            'avg_logits': avg_logits        # utterance-level scores
        }
def _features_to_device(features, device):
    """Copy the 'input_features' and 'attention_mask' entries of ``features`` to ``device``.

    Any other keys are dropped.
    """
    return {key: features[key].to(device) for key in ('input_features', 'attention_mask')}


def _report_wav2vec2bert_params(model):
    """Print trainable/frozen parameter counts of the encoder, grouped by top-level submodule.

    NOTE: the printed "MB" figures are parameter counts divided by 2**20,
    not memory sizes in megabytes.
    """
    print("\n=== Wav2Vec2BERT 参数结构 ===")
    w2v_params_by_layer = {}
    for name, param in model.wav2vec2bert.named_parameters():
        layer_name = name.split('.')[0]
        info = w2v_params_by_layer.setdefault(layer_name, {
            'trainable_params': 0,
            'frozen_params': 0,
            'parameter_names': [],
        })
        bucket = 'trainable_params' if param.requires_grad else 'frozen_params'
        info[bucket] += param.numel()
        info['parameter_names'].append(name)

    print("\n各层参数统计:")
    for layer_name, info in w2v_params_by_layer.items():
        trainable_mb = info['trainable_params'] / 1024 / 1024
        frozen_mb = info['frozen_params'] / 1024 / 1024
        total_mb = (info['trainable_params'] + info['frozen_params']) / 1024 / 1024
        print(f"\n{layer_name}:")
        print(f" - 总参数量: {total_mb:.2f}MB")
        print(f" - 可训练参数: {trainable_mb:.2f}MB")
        print(f" - 冻结参数: {frozen_mb:.2f}MB")
        print(f" - 参数名称:")
        for param_name in info['parameter_names']:
            print(f" * {param_name}")

    total_trainable = sum(info['trainable_params'] for info in w2v_params_by_layer.values())
    total_frozen = sum(info['frozen_params'] for info in w2v_params_by_layer.values())
    total = total_trainable + total_frozen
    print("\n=== 总体统计 ===")
    print(f"可训练参数总量: {total_trainable/1024/1024:.2f}MB")
    print(f"冻结参数总量: {total_frozen/1024/1024:.2f}MB")
    print(f"参数总量: {total/1024/1024:.2f}MB")
    if total:  # guard against an encoder with no parameters
        print(f"可训练参数占比: {total_trainable/total*100:.2f}%")


def _report_module_params(model):
    """Print parameter counts and shares for the encoder, the Llama and all other modules."""
    wav2vec2bert_params = sum(p.numel() for p in model.wav2vec2bert.parameters())
    llama_params = sum(p.numel() for p in model.llama_nar.parameters())
    other_params = sum(p.numel() for name, p in model.named_parameters()
                       if not name.startswith(('wav2vec2bert.', 'llama_nar.')))
    total_params = wav2vec2bert_params + llama_params + other_params
    print(f"\n=== 参数量统计 ===")
    print(f"Wav2Vec2BERT参数量: {wav2vec2bert_params:,} ({wav2vec2bert_params/1024/1024:.2f}MB)")
    print(f"LlamaNAR参数量: {llama_params:,} ({llama_params/1024/1024:.2f}MB)")
    print(f"其他模块参数量: {other_params:,} ({other_params/1024/1024:.2f}MB)")
    print(f"总参数量: {total_params:,} ({total_params/1024/1024:.2f}MB)")
    print(f"\n=== 参数量占比 ===")
    print(f"Wav2Vec2BERT: {wav2vec2bert_params/total_params*100:.2f}%")
    print(f"LlamaNAR: {llama_params/total_params*100:.2f}%")
    print(f"其他模块: {other_params/total_params*100:.2f}%")


def _benchmark(model, device, batch_size=4, total_samples=600000, num_test_batches=10):
    """Time up to ``num_test_batches`` forward passes and report speed and peak GPU memory."""
    import itertools

    from torch.utils.data import DataLoader
    from tqdm import tqdm
    from dataset.train_MultiDataset import train_MultiDataset, collate_fn

    print(f"\n=== 测试运行时间和内存使用 (batch_size={batch_size}) ===")
    use_cuda = torch.cuda.is_available()
    initial_memory = 0.0
    if use_cuda:
        torch.cuda.empty_cache()
        initial_memory = torch.cuda.memory_allocated() / 1024 / 1024
        print(f"初始GPU内存使用: {initial_memory:.2f}MB")

    print("\n初始化数据集...")
    ds = train_MultiDataset(max_prompts=3)
    dl = DataLoader(ds,
                    batch_size=batch_size,
                    shuffle=True,
                    collate_fn=collate_fn,
                    num_workers=4)
    print(f"\n数据集大小: {len(ds)}")
    print(f"批次数量: {len(dl)}")

    if use_cuda:
        # Track the true peak, including activations freed before forward returns.
        torch.cuda.reset_peak_memory_stats()

    total_time = 0.0
    timed_batches = 0
    print(f"\n测试{num_test_batches}个batch的平均运行时间...")
    with torch.no_grad():
        batches = itertools.islice(dl, num_test_batches)
        for i, batch in enumerate(tqdm(batches, total=num_test_batches)):
            main_features = _features_to_device(batch['main_features'], device)
            prompt_features = [_features_to_device(pf, device) for pf in batch['prompt_features']]
            labels = batch['labels'].to(device)
            prompt_labels = batch['prompt_labels'].to(device)

            start_time = time.perf_counter()
            outputs = model({
                'main_features': main_features,
                'prompt_features': prompt_features,
                'prompt_labels': prompt_labels
            })
            if use_cuda:
                # CUDA kernels run asynchronously; wait before stopping the clock.
                torch.cuda.synchronize()
            total_time += time.perf_counter() - start_time
            timed_batches += 1

            if i == 0:
                frame_logits = outputs['frame_logits']
                print("\n=== 第一个Batch的详细信息 ===")
                print(f"主特征形状: {main_features['input_features'].shape}")
                print(f"主掩码形状: {main_features['attention_mask'].shape}")
                if prompt_features:
                    print(f"Prompt特征形状: {prompt_features[0]['input_features'].shape}")
                    print(f"Prompt掩码形状: {prompt_features[0]['attention_mask'].shape}")
                print(f"标签形状: {labels.shape}")
                print(f"Prompt标签形状: {prompt_labels.shape}")
                # The model returns a dict of logits, not a single tensor.
                print(f"模型输出形状: frame_logits={tuple(frame_logits.shape)}, "
                      f"avg_logits={tuple(outputs['avg_logits'].shape)}")
                print(f"输出logits范围: [{frame_logits.min().item():.3f}, {frame_logits.max().item():.3f}]")

    if timed_batches == 0:
        print("\n没有可测试的batch")
        return

    # Average over the batches actually run (the loader may be shorter than requested).
    avg_time = total_time / timed_batches
    print(f"\n=== 性能统计 ===")
    print(f"平均每个batch处理时间: {avg_time:.4f}秒")
    print(f"估计处理{total_samples}个样本需要: {(total_samples/batch_size*avg_time/3600):.2f}小时")
    if use_cuda:
        max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
        print(f"最大GPU内存使用: {max_memory:.2f}MB")
        print(f"GPU内存增长: {max_memory - initial_memory:.2f}MB")


def main():
    """Build the model, print parameter statistics, then benchmark forward passes."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\n=== 使用设备: {device} ===")
    print("\n=== 初始化模型 ===")
    model = Wav2Vec2BERT_Llama().to(device)
    model.eval()  # evaluation mode: disables dropout
    _report_wav2vec2bert_params(model)
    _report_module_params(model)
    _benchmark(model, device)
    print("\n测试完成!")


if __name__ == '__main__':
    main()