|
|
|
import spaces |
|
|
|
import sys |
|
import os |
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src'))) |
|
|
|
import subprocess |
|
from typing import Tuple, Dict, Literal |
|
from ctypes import ArgumentError |
|
|
|
from html_helper import * |
|
from model_helper import * |
|
|
|
from pytube import YouTube |
|
import torchaudio |
|
import glob |
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Model selection and loading
# ---------------------------------------------------------------------------
# `model_name` picks one of the presets below; `precision` and `project` are
# forwarded to the loader as the values of the '-pr' and '-p' options.
model_name = 'YPTF.MoE+Multi (noPS)'
precision = '16'
project = '2024'

# Options shared verbatim by both mixture-of-experts presets.
_moe_flags = ['-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
              '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
              '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
              '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1']

# preset name -> (checkpoint, options placed between '-p <project>' and '-pr <precision>')
_presets = {
    # NOTE(review): this checkpoint string looks like an e-mail-obfuscation
    # artifact rather than a real checkpoint name - confirm against the repo.
    "YMT3+": (
        "[email protected]",
        [],
    ),
    "YPTF+Single (noPS)": (
        "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt",
        ['-enc', 'perceiver-tf', '-ac', 'spec',
         '-hop', '300', '-atc', '1'],
    ),
    "YPTF+Multi (PS)": (
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
        ['-tk', 'mc13_full_plus_256',
         '-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf',
         '-ac', 'spec', '-hop', '300', '-atc', '1'],
    ),
    "YPTF.MoE+Multi (noPS)": (
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
        _moe_flags,
    ),
    "YPTF.MoE+Multi (PS)": (
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
        _moe_flags,
    ),
}

# Unknown preset names are rejected with the offending name as the message.
if model_name not in _presets:
    raise ValueError(model_name)

checkpoint, _preset_flags = _presets[model_name]
args = [checkpoint, '-p', project, *_preset_flags, '-pr', precision]

# The checkpoint is loaded on the CPU first and then moved to CUDA.
model = load_model_checkpoint(args=args, device="cpu")
model.to("cuda")
|
|
|
|
|
|
|
def prepare_media(source_path_or_url: os.PathLike,
                  source_type: Literal['audio_filepath', 'youtube_url'],
                  delete_video: bool = True) -> Dict:
    """Resolve an audio source to a local file and return its audio info.

    Args:
        source_path_or_url: Path of a local audio file, or a YouTube URL.
        source_type: 'audio_filepath' uses the path as given; 'youtube_url'
            downloads the audio first (PyTube, then yt-dlp as a fallback).
        delete_video: When True (the default), the file downloaded by PyTube
            is removed after it has been converted to mp3.

    Returns:
        A dict with the keys "filepath", "track_name", "sample_rate",
        "bits_per_sample", "num_channels", "num_frames", "duration" (whole
        seconds) and "encoding" (lower-cased), filled from `torchaudio.info`.
        Returns None when `source_type` is 'youtube_url' and both download
        methods fail.

    Raises:
        ValueError: If `source_type` is not one of the two supported values.
    """
    if source_type == 'audio_filepath':
        audio_file = source_path_or_url
    elif source_type == 'youtube_url':
        try:
            # First attempt: PyTube, taking the audio-only stream with the
            # lowest bitrate.
            yt = YouTube(source_path_or_url)
            audio_stream = min(yt.streams.filter(only_audio=True), key=lambda s: s.bitrate)
            mp4_file = audio_stream.download(output_path='downloaded')
            # splitext rather than slicing off three characters, so that
            # extensions of any length (e.g. '.webm') are replaced correctly.
            audio_file = os.path.splitext(mp4_file)[0] + '.mp3'
            # '-y' overwrites an mp3 left by an earlier request instead of
            # failing/prompting; check=True turns a failed conversion into an
            # exception so the fallback below is actually reached.
            subprocess.run(['ffmpeg', '-y', '-i', mp4_file, '-ac', '1', audio_file],
                           check=True)
            if delete_video:
                os.remove(mp4_file)
        except Exception as e:
            try:
                # Second attempt: yt-dlp.
                print(f"Failed with PyTube, error: {e}. Trying yt-dlp...")
                audio_file = './downloaded/yt_audio'
                # check=True: without it a non-zero exit went unnoticed and a
                # stale './downloaded/yt_audio.mp3' could be picked up below.
                subprocess.run(['yt-dlp', '-x', source_path_or_url, '-f', 'bestaudio',
                                '-o', audio_file, '--audio-format', 'mp3', '--restrict-filenames',
                                '--force-overwrites'],
                               check=True)
                audio_file += '.mp3'
            except Exception as e:
                print(f"Alternative downloader failed, error: {e}. Please try again later!")
                return None
    else:
        raise ValueError(source_type)

    # Collect the audio metadata of the resolved file.
    info = torchaudio.info(audio_file)
    return {
        "filepath": audio_file,
        # Everything before the first '.' of the file name.
        "track_name": os.path.basename(audio_file).split('.')[0],
        "sample_rate": int(info.sample_rate),
        "bits_per_sample": int(info.bits_per_sample),
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        "duration": int(info.num_frames / info.sample_rate),
        "encoding": str.lower(info.encoding),
    }
|
|
|
@spaces.GPU
def process_audio(audio_filepath):
    """Transcribe an uploaded/recorded audio file and return the result HTML.

    Returns None when no file was provided. Otherwise the file's audio info
    is passed to `transcribe`, the result is converted with `to_data_url`,
    and the output of `create_html_from_midi` is returned.
    """
    if audio_filepath is not None:
        media_info = prepare_media(audio_filepath, source_type='audio_filepath')
        transcription = transcribe(model, media_info)
        return create_html_from_midi(to_data_url(transcription))
    return None
|
|
|
@spaces.GPU
def process_video(youtube_url):
    """Download the audio of a YouTube link, transcribe it, return result HTML.

    Args:
        youtube_url: Text entered in the URL box.

    Returns:
        The output of `create_html_from_midi` for the transcription, or None
        when the input is empty or does not contain 'youtu'.

    Raises:
        gr.Error: When the audio could not be downloaded, so the UI shows a
            readable message.
    """
    # An empty or None value would make the substring test raise TypeError
    # (None) or is simply not a link (''); treat both as "nothing to do".
    if not youtube_url or 'youtu' not in youtube_url:
        return None

    audio_info = prepare_media(youtube_url, source_type='youtube_url')
    if audio_info is None:
        # prepare_media returns None when every download method failed;
        # passing that on to transcribe() would crash with an opaque error.
        raise gr.Error("Could not download audio from the given YouTube URL. "
                       "Please try again later.")

    midifile = transcribe(model, audio_info)
    midifile = to_data_url(midifile)
    return create_html_from_midi(midifile)
|
|
|
def play_video(youtube_url):
    """Return the embedded-player HTML for a YouTube link.

    The result of `create_html_youtube_player` is returned when the text
    contains 'youtu'; any other text yields None.
    """
    looks_like_youtube = 'youtu' in youtube_url
    return create_html_youtube_player(youtube_url) if looks_like_youtube else None
|
|
|
|
|
|
|
# Example inputs offered in the UI.
# Audio examples: every file with an extension directly under ./examples.
# Sorted so the list order is stable across runs (glob order is arbitrary).
# `recursive=True` has no effect without a '**' in the pattern; it is kept
# unchanged.
AUDIO_EXAMPLES = sorted(glob.glob('examples/*.*', recursive=True))

