balkite commited on
Commit
c5de5c7
·
1 Parent(s): 6dddd69

Upload 7 files

Browse files
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+
4
+ from backend import sample_preparer, load_model
5
+
6
+
7
+ def process_file(file, model):
8
+ prediction = model.predict(np.array(sample_preparer(file)))
9
+
10
+ type_num = np.argmax(prediction, axis=1)
11
+
12
+ drum_types = ['Clap', 'Closed Hat', 'Kick', 'Open Hat', 'Snare']
13
+
14
+ return drum_types[int(type_num)]
15
+
16
+
17
+ def main_page():
18
+ st.set_page_config(page_title="Drum Classifier",
19
+ page_icon="🥁")
20
+
21
+ st.markdown("# Drum Classifier 🥁")
22
+ st.markdown("Classify Drum audio samples through the use of Artificial Intelligence / Machine Learning. The Drum "
23
+ "Audio Classifier, uses a Convolutional Neural Network to predict the most likely drum type of a "
24
+ "audio file. The dataset used to create this model was 2,700+ of my freelance music production audio "
25
+ "samples.")
26
+ st.markdown("Currently supported drums: Clap, Closed Hat, Kick, Open Hat, Snare.")
27
+ st.markdown("Drag and Drop a WAV or Mp3 audio File to classify.")
28
+
29
+ if "model" not in st.session_state:
30
+ with st.spinner("Loading Database..."):
31
+ st.session_state.model = load_model()
32
+
33
+ file = st.file_uploader(
34
+ "Upload an Audio File",
35
+ accept_multiple_files=False,
36
+ type=['wav', 'mp3'],
37
+ label_visibility="hidden"
38
+ )
39
+
40
+ if st.session_state.model and file:
41
+
42
+ st.audio(file)
43
+
44
+ with st.spinner("Processing..."):
45
+ drum_type = process_file(file, st.session_state.model)
46
+
47
+ st.markdown(f"\"{file.name}\" is a {drum_type}.")
48
+
49
+
50
+ if __name__ == '__main__':
51
+ main_page()
backend.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import librosa
4
+ import os
5
+ import warnings
6
+ warnings.filterwarnings("ignore")
7
+
8
+
9
+ def load_model():
10
+ abs_path = os.getcwd()
11
+ model = tf.keras.models.load_model(abs_path + "/saved_model/model_20230607_02")
12
+
13
+ model.compile(optimizer='adam',
14
+ loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
15
+ metrics=['accuracy'])
16
+
17
+ return model
18
+
19
+
20
+ def sample_preparer(location):
21
+ sample_data = []
22
+ sample = np.zeros((128, 100, 3))
23
+ y, sr = librosa.load(location, sr=22050)
24
+ y, _ = librosa.effects.trim(y, top_db=50)
25
+ y = librosa.resample(y=y, orig_sr=sr, target_sr=22050)
26
+ melspect = librosa.feature.melspectrogram(y=y)
27
+
28
+ for i, _ in enumerate(melspect): # 128
29
+ for j, _ in enumerate(melspect[i]): # LENGTH
30
+ sample[i][j] = melspect[i][j]
31
+
32
+ sample_data = [sample]
33
+
34
+ return sample_data
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ streamlit
2
+ numpy
3
+ tensorflow
4
+ librosa
saved_model/model_20230607_02/keras_metadata.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3c6d15493594f284ee77a16c6e0fa6e3b4c6be1317529403a15afe0e85404e4
3
+ size 23303
saved_model/model_20230607_02/saved_model.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0fcbde72731bc9e8898b4d9605122bdec5245cd301fa378d5548c28be5e43686
3
+ size 208998
saved_model/model_20230607_02/variables/variables.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cae4da18c25adb2f93e3fb42f9d4068b215cd017d47dfc8f9eb602bd4c5f1522
3
+ size 29618497
saved_model/model_20230607_02/variables/variables.index ADDED
Binary file (3.13 kB). View file