Spaces:
Running
on
Zero
Running
on
Zero
import logging
import os
import re
import shutil
import subprocess
import sys
import traceback
from pathlib import Path

import gradio as gr
import yt_dlp
from pydub import AudioSegment

from inference import proc_folder_direct
# Root folder for separation results; per-song result folders are created below it.
OUTPUT_FOLDER: str = "separation_results/"
# Root folder for input audio; WAV inputs live in its "wav" sub-folder.
INPUT_FOLDER: str = "input"
# NOTE(review): not read anywhere in this file -- candidate for removal.
download_path: str = ""
def sanitize_filename(filename):
    """Return *filename* with characters that are unsafe in file names replaced.

    Each of ``\\ / * ? : " < > |`` becomes an underscore; every other
    character is kept as-is.
    """
    replacements = str.maketrans({char: "_" for char in '\\/*?:"<>|'})
    return filename.translate(replacements)
def delete_input_files(input_dir):
    """Remove every ``*.wav`` file from the ``wav`` sub-folder of *input_dir*.

    Only files directly inside ``<input_dir>/wav`` are matched; each removed
    path is printed.
    """
    wav_folder = Path(input_dir).joinpath("wav")
    # Snapshot the listing before removing entries from the directory.
    for target in list(wav_folder.glob("*.wav")):
        target.unlink()
        print(f"Deleted {target}")
def download_youtube_audio_by_title(query, state=True):
    """Resolve *query* on YouTube and, optionally, download it as WAV.

    Args:
        query: Search text (resolved with yt-dlp's ``ytsearch``) or a direct
            video URL.
        state: If True, clear ``<INPUT_FOLDER>/wav``, download the top result
            there as ``<title>.wav`` and return that path. If False, only
            resolve the title; nothing is downloaded or deleted.

    Returns:
        The downloaded WAV path when *state* is True, otherwise the sanitised
        title (tag-stripped when it looks like "Artist - Song") used as the
        file name.

    Raises:
        ValueError: If *query* is empty or yt-dlp returns no result.
    """
    if not query or not query.strip():
        raise ValueError("Please enter a song title.")

    wav_dir = os.path.join(INPUT_FOLDER, "wav")
    ydl_opts = {
        'quiet': True,
        'default_search': 'ytsearch',
        'noplaylist': True,
        'format': 'bestaudio/best',
        'outtmpl': os.path.join(wav_dir, '%(title)s.%(ext)s'),
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
        }],
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        search_results = ydl.extract_info(query, download=False)

    # A search returns a playlist-like dict with 'entries'; a direct URL
    # resolves straight to the video's own info dict.
    entries = (search_results or {}).get('entries')
    video_info = search_results if entries is None else next(iter(entries), None)
    if not video_info:
        raise ValueError(f"No YouTube results found for: {query}")

    video_url = video_info['webpage_url']
    video_title = video_info['title']
    # Keep the leading "Artist - Song" part, dropping a trailing " [...]" or
    # "(...)" tag such as "(Official Video)".
    match = re.match(r'^(.*? - .*?)(?: \[.*\]|\(.*\))?$', video_title)
    formatted_title = match.group(1) if match else video_title
    formatted_title = sanitize_filename(formatted_title.strip())

    if not state:
        return formatted_title

    # Clear old inputs only once the new song is known, so a failed search
    # does not wipe the previously downloaded file.
    delete_input_files(INPUT_FOLDER)
    ydl_opts['outtmpl'] = os.path.join(wav_dir, f'{formatted_title}.%(ext)s')
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.download([video_url])
    # FFmpegExtractAudio with preferredcodec 'wav' leaves "<title>.wav".
    return os.path.join(wav_dir, f'{formatted_title}.wav')
def run_inference(model_type, config_path, start_check_point, input_dir, output_dir, device_ids="0"):
    """Run ``inference.py`` in a child process and return the finished process.

    The child uses the same interpreter as this app (``sys.executable``), so it
    sees the same virtualenv/packages; a bare ``"python"`` could resolve to a
    different interpreter on PATH. ``inference.py`` is resolved relative to the
    current working directory.

    Args:
        model_type: Value for ``--model_type``.
        config_path: Value for ``--config_path``.
        start_check_point: Value for ``--start_check_point``.
        input_dir: Value for ``--INPUT_FOLDER``.
        output_dir: Value for ``--store_dir``.
        device_ids: Value for ``--device_ids``; non-strings (e.g. ``0``) are
            converted with ``str``.

    Returns:
        The ``subprocess.CompletedProcess`` with captured text stdout/stderr.

    Raises:
        subprocess.CalledProcessError: If the child exits with a non-zero code.
    """
    command = [
        sys.executable, "inference.py",
        "--model_type", model_type,
        "--config_path", config_path,
        "--start_check_point", start_check_point,
        "--INPUT_FOLDER", input_dir,
        "--store_dir", output_dir,
        "--device_ids", str(device_ids),
    ]
    return subprocess.run(command, check=True, capture_output=True, text=True)
