WpythonW committed
Commit e401779 · verified · 1 Parent(s): b11b92d

Update README.md

Files changed (1): README.md (+139 -51)
README.md CHANGED
@@ -86,65 +86,153 @@ Final metrics on validation set:
  - Precision: 0.9692 (96.92%)
  - Recall: 0.9728 (97.28%)
 
- ## Usage
-
- Here's how to use the model:
-
- ```python
- import torch
- import torchaudio
- from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
-
- # Load model and processor
- model = AutoModelForAudioClassification.from_pretrained("WpythonW/ast-fakeaudio-detector")
- processor = AutoFeatureExtractor.from_pretrained("WpythonW/ast-fakeaudio-detector")
-
- # Load audio file
- audio_path = "/content/r.mp3"
- waveform, sample_rate = torchaudio.load(audio_path)
-
- # Resample to required sample rate
- if sample_rate != 16000:
-     resampler = torchaudio.transforms.Resample(sample_rate, 16000)
-     waveform = resampler(waveform)
-
- # Convert stereo to mono if needed
- if waveform.shape[0] > 1:
-     waveform = torch.mean(waveform, dim=0, keepdim=True)
-
- # Create mel spectrogram
- mel_spec = torchaudio.transforms.MelSpectrogram(
-     sample_rate=16000,
-     n_mels=128,
-     n_fft=2048,
-     hop_length=160
- )(waveform)
-
- # Convert to decibels
- mel_spec_db = torchaudio.transforms.AmplitudeToDB()(mel_spec)
-
- # Normalize
- mel_spec_norm = (mel_spec_db + 4.26) / (4.57 * 2)
-
- # Adjust to required length (1024 frames)
- if mel_spec_norm.shape[2] < 1024:
-     # If shorter - pad with zeros
-     padding = torch.zeros(1, 128, 1024 - mel_spec_norm.shape[2])
-     mel_spec_norm = torch.cat([mel_spec_norm, padding], dim=2)
- else:
-     # If longer - truncate
-     mel_spec_norm = mel_spec_norm[:, :, :1024]
-
- # Get prediction
- with torch.no_grad():
-     outputs = model(mel_spec_norm)
-     probabilities = torch.sigmoid(outputs.logits)
-     is_fake = probabilities > 0.5
-
- print(f"Probability of fake audio: {probabilities[0][0]:.4f}")
- print(f"Prediction: {'FAKE' if is_fake[0][0] else 'REAL'} audio")
- ```
+ ## Usage Guide
+
+ ### 1. Environment Setup
+ First, clone the AST repository and install the required dependencies:
+ ```bash
+ # Clone the AST repository and install dependencies
+ git clone https://github.com/YuanGongND/ast.git
+ pip install timm==0.4.5 wget
+ ```
+
+ Then, in Python, make the repository importable and pull in the required modules:
+ ```python
+ import os
+ import sys
+ sys.path.append('./ast')  # make src.models importable (run from the directory containing ast/)
+
+ import torch
+ import torchaudio
+ import matplotlib.pyplot as plt
+ from torch import nn
+ from src.models import ASTModel
+ ```
+
+ ### 2. Model Implementation
+ Define the BinaryAST model class, an AST backbone whose 527-class head is replaced with a single-logit binary head:
+ ```python
+ class BinaryAST(nn.Module):
+     def __init__(self, pretrained_path='pretrained_models/audioset_10_10_0.4593.pth'):
+         super().__init__()
+         # Initialize AST base model
+         self.ast = ASTModel(
+             label_dim=527,
+             input_fdim=128,
+             input_tdim=1024,
+             imagenet_pretrain=True,
+             audioset_pretrain=False,
+             model_size='base384'
+         )
+
+         # Load pretrained weights if available
+         if os.path.exists(pretrained_path):
+             print(f"Loading pretrained weights from {pretrained_path}")
+             state_dict = torch.load(pretrained_path, map_location='cpu', weights_only=True)
+             self.ast.load_state_dict(state_dict, strict=False)
+
+         # Binary classification head (single logit)
+         self.ast.mlp_head = nn.Sequential(
+             nn.LayerNorm(768),
+             nn.Dropout(0.3),
+             nn.Linear(768, 1)
+         )
+
+     def forward(self, x):
+         return self.ast(x)
+ ```
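+
+ As a quick sanity check, a random tensor with the same 1024 x 128 layout that `process_audio` (below) returns should yield a single logit:
+ ```python
+ model = BinaryAST()
+ dummy = torch.randn(1, 1024, 128)  # (batch, time frames, mel bins)
+ with torch.no_grad():
+     print(model(dummy).shape)  # expected: torch.Size([1, 1])
+ ```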
+
+ ### 3. Audio Processing Function
+ The following function converts an audio file into the normalized 1024 x 128 log-mel spectrogram the model expects:
+ ```python
+ def process_audio(file_path, sr=16000):
+     """
+     Process an audio file for model inference.
+
+     Args:
+         file_path (str): Path to audio file
+         sr (int): Target sample rate (default: 16000)
+
+     Returns:
+         torch.Tensor: Processed mel spectrogram of shape (1024, 128)
+     """
+     # Load audio
+     audio_tensor, orig_sr = torchaudio.load(file_path)
+     print(f"Initial tensor shape: {audio_tensor.shape}, sample_rate={orig_sr}")
+
+     # Convert to mono if needed
+     if audio_tensor.shape[0] > 1:
+         audio_tensor = torch.mean(audio_tensor, dim=0, keepdim=True)
+
+     # Resample to target sample rate
+     if orig_sr != sr:
+         resampler = torchaudio.transforms.Resample(orig_sr, sr)
+         audio_tensor = resampler(audio_tensor)
+
+     # Create mel spectrogram and convert to decibels
+     mel_spec = torchaudio.transforms.MelSpectrogram(
+         sample_rate=sr,
+         n_mels=128,
+         n_fft=2048,
+         hop_length=160
+     )(audio_tensor)
+     spec_db = torchaudio.transforms.AmplitudeToDB()(mel_spec)
+
+     # Post-process: reshape to (time, mel bins) and normalize
+     spec_db = spec_db.squeeze(0).transpose(0, 1)
+     spec_db = (spec_db + 4.26) / (4.57 * 2)  # Normalize
+
+     # Ensure correct length (pad/trim to 1024 frames)
+     target_len = 1024
+     if spec_db.shape[0] < target_len:
+         pad = torch.zeros(target_len - spec_db.shape[0], 128)
+         spec_db = torch.cat([spec_db, pad], dim=0)
+     else:
+         spec_db = spec_db[:target_len, :]
+
+     return spec_db
+ ```
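+
+ With hop_length=160 at 16 kHz (100 frames per second), the 1024 retained frames cover roughly the first 10.24 seconds of audio; longer recordings are truncated.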
+
+ ### 4. Model Loading and Inference
+ Load the fine-tuned checkpoint and run inference on a single audio file:
+ ```python
+ # Initialize and load model
+ model = BinaryAST()
+ checkpoint = torch.load('/content/final_model.pth', map_location='cpu')
+ model.load_state_dict(checkpoint['model_state_dict'])
+ model.eval()
+
+ # Process audio file
+ spec = process_audio('path_to_audio.mp3')
+
+ # Visualize spectrogram (optional)
+ plt.figure(figsize=(10, 3))
+ plt.imshow(spec.numpy().T, aspect='auto', origin='lower')
+ plt.title('Mel Spectrogram')
+ plt.xlabel('Time Frames')
+ plt.ylabel('Mel Bins')
+ plt.colorbar()
+ plt.show()
+
+ # Run inference
+ spec_batch = spec.unsqueeze(0)  # add batch dimension
+ with torch.no_grad():
+     output = model(spec_batch)
+     prob_fake = torch.sigmoid(output).item()
+
+ print(f"Probability of fake audio: {prob_fake:.4f}")
+ print("Prediction:", "FAKE" if prob_fake > 0.5 else "REAL")
+ ```
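+
+ If `final_model.pth` is not already available locally, it can be fetched from this repository with `huggingface_hub` (assuming the checkpoint is stored under that filename; adjust to the actual filename in the repo's file list):
+ ```python
+ from huggingface_hub import hf_hub_download
+
+ # Hypothetical filename; check the repository's file list
+ ckpt_path = hf_hub_download(repo_id="WpythonW/ast-fakeaudio-detector", filename="final_model.pth")
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
+ ```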
+
+ ### Key Notes
+ - Ensure audio files are accessible and in a format torchaudio can decode
+ - The model expects 16 kHz input; other sample rates are resampled automatically
+ - Stereo input is converted to mono
+ - The model outputs a probability score; values above 0.5 indicate fake audio (see the wrapper sketch after this list)
+ - Spectrogram visualization is optional but useful for debugging
+
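+ Tying the steps together, a minimal convenience wrapper might look like this (`classify` is a hypothetical helper, not part of the original code):
+ ```python
+ def classify(path, threshold=0.5):
+     """Return (label, probability of fake) for a single audio file."""
+     spec = process_audio(path)
+     with torch.no_grad():
+         prob = torch.sigmoid(model(spec.unsqueeze(0))).item()
+     return ("FAKE" if prob > threshold else "REAL"), prob
+
+ label, prob = classify('path_to_audio.mp3')
+ print(f"{label} ({prob:.4f})")
+ ```
+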
  ## Limitations

  Important considerations when using this model: