Hmjz100 committed on
Commit 94ad715 · 1 Parent(s): 8559b66

Upload 3 files

Files changed (3)
  1. README.md +3 -3
  2. app.py +54 -37
  3. requirements.txt +7 -5
README.md CHANGED
@@ -1,12 +1,12 @@
---
title: YouTube To MT3
emoji: 🎼
- colorFrom: green
- colorTo: gray
+ colorFrom: purple
+ colorTo: green
sdk: gradio
sdk_version: 3.4.1
app_file: app.py
- pinned: false
+ pinned: true
duplicated_from: mdnestor/YouTube-to-MT3
---

app.py CHANGED
@@ -1,18 +1,31 @@
import os
+ os.system("pip install gradio")
+
import gradio as gr
+ from pathlib import Path
+ os.system("pip install gsutil")
import glob

- os.system("apt-get update -qq && apt-get install -qq libfluidsynth2 build-essential libasound2-dev libjack-dev")
-
- # install mt3
+
+ os.system("git clone --branch=main https://github.com/google-research/t5x")
+ os.system("mv t5x t5x_tmp; mv t5x_tmp/* .; rm -r t5x_tmp")
+ os.system("sed -i 's:jax\[tpu\]:jax:' setup.py")
+ os.system("python3 -m pip install -e .")
+ os.system("python3 -m pip install --upgrade pip")
+
+
+ # install mt3
os.system("git clone --branch=main https://github.com/magenta/mt3")
os.system("mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp")
- os.system("python3 -m pip install nest-asyncio pyfluidsynth==1.3.0 -e .")
-
+ os.system("python3 -m pip install -e .")
+ os.system("pip install tensorflow_cpu")
+ # copy the checkpoints
os.system("gsutil -q -m cp -r gs://mt3/checkpoints .")
- os.system("gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .")

+ # copy the soundfont file (original from https://sites.google.com/site/soundfonts4u)
+ os.system("gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .")

+ #@title Imports and definitions
import functools
import os

@@ -24,6 +37,7 @@ import gin
import jax
import librosa
import note_seq
+
import seqio
import t5
import t5x
@@ -42,12 +56,16 @@ nest_asyncio.apply()
SAMPLE_RATE = 16000
SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'

+ def callbak_audio(audio, sample_rate):
+   return note_seq.audio_io.wav_data_to_samples_librosa(
+       audio, sample_rate=sample_rate)
+
class InferenceModel(object):
-   """Wrapper of T5X model for music transcription."""
+   """T5X model wrapper for music transcription."""

  def __init__(self, checkpoint_path, model_type='mt3'):

-     # Model Constants.
+     # Model constants.
    if model_type == 'ismir2021':
      num_velocity_bins = 127
      self.encoding_spec = note_sequences.NoteEncodingSpec
@@ -68,9 +86,9 @@ class InferenceModel(object):
        'targets': self.outputs_length}

    self.partitioner = t5x.partitioning.PjitPartitioner(
-       model_parallel_submesh=(1, 1, 1, 1), num_partitions=1)
+       model_parallel_submesh=(1, 1, 1, 1))

-     # Build Codecs and Vocabularies.
+     # Build the codec and vocabularies.
    self.spectrogram_config = spectrograms.SpectrogramConfig()
    self.codec = vocabularies.build_codec(
        vocab_config=vocabularies.VocabularyConfig(
@@ -81,11 +99,11 @@ class InferenceModel(object):
        'targets': seqio.Feature(vocabulary=self.vocabulary),
    }

-     # Create a T5X model.
+     # Create the T5X model.
    self._parse_gin(gin_files)
    self.model = self._load_model()

-     # Restore from checkpoint.
+     # Restore from the checkpoint.
    self.restore_from_checkpoint(checkpoint_path)

  @property
@@ -96,7 +114,7 @@ class InferenceModel(object):
    }

  def _parse_gin(self, gin_files):
-     """Parse gin files used to train the model."""
+     """Parse the gin files used to train the model."""
    gin_bindings = [
        'from __gin__ import dynamic_registration',
        'from mt3 import vocabularies',
@@ -108,7 +126,7 @@ class InferenceModel(object):
        gin_files, gin_bindings, finalize_config=False)

  def _load_model(self):
-     """Load up a T5X `Model` after parsing training gin config."""
+     """Load a T5X `Model` after parsing the training gin config."""
    model_config = gin.get_configurable(network.T5Config)()
    module = network.Transformer(config=model_config)
    return models.ContinuousInputsEncoderDecoderModel(
@@ -120,7 +138,7 @@ class InferenceModel(object):


  def restore_from_checkpoint(self, checkpoint_path):
-     """Restore training state from checkpoint, resets self._predict_fn()."""
+     """Restore the training state from the checkpoint and reset self._predict_fn()."""
    train_state_initializer = t5x.utils.TrainStateInitializer(
        optimizer_def=self.model.optimizer_def,
        init_fn=self.model.get_initial_variables,
@@ -137,7 +155,7 @@ class InferenceModel(object):

  @functools.lru_cache()
  def _get_predict_fn(self, train_state_axes):
-     """Generate a partitioned prediction function for decoding."""
+     """Generates a partitioned prediction function for decoding."""
    def partial_predict_fn(params, batch, decode_rng):
      return self.model.predict_batch_with_aux(
          params, batch, decoder_params={'decode_rng': None})
@@ -150,18 +168,18 @@ class InferenceModel(object):
    )

  def predict_tokens(self, batch, seed=0):
-     """Predict tokens from preprocessed dataset batch."""
+     """Predict tokens from a preprocessed dataset batch."""
    prediction, _ = self._predict_fn(
        self._train_state.params, batch, jax.random.PRNGKey(seed))
    return self.vocabulary.decode_tf(prediction).numpy()

  def __call__(self, audio):
-     """Infer note sequence from audio samples.
+     """Infer a note sequence from audio samples.

-     Args:
-       audio: 1-d numpy array of audio samples (16kHz) for a single example.
-     Returns:
-       A note_sequence of the transcribed audio.
+     Args:
+       audio: A 1-D numpy array of 16 kHz audio samples for a single example.
+     Returns:
+       A note sequence of the transcribed audio.
    """
    ds = self.audio_to_dataset(audio)
    ds = self.preprocess(ds)
@@ -182,7 +200,7 @@ class InferenceModel(object):
    return result['est_ns']

  def audio_to_dataset(self, audio):
-     """Create a TF Dataset of spectrograms from input audio."""
+     """Create a TF Dataset of spectrograms from the input audio."""
    frames, frame_times = self._audio_to_frames(audio)
    return tf.data.Dataset.from_tensors({
        'inputs': frames,
@@ -190,7 +208,7 @@ class InferenceModel(object):
    })

  def _audio_to_frames(self, audio):
-     """Compute spectrogram frames from audio."""
+     """Compute spectrogram frames from the audio."""
    frame_size = self.spectrogram_config.hop_width
    padding = [0, frame_size - len(audio) % frame_size]
    audio = np.pad(audio, padding, mode='constant')
@@ -207,7 +225,7 @@ class InferenceModel(object):
            output_features=self.output_features,
            feature_key='inputs',
            additional_feature_keys=['input_times']),
-         # Cache occurs here during training.
+         # Caching happens here during training.
        preprocessors.add_dummy_targets,
        functools.partial(
            preprocessors.compute_spectrograms,
@@ -220,12 +238,12 @@ class InferenceModel(object):
  def postprocess(self, tokens, example):
    tokens = self._trim_eos(tokens)
    start_time = example['input_times'][0]
-     # Round down to nearest symbolic token step.
+     # Round down to the nearest symbolic token step.
    start_time -= start_time % (1 / self.codec.steps_per_second)
    return {
        'est_tokens': tokens,
        'start_time': start_time,
-         # Internal MT3 code expects raw inputs, not used here.
+         # Internal MT3 code expects raw inputs; not used here.
        'raw_inputs': []
    }

@@ -241,24 +259,23 @@ inference_model = InferenceModel('/home/user/app/checkpoints/mt3/', 'mt3')

def inference(url):
  os.system(f"yt-dlp -x {url} -o 'audio.%(ext)s'")
- audio_file = glob.glob('audio.*')[0]
- with open(audio_file, 'rb') as f:
-   data = f.read()
- audio = note_seq.audio_io.wav_data_to_samples_librosa(data, sample_rate=SAMPLE_RATE)
+ audio = glob.glob('audio.*')[0]
+ with open(audio, 'rb') as fd:
+   contents = fd.read()
+ audio = callbak_audio(contents, sample_rate=16000)
  est_ns = inference_model(audio)
- midi_file = f"./transcribed.mid"
- note_seq.sequence_proto_to_midi_file(est_ns, midi_file)
- return midi_file
+ note_seq.sequence_proto_to_midi_file(est_ns, './transcribed.mid')
+ return './transcribed.mid'

title = "YouTube-to-MT3"
- description = "Upload YouTube audio to MT3: Multi-Task Multitrack Music Transcription. Thanks to <a href=\"https://huggingface.co/spaces/akhaliq/MT3\">akhaliq</a> for the original <i>Spaces</i> implementation."
+ description = "YouTube audio uploaded to MT3: Multi-Task Multitrack Music Transcription. Thanks to <a href=\"https://huggingface.co/spaces/akhaliq/MT3\">akhaliq</a> for the original <i>Spaces</i> implementation."

- article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.03017' target='_blank'>MT3: Multi-Task Multitrack Music Transcription</a> | <a href='https://github.com/magenta/mt3' target='_blank'>Github Repo</a></p>"
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.03017' target='_blank'>MT3: Multi-Task Multitrack Music Transcription</a> | <a href='https://github.com/magenta/mt3' target='_blank'>GitHub repo</a></p>"

gr.Interface(
    inference,
-   gr.Textbox(label="Audio URL"),
-   gr.outputs.File(label="Transcribed MIDI"),
+   gr.inputs.Textbox(label="URL"),
+   gr.outputs.File(label="Output"),
    title=title,
    description=description,
    article=article,
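
The new setup block installs t5x and mt3 from source through a chain of os.system calls, which silently ignore failures. Below is a minimal sketch of the same steps using subprocess.run with check=True, so a failed clone or install aborts the build instead of surfacing later as an import error; the sh helper and its name are additions here, while the commands themselves are taken from the diff above.

import subprocess

def sh(cmd: str) -> None:
    # Hypothetical helper: run a shell command and raise on failure,
    # unlike os.system, whose return code app.py ignores.
    subprocess.run(cmd, shell=True, check=True)

# Install t5x from source; the sed call swaps the TPU-only jax extra
# ("jax[tpu]") for plain jax so the install does not require TPU wheels.
sh("git clone --branch=main https://github.com/google-research/t5x")
sh("mv t5x t5x_tmp; mv t5x_tmp/* .; rm -r t5x_tmp")
sh("sed -i 's:jax\\[tpu\\]:jax:' setup.py")
sh("python3 -m pip install -e .")

# Install mt3 from source the same way.
sh("git clone --branch=main https://github.com/magenta/mt3")
sh("mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp")
sh("python3 -m pip install -e .")

# Fetch the checkpoints and the soundfont from the public GCS buckets.
sh("gsutil -q -m cp -r gs://mt3/checkpoints .")
sh("gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .")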
 
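For reference, a minimal sketch of the new inference path. Two deviations from the committed code are assumptions of this sketch: the URL is passed to yt-dlp as an argument list via subprocess.run (the committed f-string interpolates the URL into a shell command, which is open to shell injection), and the misspelled callbak_audio helper is inlined. transcribe_url is a hypothetical name; inference_model is the InferenceModel instance constructed in app.py.

import glob
import subprocess

import note_seq

def transcribe_url(url: str, midi_path: str = "./transcribed.mid") -> str:
    # Download only the audio track; yt-dlp chooses the file extension.
    subprocess.run(["yt-dlp", "-x", url, "-o", "audio.%(ext)s"], check=True)
    audio_file = glob.glob("audio.*")[0]
    with open(audio_file, "rb") as fd:
        contents = fd.read()
    # Decode to 16 kHz samples, matching the SAMPLE_RATE constant in app.py.
    samples = note_seq.audio_io.wav_data_to_samples_librosa(
        contents, sample_rate=16000)
    est_ns = inference_model(samples)  # InferenceModel instance from app.py
    note_seq.sequence_proto_to_midi_file(est_ns, midi_path)
    return midi_path
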
requirements.txt CHANGED
@@ -1,6 +1,9 @@
- absl-py
+ nest-asyncio
+ pyfluidsynth
+ absl-py
ddsp
- flax@git+https://github.com/google/flax#egg=flax
+ flax
+ glob
gin-config
immutabledict
librosa
@@ -10,9 +13,8 @@ numpy
pretty_midi
scikit-learn
scipy
- seqio @ git+https://github.com/google/seqio#egg=seqio
+ seqio
t5
- t5x@git+https://github.com/google-research/t5x#egg=t5x
- tensorflow
+ tensorflow_cpu
tensorflow-datasets
yt-dlp
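
Two caveats on the new list, with a small sanity-check sketch below. First, glob names a standard-library module; there is no such package on PyPI, so pip will fail on that line. Second, several entries install under a PyPI name that differs from the import name, and app.py also pip-installs packages at runtime, so it is worth confirming the imports actually resolve. The mapping here covers only a few packages from this file and is an assumption of the sketch.

import importlib

# PyPI name -> import name for a few of the packages listed above.
# glob is omitted on purpose: it is part of the Python standard library.
MODULES = {
    "nest-asyncio": "nest_asyncio",
    "gin-config": "gin",
    "scikit-learn": "sklearn",
    "pretty_midi": "pretty_midi",
    "tensorflow_cpu": "tensorflow",
}

for pypi_name, import_name in MODULES.items():
    try:
        importlib.import_module(import_name)
        print(f"ok: {pypi_name}")
    except ImportError as exc:
        print(f"missing: {pypi_name} ({exc})")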