cache + dl
- .gitignore +1 -0
- app.py +32 -6
- model.py +14 -9
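Wrap the CVAE checkpoint load in a cached load_model() (@st.cache_resource) so the model is read from disk once per process instead of on every Streamlit rerun, cache the numpy-to-WAV conversion (@st.cache_data), and add a Download button for the generated sound. The "magical parameters" expander now exposes five sliders whose values are passed to generate() as an optional tuple, and *.pyc files are ignored.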
.gitignore ADDED
@@ -0,0 +1 @@
+*.pyc
app.py CHANGED
@@ -1,9 +1,18 @@
 import streamlit as st
 from model import generate
+import io
 import numpy as np
+from scipy.io.wavfile import write
+
+@st.cache_data
+def np_to_wav(waveform, sample_rate) -> bytes:
+    bytes_wav = bytes()
+    byte_io = io.BytesIO(bytes_wav)
+    write(byte_io, sample_rate, waveform.T)
+    return byte_io.read()
 
 if "result" not in st.session_state:
-    st.session_state["result"] =
+    st.session_state["result"] = None
 
 st.title("Sound Exploration")
 
@@ -12,7 +21,7 @@ col1, col2 = st.columns(2)
 with col1:
     instrument = st.selectbox(
         'Which instrument do you want?',
-        ('🎸 Bass', '🎺 Brass', '🪈 Flute', '🪕 Guitar', '🎹 Keyboard', '🔨 Mallet', 'Organ', 'Reed', '🎻 String', 'Synth lead', 'Vocal')
+        ('🎸 Bass', '🎺 Brass', '🪈 Flute', '🪕 Guitar', '🎹 Keyboard', '🔨 Mallet', 'Organ', '🎷 Reed', '🎻 String', '⚡ Synth lead', '🎤 Vocal')
     )
 
 with col2:
@@ -22,11 +31,28 @@ with col2:
     )
 
 with st.expander("Magical parameters 🪄"):
-
+    col1, col2 = st.columns(2)
+    with col1:
+        p1 = st.slider('p1', 0., 1., step=0.001, label_visibility='collapsed')
+        p2 = st.slider('p2', 0., 1., step=0.001, label_visibility='collapsed')
+        p3 = st.slider('p3', 0., 1., step=0.001, label_visibility='collapsed')
+    with col2:
+        p4 = st.slider('p4', 0., 1., step=0.001, label_visibility='collapsed')
+        p5 = st.slider('p5', 0., 1., step=0.001, label_visibility='collapsed')
+    use_params = st.toggle('Use magical parameters?')
+    params = (p1, p2, p3, p4, p5) if use_params else None
 
 if st.button("Generate ✨", type="primary"):
-    st.session_state["result"] = generate([instrument, instrument_t])
+    st.session_state["result"] = generate([instrument, instrument_t], params)
 
-if st.session_state["result"]
-    st.
+if st.session_state["result"] is not None:
+    col1, col2 = st.columns(2)
+    with col1:
+        st.audio(st.session_state["result"], sample_rate=16000)
+    with col2:
+        st.download_button(
+            label="Download ⬇️",
+            data=np_to_wav(st.session_state["result"], 16000),
+            file_name='result.wav',
+        )
 
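A note on the np_to_wav helper added above: scipy.io.wavfile.write seeks a file-like object back to position 0 after patching the RIFF header sizes, which is why byte_io.read() returns the complete WAV file rather than an empty string. A minimal round-trip sketch of the same pattern (the 4-second, 16 kHz mono waveform shaped (1, samples) mirrors io_features=16000*4 in model.py, but generate()'s actual output shape and dtype are assumptions here):

import io
import numpy as np
from scipy.io.wavfile import read, write

# Hypothetical stand-in for st.session_state["result"]: a 440 Hz tone,
# float32, shaped (1, samples) so that .T gives one channel per column.
waveform = np.sin(2 * np.pi * 440 * np.arange(16000 * 4) / 16000)
waveform = waveform.astype(np.float32)[None, :]

buf = io.BytesIO()
write(buf, 16000, waveform.T)   # write() rewinds buf to 0 when it finishes
wav_bytes = buf.read()          # the full RIFF/WAVE file, as in np_to_wav

rate, data = read(io.BytesIO(wav_bytes))
assert rate == 16000 and len(data) == waveform.shape[1]

Because np_to_wav is wrapped in @st.cache_data, Streamlit keys the cached bytes on a hash of the waveform argument, so reruns with the same generated result do not redo the conversion.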
model.py CHANGED
@@ -1,20 +1,25 @@
 from cvae import CVAE
 import torch
 from typing import Sequence
+import streamlit as st
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 instruments = ['bass_acoustic', 'brass_acoustic', 'flute_acoustic', 'guitar_acoustic', 'keyboard_acoustic', 'mallet_acoustic', 'organ_acoustic', 'reed_acoustic', 'string_acoustic', 'synth_lead_acoustic', 'vocal_acoustic', 'bass_synthetic', 'brass_synthetic', 'flute_synthetic', 'guitar_synthetic', 'keyboard_synthetic', 'mallet_synthetic', 'organ_synthetic', 'reed_synthetic', 'string_synthetic', 'synth_lead_synthetic', 'vocal_synthetic', 'bass_electronic', 'brass_electronic', 'flute_electronic', 'guitar_electronic', 'keyboard_electronic', 'mallet_electronic', 'organ_electronic', 'reed_electronic', 'string_electronic', 'synth_lead_electronic', 'vocal_electronic']
 
-model = CVAE.load_from_checkpoint(
-    'epoch=17-step=650718.ckpt',
-    io_channels=1,
-    io_features=16000*4,
-    latent_features=5,
-    channels=[32, 64, 128, 256, 512],
-    num_classes=len(instruments),
-    learning_rate=1e-5
-).to(device)
+@st.cache_resource
+def load_model(device):
+    return CVAE.load_from_checkpoint(
+        'epoch=17-step=650718.ckpt',
+        io_channels=1,
+        io_features=16000*4,
+        latent_features=5,
+        channels=[32, 64, 128, 256, 512],
+        num_classes=len(instruments),
+        learning_rate=1e-5
+    ).to(device)
+
+model = load_model(device)
 
 def format(text):
     text = text.split(' ')[-1]
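Two caching notes on the model side. @st.cache_resource (used for load_model above) stores one shared instance per process and hands the same torch module to every session and rerun, without trying to hash or copy it; @st.cache_data (used for np_to_wav in app.py) instead memoizes serializable return values keyed on the function arguments. Separately, the five app sliders line up with latent_features=5 in the checkpoint hyper-parameters, which suggests the "magical parameters" pin the CVAE latent code. generate() is outside this diff, so the following is only an illustrative sketch of that idea, not the repository's implementation (generate_sketch, model.decode, the [0, 1]-to-latent rescaling, and all shapes are assumptions):

import torch

def generate_sketch(class_idx: int, params=None):
    # Sample a random latent unless the user pinned it with the sliders.
    if params is None:
        z = torch.randn(1, 5, device=device)           # latent_features=5
    else:
        # Sliders are in [0, 1]; stretch to a rough N(0, 1) range.
        z = (torch.tensor([params], device=device) - 0.5) * 6.0
    with torch.no_grad():
        label = torch.tensor([class_idx], device=device)
        return model.decode(z, label).squeeze(0).cpu().numpy()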