"""Gradio app that downloads a song from YouTube and remixes it from separated stems.

"Play" searches YouTube for the entered title and downloads the audio as WAV
into ``<INPUT_FOLDER>/wav``. "Process Audio" then runs the SCNet, Mel Band
Roformer, HTDemucs and BS Roformer models through
``inference.proc_folder_direct``, moves one chosen stem from each model into
the song's folder under ``OUTPUT_FOLDER``, overlays them into
``<title>.MDS.wav`` and deletes the intermediate files.
"""
import logging
import os
import re
import shutil
import subprocess
import sys
import traceback
from pathlib import Path

import yt_dlp
from pydub import AudioSegment
import gradio as gr

from inference import proc_folder_direct

OUTPUT_FOLDER = "separation_results/"
INPUT_FOLDER = "input"
# NOTE(review): not referenced in this file; kept in case other modules import it.
download_path = ""

# Keeps the leading "Artist - Song" part of a video title and drops an
# optional trailing " [...]" or "(...)" suffix such as "(Official Video)".
_TITLE_RE = re.compile(r'^(.*? - .*?)(?: \[.*\]|\(.*\))?$')

# Model output sub-folder -> (stem suffix of the file that model writes,
# file name the stem is moved to inside the song folder, label for logs).
# Shared by move_stems_to_parent and delete_folders_and_files.
_STEM_SOURCES = {
    "htdemucs": ("bass", "bass.wav", "Bass"),
    "mel_band_roformer": ("vocals", "vocals.wav", "Vocals"),
    "scnet": ("other", "other.wav", "Other"),
    # The BS Roformer instrumental model's result is read from its "_other" file.
    "bs_roformer": ("other", "instrumental.wav", "Instrumental"),
}


def sanitize_filename(filename):
    """Return *filename* with each of ``\\ / * ? : " < > |`` replaced by ``_``."""
    return re.sub(r'[\\/*?:"<>|]', '_', filename)


def delete_input_files(input_dir):
    """Delete every ``*.wav`` file in ``<input_dir>/wav``; do nothing if that folder is missing."""
    wav_dir = Path(input_dir) / "wav"
    if not wav_dir.is_dir():
        return
    for wav_file in wav_dir.glob("*.wav"):
        wav_file.unlink()
        print(f"Deleted {wav_file}")


def download_youtube_audio_by_title(query, state=True):
    """Search YouTube for *query* and optionally download the first hit's audio as WAV.

    The hit's title is shortened with ``_TITLE_RE`` and passed through
    ``sanitize_filename``; the result is the file/folder name used by the
    rest of the pipeline.

    Args:
        query: Search text (yt-dlp only applies ``ytsearch`` to non-URL
            input, so a direct video URL is accepted as well).
        state: When true, clear ``<INPUT_FOLDER>/wav`` and download the audio
            there; when false, only resolve the title (no download).

    Returns:
        ``./<INPUT_FOLDER>/wav/<title>.wav`` when *state* is true, otherwise
        the sanitized title.

    Raises:
        ValueError: If *query* is empty or the search returns no results.
    """
    if not query or not query.strip():
        raise ValueError("Please enter a song title.")
    if state:
        delete_input_files(INPUT_FOLDER)

    wav_dir = f'./{INPUT_FOLDER}/wav'
    ydl_opts = {
        'quiet': True,
        'default_search': 'ytsearch',
        'noplaylist': True,
        'format': 'bestaudio/best',
        'outtmpl': f'{wav_dir}/%(title)s.%(ext)s',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
        }],
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        search_results = ydl.extract_info(query, download=False)

    # A search yields a playlist-like dict with 'entries'; a direct video URL
    # yields the video's own info dict.
    if 'entries' in search_results:
        entries = list(search_results['entries'] or [])
        if not entries:
            raise ValueError(f"No YouTube results found for {query!r}.")
        video_info = entries[0]
    else:
        video_info = search_results

    video_url = video_info['webpage_url']
    video_title = video_info['title']

    match = _TITLE_RE.match(video_title)
    formatted_title = match.group(1) if match else video_title
    formatted_title = sanitize_filename(formatted_title.strip())

    if not state:
        return formatted_title

    # '%' starts a field in yt-dlp output templates; double it so a title
    # like "100% Pure" is written literally.
    ydl_opts['outtmpl'] = f"{wav_dir}/{formatted_title.replace('%', '%%')}.%(ext)s"
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.download([video_url])
    return f'{wav_dir}/{formatted_title}.wav'


