# DeepfakeDetection/model.py
import torch
import torch.nn as nn
from transformers import LlamaConfig, Wav2Vec2BertModel

from llama_nar import LlamaNAREmb
class Wav2Vec2BERT_Llama(nn.Module):
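    """Wav2Vec2-BERT encoder with a non-autoregressive Llama head for audio deepfake detection.

    Prompt examples (audio features plus label embeddings) and the target audio
    are concatenated into a single sequence, delimited by learned special
    tokens, and classified per frame by a small LlamaNAREmb backbone.
    """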
def __init__(self):
super().__init__()
        # 1. Load the pretrained speech encoder
self.wav2vec2bert = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", output_hidden_states=True)
        # 2. Selectively freeze parameters
        for name, param in self.wav2vec2bert.named_parameters():
            # Freeze all FFN1 blocks (keep FFN2 trainable for adaptation)
            if 'ffn1' in name:
                param.requires_grad = False
            # Freeze the K/V projections in multi-head attention
            if any(proj in name for proj in ['linear_k', 'linear_v']):
                param.requires_grad = False
            # Freeze distance_embedding
            if 'distance_embedding' in name:
                param.requires_grad = False
            # Freeze all convolution-related modules
            if any(conv_name in name for conv_name in [
                'conv_module', 'pointwise_conv', 'depthwise_conv',
                'feature_extractor', 'pos_conv_embed', 'conv_layers'
            ]):
                param.requires_grad = False
        # 3. Use a smaller Llama backbone
self.llama_nar = LlamaNAREmb(
config=LlamaConfig(
hidden_size=512,
num_attention_heads=8,
num_hidden_layers=8,
),
num_heads=8,
num_layers=8,
hidden_size=512
)
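        # (num_heads / num_layers / hidden_size mirror the LlamaConfig values above;
        #  LlamaNAREmb is the local module imported from llama_nar)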
        # 4. Projection layer for dimensionality reduction (1024 -> 512)
self.projection = nn.Sequential(
nn.Linear(1024, 512),
nn.LayerNorm(512)
)
        # 5. Simplified two-way classification head
self.classifier = nn.Sequential(
nn.Linear(512, 128),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(128, 2)
)
        # 6. Smaller label-embedding dimension
self.label_embedding = nn.Embedding(num_embeddings=2, embedding_dim=512)
        # 7. Simplified feature-processing layer
self.feature_processor = nn.Sequential(
nn.Linear(512, 512),
nn.LayerNorm(512),
nn.ReLU(),
nn.Dropout(0.1)
)
        # 8. Smaller special-token dimension
self.special_tokens = nn.Parameter(torch.randn(4, 512))
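        # Special-token roles (see usage in forward):
        #   0 -> [SEP], 1 -> [PROMPT], 2 -> [AUDIO], 3 -> [LABEL]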
def _fuse_layers(self, hidden_states):
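        """Fuse encoder outputs: take the last hidden layer, downsample the
        time axis 10x by average pooling, and project 1024 -> 512
        (e.g. [4, 500, 1024] -> [4, 50, 512])."""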
        def downsample_sequence(sequence, factor=10):
            """Downsample the time axis by average pooling."""
            batch_size, seq_len, hidden_size = sequence.shape
            # Truncate so the sequence length is divisible by the factor
            new_len = seq_len // factor
            padded_len = new_len * factor
            if seq_len > padded_len:
                sequence = sequence[:, :padded_len, :]
            # Reshape to [batch_size, new_len, factor, hidden_size] and mean-pool
            reshaped = sequence.reshape(batch_size, new_len, factor, hidden_size)
            downsampled = torch.mean(reshaped, dim=2)  # [batch_size, new_len, hidden_size]
            return downsampled
        # 1. Take the last hidden layer and downsample it
        last_layer = hidden_states[-1]  # [batch_size, seq_len, 1024]
        downsampled_features = downsample_sequence(last_layer)  # [batch_size, seq_len//10, 1024]
        # 2. Project down to 512 dimensions
        projected_features = self.projection(downsampled_features)  # [batch_size, seq_len//10, 512]
        return projected_features  # sequence dimension is preserved, so no unsqueeze is needed
def forward(self, batch):
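        """Run in-context classification.

        `batch['main_features']` holds the encoder inputs for the target audio;
        optionally, `batch['prompt_features']` (a list of such dicts) and
        `batch['prompt_labels']` ([batch_size, num_prompts]) supply labeled
        in-context examples. Returns per-frame and clip-level logits.
        """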
main_output = self.wav2vec2bert(
**batch['main_features']
)
fused_features = self._fuse_layers(main_output.hidden_states)
fused_features = self.feature_processor(fused_features)
if ('prompt_labels' in batch and
batch['prompt_labels'] is not None and
'prompt_features' in batch and
batch['prompt_features'] and
len(batch['prompt_features']) > 0):
            batch_size, num_prompts = batch['prompt_labels'].shape
            # Encode each prompt clip separately
            prompt_features = batch['prompt_features']
            all_prompt_outputs = []
for i in range(num_prompts):
prompt_output = self.wav2vec2bert(
**prompt_features[i]
)
all_prompt_outputs.append(self._fuse_layers(prompt_output.hidden_states))
if all_prompt_outputs:
fused_prompts = torch.stack([
self.feature_processor(p) for p in all_prompt_outputs
], dim=1) # [batch_size, num_prompts, seq_len, hidden_size]
                # Label embeddings for each prompt (expanded to sequence length below)
                label_embs = self.label_embedding(batch['prompt_labels'])  # [batch_size, num_prompts, 512]
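                # Layout per batch item:
                #   ([PROMPT][AUDIO]<prompt feats>[LABEL]<label embs>[SEP]) x num_prompts,
                #   then [PROMPT][AUDIO]<main feats>[LABEL]<zeros as prediction slots>.
                # NOTE: torch.stack below assumes every item yields the same total length.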
                prompt_embeddings = []
                for i in range(batch_size):
                    sequence = []
                    # Append the in-context example prompts
                    for j in range(num_prompts):
                        prompt_seq_len = fused_prompts[i, j].size(0)  # length of this prompt
                        sequence.append(self.special_tokens[1].expand(1, -1))  # [PROMPT]
                        sequence.append(self.special_tokens[2].expand(1, -1))  # [AUDIO]
                        sequence.append(fused_prompts[i, j])  # [seq_len, hidden_size]
                        sequence.append(self.special_tokens[3].expand(1, -1))  # [LABEL]
                        # Expand the label embedding to the same length as the audio features
                        expanded_label = label_embs[i, j].unsqueeze(0).expand(prompt_seq_len, -1)
                        sequence.append(expanded_label)  # [seq_len, hidden_size]
                        sequence.append(self.special_tokens[0].expand(1, -1))  # [SEP]
                    # Append the main features to be classified
                    main_seq_len = fused_features[i].size(0)  # length of the main clip
                    sequence.append(self.special_tokens[1].expand(1, -1))  # [PROMPT]
                    sequence.append(self.special_tokens[2].expand(1, -1))  # [AUDIO]
                    sequence.append(fused_features[i])  # [main_seq_len, hidden_size]
                    sequence.append(self.special_tokens[3].expand(1, -1))  # [LABEL]
                    # Zero vectors mark the prediction slots, same length as the main clip
                    sequence.append(torch.zeros(main_seq_len, fused_features.size(-1),
                                                device=fused_features.device))
                    prompt_embeddings.append(torch.cat(sequence, dim=0))
                prompt_embeddings = torch.stack(prompt_embeddings, dim=0)
        else:
            # Simplified handling of the no-prompt case
            batch_size = fused_features.size(0)
            main_seq_len = fused_features.size(1)  # sequence length of the main features
            # Build the sequence [batch_size, total_len, hidden_size]
            prompt_embeddings = torch.cat([
                self.special_tokens[1].expand(batch_size, 1, -1),  # [PROMPT]
                self.special_tokens[2].expand(batch_size, 1, -1),  # [AUDIO]
                fused_features,  # [batch_size, main_seq_len, hidden_size]
                self.special_tokens[3].expand(batch_size, 1, -1),  # [LABEL]
                torch.zeros(batch_size, main_seq_len, fused_features.size(-1),
                            device=fused_features.device)  # prediction slots
            ], dim=1)
        # Feed the assembled sequence through the LlamaNAR backbone
        output = self.llama_nar(inputs_embeds=prompt_embeddings)
        # Outputs at the prediction slots (the last main_seq_len positions)
        pred_pos_embeddings = output[:, -main_seq_len:, :]  # [batch_size, main_seq_len, hidden_size]
        # Per-frame classification
        frame_logits = self.classifier(pred_pos_embeddings)  # [batch_size, main_seq_len, 2]
        # Clip-level logits from the averaged prediction embeddings
        avg_embedding = torch.mean(pred_pos_embeddings, dim=1)  # [batch_size, hidden_size]
        avg_logits = self.classifier(avg_embedding)  # [batch_size, 2]
        return {
            'frame_logits': frame_logits,  # per-frame scores
            'avg_logits': avg_logits  # clip-level scores
        }
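# A minimal sketch (an assumption, not part of this repo) of how the encoder
# inputs in `batch['main_features']` could be built with the processor that
# pairs with facebook/w2v-bert-2.0; `waveforms` is a hypothetical list of
# 16 kHz mono numpy arrays:
#
#     from transformers import AutoFeatureExtractor
#     extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
#     inputs = extractor(waveforms, sampling_rate=16000,
#                        return_tensors="pt", padding=True)
#     batch = {'main_features': {'input_features': inputs.input_features,
#                                'attention_mask': inputs.attention_mask}}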
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from dataset.train_MultiDataset import train_MultiDataset, collate_fn
    from tqdm import tqdm
    import time
    # Select device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\n=== Using device: {device} ===")
    # Initialize the model
    print("\n=== Initializing model ===")
    model = Wav2Vec2BERT_Llama().to(device)
    model.eval()  # evaluation mode
    # Print the parameter structure of wav2vec2bert
    print("\n=== Wav2Vec2BERT parameter structure ===")
    w2v_params_by_layer = {}
    total_trainable = 0
    total_frozen = 0
    for name, param in model.wav2vec2bert.named_parameters():
        # Top-level module name
        layer_name = name.split('.')[0]
        if layer_name not in w2v_params_by_layer:
            w2v_params_by_layer[layer_name] = {
                'trainable_params': 0,
                'frozen_params': 0,
                'parameter_names': []
            }
        # Tally parameters
        if param.requires_grad:
            w2v_params_by_layer[layer_name]['trainable_params'] += param.numel()
            total_trainable += param.numel()
        else:
            w2v_params_by_layer[layer_name]['frozen_params'] += param.numel()
            total_frozen += param.numel()
        w2v_params_by_layer[layer_name]['parameter_names'].append(name)
    # Per-layer breakdown (counts reported in millions of parameters;
    # the original divided by 1024^2 and labeled the result "MB")
    print("\nPer-layer parameter statistics:")
    for layer_name, info in w2v_params_by_layer.items():
        trainable_m = info['trainable_params'] / 1e6
        frozen_m = info['frozen_params'] / 1e6
        total_m = (info['trainable_params'] + info['frozen_params']) / 1e6
        print(f"\n{layer_name}:")
        print(f"  - total params: {total_m:.2f}M")
        print(f"  - trainable params: {trainable_m:.2f}M")
        print(f"  - frozen params: {frozen_m:.2f}M")
        print(f"  - parameter names:")
        for param_name in info['parameter_names']:
            print(f"    * {param_name}")
    # Overall statistics
    print("\n=== Overall statistics ===")
    print(f"Total trainable params: {total_trainable/1e6:.2f}M")
    print(f"Total frozen params: {total_frozen/1e6:.2f}M")
    print(f"Total params: {(total_trainable + total_frozen)/1e6:.2f}M")
    print(f"Trainable share: {total_trainable/(total_trainable + total_frozen)*100:.2f}%")
    # Per-module parameter counts
    wav2vec2bert_params = sum(p.numel() for p in model.wav2vec2bert.parameters())
    llama_params = sum(p.numel() for p in model.llama_nar.parameters())
    other_params = sum(p.numel() for name, p in model.named_parameters()
                       if not name.startswith('wav2vec2bert.') and not name.startswith('llama_nar.'))
    total_params = wav2vec2bert_params + llama_params + other_params
    print(f"\n=== Parameter counts ===")
    print(f"Wav2Vec2BERT params: {wav2vec2bert_params:,} ({wav2vec2bert_params/1e6:.2f}M)")
    print(f"LlamaNAR params: {llama_params:,} ({llama_params/1e6:.2f}M)")
    print(f"Other modules: {other_params:,} ({other_params/1e6:.2f}M)")
    print(f"Total params: {total_params:,} ({total_params/1e6:.2f}M)")
    # Percentages
    print(f"\n=== Parameter shares ===")
    print(f"Wav2Vec2BERT: {wav2vec2bert_params/total_params*100:.2f}%")
    print(f"LlamaNAR: {llama_params/total_params*100:.2f}%")
    print(f"Other modules: {other_params/total_params*100:.2f}%")
    # Measure runtime and memory usage
    print("\n=== Timing and memory test (batch_size=4) ===")
    batch_size = 4
    total_samples = 600000
    # Clear the GPU cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        initial_memory = torch.cuda.memory_allocated() / 1024 / 1024
        print(f"Initial GPU memory usage: {initial_memory:.2f}MB")
    # Initialize the dataset
    print("\nInitializing dataset...")
    ds = train_MultiDataset(max_prompts=3)
    # Create the DataLoader
    dl = DataLoader(ds,
                    batch_size=batch_size,
                    shuffle=True,
                    collate_fn=collate_fn,
                    num_workers=4)
    print(f"\nDataset size: {len(ds)}")
    print(f"Number of batches: {len(dl)}")
    # Average time per batch
    num_test_batches = 10
    total_time = 0
    max_memory = 0
    print(f"\nTiming {num_test_batches} batches...")
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dl, total=num_test_batches)):
            if i >= num_test_batches:
                break
            # Move the dict-valued features to the device
            main_features = {
                'input_features': batch['main_features']['input_features'].to(device),
                'attention_mask': batch['main_features']['attention_mask'].to(device)
            }
            prompt_features = [{
                'input_features': pf['input_features'].to(device),
                'attention_mask': pf['attention_mask'].to(device)
            } for pf in batch['prompt_features']]
            labels = batch['labels'].to(device)
            prompt_labels = batch['prompt_labels'].to(device)
            # Start timing
            start_time = time.time()
            # Forward pass
            outputs = model({
                'main_features': main_features,
                'prompt_features': prompt_features,
                'prompt_labels': prompt_labels
            })
            # Make sure all GPU work has finished
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            # Record elapsed time and memory usage
            end_time = time.time()
            total_time += (end_time - start_time)
            if torch.cuda.is_available():
                current_memory = torch.cuda.memory_allocated() / 1024 / 1024
                max_memory = max(max_memory, current_memory)
            # Print details for the first batch
            if i == 0:
                print("\n=== First batch details ===")
                print(f"Main feature shape: {main_features['input_features'].shape}")
                print(f"Main mask shape: {main_features['attention_mask'].shape}")
                print(f"Prompt feature shape: {prompt_features[0]['input_features'].shape}")
                print(f"Prompt mask shape: {prompt_features[0]['attention_mask'].shape}")
                print(f"Label shape: {labels.shape}")
                print(f"Prompt label shape: {prompt_labels.shape}")
                # The model returns a dict, so index into it rather than calling .shape on it
                print(f"Frame logits shape: {outputs['frame_logits'].shape}")
                print(f"Avg logits range: [{outputs['avg_logits'].min().item():.3f}, "
                      f"{outputs['avg_logits'].max().item():.3f}]")
    # Compute and print summary statistics
    avg_time = total_time / num_test_batches
    print(f"\n=== Performance summary ===")
    print(f"Average time per batch: {avg_time:.4f}s")
    print(f"Estimated time for {total_samples} samples: {(total_samples/batch_size*avg_time/3600):.2f}h")
    if torch.cuda.is_available():
        print(f"Peak GPU memory usage: {max_memory:.2f}MB")
        print(f"GPU memory growth: {max_memory - initial_memory:.2f}MB")
    print("\nTest complete!")