File size: 16,224 Bytes
34146f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 |
import torch
import torch.nn as nn
from transformers import Wav2Vec2BertModel
from llama_nar import LlamaNAREmb
from transformers import LlamaConfig
import time
import torch.nn.functional as F
class Wav2Vec2BERT_Llama(nn.Module):
def __init__(self):
# Build the model: a pretrained Wav2Vec2-BERT encoder with part of its
# parameters frozen, a LlamaNAREmb sequence model, and small projection,
# classification, embedding and feature-processing layers around them.
# NOTE(review): this view of the file has lost its indentation and several
# lines (closing brackets, call arguments); the notes below mark the gaps.
# NOTE(review): no super().__init__() call is visible before the first
# submodule assignment; nn.Module needs one - confirm against the real file.
# 1. Load the pretrained model
# output_hidden_states=True is passed so the encoder output exposes
# .hidden_states, which forward() hands to _fuse_layers().
self.wav2vec2bert = Wav2Vec2BertModel.from_pretrained("/mntcephfs/lab_data/wangli/pretrain/w2v-bert-2.0/", output_hidden_states=True)
# 2. Selectively freeze parameters
# Matching is done by substring on the parameter name.
for name, param in self.wav2vec2bert.named_parameters():
# Freeze every FFN1 (leave FFN2 trainable so it can still adapt)
if 'ffn1' in name:
param.requires_grad = False
# Freeze the K and V projections of multi-head attention
if any(proj in name for proj in ['linear_k', 'linear_v']):
param.requires_grad = False
# Freeze distance_embedding
if 'distance_embedding' in name:
param.requires_grad = False
# Freeze all convolution-related modules
# NOTE(review): the closing "]):" of this condition is not visible here.
if any(conv_name in name for conv_name in [
'conv_module', 'pointwise_conv', 'depthwise_conv',
'feature_extractor', 'pos_conv_embed', 'conv_layers'
param.requires_grad = False
# 3. Use a smaller Llama model
# NOTE(review): the arguments of LlamaNAREmb(...) and its closing parenthesis
# are not visible in this view; the configuration cannot be documented here.
self.llama_nar = LlamaNAREmb(
# 4. Dimension-reducing projection layer
# 1024 presumably matches the encoder hidden size - TODO confirm.
# NOTE(review): any further layers and the closing ")" are not visible.
self.projection = nn.Sequential(
nn.Linear(1024, 512),
# 5. Simplified classification head
# Ends in 2 output units; forward() applies it per frame and to the mean.
self.classifier = nn.Sequential(
nn.Linear(512, 128),
nn.Linear(128, 2)
# 6. Smaller embedding dimension
# Two label ids, each mapped to a 512-d vector.
self.label_embedding = nn.Embedding(num_embeddings=2, embedding_dim=512)
# 7. Simplified feature-processing layer
# NOTE(review): any further layers and the closing ")" are not visible.
self.feature_processor = nn.Sequential(
nn.Linear(512, 512),
# 8. Smaller dimension for the special tokens
# Four trainable 512-d vectors; forward() indexes them as
# 0=[SEP], 1=[PROMPT], 2=[AUDIO], 3=[LABEL].
self.special_tokens = nn.Parameter(torch.randn(4, 512))
def _fuse_layers(self, hidden_states):
    """Pool the encoder's last hidden layer over time, then project it.

    Args:
        hidden_states: Sequence of per-layer tensors. Only the last entry
            is used; it must be 3-D, ``[batch_size, seq_len, hidden_size]``,
            with ``hidden_size`` accepted by ``self.projection``.

    Returns:
        ``self.projection`` applied to the time-pooled features. The time
        axis has length ``seq_len // 10``, or 1 when ``0 < seq_len < 10``.
        An empty input (``seq_len == 0``) stays empty.
    """

    def downsample_sequence(sequence, factor=10):
        """Average every ``factor`` consecutive frames along dim 1."""
        batch_size, seq_len, hidden_size = sequence.shape
        new_len = seq_len // factor

        if new_len == 0 and seq_len > 0:
            # Fewer frames than one pooling window. Without this branch the
            # result would have zero frames, and a caller slicing with
            # ``x[:, -0:, :]`` would silently select every position.
            # Pool whatever frames exist into a single frame instead.
            return sequence.mean(dim=1, keepdim=True)

        # Drop the trailing remainder so the length divides evenly
        # (this truncates; nothing is padded).
        usable_len = new_len * factor
        if seq_len > usable_len:
            sequence = sequence[:, :usable_len, :]

        # [batch_size, new_len, factor, hidden_size] -> mean over each window
        windows = sequence.reshape(batch_size, new_len, factor, hidden_size)
        return windows.mean(dim=2)  # [batch_size, new_len, hidden_size]

    # 1. Take the last layer's features and downsample them in time.
    last_layer = hidden_states[-1]
    pooled = downsample_sequence(last_layer)

    # 2. Project; the sequence dimension is kept, so no unsqueeze is needed.
    return self.projection(pooled)
def forward(self, batch):
# Encode the main utterance, optionally prepend encoded prompt examples
# together with their label embeddings, run the assembled embedding sequence
# through llama_nar, and classify the trailing "prediction" positions.
# Reads batch['prompt_labels'] and batch['prompt_features'] when present;
# returns a dict with 'frame_logits' and 'avg_logits'.
# NOTE(review): this view has lost indentation and several lines; the
# arguments of both self.wav2vec2bert(...) calls are not visible, so the
# batch keys used for the main input cannot be confirmed from here.
main_output = self.wav2vec2bert(
fused_features = self._fuse_layers(main_output.hidden_states)
fused_features = self.feature_processor(fused_features)
# Prompt branch: taken only when prompt labels exist and prompt_features is
# a non-empty container.
if ('prompt_labels' in batch and
batch['prompt_labels'] is not None and
'prompt_features' in batch and
batch['prompt_features'] and
len(batch['prompt_features']) > 0):
batch_size, num_prompts = batch['prompt_labels'].shape
# Reshape the features for batched processing
prompt_features = batch['prompt_features']
all_prompt_outputs = []
# NOTE(review): the body of this loop (encoder arguments and the append to
# all_prompt_outputs) is not visible in this view.
for i in range(num_prompts):
prompt_output = self.wav2vec2bert(
if all_prompt_outputs:
# torch.stack requires every prompt tensor to have the same shape.
fused_prompts = torch.stack([
self.feature_processor(p) for p in all_prompt_outputs
], dim=1) # [batch_size, num_prompts, seq_len, hidden_size]
# Look up label embeddings and expand them to the matching sequence length
label_embs = self.label_embedding(batch['prompt_labels']) # [batch_size, num_prompts, 512]
prompt_embeddings = []
for i in range(batch_size):
sequence = []
# Append the example prompts
for j in range(num_prompts):
prompt_seq_len = fused_prompts[i, j].size(0) # sequence length of the current prompt
sequence.append(self.special_tokens[1].expand(1, -1)) # [PROMPT]
sequence.append(self.special_tokens[2].expand(1, -1)) # [AUDIO]
sequence.append(fused_prompts[i, j]) # [seq_len, hidden_size]
sequence.append(self.special_tokens[3].expand(1, -1)) # [LABEL]
# Expand the label embedding to the same length as the audio features
expanded_label = label_embs[i, j].unsqueeze(0).expand(prompt_seq_len, -1)
sequence.append(expanded_label) # [seq_len, hidden_size]
sequence.append(self.special_tokens[0].expand(1, -1)) # [SEP]
# Append the main features to be predicted
main_seq_len = fused_features[i].size(0) # sequence length of the main features
sequence.append(self.special_tokens[1].expand(1, -1)) # [PROMPT]
sequence.append(self.special_tokens[2].expand(1, -1)) # [AUDIO]
sequence.append(fused_features[i]) # [main_seq_len, hidden_size]
sequence.append(self.special_tokens[3].expand(1, -1)) # [LABEL]
# Prediction positions are zero vectors, same length as the main features
sequence.append(torch.zeros(main_seq_len, fused_features.size(-1)).to(fused_features.device))
# NOTE(review): the first argument of this call is missing in this view;
# presumably the per-sample pieces in `sequence` are concatenated along
# dim 0 here - confirm against the real file.
prompt_embeddings.append(, dim=0))
prompt_embeddings = torch.stack(prompt_embeddings, dim=0)
# Simplified handling of the no-prompt case
# NOTE(review): the "else:" that should introduce this branch is not visible.
batch_size = fused_features.size(0)
main_seq_len = fused_features.size(1) # read the main-feature sequence length directly
# Build the sequence [batch_size, total_len, hidden_size]
# NOTE(review): the callable in front of "[" is missing in this view; the
# trailing "dim=1" suggests a concatenation along the sequence axis - confirm.
prompt_embeddings =[
self.special_tokens[1].expand(batch_size, 1, -1), # [PROMPT]
self.special_tokens[2].expand(batch_size, 1, -1), # [AUDIO]
fused_features, # [batch_size, main_seq_len, hidden_size]
self.special_tokens[3].expand(batch_size, 1, -1), # [LABEL]
torch.zeros(batch_size, main_seq_len, fused_features.size(-1)).to(fused_features.device) # prediction positions
], dim=1)
# Feed into llama_nar
output = self.llama_nar(inputs_embeds=prompt_embeddings)
# Take the outputs at all prediction positions (i.e. the last main_seq_len positions)
# NOTE(review): if main_seq_len is 0, "-main_seq_len:" is "-0:" and selects
# the whole sequence instead of nothing.
pred_pos_embeddings = output[:, -main_seq_len:, :] # [batch_size, main_seq_len, hidden_size]
# Classify every frame
frame_logits = self.classifier(pred_pos_embeddings) # [batch_size, main_seq_len, 2]
# Return both frame-level logits and overall logits (obtained by averaging)
# The average is taken over the embeddings, then classified with the same head.
avg_embedding = torch.mean(pred_pos_embeddings, dim=1) # [batch_size, hidden_size]
avg_logits = self.classifier(avg_embedding) # [batch_size, 2]
return {
'frame_logits': frame_logits, # prediction scores for each frame
'avg_logits': avg_logits # overall prediction scores
if __name__ == '__main__':
# Smoke test / benchmark: builds the model, prints parameter statistics for
# the encoder and the whole model, then times forward passes over a few
# batches and reports GPU memory use.
# NOTE(review): this view has lost indentation and several lines (closing
# brackets, DataLoader arguments, loop "break", CUDA calls); gaps are marked.
import torch
# NOTE(review): the module path of this import is missing in this view.
from import DataLoader
from dataset.train_MultiDataset import train_MultiDataset, collate_fn
from tqdm import tqdm
import time
# Select the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n=== 使用设备: {device} ===")
# Initialise the model
print("\n=== 初始化模型 ===")
model = Wav2Vec2BERT_Llama().to(device)
model.eval() # switch to evaluation mode
# Print the parameter structure of wav2vec2bert
print("\n=== Wav2Vec2BERT 参数结构 ===")
w2v_params_by_layer = {}
total_trainable = 0
total_frozen = 0
for name, param in model.wav2vec2bert.named_parameters():
# Use the first dotted component of the name as the top-level layer name
layer_name = name.split('.')[0]
if layer_name not in w2v_params_by_layer:
w2v_params_by_layer[layer_name] = {
'trainable_params': 0,
'frozen_params': 0,
'parameter_names': []
# Accumulate the counts
# NOTE(review): the "else:" separating the trainable and frozen updates is
# not visible here, and nothing visible appends to 'parameter_names'.
if param.requires_grad:
w2v_params_by_layer[layer_name]['trainable_params'] += param.numel()
total_trainable += param.numel()
w2v_params_by_layer[layer_name]['frozen_params'] += param.numel()
total_frozen += param.numel()
# Print the details of every layer
# NOTE: the values are parameter counts divided by 1024**2; the "MB" suffix
# in the printed text is a count unit here, not a size in bytes.
for layer_name, info in w2v_params_by_layer.items():
trainable_mb = info['trainable_params'] / 1024 / 1024
frozen_mb = info['frozen_params'] / 1024 / 1024
total_mb = (info['trainable_params'] + info['frozen_params']) / 1024 / 1024
print(f" - 总参数量: {total_mb:.2f}MB")
print(f" - 可训练参数: {trainable_mb:.2f}MB")
print(f" - 冻结参数: {frozen_mb:.2f}MB")
print(f" - 参数名称:")
for param_name in info['parameter_names']:
print(f" * {param_name}")
# Print the overall statistics
print("\n=== 总体统计 ===")
print(f"可训练参数总量: {total_trainable/1024/1024:.2f}MB")
print(f"冻结参数总量: {total_frozen/1024/1024:.2f}MB")
print(f"参数总量: {(total_trainable + total_frozen)/1024/1024:.2f}MB")
print(f"可训练参数占比: {total_trainable/(total_trainable + total_frozen)*100:.2f}%")
# Count the parameters of each module separately
wav2vec2bert_params = sum(p.numel() for p in model.wav2vec2bert.parameters())
llama_params = sum(p.numel() for p in model.llama_nar.parameters())
# Everything that is not under the wav2vec2bert. or llama_nar. prefixes
other_params = sum(p.numel() for name, p in model.named_parameters()
if not name.startswith('wav2vec2bert.') and not name.startswith('llama_nar.'))
total_params = wav2vec2bert_params + llama_params + other_params
print(f"\n=== 参数量统计 ===")
print(f"Wav2Vec2BERT参数量: {wav2vec2bert_params:,} ({wav2vec2bert_params/1024/1024:.2f}MB)")
print(f"LlamaNAR参数量: {llama_params:,} ({llama_params/1024/1024:.2f}MB)")
print(f"其他模块参数量: {other_params:,} ({other_params/1024/1024:.2f}MB)")
print(f"总参数量: {total_params:,} ({total_params/1024/1024:.2f}MB)")
# Compute the percentages
print(f"\n=== 参数量占比 ===")
print(f"Wav2Vec2BERT: {wav2vec2bert_params/total_params*100:.2f}%")
print(f"LlamaNAR: {llama_params/total_params*100:.2f}%")
print(f"其他模块: {other_params/total_params*100:.2f}%")
# Measure run time and memory use
print("\n=== 测试运行时间和内存使用 (batch_size=4) ===")
batch_size = 4
# Only used below to extrapolate the total processing time
total_samples = 600000
# Clear the GPU cache
# NOTE(review): the cache-clearing call itself is not visible in this view.
if torch.cuda.is_available():
initial_memory = torch.cuda.memory_allocated() / 1024 / 1024
print(f"初始GPU内存使用: {initial_memory:.2f}MB")
# Initialise the dataset
ds = train_MultiDataset(max_prompts=3)
# Create the DataLoader
# NOTE(review): the remaining DataLoader arguments and ")" are not visible.
dl = DataLoader(ds,
print(f"\n数据集大小: {len(ds)}")
print(f"批次数量: {len(dl)}")
# Measure the average time of one batch
num_test_batches = 10
total_time = 0
max_memory = 0
with torch.no_grad():
for i, batch in enumerate(tqdm(dl, total=num_test_batches)):
# NOTE(review): the statement under this condition (presumably a break) is
# not visible in this view.
if i >= num_test_batches:
# Handle the dict-typed features correctly (move each tensor to the device)
main_features = {
'input_features': batch['main_features']['input_features'].to(device),
'attention_mask': batch['main_features']['attention_mask'].to(device)
prompt_features = [{
'input_features': pf['input_features'].to(device),
'attention_mask': pf['attention_mask'].to(device)
} for pf in batch['prompt_features']]
labels = batch['labels'].to(device)
prompt_labels = batch['prompt_labels'].to(device)
# Record the start time
start_time = time.time()
# Forward pass
outputs = model({
'main_features': main_features,
'prompt_features': prompt_features,
'prompt_labels': prompt_labels
# Make sure the GPU work has finished
# NOTE(review): the call under this condition is not visible in this view.
if torch.cuda.is_available():
# Record the end time and the memory use
end_time = time.time()
total_time += (end_time - start_time)
if torch.cuda.is_available():
current_memory = torch.cuda.memory_allocated() / 1024 / 1024
max_memory = max(max_memory, current_memory)
# Print the details of the first batch
if i == 0:
print("\n=== 第一个Batch的详细信息 ===")
print(f"主特征形状: {main_features['input_features'].shape}")
print(f"主掩码形状: {main_features['attention_mask'].shape}")
print(f"Prompt特征形状: {prompt_features[0]['input_features'].shape}")
print(f"Prompt掩码形状: {prompt_features[0]['attention_mask'].shape}")
print(f"标签形状: {labels.shape}")
print(f"Prompt标签形状: {prompt_labels.shape}")
# NOTE(review): Wav2Vec2BERT_Llama.forward returns a dict with
# 'frame_logits' and 'avg_logits'; a dict has no .shape/.min()/.max(), so
# the next two lines look like they would raise AttributeError - confirm.
print(f"模型输出形状: {outputs.shape}")
print(f"输出logits范围: [{outputs.min().item():.3f}, {outputs.max().item():.3f}]")
# Compute and print the statistics
# NOTE(review): divides by num_test_batches even if the loader yielded fewer.
avg_time = total_time / num_test_batches
print(f"\n=== 性能统计 ===")
print(f"平均每个batch处理时间: {avg_time:.4f}秒")
print(f"估计处理{total_samples}个样本需要: {(total_samples/batch_size*avg_time/3600):.2f}小时")
if torch.cuda.is_available():
print(f"最大GPU内存使用: {max_memory:.2f}MB")
print(f"GPU内存增长: {max_memory - initial_memory:.2f}MB")
print("\n测试完成!") |