def run_inference(model_type, config_path, start_check_point, input_dir, output_dir, device_ids="0"):
    """Run ``inference.py`` in a child process with the current interpreter.

    Returns:
        The ``subprocess.CompletedProcess`` with captured text stdout/stderr.

    Raises:
        subprocess.CalledProcessError: If the script exits non-zero.
    """
    # NOTE(review): confirm "--INPUT_FOLDER" matches inference.py's argparse flag name.
    command = [
        sys.executable, "inference.py",
        "--model_type", model_type,
        "--config_path", config_path,
        "--start_check_point", start_check_point,
        "--INPUT_FOLDER", input_dir,
        "--store_dir", output_dir,
        "--device_ids", device_ids,
    ]
    return subprocess.run(command, check=True, capture_output=True, text=True)


def move_stems_to_parent(input_dir):
    """Move each model's chosen stem up into its song folder.

    For every ``<input_dir>/<song>/<model>/`` folder whose name is a key of
    ``_STEM_SOURCES``, moves ``<song>_<stem>.wav`` to
    ``<input_dir>/<song>/<target name>``. Missing stems are reported and skipped.
    """
    for subdir, _dirs, _files in os.walk(input_dir):
        if subdir == input_dir:
            continue
        # Match the model folder by its exact name rather than a substring of
        # the whole path, so a song title containing e.g. "scnet" cannot
        # send another model's folder down the wrong branch.
        model_name = os.path.basename(subdir)
        spec = _STEM_SOURCES.get(model_name)
        if spec is None:
            continue
        stem, target_name, label = spec

        parent_dir = os.path.dirname(subdir)
        song_name = os.path.basename(parent_dir)
        print(f"Processing {model_name} in {subdir}")
        source_path = os.path.join(subdir, f"{song_name}_{stem}.wav")
        if os.path.exists(source_path):
            target_path = os.path.join(parent_dir, target_name)
            print(f"Moving {source_path} to {target_path}")
            shutil.move(source_path, target_path)
        else:
            print(f"{label} file not found: {source_path}")


def combine_stems_for_all(input_dir):
    """Overlay the four stems of every complete song folder into ``<song>.MDS.wav``.

    A folder under *input_dir* is combined only if it contains ``vocals.wav``,
    ``bass.wav``, ``other.wav`` and ``instrumental.wav``; otherwise it is skipped.
    """
    for subdir, _, _ in os.walk(input_dir):
        if subdir == input_dir:
            continue
        song_name = os.path.basename(subdir)
        print(f"Processing {subdir}")
        stem_paths = {
            "vocals": os.path.join(subdir, "vocals.wav"),
            "bass": os.path.join(subdir, "bass.wav"),
            "others": os.path.join(subdir, "other.wav"),
            "instrumental": os.path.join(subdir, "instrumental.wav"),
        }
        if not all(os.path.exists(path) for path in stem_paths.values()):
            print(f"Skipping {subdir}, not all stems are present.")
            continue

        stems = {name: AudioSegment.from_file(path) for name, path in stem_paths.items()}
        # Per the pydub docs, overlay() keeps the length of the segment it is
        # called on, so the mix is as long as the vocals stem.
        combined = (
            stems["vocals"]
            .overlay(stems["bass"])
            .overlay(stems["others"])
            .overlay(stems["instrumental"])
        )
        output_file = os.path.join(subdir, f"{song_name}.MDS.wav")
        combined.export(output_file, format="wav")
        print(f"Exported combined stems to {output_file}")


def delete_folders_and_files(input_dir):
    """Remove intermediate results under *input_dir*, keeping the combined mixes.

    Deletes the per-model sub-folders and the moved stem files inside each
    folder below *input_dir*, then deletes every folder whose name ends in
    ``_vocals`` (presumably created when BS Roformer also processes the
    vocals file in the Mel Band Roformer folder — TODO confirm).
    """
    folders_to_delete = list(_STEM_SOURCES)
    files_to_delete = [target_name for _, target_name, _ in _STEM_SOURCES.values()]

    # Bottom-up so each folder's children are visited before it is cleaned.
    for root, _dirs, _files in os.walk(input_dir, topdown=False):
        if root == input_dir:
            continue
        for folder in folders_to_delete:
            folder_path = os.path.join(root, folder)
            if os.path.isdir(folder_path):
                print(f"Deleting folder: {folder_path}")
                shutil.rmtree(folder_path)
        for file_name in files_to_delete:
            file_path = os.path.join(root, file_name)
            if os.path.isfile(file_path):
                print(f"Deleting file: {file_path}")
                os.remove(file_path)

    for root, dirs, _files in os.walk(input_dir):
        kept = []
        for dir_name in dirs:
            if dir_name.endswith('_vocals'):
                dir_path = os.path.join(root, dir_name)
                print(f"Deleting folder: {dir_path}")
                shutil.rmtree(dir_path)
            else:
                kept.append(dir_name)
        # Prune in place so os.walk does not try to descend into deleted folders.
        dirs[:] = kept
    print("Cleanup completed.")