def move_stems_to_parent(input_dir):
    """Promote the stem each model is used for into its song folder.

    Expects ``<input_dir>/<song>/<model>/<song>_<stem>.wav``. For every folder
    below *input_dir* whose name is exactly one of the model names, the
    relevant stem is moved up to ``<song>/``:

    - ``htdemucs``          ``<song>_bass.wav``   -> ``bass.wav``
    - ``mel_band_roformer`` ``<song>_vocals.wav`` -> ``vocals.wav``
    - ``scnet``             ``<song>_other.wav``  -> ``other.wav``
    - ``bs_roformer``       ``<song>_other.wav``  -> ``instrumental.wav``

    Progress and missing files are reported with ``print``.
    """
    # model folder -> (stem suffix in the source file, target file name, label)
    stem_routes = {
        "htdemucs": ("bass", "bass.wav", "Bass"),
        "mel_band_roformer": ("vocals", "vocals.wav", "Vocals"),
        "scnet": ("other", "other.wav", "Other"),
        "bs_roformer": ("other", "instrumental.wav", "Instrumental"),
    }
    for subdir, _dirs, _files in os.walk(input_dir):
        if subdir == input_dir:
            continue
        # Match the folder's own name exactly: a substring test on the full
        # path misroutes songs whose title contains a model name.
        model = os.path.basename(subdir)
        route = stem_routes.get(model)
        if route is None:
            continue
        stem, target_name, label = route
        parent_dir = os.path.dirname(subdir)
        song_name = os.path.basename(parent_dir)
        print(f"Processing {model} in {subdir}")
        source_path = os.path.join(subdir, f"{song_name}_{stem}.wav")
        if os.path.exists(source_path):
            target_path = os.path.join(parent_dir, target_name)
            print(f"Moving {source_path} to {target_path}")
            shutil.move(source_path, target_path)
        else:
            print(f"{label} file not found: {source_path}")
def combine_stems_for_all(input_dir):
    """Mix the four promoted stems of every song folder into one WAV.

    For each directory below *input_dir* (the top level itself is skipped)
    that contains ``vocals.wav``, ``bass.wav``, ``other.wav`` and
    ``instrumental.wav``, the stems are loaded with pydub, overlaid onto the
    vocals track in that order, and exported as ``<dir>/<dirname>.MDS.wav``.
    Directories missing any stem are reported and skipped.
    """
    stem_file_names = ("vocals.wav", "bass.wav", "other.wav", "instrumental.wav")
    for folder, _subdirs, _files in os.walk(input_dir):
        if folder == input_dir:
            continue
        print(f"Processing {folder}")
        stem_files = [os.path.join(folder, name) for name in stem_file_names]
        if any(not os.path.exists(path) for path in stem_files):
            print(f"Skipping {folder}, not all stems are present.")
            continue
        vocals, bass, others, instrumental = (AudioSegment.from_file(path) for path in stem_files)
        mixdown = vocals.overlay(bass).overlay(others).overlay(instrumental)
        destination = os.path.join(folder, f"{os.path.basename(folder)}.MDS.wav")
        mixdown.export(destination, format="wav")
        print(f"Exported combined stems to {destination}")
def delete_folders_and_files(input_dir):
    """Remove intermediate separation artefacts below *input_dir*.

    Two passes:

    1. In every directory strictly below *input_dir*, delete the per-model
       result folders (``htdemucs``, ``mel_band_roformer``, ``scnet``,
       ``bs_roformer``) and the promoted stem files (``bass.wav``,
       ``vocals.wav``, ``other.wav``, ``instrumental.wav``).
    2. Delete every directory under *input_dir* whose name ends in
       ``_vocals``. Presumably these come from BS Roformer also processing the
       ``<title>_vocals.wav`` file that sits in its input folder -- confirm.

    Progress is reported with ``print``.
    """
    folders_to_delete = ('htdemucs', 'mel_band_roformer', 'scnet', 'bs_roformer')
    files_to_delete = ('bass.wav', 'vocals.wav', 'other.wav', 'instrumental.wav')

    # Bottom-up: a folder's contents are yielded before the folder itself is
    # removed while visiting its parent.
    for root, _dirs, _files in os.walk(input_dir, topdown=False):
        if root == input_dir:
            continue
        for folder in folders_to_delete:
            folder_path = os.path.join(root, folder)
            if os.path.isdir(folder_path):
                print(f"Deleting folder: {folder_path}")
                shutil.rmtree(folder_path)
        for file_name in files_to_delete:
            file_path = os.path.join(root, file_name)
            if os.path.isfile(file_path):
                print(f"Deleting file: {file_path}")
                os.remove(file_path)

    # Walk the directory the caller passed in (not a module-level global).
    for root, dirs, _files in os.walk(input_dir):
        for dir_name in list(dirs):
            if dir_name.endswith('_vocals'):
                dir_path = os.path.join(root, dir_name)
                print(f"Deleting folder: {dir_path}")
                shutil.rmtree(dir_path)
                # Prune so os.walk does not try to descend into it.
                dirs.remove(dir_name)
    print("Cleanup completed.")
