suric committed on
Commit
97a428f
·
1 Parent(s): a114736

update prediction method

Browse files
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import os
2
 
3
  import gradio as gr
4
- from gradio_components.prediction import load_model, predict
 
5
 
6
  theme = gr.themes.Glass(
7
  primary_hue="fuchsia",
@@ -60,7 +61,6 @@ theme = gr.themes.Glass(
60
  button_cancel_background_fill_hover_dark='*button_cancel_background_fill'
61
  )
62
 
63
-
64
  _AUDIOCRAFT_MODELS = ["facebook/musicgen-melody",
65
  "facebook/musicgen-medium",
66
  "facebook/musicgen-small",
@@ -79,6 +79,7 @@ def generate_prompt(difficulty, style):
79
  )
80
  return prompt
81
 
 
82
  def UI():
83
  with gr.Blocks() as demo:
84
  with gr.Tab("Generate Music by melody"):
@@ -104,9 +105,19 @@ def UI():
104
  temperature = gr.Number(
105
  label="Temperature", value=1.0, interactive=True
106
  )
107
- sample_rate = gr.Number(label="output music sample rate", value=32000, interactive=True)
108
- difficulty = gr.Radio(["Easy", "Medium", "Hard"], label="Difficulty", value="Easy", interactive=True)
109
- style = gr.Radio(["Jazz", "Classical Music", "Hip Hop", "Others"], value="Classical Music", label="music genre", interactive=True)
 
 
 
 
 
 
 
 
 
 
110
  if style == "Others":
111
  style = gr.Textbox(label="Type your music genre")
112
  prompt = generate_prompt(difficulty.value, style.value)
@@ -119,29 +130,27 @@ def UI():
119
  with gr.Row():
120
  melody = gr.Audio(
121
  sources=["microphone", "upload"],
122
- streaming=True,
123
  label="Record or upload your audio",
124
- interactive=True,
125
- type="numpy",
126
  show_label=True,
127
- )
128
  with gr.Row():
129
  submit = gr.Button("Generate Music")
130
  output = gr.Audio("listen to the generated music")
131
- submit.click(fn=predict, inputs=melody, outputs=output)
132
 
133
- generate_music = gr.Button("Generate Music")
134
- output = gr.Audio("listen to the music")
135
- generate_music.click(fn=predict,
136
- inputs=[model_path, prompt, melody, duration, topk, topp, temperature, sample_rate],
137
- outputs=output)
 
138
 
139
  gr.Examples(
140
  examples=[
141
  [
142
  os.path.join(
143
  os.path.dirname(__file__),
144
- "./data/audio/twinkle_twinkle_little_stars_mozart.mp3"
145
  ),
146
  "Easy",
147
  32000,
@@ -150,7 +159,7 @@ def UI():
150
  [
151
  os.path.join(
152
  os.path.dirname(__file__),
153
- "./data/audio/golden_hour.mp3"
154
  ),
155
  "Easy",
156
  32000,
@@ -159,7 +168,7 @@ def UI():
159
  [
160
  os.path.join(
161
  os.path.dirname(__file__),
162
- "./data/audio/turkish_march_mozart.mp3"
163
  ),
164
  "Easy",
165
  32000,
@@ -168,7 +177,7 @@ def UI():
168
  [
169
  os.path.join(
170
  os.path.dirname(__file__),
171
- "./data/audio/golden_hour.mp3"
172
  ),
173
  "Hard",
174
  32000,
@@ -177,7 +186,7 @@ def UI():
177
  [
178
  os.path.join(
179
  os.path.dirname(__file__),
180
- "./data/audio/golden_hour.mp3"
181
  ),
182
  "Hard",
183
  32000,
@@ -186,7 +195,7 @@ def UI():
186
  [
187
  os.path.join(
188
  os.path.dirname(__file__),
189
- "./data/audio/golden_hour.mp3"
190
  ),
191
  "Hard",
192
  16000,
@@ -200,6 +209,6 @@ def UI():
200
  )
201
  demo.queue().launch()
202
 
 
203
  if __name__ == "__main__":
204
  UI()
205
-
 
1
  import os
2
 
3
  import gradio as gr
4
+
5
+ from gradio_components.prediction import predict
6
 
7
  theme = gr.themes.Glass(
8
  primary_hue="fuchsia",
 
61
  button_cancel_background_fill_hover_dark='*button_cancel_background_fill'
62
  )
63
 
 
64
  _AUDIOCRAFT_MODELS = ["facebook/musicgen-melody",
65
  "facebook/musicgen-medium",
66
  "facebook/musicgen-small",
 
79
  )
80
  return prompt
81
 
82
+
83
  def UI():
84
  with gr.Blocks() as demo:
85
  with gr.Tab("Generate Music by melody"):
 
105
  temperature = gr.Number(
106
  label="Temperature", value=1.0, interactive=True
107
  )
108
+ sample_rate = gr.Number(
109
+ label="output music sample rate", value=32000,
110
+ interactive=True
111
+ )
112
+ difficulty = gr.Radio(
113
+ ["Easy", "Medium", "Hard"], label="Difficulty",
114
+ value="Easy", interactive=True
115
+ )
116
+ style = gr.Radio(
117
+ ["Jazz", "Classical Music", "Hip Hop", "Others"],
118
+ value="Classical Music", label="music genre",
119
+ interactive=True
120
+ )
121
  if style == "Others":
122
  style = gr.Textbox(label="Type your music genre")
123
  prompt = generate_prompt(difficulty.value, style.value)
 
130
  with gr.Row():
131
  melody = gr.Audio(
132
  sources=["microphone", "upload"],
 
133
  label="Record or upload your audio",
134
+ #interactive=True,
 
135
  show_label=True,
136
+ )
137
  with gr.Row():
138
  submit = gr.Button("Generate Music")
139
  output = gr.Audio("listen to the generated music")
 
140
 
141
+ submit.click(
142
+ fn=predict,
143
+ inputs=[model_path, prompt, melody, duration, topk, topp, temperature,
144
+ sample_rate],
145
+ outputs=output
146
+ )
147
 
148
  gr.Examples(
149
  examples=[
150
  [
151
  os.path.join(
152
  os.path.dirname(__file__),
153
+ "./data/audio/twinkle_twinkle_little_stars_mozart_20sec.mp3"
154
  ),
155
  "Easy",
156
  32000,
 
159
  [
160
  os.path.join(
161
  os.path.dirname(__file__),
162
+ "./data/audio/golden_hour_20sec.mp3"
163
  ),
164
  "Easy",
165
  32000,
 
168
  [
169
  os.path.join(
170
  os.path.dirname(__file__),
171
+ "./data/audio/turkish_march_mozart_20sec.mp3"
172
  ),
173
  "Easy",
174
  32000,
 
177
  [
178
  os.path.join(
179
  os.path.dirname(__file__),
180
+ "./data/audio/golden_hour_20sec.mp3"
181
  ),
182
  "Hard",
183
  32000,
 
186
  [
187
  os.path.join(
188
  os.path.dirname(__file__),
189
+ "./data/audio/golden_hour_20sec.mp3"
190
  ),
191
  "Hard",
192
  32000,
 
195
  [
196
  os.path.join(
197
  os.path.dirname(__file__),
198
+ "./data/audio/golden_hour_20sec.mp3"
199
  ),
200
  "Hard",
201
  16000,
 
209
  )
210
  demo.queue().launch()
211
 
212
+
213
  if __name__ == "__main__":
214
  UI()
 
data/audio/golden_hour_20sec.mp3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b4288206a5e10ba85e38f9f734ab21a01ae2e038028fc4d58b2e0404eaa6b38
3
+ size 150626
data/audio/turkish_march_mozart_20sec.mp3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b503d887b6f567819f74bdfb491540e7e49708858341d02c83a045a53d1f7dd
3
+ size 146599
data/audio/twinkle_twinkle_little_stars_mozart.mp3 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ed46f25fb0031b270dafc14981e67121ecf094e15c6c6c138f7998672de8ce7a
3
  size 20276397
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f0082e0f5f042b80d91385c8f68c9f9109a0abc164972c0535705ebeb6708c2
3
  size 20276397
data/audio/twinkle_twinkle_little_stars_mozart_20sec.mp3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d959a1125db5075112d4be8086f81722d216154fe721f4d8463cd8a2e06b3f3
3
+ size 154344
gradio_components/prediction.py CHANGED
@@ -8,6 +8,7 @@ from audiocraft.models import MusicGen
8
 
9
  from tempfile import NamedTemporaryFile
10
  from pathlib import Path
 
11
 
12
 
13
  def load_model(version='facebook/musicgen-melody'):
@@ -23,6 +24,7 @@ def _do_predictions(model, texts, melodies, duration, progress=False, gradio_pro
23
  processed_melodies.append(None)
24
  else:
25
  sr, melody = melody[0], torch.from_numpy(melody[1]).to(model.device).float().t()
 
26
  if melody.dim() == 1:
27
  melody = melody[None]
28
  melody = melody[..., :int(sr * duration)]
@@ -63,12 +65,12 @@ def predict(model_path, text, melody, duration, topk, topp, temperature, target_
63
  INTERRUPTING = False
64
  progress(0, desc="Loading model...")
65
  model_path = model_path.strip()
66
- if model_path:
67
- if not Path(model_path).exists():
68
- raise gr.Error(f"Model path {model_path} doesn't exist.")
69
- if not Path(model_path).is_dir():
70
- raise gr.Error(f"Model path {model_path} must be a folder containing "
71
- "state_dict.bin and compression_state_dict_.bin.")
72
  if temperature < 0:
73
  raise gr.Error("Temperature must be >= 0.")
74
  if topk < 0:
 
8
 
9
  from tempfile import NamedTemporaryFile
10
  from pathlib import Path
11
+ from transformers import AutoModelForSeq2SeqLM
12
 
13
 
14
  def load_model(version='facebook/musicgen-melody'):
 
24
  processed_melodies.append(None)
25
  else:
26
  sr, melody = melody[0], torch.from_numpy(melody[1]).to(model.device).float().t()
27
+ print(f"Input audio sample rate is {sr}")
28
  if melody.dim() == 1:
29
  melody = melody[None]
30
  melody = melody[..., :int(sr * duration)]
 
65
  INTERRUPTING = False
66
  progress(0, desc="Loading model...")
67
  model_path = model_path.strip()
68
+ # if model_path:
69
+ # if not Path(model_path).exists():
70
+ # raise gr.Error(f"Model path {model_path} doesn't exist.")
71
+ # if not Path(model_path).is_dir():
72
+ # raise gr.Error(f"Model path {model_path} must be a folder containing "
73
+ # "state_dict.bin and compression_state_dict_.bin.")
74
  if temperature < 0:
75
  raise gr.Error("Temperature must be >= 0.")
76
  if topk < 0: