Spaces: Running on Zero
wli3221134 committed on
Upload 5 files
Browse files
- app.py +64 -15
- dataset.py +130 -0
- inference.py +2 -0
- llama_nar.py +571 -0
- model.py +379 -0

app.py
CHANGED
@@ -1,15 +1,56 @@
 import gradio as gr
 import os
-
+import dataset
+import torch
+from model import Wav2Vec2BERT_Llama
 
-
+# init
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# init model
+model = Wav2Vec2BERT_Llama().to(device)
+checkpoint_path = "ckpt/model_checkpoint.pth"
+if os.path.exists(checkpoint_path):
+    checkpoint = torch.load(checkpoint_path)
+    model_state_dict = checkpoint['model_state_dict']
+
+    # Normalize the state-dict keys (add or strip the DataParallel 'module.' prefix)
+    if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
+        model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
+    elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
+        model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
+
+    model.load_state_dict(model_state_dict)
+    model.eval()
+else:
+    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
+
+
+def detect(dataset, model):
+    with torch.no_grad():
+        for batch in dataset:
+            main_features = {
+                'input_features': batch['main_features']['input_features'].to(device),
+                'attention_mask': batch['main_features']['attention_mask'].to(device)
+            }
+            prompt_features = [{
+                'input_features': pf['input_features'].to(device),
+                'attention_mask': pf['attention_mask'].to(device)
+            } for pf in batch['prompt_features']]
+
+
+
+
+def audio_deepfake_detection(demonstration_paths, audio_path, model):
     """Audio deepfake detection function"""
     # Replace with your actual detection logic
     print("Demonstration audio paths: {}".format(demonstration_paths))
     print("Query audio path: {}".format(audio_path))
-
+
+    # Build the demo dataset (local name chosen so it does not shadow the imported `dataset` module)
+    demo_dataset = dataset.DemoDataset(demonstration_paths, audio_path)
     # Example return value, modify according to your model
-    result =
+    result = detect(demo_dataset, model)
 
     # Return detection results and confidence scores
     return {
@@ -33,14 +74,22 @@ with gr.Blocks() as demo:
     """
     )
 
-    #
-
-
-
-
-
+    # Create container for demonstration audio
+    with gr.Row():
+        # Demonstration audio file upload
+        demonstration_audio_input = gr.File(
+            file_count="multiple",
+            file_types=["audio"],
+            label="Demonstration Audios",
+        )
+        # Add demonstration type selection
+        demonstration_type = gr.Dropdown(
+            choices=["bonafide", "spoof"],
+            value="bonafide",
+            label="Demonstration Label",
+        )
 
-    #
+    # Query audio input component
     query_audio_input = gr.Audio(
         sources=["upload"],
         label="Query Audio (Audio for Detection)",
@@ -56,7 +105,7 @@ with gr.Blocks() as demo:
     # Set click event
     submit_btn.click(
         fn=audio_deepfake_detection,
-        inputs=[demonstration_audio_input, query_audio_input],
+        inputs=[demonstration_audio_input, demonstration_type, query_audio_input],
         outputs=[output_labels]
     )
 
@@ -64,10 +113,10 @@ with gr.Blocks() as demo:
     gr.Markdown("## Test Examples")
     gr.Examples(
         examples=[
-            ["examples/real_audio.wav", "examples/query_audio.wav"],
-            ["examples/fake_audio.wav", "examples/query_audio.wav"],
+            ["examples/real_audio.wav", "bonafide", "examples/query_audio.wav"],
+            ["examples/fake_audio.wav", "spoof", "examples/query_audio.wav"],
         ],
-        inputs=[demonstration_audio_input, query_audio_input],
+        inputs=[demonstration_audio_input, demonstration_type, query_audio_input],
     )
 
 if __name__ == "__main__":
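The committed detect() in app.py moves the batch features onto the device but does not yet run the model or return anything, so the Gradio Label output never receives scores. Below is a minimal sketch of how the call could be completed, assuming the output dict returned by Wav2Vec2BERT_Llama.forward in model.py ('avg_logits') and the collate_fn from dataset.py; the helper name run_detection, the no-prompt path (prompt_labels=None), and the label order bonafide/spoof are assumptions, not part of this commit.

# Hypothetical completion of app.py's detection path; a sketch, not the committed implementation.
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from dataset import DemoDataset, collate_fn

def run_detection(model, device, demonstration_paths, audio_path):
    ds = DemoDataset(demonstration_paths, audio_path)
    loader = DataLoader(ds, batch_size=1, collate_fn=collate_fn)
    with torch.no_grad():
        for batch in loader:
            main_features = {k: v.to(device) for k, v in batch['main_features'].items()}
            prompt_features = [{k: v.to(device) for k, v in pf.items()}
                               for pf in batch['prompt_features']]
            outputs = model({'main_features': main_features,
                             'prompt_features': prompt_features,
                             'prompt_labels': None})          # no-prompt branch in model.py
            probs = F.softmax(outputs['avg_logits'], dim=-1)[0]
            # assumed label order: index 0 = bonafide, index 1 = spoof
            return {"bonafide": probs[0].item(), "spoof": probs[1].item()}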
dataset.py
ADDED
@@ -0,0 +1,130 @@
import torch
from torch.utils.data import Dataset
from transformers import AutoFeatureExtractor
import os
import librosa
import numpy as np


class DemoDataset(Dataset):
    def __init__(self, demonstration_paths, query_path, sample_rate=16000):
        self.sample_rate = sample_rate
        self.query_path = query_path

        # Convert to list if single path
        if isinstance(demonstration_paths, str):
            self.demonstration_paths = [demonstration_paths]
        else:
            self.demonstration_paths = demonstration_paths

        # Load feature extractor
        self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")

        print(f'Number of demonstration audios: {len(self.demonstration_paths)}')
        print(f'Query audio: {self.query_path}')

    def load_pad(self, path, max_length=64000):
        """Load and pad audio file"""
        X, sr = librosa.load(path, sr=self.sample_rate)
        X = self.pad(X, max_length)
        return X

    def pad(self, x, max_len=64000):
        """Pad audio to fixed length"""
        x_len = x.shape[0]
        if x_len >= max_len:
            return x[:max_len]
        pad_length = max_len - x_len
        return np.concatenate([x, np.zeros(pad_length)], axis=0)

    def __len__(self):
        return 1  # Only one query audio

    def __getitem__(self, idx):
        # Load query audio
        query_waveform = self.load_pad(self.query_path)
        query_waveform = torch.from_numpy(query_waveform).float()
        if len(query_waveform.shape) == 1:
            query_waveform = query_waveform.unsqueeze(0)

        # Extract features for query audio
        main_features = self.feature_extractor(
            query_waveform,
            sampling_rate=self.sample_rate,
            padding=True,
            return_attention_mask=True,
            return_tensors="pt"
        )

        # Process demonstration audios
        prompt_features = []
        for demo_path in self.demonstration_paths:
            # Load demonstration audio
            demo_waveform = self.load_pad(demo_path)
            demo_waveform = torch.from_numpy(demo_waveform).float()
            if len(demo_waveform.shape) == 1:
                demo_waveform = demo_waveform.unsqueeze(0)

            # Extract features
            prompt_feature = self.feature_extractor(
                demo_waveform,
                sampling_rate=self.sample_rate,
                padding=True,
                return_attention_mask=True,
                return_tensors="pt"
            )
            prompt_features.append(prompt_feature)

        return {
            'main_features': main_features,
            'prompt_features': prompt_features,
            'file_name': os.path.basename(self.query_path),
            'file_path': self.query_path
        }


def collate_fn(batch):
    """
    Collate function for dataloader
    Args:
        batch: List containing dictionaries with:
            - main_features: feature extractor output
            - prompt_features: list of feature extractor outputs
            - file_name: file name
            - file_path: file path
    """
    batch_size = len(batch)

    # Process main features
    main_features_keys = batch[0]['main_features'].keys()
    main_features = {}
    for key in main_features_keys:
        main_features[key] = torch.cat([item['main_features'][key] for item in batch], dim=0)

    # Get number of prompts
    num_prompts = len(batch[0]['prompt_features'])

    # Process prompt features
    prompt_features = []
    for i in range(num_prompts):
        prompt_feature = {}
        for key in main_features_keys:
            prompt_feature[key] = torch.cat([item['prompt_features'][i][key] for item in batch], dim=0)
        prompt_features.append(prompt_feature)

    # Collect file names and paths
    file_names = [item['file_name'] for item in batch]
    file_paths = [item['file_path'] for item in batch]

    return {
        'main_features': main_features,
        'prompt_features': prompt_features,
        'file_names': file_names,
        'file_paths': file_paths
    }


if __name__ == '__main__':
    # Test the dataset
    demo_paths = ["examples/demo1.wav", "examples/demo2.wav"]
    query_path = "examples/query.wav"

    dataset = DemoDataset(demo_paths, query_path)
    print(dataset[0])
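DemoDataset returns nested dictionaries rather than flat tensors, so it is meant to be batched through the collate_fn defined above. A minimal usage sketch (the example paths are placeholders; the feature keys printed come from whatever the feature extractor returns):

# Minimal usage sketch for DemoDataset with the collate_fn above; paths are placeholders.
from torch.utils.data import DataLoader
from dataset import DemoDataset, collate_fn

demo_dataset = DemoDataset(["examples/demo1.wav", "examples/demo2.wav"], "examples/query.wav")
loader = DataLoader(demo_dataset, batch_size=1, collate_fn=collate_fn)

for batch in loader:
    print(batch['file_names'])                      # file name of the query audio
    for key, value in batch['main_features'].items():
        print(key, value.shape)                     # feature tensors produced by the extractor
    print(len(batch['prompt_features']))            # one entry per demonstration audio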
inference.py
ADDED
@@ -0,0 +1,2 @@
def detect(dataset):
    pass
llama_nar.py
ADDED
@@ -0,0 +1,571 @@
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
import torch
import torch.nn.functional as F
import numpy as np
import os
import torch.nn as nn
from typing import List, Optional, Tuple, Union
import math
import warnings  # needed for warnings.warn() in LlamaNAREmb.forward

from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from torchmetrics.classification import MulticlassAccuracy
from transformers.models.llama.modeling_llama import BaseModelOutputWithPast


# sinusoidal positional encoding
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :] * 1.0
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class LlamaAdaptiveRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # The gamma parameter
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # (B, Seq_Len, Dim) * (B, Seq_Len, 1) = (B, Seq_Len, Dim)
        # rsqrt: 1 / sqrt(x)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        # (Dim) * (B, Seq_Len, Dim) = (B, Seq_Len, Dim)
        return self.weight * self._norm(x.float()).type_as(x)


class MultiEmbedding(nn.Module):
    """Embedding for multiple quantization layers, summing up the embeddings of each layer."""

    def __init__(
        self,
        num_embeddings=1028,
        embedding_dim=1024,
        num_quantization_layers=8,
    ):
        super().__init__()
        self.embeddings = nn.ModuleList(
            [
                nn.Embedding(num_embeddings, embedding_dim)
                for _ in range(num_quantization_layers)
            ]
        )

        # initialize embeddings
        for i in range(num_quantization_layers):
            self.embeddings[i].weight.data.normal_(mean=0.0, std=0.02)
        self._is_hf_initialized = True  # disable automatic init

    def forward(self, input_ids):
        """Input: [num_quant, B, T] -> Output: [B, T, H]"""
        num_quant, B, T = input_ids.shape
        summed_embeddings = torch.zeros(
            B, T, self.embeddings[0].embedding_dim, device=input_ids.device
        )
        for i in range(num_quant):
            summed_embeddings += self.embeddings[i](input_ids[i])
        return summed_embeddings


class LlamaNARDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        """Override to adaptive layer norm"""
        super().__init__(config, layer_idx)  # init attention, mlp, etc.
        self.input_layernorm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class LlamaNAR(LlamaModel):
    def __init__(
        self,
        hidden_size=1024,
        num_heads=16,
        num_layers=16,
        config=LlamaConfig(0, 256, 1024, 1, 1),
    ):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [
                LlamaNARDecoderLayer(
                    config=LlamaConfig(
                        hidden_size=hidden_size,
                        num_attention_heads=num_heads,
                        max_position_embeddings=4096,
                        intermediate_size=hidden_size * 4,
                    ),
                    layer_idx=i,
                )
                for i in range(num_layers)
            ]
        )

        self.norm = LlamaAdaptiveRMSNorm(hidden_size)

        self.multi_embedding = MultiEmbedding(
            num_quantization_layers=8, embedding_dim=hidden_size
        )

        self.post_init()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # create noncausal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        def _expand_mask(
            mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
        ):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len

            expanded_mask = (
                mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
            )

            inverted_mask = 1.0 - expanded_mask

            return inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(dtype).min
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        length: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length, num_quant = input_ids.shape
        input_ids = input_ids.permute(2, 0, 1)  # [num_quant, B, T]
        inputs_embeds = self.multi_embedding(input_ids)

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            if self.gradient_checkpointing and self.training:
                raise NotImplementedError

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        return hidden_states


class LlamaNAREmb(LlamaModel):
    """LlamaNAR model that works directly with embeddings input.

    This variant of LlamaNAR takes pre-computed embeddings as input
    instead of token IDs that need to be embedded.
    """

    def __init__(
        self,
        hidden_size=1024,
        num_heads=16,
        num_layers=16,
        config=LlamaConfig(0, 256, 1024, 1, 1),
    ):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [
                LlamaNARDecoderLayer(
                    config=LlamaConfig(
                        hidden_size=hidden_size,
                        num_attention_heads=num_heads,
                        max_position_embeddings=4096,
                        intermediate_size=hidden_size * 4,
                    ),
                    layer_idx=i,
                )
                for i in range(num_layers)
            ]
        )

        self.norm = LlamaAdaptiveRMSNorm(hidden_size)

        self.post_init()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # create noncausal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        def _expand_mask(
            mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
        ):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len

            expanded_mask = (
                mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
            )

            inverted_mask = 1.0 - expanded_mask

            return inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(dtype).min
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        Returns:
            hidden_states: Tensor of shape (batch_size, sequence_length, hidden_size)
        """

        if inputs_embeds is None:
            raise ValueError("inputs_embeds must be provided for LlamaNAREmb")

        if input_ids is not None:
            warnings.warn("input_ids is ignored in LlamaNAREmb, use inputs_embeds instead")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length, hidden_size = inputs_embeds.shape

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            if self.gradient_checkpointing and self.training:
                raise NotImplementedError

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        return hidden_states


if __name__ == '__main__':
    config = LlamaConfig(hidden_size=1024, num_attention_heads=8, num_hidden_layers=8)

    model = LlamaNAR(config=config)

    # Simulated input data
    batch_size = 2
    seq_length = 10
    n_q = 8
    input_ids = torch.randint(0, 1028, (batch_size, seq_length, n_q))  # randomly generated input IDs
    inputs_embeds = torch.randn(batch_size, seq_length, config.hidden_size)  # randomly generated input embeddings
    attention_mask = torch.ones(batch_size, seq_length)  # all positions visible
    length = torch.tensor([4, 10])  # input lengths

    # Forward pass (LlamaNAR.forward returns only the final hidden states)
    hidden_states = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_attentions=True,
        output_hidden_states=True,
        length=length
    )

    # Print the output shape
    print("Hidden States Shape:", hidden_states.shape)
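LlamaNAREmb skips the multi-codebook embedding table and consumes pre-computed embeddings directly, which is how model.py drives it. A minimal shape-check sketch using the same 512-dimensional configuration that model.py instantiates; the random input tensor is a placeholder, not real data.

# Minimal shape-check sketch for LlamaNAREmb, mirroring the configuration used in model.py.
import torch
from transformers import LlamaConfig
from llama_nar import LlamaNAREmb

model = LlamaNAREmb(
    config=LlamaConfig(hidden_size=512, num_attention_heads=8, num_hidden_layers=8),
    num_heads=8,
    num_layers=8,
    hidden_size=512,
)
model.eval()

inputs_embeds = torch.randn(2, 50, 512)      # [batch, seq_len, hidden_size], placeholder input
with torch.no_grad():
    hidden_states = model(inputs_embeds=inputs_embeds)
print(hidden_states.shape)                   # expected: torch.Size([2, 50, 512])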
model.py
ADDED
@@ -0,0 +1,379 @@
import torch
import torch.nn as nn
from transformers import Wav2Vec2BertModel
from llama_nar import LlamaNAREmb
from transformers import LlamaConfig
import time
import torch.nn.functional as F


class Wav2Vec2BERT_Llama(nn.Module):
    def __init__(self):
        super().__init__()

        # 1. Load the pretrained model
        self.wav2vec2bert = Wav2Vec2BertModel.from_pretrained("/mntcephfs/lab_data/wangli/pretrain/w2v-bert-2.0/", output_hidden_states=True)

        # 2. Selectively freeze parameters
        for name, param in self.wav2vec2bert.named_parameters():
            # Freeze all FFN1 modules (keep FFN2 adaptable)
            if 'ffn1' in name:
                param.requires_grad = False

            # Freeze the K and V projections in multi-head attention
            if any(proj in name for proj in ['linear_k', 'linear_v']):
                param.requires_grad = False

            # Freeze distance_embedding
            if 'distance_embedding' in name:
                param.requires_grad = False

            # Freeze all convolution-related modules
            if any(conv_name in name for conv_name in [
                'conv_module', 'pointwise_conv', 'depthwise_conv',
                'feature_extractor', 'pos_conv_embed', 'conv_layers'
            ]):
                param.requires_grad = False

        # 3. Use a smaller Llama model
        self.llama_nar = LlamaNAREmb(
            config=LlamaConfig(
                hidden_size=512,
                num_attention_heads=8,
                num_hidden_layers=8,
            ),
            num_heads=8,
            num_layers=8,
            hidden_size=512
        )

        # 4. Down-projection layer
        self.projection = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LayerNorm(512)
        )

        # 5. Simplified classification head
        self.classifier = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 2)
        )

        # 6. Smaller embedding dimension
        self.label_embedding = nn.Embedding(num_embeddings=2, embedding_dim=512)

        # 7. Simplified feature-processing layer
        self.feature_processor = nn.Sequential(
            nn.Linear(512, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # 8. Smaller special-token dimension
        self.special_tokens = nn.Parameter(torch.randn(4, 512))

    def _fuse_layers(self, hidden_states):
        # Modified feature-fusion method
        def downsample_sequence(sequence, factor=10):
            """Downsample the sequence along the time axis"""
            batch_size, seq_len, hidden_size = sequence.shape
            # Make sure the sequence length is divisible by the factor
            new_len = seq_len // factor
            padded_len = new_len * factor

            if seq_len > padded_len:
                sequence = sequence[:, :padded_len, :]

            # Reshape and average-pool: [batch_size, new_len, factor, hidden_size]
            reshaped = sequence.reshape(batch_size, new_len, factor, hidden_size)
            downsampled = torch.mean(reshaped, dim=2)  # [batch_size, new_len, hidden_size]
            return downsampled

        # 1. Take the last layer's features and downsample them
        last_layer = hidden_states[-1]  # [batch_size, seq_len, 1024]
        downsampled_features = downsample_sequence(last_layer)  # [batch_size, seq_len//10, 1024]

        # 2. Project down to 512 dimensions
        projected_features = self.projection(downsampled_features)  # [batch_size, seq_len//10, 512]

        return projected_features  # no unsqueeze needed; the sequence dimension is kept

    def forward(self, batch):
        main_output = self.wav2vec2bert(
            **batch['main_features']
        )

        fused_features = self._fuse_layers(main_output.hidden_states)
        fused_features = self.feature_processor(fused_features)

        if ('prompt_labels' in batch and
            batch['prompt_labels'] is not None and
            'prompt_features' in batch and
            batch['prompt_features'] and
            len(batch['prompt_features']) > 0):

            batch_size, num_prompts = batch['prompt_labels'].shape

            # Reshape features for batched processing
            prompt_features = batch['prompt_features']
            all_prompt_outputs = []

            for i in range(num_prompts):
                prompt_output = self.wav2vec2bert(
                    **prompt_features[i]
                )
                all_prompt_outputs.append(self._fuse_layers(prompt_output.hidden_states))

            if all_prompt_outputs:
                fused_prompts = torch.stack([
                    self.feature_processor(p) for p in all_prompt_outputs
                ], dim=1)  # [batch_size, num_prompts, seq_len, hidden_size]

                # Get label embeddings and expand them to the matching sequence length
                label_embs = self.label_embedding(batch['prompt_labels'])  # [batch_size, num_prompts, 512]

                prompt_embeddings = []
                for i in range(batch_size):
                    sequence = []

                    # Add the demonstration prompts
                    for j in range(num_prompts):
                        prompt_seq_len = fused_prompts[i, j].size(0)  # sequence length of the current prompt

                        sequence.append(self.special_tokens[1].expand(1, -1))  # [PROMPT]
                        sequence.append(self.special_tokens[2].expand(1, -1))  # [AUDIO]
                        sequence.append(fused_prompts[i, j])  # [seq_len, hidden_size]
                        sequence.append(self.special_tokens[3].expand(1, -1))  # [LABEL]

                        # Expand the label embedding to the same length as the audio features
                        expanded_label = label_embs[i, j].unsqueeze(0).expand(prompt_seq_len, -1)
                        sequence.append(expanded_label)  # [seq_len, hidden_size]

                        sequence.append(self.special_tokens[0].expand(1, -1))  # [SEP]

                    # Add the main features to be classified
                    main_seq_len = fused_features[i].size(0)  # sequence length of the main features
                    sequence.append(self.special_tokens[1].expand(1, -1))  # [PROMPT]
                    sequence.append(self.special_tokens[2].expand(1, -1))  # [AUDIO]
                    sequence.append(fused_features[i])  # [main_seq_len, hidden_size]
                    sequence.append(self.special_tokens[3].expand(1, -1))  # [LABEL]
                    # Use zero vectors at the prediction positions, same length as the main features
                    sequence.append(torch.zeros(main_seq_len, fused_features.size(-1)).to(fused_features.device))

                    prompt_embeddings.append(torch.cat(sequence, dim=0))

                prompt_embeddings = torch.stack(prompt_embeddings, dim=0)

        else:
            # Simplified handling of the no-prompt case
            batch_size = fused_features.size(0)
            main_seq_len = fused_features.size(1)  # main feature sequence length

            # Build the sequence [batch_size, total_len, hidden_size]
            prompt_embeddings = torch.cat([
                self.special_tokens[1].expand(batch_size, 1, -1),  # [PROMPT]
                self.special_tokens[2].expand(batch_size, 1, -1),  # [AUDIO]
                fused_features,  # [batch_size, main_seq_len, hidden_size]
                self.special_tokens[3].expand(batch_size, 1, -1),  # [LABEL]
                torch.zeros(batch_size, main_seq_len, fused_features.size(-1)).to(fused_features.device)  # prediction positions
            ], dim=1)

        # Feed into llama_nar
        output = self.llama_nar(inputs_embeds=prompt_embeddings)

        # Take the outputs at all prediction positions (the last main_seq_len positions)
        pred_pos_embeddings = output[:, -main_seq_len:, :]  # [batch_size, main_seq_len, hidden_size]
        # Classify each frame
        frame_logits = self.classifier(pred_pos_embeddings)  # [batch_size, main_seq_len, 2]

        # Also return utterance-level logits obtained by averaging
        avg_embedding = torch.mean(pred_pos_embeddings, dim=1)  # [batch_size, hidden_size]
        avg_logits = self.classifier(avg_embedding)  # [batch_size, 2]

        return {
            'frame_logits': frame_logits,  # per-frame prediction scores
            'avg_logits': avg_logits  # utterance-level prediction scores
        }


if __name__ == '__main__':
    import torch
    from torch.utils.data import DataLoader
    from dataset.train_MultiDataset import train_MultiDataset, collate_fn
    from tqdm import tqdm
    import time

    # Set the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\n=== Using device: {device} ===")

    # Initialize the model
    print("\n=== Initializing model ===")
    model = Wav2Vec2BERT_Llama().to(device)
    model.eval()  # set to evaluation mode

    # Print the parameter structure of wav2vec2bert
    print("\n=== Wav2Vec2BERT parameter structure ===")
    w2v_params_by_layer = {}
    total_trainable = 0
    total_frozen = 0

    for name, param in model.wav2vec2bert.named_parameters():
        # Get the top-level layer name
        layer_name = name.split('.')[0]
        if layer_name not in w2v_params_by_layer:
            w2v_params_by_layer[layer_name] = {
                'trainable_params': 0,
                'frozen_params': 0,
                'parameter_names': []
            }

        # Count the parameters
        if param.requires_grad:
            w2v_params_by_layer[layer_name]['trainable_params'] += param.numel()
            total_trainable += param.numel()
        else:
            w2v_params_by_layer[layer_name]['frozen_params'] += param.numel()
            total_frozen += param.numel()

        w2v_params_by_layer[layer_name]['parameter_names'].append(name)

    # Print per-layer details
    print("\nPer-layer parameter statistics:")
    for layer_name, info in w2v_params_by_layer.items():
        trainable_mb = info['trainable_params'] / 1024 / 1024
        frozen_mb = info['frozen_params'] / 1024 / 1024
        total_mb = (info['trainable_params'] + info['frozen_params']) / 1024 / 1024

        print(f"\n{layer_name}:")
        print(f"  - total parameters: {total_mb:.2f}M")
        print(f"  - trainable parameters: {trainable_mb:.2f}M")
        print(f"  - frozen parameters: {frozen_mb:.2f}M")
        print(f"  - parameter names:")
        for param_name in info['parameter_names']:
            print(f"    * {param_name}")

    # Print overall statistics
    print("\n=== Overall statistics ===")
    print(f"Total trainable parameters: {total_trainable/1024/1024:.2f}M")
    print(f"Total frozen parameters: {total_frozen/1024/1024:.2f}M")
    print(f"Total parameters: {(total_trainable + total_frozen)/1024/1024:.2f}M")
    print(f"Trainable parameter ratio: {total_trainable/(total_trainable + total_frozen)*100:.2f}%")

    # Count the parameters of each module separately
    wav2vec2bert_params = sum(p.numel() for p in model.wav2vec2bert.parameters())
    llama_params = sum(p.numel() for p in model.llama_nar.parameters())
    other_params = sum(p.numel() for name, p in model.named_parameters()
                       if not name.startswith('wav2vec2bert.') and not name.startswith('llama_nar.'))

    total_params = wav2vec2bert_params + llama_params + other_params

    print(f"\n=== Parameter counts ===")
    print(f"Wav2Vec2BERT parameters: {wav2vec2bert_params:,} ({wav2vec2bert_params/1024/1024:.2f}M)")
    print(f"LlamaNAR parameters: {llama_params:,} ({llama_params/1024/1024:.2f}M)")
    print(f"Other module parameters: {other_params:,} ({other_params/1024/1024:.2f}M)")
    print(f"Total parameters: {total_params:,} ({total_params/1024/1024:.2f}M)")

    # Compute the percentages
    print(f"\n=== Parameter share ===")
    print(f"Wav2Vec2BERT: {wav2vec2bert_params/total_params*100:.2f}%")
    print(f"LlamaNAR: {llama_params/total_params*100:.2f}%")
    print(f"Other modules: {other_params/total_params*100:.2f}%")

    # Benchmark runtime and memory usage
    print("\n=== Benchmarking runtime and memory usage (batch_size=4) ===")
    batch_size = 4
    total_samples = 600000

    # Clear the GPU cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        initial_memory = torch.cuda.memory_allocated() / 1024 / 1024
        print(f"Initial GPU memory usage: {initial_memory:.2f}MB")

    # Initialize the dataset
    print("\nInitializing dataset...")
    ds = train_MultiDataset(max_prompts=3)

    # Create the DataLoader
    dl = DataLoader(ds,
                    batch_size=batch_size,
                    shuffle=True,
                    collate_fn=collate_fn,
                    num_workers=4)

    print(f"\nDataset size: {len(ds)}")
    print(f"Number of batches: {len(dl)}")

    # Average the time over several batches
    num_test_batches = 10
    total_time = 0
    max_memory = 0

    print(f"\nMeasuring the average runtime over {num_test_batches} batches...")
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dl, total=num_test_batches)):
            if i >= num_test_batches:
                break

            # Handle the dict-typed features correctly
            main_features = {
                'input_features': batch['main_features']['input_features'].to(device),
                'attention_mask': batch['main_features']['attention_mask'].to(device)
            }

            prompt_features = [{
                'input_features': pf['input_features'].to(device),
                'attention_mask': pf['attention_mask'].to(device)
            } for pf in batch['prompt_features']]

            labels = batch['labels'].to(device)
            prompt_labels = batch['prompt_labels'].to(device)

            # Record the start time
            start_time = time.time()

            # Forward pass
            outputs = model({
                'main_features': main_features,
                'prompt_features': prompt_features,
                'prompt_labels': prompt_labels
            })

            # Make sure GPU work has finished
            if torch.cuda.is_available():
                torch.cuda.synchronize()

            # Record the end time and memory usage
            end_time = time.time()
            total_time += (end_time - start_time)

            if torch.cuda.is_available():
                current_memory = torch.cuda.memory_allocated() / 1024 / 1024
                max_memory = max(max_memory, current_memory)

            # Print details of the first batch (forward returns a dict of logits)
            if i == 0:
                print("\n=== First batch details ===")
                print(f"Main feature shape: {main_features['input_features'].shape}")
                print(f"Main mask shape: {main_features['attention_mask'].shape}")
                print(f"Prompt feature shape: {prompt_features[0]['input_features'].shape}")
                print(f"Prompt mask shape: {prompt_features[0]['attention_mask'].shape}")
                print(f"Label shape: {labels.shape}")
                print(f"Prompt label shape: {prompt_labels.shape}")
                print(f"Model output shapes: frame_logits {outputs['frame_logits'].shape}, avg_logits {outputs['avg_logits'].shape}")
                print(f"avg_logits range: [{outputs['avg_logits'].min().item():.3f}, {outputs['avg_logits'].max().item():.3f}]")

    # Compute and print the statistics
    avg_time = total_time / num_test_batches
    print(f"\n=== Performance statistics ===")
    print(f"Average time per batch: {avg_time:.4f}s")
    print(f"Estimated time for {total_samples} samples: {(total_samples/batch_size*avg_time/3600):.2f} hours")
    if torch.cuda.is_available():
        print(f"Peak GPU memory usage: {max_memory:.2f}MB")
        print(f"GPU memory growth: {max_memory - initial_memory:.2f}MB")

    print("\nTest finished!")