def process_audio(song_title):
    """Separate, recombine and clean up the stems for one song (Gradio generator).

    Every ``yield`` is a ``(log_message, audio_path)`` pair for the log
    textbox and the processed-audio player; ``audio_path`` stays ``None``
    until the final step. Steps:

    1. Resolve *song_title* to its sanitised title. If
       ``<INPUT_FOLDER>/wav/<title>.wav`` is missing (Play was not clicked,
       or a different song was played), download it first.
    2. Run SCNet, Mel Band Roformer (with instrumental extraction) and
       HTDemucs on the input folder.
    3. Rename the Mel Band Roformer instrumental to ``<title>.wav`` and run
       BS Roformer on that folder.
    4. Delete the input WAVs, promote stems, mix them into
       ``<title>.MDS.wav`` and remove intermediate files.

    Any exception is logged and reported as the last yielded message.
    """
    try:
        yield "Finding audio...", None
        # Check the submitted string itself (the Textbox component object is
        # never equal to "").
        if not song_title or not song_title.strip():
            raise ValueError("Please enter a song title.")
        formatted_title = download_youtube_audio_by_title(song_title, False)

        input_wav_dir = f"{INPUT_FOLDER}/wav"
        if not os.path.isfile(os.path.join(input_wav_dir, f"{formatted_title}.wav")):
            yield "Downloading audio...", None
            downloaded_path = download_youtube_audio_by_title(song_title)
            # Use the name actually written to disk for all later paths.
            formatted_title = Path(downloaded_path).stem

        song_dir = os.path.join(OUTPUT_FOLDER, formatted_title)
        mel_dir = os.path.join(song_dir, "mel_band_roformer")

        yield "Starting SCNet inference...", None
        proc_folder_direct("scnet", "configs/config_scnet_other.yaml", "results/model_scnet_other.ckpt", input_wav_dir, OUTPUT_FOLDER)
        yield "Starting Mel Band Roformer inference...", None
        proc_folder_direct("mel_band_roformer", "configs/config_mel_band_roformer_vocals.yaml", "results/model_mel_band_roformer_vocals.ckpt", input_wav_dir, OUTPUT_FOLDER, extract_instrumental=True)
        yield "Starting HTDemucs inference...", None
        proc_folder_direct("htdemucs", "configs/config_htdemucs_bass.yaml", "results/model_htdemucs_bass.th", input_wav_dir, OUTPUT_FOLDER)

        # Presumably proc_folder_direct names outputs after the input file's
        # stem: renaming the instrumental to "<title>.wav" makes BS Roformer
        # write "<title>_other.wav", the name move_stems_to_parent looks for.
        os.replace(
            os.path.join(mel_dir, f"{formatted_title}_instrumental.wav"),
            os.path.join(mel_dir, f"{formatted_title}.wav"),
        )
        yield "Starting BS Roformer inference...", None
        proc_folder_direct("bs_roformer", "configs/config_bs_roformer_instrumental.yaml", "results/model_bs_roformer_instrumental.ckpt", mel_dir, OUTPUT_FOLDER)

        yield "Deleting input files...", None
        delete_input_files(INPUT_FOLDER)
        yield "Moving stems to parent...", None
        move_stems_to_parent(OUTPUT_FOLDER)
        yield "Combining stems...", None
        combine_stems_for_all(OUTPUT_FOLDER)
        yield "Cleaning up...", None
        delete_folders_and_files(OUTPUT_FOLDER)
        yield "Audio processing completed successfully.", os.path.join(song_dir, f"{formatted_title}.MDS.wav")
    except Exception as e:
        # Top-level UI boundary: surface the error in the log box instead of
        # crashing the event handler.
        error_msg = f"An error occurred: {e}\n{traceback.format_exc()}"
        logging.error(error_msg)
        yield error_msg, None
# Gradio UI: a title box with a Play button (download + preview), and a
# Process button that streams pipeline progress into a log box and shows the
# final mix.
# NOTE(review): the original indentation was lost; this layout assumes only the
# title box and Play button sit inside gr.Row() -- confirm against the deployed app.
with gr.Blocks() as demo:
    gr.Markdown("# Music Player and Processor")
    with gr.Row():
        title_input = gr.Textbox(label="Enter Song Title")
        play_button = gr.Button("Play")
    audio_output = gr.Audio(label="Audio Player")
    process_button = gr.Button("Process Audio")
    log_output = gr.Textbox(label="Processing Log", interactive=False)
    processed_audio_output = gr.Audio(label="Processed Audio")
    # Only the title is passed, so `state` keeps its default (True): the top
    # hit is downloaded and the returned WAV path is loaded into the player.
    play_button.click(
        fn=download_youtube_audio_by_title,
        inputs=title_input,
        outputs=audio_output
    )
    # process_audio is a generator; each (message, audio) pair it yields
    # updates the log box and the processed-audio player.
    process_button.click(
        fn=process_audio,
        inputs=title_input,
        outputs=[log_output, processed_audio_output],
        show_progress=True
    )
# Start the Gradio server.
demo.launch()