def process_audio(song_title):
    """Separate, recombine and clean up the previously downloaded song.

    Gradio generator: yields ``(status message, audio path or None)`` after
    each stage; the final success yield carries the path of the combined
    ``<title>.MDS.wav``. The audio must already be in ``<INPUT_FOLDER>/wav``
    (downloaded via the Play button) — this function only re-resolves the
    title. Any exception is logged and yielded as the final status.
    """
    try:
        yield "Finding audio...", None
        if not song_title or not song_title.strip():
            raise ValueError("Please enter a song title.")
        formatted_title = download_youtube_audio_by_title(song_title, False)

        input_wav_dir = f"{INPUT_FOLDER}/wav"
        input_wav = os.path.join(input_wav_dir, f"{formatted_title}.wav")
        if not os.path.isfile(input_wav):
            raise FileNotFoundError(
                f"{input_wav} not found; press Play to download the song first."
            )
        song_dir = f"{OUTPUT_FOLDER}{formatted_title}"
        mel_dir = f"{song_dir}/mel_band_roformer"

        yield "Starting SCNet inference...", None
        proc_folder_direct("scnet", "configs/config_scnet_other.yaml",
                           "results/model_scnet_other.ckpt", input_wav_dir, OUTPUT_FOLDER)

        yield "Starting Mel Band Roformer inference...", None
        proc_folder_direct("mel_band_roformer", "configs/config_mel_band_roformer_vocals.yaml",
                           "results/model_mel_band_roformer_vocals.ckpt", input_wav_dir, OUTPUT_FOLDER,
                           extract_instrumental=True)

        yield "Starting HTDemucs inference...", None
        proc_folder_direct("htdemucs", "configs/config_htdemucs_bass.yaml",
                           "results/model_htdemucs_bass.th", input_wav_dir, OUTPUT_FOLDER)

        # Rename the instrumental to "<title>.wav", presumably so BS Roformer
        # names its output after the song like the other models — TODO confirm.
        # os.replace also overwrites an existing target on Windows.
        os.replace(f"{mel_dir}/{formatted_title}_instrumental.wav",
                   f"{mel_dir}/{formatted_title}.wav")

        yield "Starting BS Roformer inference...", None
        proc_folder_direct("bs_roformer", "configs/config_bs_roformer_instrumental.yaml",
                           "results/model_bs_roformer_instrumental.ckpt", mel_dir, OUTPUT_FOLDER)

        yield "Deleting input files...", None
        delete_input_files(INPUT_FOLDER)

        yield "Moving stems to parent...", None
        move_stems_to_parent(OUTPUT_FOLDER)

        yield "Combining stems...", None
        combine_stems_for_all(OUTPUT_FOLDER)

        yield "Cleaning up...", None
        delete_folders_and_files(OUTPUT_FOLDER)

        yield "Audio processing completed successfully.", f'{song_dir}/{formatted_title}.MDS.wav'
    except Exception as e:
        # UI boundary: report any failure in the log box instead of killing the handler.
        error_msg = f"An error occurred: {str(e)}\n{traceback.format_exc()}"
        logging.error(error_msg)
        yield error_msg, None


with gr.Blocks() as demo:
    gr.Markdown("# Music Player and Processor")
    with gr.Row():
        title_input = gr.Textbox(label="Enter Song Title")
        play_button = gr.Button("Play")
    audio_output = gr.Audio(label="Audio Player")
    process_button = gr.Button("Process Audio")
    log_output = gr.Textbox(label="Processing Log", interactive=False)
    processed_audio_output = gr.Audio(label="Processed Audio")

    play_button.click(
        fn=download_youtube_audio_by_title,
        inputs=title_input,
        outputs=audio_output,
    )
    process_button.click(
        fn=process_audio,
        inputs=title_input,
        outputs=[log_output, processed_audio_output],
        show_progress=True,
    )

demo.launch()