Update README.md
Browse files
README.md
CHANGED
@@ -86,65 +86,153 @@ Final metrics on validation set:
|
|
86 |
- Precision: 0.9692 (96.92%)
|
87 |
- Recall: 0.9728 (97.28%)
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
Here's how to use the model:
|
92 |
|
|
|
|
|
93 |
```python
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
import torch
|
95 |
import torchaudio
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
with torch.no_grad():
|
140 |
-
|
141 |
-
|
142 |
-
is_fake = probabilities > 0.5
|
143 |
|
144 |
-
print(f"Probability of fake audio: {
|
145 |
-
print(
|
146 |
```
|
147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
## Limitations
|
149 |
|
150 |
Important considerations when using this model:
|
|
|
86 |
- Precision: 0.9692 (96.92%)
|
87 |
- Recall: 0.9728 (97.28%)
|
88 |
|
89 |
+
# Usage Guide
|
|
|
|
|
90 |
|
91 |
+
## 1. Environment Setup
|
92 |
+
First, clone the AST repository and install required dependencies:
|
93 |
```python
|
94 |
+
# Clone AST repository and set up path
import os
import subprocess
import sys

# Shell commands are not valid inside a Python block, so they are run through
# subprocess; check=True raises CalledProcessError if a command fails.
if not os.path.isdir('ast'):  # skip the clone when the checkout already exists
    subprocess.run(['git', 'clone', 'https://github.com/YuanGongND/ast.git'], check=True)

# Use an absolute path: a relative sys.path entry is resolved against the
# current working directory, which changes on the next line.
sys.path.append(os.path.abspath('ast'))
os.chdir('ast')

# Install dependencies (into the interpreter that is running this code)
subprocess.run([sys.executable, '-m', 'pip', 'install', 'timm==0.4.5', 'wget'], check=True)
|
102 |
+
|
103 |
+
# Required imports
|
104 |
+
import os
|
105 |
import torch
|
106 |
import torchaudio
|
107 |
+
import matplotlib.pyplot as plt
|
108 |
+
import numpy as np
|
109 |
+
from torch import nn
|
110 |
+
from src.models import ASTModel
|
111 |
+
```
|
112 |
+
|
113 |
+
## 2. Model Implementation
|
114 |
+
Implement the BinaryAST model class:
|
115 |
+
```python
|
116 |
+
class BinaryAST(nn.Module):
    """AST backbone whose classification head emits a single logit.

    Builds an ``ASTModel``, loads weights from ``pretrained_path`` when that
    file exists, and then replaces the backbone's ``mlp_head`` with
    LayerNorm -> Dropout(0.3) -> Linear(768, 1).
    """

    def __init__(self, pretrained_path='pretrained_models/audioset_10_10_0.4593.pth'):
        super().__init__()

        # Backbone configuration gathered in one place.
        backbone_config = dict(
            label_dim=527,
            input_fdim=128,
            input_tdim=1024,
            imagenet_pretrain=True,
            audioset_pretrain=False,
            model_size='base384',
        )
        self.ast = ASTModel(**backbone_config)

        # A missing weights file is not an error; the backbone simply keeps
        # whatever initialization ASTModel gave it.
        if os.path.exists(pretrained_path):
            print(f"Loading pretrained weights from {pretrained_path}")
            weights = torch.load(pretrained_path, map_location='cpu', weights_only=True)
            # strict=False: keys that do not match the backbone are ignored.
            # NOTE(review): if this checkpoint was saved from a
            # DataParallel-wrapped model its keys carry a `module.` prefix and
            # would all be skipped silently here -- confirm the key names.
            self.ast.load_state_dict(weights, strict=False)

        # Binary classification head (replaces the backbone's original head).
        head_layers = [
            nn.LayerNorm(768),
            nn.Dropout(0.3),
            nn.Linear(768, 1),
        ]
        self.ast.mlp_head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the backbone's output for ``x`` (one raw logit per item)."""
        return self.ast(x)
|
144 |
+
```
|
145 |
+
|
146 |
+
## 3. Audio Processing Function
|
147 |
+
Function to preprocess audio files for model input:
|
148 |
+
```python
|
149 |
+
def process_audio(file_path, sr=16000):
    """
    Load an audio file and convert it to a fixed-size mel spectrogram.

    The audio is down-mixed to mono, resampled to ``sr`` when needed,
    converted to a dB-scaled mel spectrogram, normalized with a fixed
    shift/scale, and padded with zeros or trimmed to 1024 frames.

    Args:
        file_path (str): Path to audio file
        sr (int): Target sample rate (default: 16000)

    Returns:
        torch.Tensor: Processed mel spectrogram (1024 x 128)
    """
    n_frames, n_mels = 1024, 128

    waveform, source_sr = torchaudio.load(file_path)
    print(f"Initial tensor shape: {waveform.shape}, sample_rate={source_sr}")

    # Down-mix multi-channel audio by averaging over the channel axis.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample only when the file's rate differs from the target.
    if source_sr != sr:
        waveform = torchaudio.transforms.Resample(source_sr, sr)(waveform)

    to_mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr,
        n_mels=n_mels,
        n_fft=2048,
        hop_length=160,
    )
    to_db = torchaudio.transforms.AmplitudeToDB()
    features = to_db(to_mel(waveform))

    # Drop the channel axis and put time first: (frames, mel bins).
    features = features.squeeze(0).transpose(0, 1)
    # Fixed shift/scale normalization (constants kept exactly as given).
    features = (features + 4.26) / (4.57 * 2)

    # Zero-pad short clips at the end; truncate long ones.
    missing = n_frames - features.shape[0]
    if missing > 0:
        return torch.cat([features, torch.zeros(missing, n_mels)], dim=0)
    return features[:n_frames, :]
|
195 |
+
```
|
196 |
+
|
197 |
+
## 4. Model Loading and Inference
|
198 |
+
Example of loading the model and running inference:
|
199 |
+
```python
|
200 |
+
# Initialize and load model
model = BinaryAST()
# weights_only=True restricts unpickling to tensors and plain containers,
# consistent with how BinaryAST loads its pretrained weights.
checkpoint = torch.load('/content/final_model.pth', map_location='cpu', weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference mode: disables the dropout in the classification head

# Process audio file
spec = process_audio('path_to_audio.mp3')

# Visualize spectrogram (optional)
plt.figure(figsize=(10, 3))
plt.imshow(spec.numpy().T, aspect='auto', origin='lower')  # transpose so time runs along x
plt.title('Mel Spectrogram')
plt.xlabel('Time Frames')
plt.ylabel('Mel Bins')
plt.colorbar()
plt.show()

# Run inference
spec_batch = spec.unsqueeze(0)  # add a leading batch dimension
with torch.no_grad():
    output = model(spec_batch)
    # The model returns a raw logit; sigmoid maps it to a probability.
    prob_fake = torch.sigmoid(output).item()

print(f"Probability of fake audio: {prob_fake:.4f}")
print("Prediction:", "FAKE" if prob_fake > 0.5 else "REAL")
|
226 |
```
|
227 |
|
228 |
+
## Key Notes:
|
229 |
+
- Ensure audio files are accessible and in a supported format
|
230 |
+
- The model expects 16 kHz audio; `process_audio` resamples input files to this rate automatically
|
231 |
+
- Input audio is converted to mono if stereo
|
232 |
+
- The model outputs a single raw logit; apply `torch.sigmoid` to obtain the probability of fake audio (>0.5 indicates fake audio)
|
233 |
+
- Visualization of spectrograms is optional but useful for debugging
|
234 |
+
|
235 |
+
|
236 |
## Limitations
|
237 |
|
238 |
Important considerations when using this model:
|