Spaces:
Sleeping
Sleeping
Upload 7 files
Browse files
app.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from backend import sample_preparer, load_model
|
5 |
+
|
6 |
+
|
7 |
+
def process_file(file, model):
|
8 |
+
prediction = model.predict(np.array(sample_preparer(file)))
|
9 |
+
|
10 |
+
type_num = np.argmax(prediction, axis=1)
|
11 |
+
|
12 |
+
drum_types = ['Clap', 'Closed Hat', 'Kick', 'Open Hat', 'Snare']
|
13 |
+
|
14 |
+
return drum_types[int(type_num)]
|
15 |
+
|
16 |
+
|
17 |
+
def main_page():
|
18 |
+
st.set_page_config(page_title="Drum Classifier",
|
19 |
+
page_icon="🥁")
|
20 |
+
|
21 |
+
st.markdown("# Drum Classifier 🥁")
|
22 |
+
st.markdown("Classify Drum audio samples through the use of Artificial Intelligence / Machine Learning. The Drum "
|
23 |
+
"Audio Classifier, uses a Convolutional Neural Network to predict the most likely drum type of a "
|
24 |
+
"audio file. The dataset used to create this model was 2,700+ of my freelance music production audio "
|
25 |
+
"samples.")
|
26 |
+
st.markdown("Currently supported drums: Clap, Closed Hat, Kick, Open Hat, Snare.")
|
27 |
+
st.markdown("Drag and Drop a WAV or Mp3 audio File to classify.")
|
28 |
+
|
29 |
+
if "model" not in st.session_state:
|
30 |
+
with st.spinner("Loading Database..."):
|
31 |
+
st.session_state.model = load_model()
|
32 |
+
|
33 |
+
file = st.file_uploader(
|
34 |
+
"Upload an Audio File",
|
35 |
+
accept_multiple_files=False,
|
36 |
+
type=['wav', 'mp3'],
|
37 |
+
label_visibility="hidden"
|
38 |
+
)
|
39 |
+
|
40 |
+
if st.session_state.model and file:
|
41 |
+
|
42 |
+
st.audio(file)
|
43 |
+
|
44 |
+
with st.spinner("Processing..."):
|
45 |
+
drum_type = process_file(file, st.session_state.model)
|
46 |
+
|
47 |
+
st.markdown(f"\"{file.name}\" is a {drum_type}.")
|
48 |
+
|
49 |
+
|
50 |
+
if __name__ == '__main__':
|
51 |
+
main_page()
|
backend.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import tensorflow as tf
|
3 |
+
import librosa
|
4 |
+
import os
|
5 |
+
import warnings
|
6 |
+
warnings.filterwarnings("ignore")
|
7 |
+
|
8 |
+
|
9 |
+
def load_model():
|
10 |
+
abs_path = os.getcwd()
|
11 |
+
model = tf.keras.models.load_model(abs_path + "/saved_model/model_20230607_02")
|
12 |
+
|
13 |
+
model.compile(optimizer='adam',
|
14 |
+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
|
15 |
+
metrics=['accuracy'])
|
16 |
+
|
17 |
+
return model
|
18 |
+
|
19 |
+
|
20 |
+
def sample_preparer(location):
|
21 |
+
sample_data = []
|
22 |
+
sample = np.zeros((128, 100, 3))
|
23 |
+
y, sr = librosa.load(location, sr=22050)
|
24 |
+
y, _ = librosa.effects.trim(y, top_db=50)
|
25 |
+
y = librosa.resample(y=y, orig_sr=sr, target_sr=22050)
|
26 |
+
melspect = librosa.feature.melspectrogram(y=y)
|
27 |
+
|
28 |
+
for i, _ in enumerate(melspect): # 128
|
29 |
+
for j, _ in enumerate(melspect[i]): # LENGTH
|
30 |
+
sample[i][j] = melspect[i][j]
|
31 |
+
|
32 |
+
sample_data = [sample]
|
33 |
+
|
34 |
+
return sample_data
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
numpy
|
3 |
+
tensorflow
|
4 |
+
librosa
|
saved_model/model_20230607_02/keras_metadata.pb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b3c6d15493594f284ee77a16c6e0fa6e3b4c6be1317529403a15afe0e85404e4
|
3 |
+
size 23303
|
saved_model/model_20230607_02/saved_model.pb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0fcbde72731bc9e8898b4d9605122bdec5245cd301fa378d5548c28be5e43686
|
3 |
+
size 208998
|
saved_model/model_20230607_02/variables/variables.data-00000-of-00001
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cae4da18c25adb2f93e3fb42f9d4068b215cd017d47dfc8f9eb602bd4c5f1522
|
3 |
+
size 29618497
|
saved_model/model_20230607_02/variables/variables.index
ADDED
Binary file (3.13 kB). View file
|
|