# YouTube examples: one URL per entry. Two links had been fused into a single
# string by a missing '", "' and are now listed separately.
YOUTUBE_EXAMPLES = [
    "https://www.youtube.com/watch?v=vMboypSkj3c",
    "https://youtu.be/OXXRoa1U6xU?si=nhJ6lzGenCmk4P7R",
    "https://youtu.be/EOJ0wH6h3rE?si=a99k6BnSajvNmXcn",
    "https://youtu.be/7mjQooXt28o?si=qqmMxCxwqBlLPDI2",
    "https://youtu.be/bnS-HK_lTHA?si=PQLVAab3QHMbv0S3",
    "https://youtu.be/zJB0nnOc7bM?si=EA1DN8nHWJcpQWp_",
    "https://youtu.be/mIWYTg55h10?si=WkbtKfL6NlNquvT8",
]
|
|
|
|
|
|
|
# Visual styling for the demo.
# The base theme is obtained with gr.Theme.from_hub, then its `text_md` and
# `text_lg` attributes are overridden with smaller pixel sizes.
theme = gr.Theme.from_hub("gradio/dracula_revamped")
theme.text_md = '9px'
theme.text_lg = '11px'
# Extra CSS passed to gr.Blocks: a four-colour diagonal gradient on
# `.gradio-container`, animated by the `gradient` keyframes, which move the
# background position from 0% to 100% and back in a 15 s infinite loop.
css = """
.gradio-container {
  background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
  background-size: 400% 400%;
  animation: gradient 15s ease infinite;
  height: 100vh;
}
@keyframes gradient {
  0% {
    background-position: 0% 50%;
  }
  50% {
    background-position: 100% 50%;
  }
  100% {
    background-position: 0% 50%;
  }
}
"""
|
|
|
# Page layout: a header with the model card, then a group with two tabs
# ("Upload audio" and "From YouTube"). The app is launched at import time,
# as there is no `__main__` guard.
with gr.Blocks(theme=theme, css=css) as demo:

    # Header: title, model card and usage notes.
    # NOTE(review): "acadmic" in the caution text below is a typo for
    # "academic"; the text also says the demo runs on CPU although the model
    # is moved to CUDA elsewhere in this file - confirm which is intended.
    with gr.Row():
        with gr.Column(scale=10):
            gr.Markdown(
            """
            ## 🎶YourMT3+: Multi-instrument Music Transcription with Enhanced Transformer Architectures and Cross-dataset Stem Augmentation
            ### Model card:
            - Model name: `YPTF.MoE+Multi`
            - Encoder backbone: Perceiver-TF + Mixture of Experts (2/8)
            - Decoder backbone: Multi-channel T5-small
            - Tokenizer: MT3 tokens with Singing extension
            - Dataset: YourMT3 dataset
            - Augmentation strategy: Intra-/Cross dataset stem augment, No Pitch-shifting
            - FP Precision: BF16-mixed for training, FP16 for inference

            #### Caution:
            - Currently running on CPU, and it takes longer than 3 minutes for a 30-second input.
            - For acadmic reproduction purpose, we strongly recommend to use [Colab Demo](https://colab.research.google.com/drive/1AgOVEBfZknDkjmSRA7leoa81a2vrnhBG?usp=sharing) with multiple checkpoints.
            ### [arxiv:2407.04822](https://arxiv.org/abs/2407.04822) | [Code](https://github.com/mimbres/YourMT3)
            """)

    with gr.Group():
        # Tab 1: transcribe an audio file.
        with gr.Tab("Upload audio"):
            # Audio input; its value is handed to process_audio, which
            # treats None as "nothing to transcribe".
            audio_input = gr.Audio(label="Record Audio", type="filepath",
                                   show_share_button=True, show_download_button=True)

            # Bundled example clips that fill the audio input.
            gr.Examples(examples=AUDIO_EXAMPLES, inputs=audio_input)

            # Button that starts the transcription.
            transcribe_audio_button = gr.Button("Transcribe", variant="primary")

            # HTML area that receives the value returned by process_audio.
            output_tab1 = gr.HTML()

            transcribe_audio_button.click(process_audio, inputs=audio_input, outputs=output_tab1)

        # Tab 2: transcribe the audio of a YouTube link.
        with gr.Tab("From YouTube"):
            with gr.Row():
                # URL text box.
                youtube_url = gr.Textbox(label="YouTube Link URL",
                                         placeholder="https://youtu.be/...")

                # HTML area that receives the value returned by play_video.
                youtube_player = gr.HTML(render=True)
            with gr.Row():
                # Button wired to play_video (see the click handler below).
                play_video_button = gr.Button("Get Audio from YouTube", variant="primary")

                # Button wired to process_video.
                transcribe_video_button = gr.Button("Transcribe", variant="primary")

            # HTML area that receives the value returned by process_video.
            output_tab2 = gr.HTML(render=True)

            transcribe_video_button.click(process_video, inputs=youtube_url, outputs=output_tab2)

            play_video_button.click(play_video, inputs=youtube_url, outputs=youtube_player)

            # Example links that fill the URL text box.
            gr.Examples(examples=YOUTUBE_EXAMPLES, inputs=youtube_url)

# Start the Gradio app with debug mode enabled.
demo.launch(debug=True)
|
|