shreyasr09 commited on
Commit
574ecd0
·
verified ·
1 Parent(s): 0bdb746

Upload streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +171 -0
streamlit_app.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # streamlit_app.py
2
+
3
+ import streamlit as st
4
+ import os
5
+ import librosa
6
+ import librosa.display
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ import seaborn as sns
10
+ import tensorflow as tf
11
+ from tensorflow.keras.utils import to_categorical
12
+ from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve, average_precision_score, calibration_curve, ConfusionMatrixDisplay
13
+ from keras.models import load_model
14
+
15
+ SAMPLE_RATE = 16000
16
+ DURATION = 5
17
+ N_MELS = 128
18
+ MAX_TIME_STEPS = 109
19
+ NUM_CLASSES = 2
20
+
21
+ # Streamlit App
22
+ st.title("Audio Spoofing Detection App")
23
+
24
+ st.sidebar.header("Model Options")
25
+ task = st.sidebar.selectbox("Select Task", ["Train Model", "Evaluate Model", "Visualize Spectrogram"])
26
+
27
+ if task == "Train Model":
28
+ st.header("Train a New Model")
29
+
30
+ uploaded_files = st.file_uploader("Upload FLAC Training Files", accept_multiple_files=True, type='flac')
31
+ label_file = st.file_uploader("Upload Labels File (txt)", type="txt")
32
+
33
+ if uploaded_files and label_file:
34
+ # Parse the label file
35
+ labels = {}
36
+ for line in label_file.getvalue().decode("utf-8").splitlines():
37
+ parts = line.strip().split()
38
+ file_name = parts[1]
39
+ label = 1 if parts[-1] == "bonafide" else 0
40
+ labels[file_name] = label
41
+
42
+ X, y = [], []
43
+ for file in uploaded_files:
44
+ file_name = file.name.split(".")[0]
45
+ label = labels[file_name]
46
+
47
+ # Load audio file
48
+ audio, _ = librosa.load(file, sr=SAMPLE_RATE, duration=DURATION)
49
+
50
+ # Extract Mel spectrogram
51
+ mel_spectrogram = librosa.feature.melspectrogram(y=audio, sr=SAMPLE_RATE, n_mels=N_MELS)
52
+ mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)
53
+
54
+ # Padding
55
+ if mel_spectrogram.shape[1] < MAX_TIME_STEPS:
56
+ mel_spectrogram = np.pad(mel_spectrogram, ((0, 0), (0, MAX_TIME_STEPS - mel_spectrogram.shape[1])), mode='constant')
57
+ else:
58
+ mel_spectrogram = mel_spectrogram[:, :MAX_TIME_STEPS]
59
+
60
+ X.append(mel_spectrogram)
61
+ y.append(label)
62
+
63
+ X = np.array(X)
64
+ y = np.array(y)
65
+
66
+ y_encoded = to_categorical(y, NUM_CLASSES)
67
+
68
+ # Split into train and validation sets
69
+ split_index = int(0.8 * len(X))
70
+ X_train, X_val = X[:split_index], X[split_index:]
71
+ y_train, y_val = y_encoded[:split_index], y_encoded[split_index:]
72
+
73
+ input_shape = (N_MELS, X_train.shape[2], 1)
74
+
75
+ # Define CNN model
76
+ model_input = tf.keras.Input(shape=input_shape)
77
+ x = tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu')(model_input)
78
+ x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
79
+ x = tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu')(x)
80
+ x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
81
+ x = tf.keras.layers.Flatten()(x)
82
+ x = tf.keras.layers.Dense(128, activation='relu')(x)
83
+ x = tf.keras.layers.Dropout(0.5)(x)
84
+ model_output = tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')(x)
85
+
86
+ model = tf.keras.Model(inputs=model_input, outputs=model_output)
87
+ model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
88
+
89
+ # Train the model
90
+ if st.button("Start Training"):
91
+ st.write("Training in progress...")
92
+ model.fit(X_train, y_train, batch_size=32, epochs=10, validation_data=(X_val, y_val))
93
+ model.save("audio_classifier.h5")
94
+ st.success("Training Complete. Model Saved!")
95
+
96
+ if task == "Evaluate Model":
97
+ st.header("Evaluate a Trained Model")
98
+
99
+ model_file = st.file_uploader("Upload Model (h5)", type='h5')
100
+ test_files = st.file_uploader("Upload Test FLAC Files", accept_multiple_files=True, type='flac')
101
+ protocol_file = st.file_uploader("Upload Protocol File (txt)", type='txt')
102
+
103
+ if model_file and test_files and protocol_file:
104
+ # Load Model
105
+ model = load_model(model_file)
106
+
107
+ # Prepare test data
108
+ X_test = []
109
+ for file in test_files:
110
+ audio, _ = librosa.load(file, sr=SAMPLE_RATE, duration=DURATION)
111
+ mel_spectrogram = librosa.feature.melspectrogram(y=audio, sr=SAMPLE_RATE, n_mels=N_MELS)
112
+ mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)
113
+ if mel_spectrogram.shape[1] < MAX_TIME_STEPS:
114
+ mel_spectrogram = np.pad(mel_spectrogram, ((0, 0), (0, MAX_TIME_STEPS - mel_spectrogram.shape[1])), mode='constant')
115
+ else:
116
+ mel_spectrogram = mel_spectrogram[:, :MAX_TIME_STEPS]
117
+ X_test.append(mel_spectrogram)
118
+
119
+ X_test = np.array(X_test)
120
+
121
+ y_pred = model.predict(X_test)
122
+ y_pred_classes = np.argmax(y_pred, axis=1)
123
+
124
+ # Parse the true labels
125
+ true_labels = {}
126
+ for line in protocol_file.getvalue().decode("utf-8").splitlines():
127
+ parts = line.strip().split()
128
+ if len(parts) > 1:
129
+ file_name = parts[0]
130
+ label = parts[-1]
131
+ true_labels[file_name] = 1 if label == "bonafide" else 0
132
+
133
+ y_true = np.array([label for label in true_labels.values()])
134
+
135
+ # Confusion Matrix
136
+ cm = confusion_matrix(y_true, y_pred_classes)
137
+ ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Spoof", "Bonafide"]).plot(cmap=plt.cm.Blues)
138
+ st.pyplot(plt)
139
+
140
+ # ROC Curve
141
+ y_pred_prob = y_pred[:, 1]
142
+ fpr, tpr, _ = roc_curve(y_true, y_pred_prob)
143
+ roc_auc = auc(fpr, tpr)
144
+ plt.figure()
145
+ plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
146
+ plt.legend(loc="lower right")
147
+ st.pyplot(plt)
148
+
149
+ # Precision-Recall Curve
150
+ precision, recall, _ = precision_recall_curve(y_true, y_pred_prob)
151
+ avg_precision = average_precision_score(y_true, y_pred_prob)
152
+ plt.figure()
153
+ plt.plot(recall, precision, color='darkorange', lw=2, label=f'Avg. Precision = {avg_precision:.2f}')
154
+ st.pyplot(plt)
155
+
156
+ if task == "Visualize Spectrogram":
157
+ st.header("Visualize Mel Spectrogram")
158
+
159
+ test_files = st.file_uploader("Upload Test FLAC Files", accept_multiple_files=True, type='flac')
160
+
161
+ if test_files:
162
+ for file in test_files:
163
+ audio, _ = librosa.load(file, sr=SAMPLE_RATE, duration=DURATION)
164
+ mel_spectrogram = librosa.feature.melspectrogram(y=audio, sr=SAMPLE_RATE, n_mels=N_MELS)
165
+ mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)
166
+
167
+ plt.figure(figsize=(10, 6))
168
+ librosa.display.specshow(mel_spectrogram, x_axis='time', y_axis='mel', sr=SAMPLE_RATE)
169
+ plt.colorbar(format='%+2.0f dB')
170
+ plt.title(f'Mel Spectrogram - {file.name}')
171
+ st.pyplot(plt)