Awell00's picture
feat: add support for processing uploaded WAV files instead of downloading YouTube audio.
e1ba51e verified
raw
history blame
8.33 kB
import yt_dlp
import re
import subprocess
import os
import shutil
from pydub import AudioSegment
import gradio as gr
import traceback
import logging
from inference import proc_folder_direct
from pathlib import Path
# Root folder for separation results; each song gets its own subfolder here.
OUTPUT_FOLDER: str = "separation_results/"
# Root folder for uploads; the audio itself is stored under INPUT_FOLDER/wav.
INPUT_FOLDER: str = "input"
# NOTE(review): not read anywhere in this file; looks like a remnant of the
# former YouTube-download flow.
download_path: str = ""
def sanitize_filename(filename):
    """Return *filename* with each unsafe character replaced by ``_``.

    The replaced characters are: backslash, ``/``, ``*``, ``?``, ``:``,
    ``"``, ``<``, ``>`` and ``|``. Everything else is kept unchanged.
    """
    forbidden = set('\\/*?:"<>|')
    return "".join("_" if ch in forbidden else ch for ch in filename)
def delete_input_files(input_dir):
    """Delete every ``*.wav`` file directly inside ``<input_dir>/wav``.

    Each deleted path is printed to stdout. Non-WAV files and subfolders
    are left alone.
    """
    target_dir = Path(input_dir).joinpath("wav")
    for stale in target_dir.glob("*.wav"):
        stale.unlink()
        print(f"Deleted {stale}")
def process_uploaded_audio(file, state=True):
    """Save an uploaded audio file into ``INPUT_FOLDER/wav`` and return its path.

    Args:
        file: The upload. Either a filesystem path (``str`` / ``os.PathLike``)
            or an object whose ``name`` attribute holds the uploaded file's
            path. If that path is not an existing file, the bytes returned by
            the object's ``read()`` are written instead.
        state: When true, existing ``*.wav`` files in ``INPUT_FOLDER/wav`` are
            deleted before the upload is saved.

    Returns:
        The destination path, as a ``str``.

    Raises:
        ValueError: if no usable file name can be derived from *file*.
        FileNotFoundError: if *file* is a path that does not exist.
    """
    if state:
        delete_input_files(INPUT_FOLDER)
    is_path_input = isinstance(file, (str, os.PathLike))
    source = Path(file) if is_path_input else Path(file.name)
    # The upload's name may be a full (temp-file) path; keep only the last
    # component, otherwise its directory separators get sanitized into the
    # saved file name.
    sanitized_filename = sanitize_filename(source.name)
    if not sanitized_filename:
        raise ValueError("Uploaded file has no usable file name.")
    input_path = Path(INPUT_FOLDER) / "wav" / sanitized_filename
    input_path.parent.mkdir(parents=True, exist_ok=True)
    if source.is_file():
        # Copying from disk streams the data and does not depend on the
        # upload object's current read position.
        if source.resolve() != input_path.resolve():
            shutil.copyfile(source, input_path)
    elif is_path_input:
        raise FileNotFoundError(f"Uploaded file not found: {source}")
    else:
        with open(input_path, "wb") as f:
            f.write(file.read())
    return str(input_path)
def run_inference(model_type, config_path, start_check_point, input_dir, output_dir, device_ids="0"):
    """Run ``inference.py`` in a child Python process and return the result.

    The child is started with ``sys.executable``, i.e. the same interpreter
    (and therefore the same installed packages) that is running this module,
    instead of whatever ``python`` happens to be on ``PATH``.

    Args:
        model_type: Value for ``--model_type``.
        config_path: Value for ``--config_path``.
        start_check_point: Value for ``--start_check_point``.
        input_dir: Folder passed as the input-folder argument.
        output_dir: Value for ``--store_dir``.
        device_ids: Value for ``--device_ids``; converted with ``str()``, so
            an int such as ``0`` is accepted too.

    Returns:
        The ``subprocess.CompletedProcess`` with captured text stdout/stderr.

    Raises:
        subprocess.CalledProcessError: if the script exits with a non-zero
            status (captured output is available on the exception).
    """
    import sys  # only this helper needs the interpreter path

    command = [
        sys.executable, "inference.py",
        "--model_type", model_type,
        "--config_path", config_path,
        "--start_check_point", start_check_point,
        # NOTE(review): "--INPUT_FOLDER" looks like a side effect of renaming
        # the INPUT_FOLDER constant; confirm it matches inference.py's
        # argparse definition.
        "--INPUT_FOLDER", input_dir,
        "--store_dir", output_dir,
        "--device_ids", str(device_ids),
    ]
    return subprocess.run(command, check=True, capture_output=True, text=True)
def move_stems_to_parent(input_dir):
    """Move each model's stem file from ``<song>/<model>/`` up into ``<song>/``.

    Walks every directory below *input_dir*. A directory whose own name is a
    known model folder is expected to hold ``<song>_<suffix>.wav``, where
    ``<song>`` is the name of its parent directory. That file is moved to
    ``<song>/<target>``:

    ==================  ======  ================
    model folder        suffix  target
    ==================  ======  ================
    htdemucs            bass    bass.wav
    mel_band_roformer   vocals  vocals.wav
    scnet               other   other.wav
    bs_roformer         other   instrumental.wav
    ==================  ======  ================

    Missing source files are reported on stdout and skipped.
    """
    # model folder -> (source suffix, target file name, label for messages)
    stem_map = {
        "htdemucs": ("bass", "bass.wav", "Bass"),
        "mel_band_roformer": ("vocals", "vocals.wav", "Vocals"),
        "scnet": ("other", "other.wav", "Other"),
        "bs_roformer": ("other", "instrumental.wav", "Instrumental"),
    }
    for subdir, _dirs, _files in os.walk(input_dir):
        if subdir == input_dir:
            continue
        model = os.path.basename(subdir)
        # Match the folder's own name, not a substring of the whole path: a
        # song called e.g. "htdemucs_mix" must not send its scnet folder down
        # the htdemucs branch.
        spec = stem_map.get(model)
        if spec is None:
            continue
        suffix, target, label = spec
        parent_dir = os.path.dirname(subdir)
        song_name = os.path.basename(parent_dir)
        print(f"Processing {model} in {subdir}")
        source_path = os.path.join(subdir, f"{song_name}_{suffix}.wav")
        if not os.path.exists(source_path):
            print(f"{label} file not found: {source_path}")
            continue
        target_path = os.path.join(parent_dir, target)
        print(f"Moving {source_path} to {target_path}")
        shutil.move(source_path, target_path)
def combine_stems_for_all(input_dir):
    """Mix the four stems of every song folder into ``<folder name>.MDS.wav``.

    Every directory below *input_dir* is visited. When it contains all of
    ``vocals.wav``, ``bass.wav``, ``other.wav`` and ``instrumental.wav``, the
    files are loaded with pydub, overlaid in that order onto the vocals
    track, and the result is exported as WAV into the same directory.
    Directories missing any of the stems are skipped with a message.
    """
    stem_names = ("vocals.wav", "bass.wav", "other.wav", "instrumental.wav")
    for folder, _subdirs, _names in os.walk(input_dir):
        if folder == input_dir:
            continue
        print(f"Processing {folder}")
        paths = [os.path.join(folder, name) for name in stem_names]
        if any(not os.path.exists(p) for p in paths):
            print(f"Skipping {folder}, not all stems are present.")
            continue
        vocals, bass, others, instrumental = (AudioSegment.from_file(p) for p in paths)
        mix = vocals.overlay(bass).overlay(others).overlay(instrumental)
        song = os.path.basename(folder)
        destination = os.path.join(folder, f"{song}.MDS.wav")
        mix.export(destination, format="wav")
        print(f"Exported combined stems to {destination}")
def delete_folders_and_files(input_dir):
    """Remove intermediate separation artefacts below *input_dir*.

    1. In every directory strictly below *input_dir*, delete the model
       folders (``htdemucs``, ``mel_band_roformer``, ``scnet``,
       ``bs_roformer``) and the moved-up stems (``bass.wav``, ``vocals.wav``,
       ``other.wav``, ``instrumental.wav``). Other files, such as the combined
       ``<song>.MDS.wav``, are kept.
    2. Delete ``<name>_vocals`` folders anywhere under *input_dir*, but only
       when a sibling ``<name>`` folder exists. These folders are presumably
       the by-product of running BS Roformer over the mel_band_roformer
       folder, which also holds ``<song>_vocals.wav``. The sibling check
       keeps a real song whose own name ends in ``_vocals``.
    """
    folders_to_delete = ['htdemucs', 'mel_band_roformer', 'scnet', 'bs_roformer']
    files_to_delete = ['bass.wav', 'vocals.wav', 'other.wav', 'instrumental.wav']
    # Bottom-up walk, so a folder is only rmtree'd after its subtree was visited.
    for root, _dirs, _files in os.walk(input_dir, topdown=False):
        if root == input_dir:
            continue
        for folder in folders_to_delete:
            folder_path = os.path.join(root, folder)
            if os.path.isdir(folder_path):
                print(f"Deleting folder: {folder_path}")
                shutil.rmtree(folder_path)
        for file in files_to_delete:
            file_path = os.path.join(root, file)
            if os.path.isfile(file_path):
                print(f"Deleting file: {file_path}")
                os.remove(file_path)
    suffix = "_vocals"
    # Walk the directory we were given (not a module-level global).
    for root, dirs, _files in os.walk(input_dir):
        doomed = [
            name for name in dirs
            if name.endswith(suffix) and name[:-len(suffix)] in dirs
        ]
        for dir_name in doomed:
            dir_path = os.path.join(root, dir_name)
            print(f"Deleting folder: {dir_path}")
            shutil.rmtree(dir_path)
            # Prune so os.walk does not try to descend into the removed folder.
            dirs.remove(dir_name)
    print("Cleanup completed.")
def process_audio(uploaded_file):
    """Run the full separation pipeline on an uploaded WAV, streaming status.

    Generator used as the Gradio click handler. Every ``yield`` is a
    ``(status_message, audio_path_or_None)`` pair for the log textbox and the
    audio player.

    Steps:
        1. Save the upload to ``INPUT_FOLDER/wav``.
        2. Run SCNet, Mel Band Roformer (with ``extract_instrumental=True``)
           and HTDemucs on that folder.
        3. Run BS Roformer on the song's ``mel_band_roformer`` output folder.
        4. Delete the input files.
        5. Move the stems up into the song folder, combine them and clean up.

    The final successful yield carries the path of ``<song>.MDS.wav``. Any
    exception is logged and reported as a final ``(error_message, None)``
    yield rather than propagated. In that case the saved upload is removed
    (best effort) so a later run does not pick it up again.
    """
    try:
        yield "Processing audio file...", None
        if uploaded_file is None:
            raise ValueError("Please upload a WAV file.")
        # process_uploaded_audio returns a str; wrap it in a Path so the
        # song's stem (its folder name under OUTPUT_FOLDER) can be taken.
        file_path = Path(process_uploaded_audio(uploaded_file, False))
        stem = file_path.stem
        wav_dir = f"{INPUT_FOLDER}/wav"
        song_dir = Path(OUTPUT_FOLDER) / stem
        mel_dir = song_dir / "mel_band_roformer"

        yield "Starting SCNet inference...", None
        proc_folder_direct(
            "scnet", "configs/config_scnet_other.yaml",
            "results/model_scnet_other.ckpt", wav_dir, OUTPUT_FOLDER,
        )
        yield "Starting Mel Band Roformer inference...", None
        proc_folder_direct(
            "mel_band_roformer", "configs/config_mel_band_roformer_vocals.yaml",
            "results/model_mel_band_roformer_vocals.ckpt", wav_dir, OUTPUT_FOLDER,
            extract_instrumental=True,
        )
        yield "Starting HTDemucs inference...", None
        proc_folder_direct(
            "htdemucs", "configs/config_htdemucs_bass.yaml",
            "results/model_htdemucs_bass.th", wav_dir, OUTPUT_FOLDER,
        )
        # Give the extracted instrumental the song's own name. The BS Roformer
        # pass below reads this whole folder, and move_stems_to_parent looks
        # for "<song>_other.wav" in <song>/bs_roformer (presumably named after
        # the input file).
        os.rename(mel_dir / f"{stem}_instrumental.wav", mel_dir / f"{stem}.wav")

        yield "Starting BS Roformer inference...", None
        proc_folder_direct(
            "bs_roformer", "configs/config_bs_roformer_instrumental.yaml",
            "results/model_bs_roformer_instrumental.ckpt", str(mel_dir), OUTPUT_FOLDER,
        )
        yield "Deleting input files...", None
        delete_input_files(INPUT_FOLDER)
        yield "Moving stems to parent...", None
        move_stems_to_parent(OUTPUT_FOLDER)
        yield "Combining stems...", None
        combine_stems_for_all(OUTPUT_FOLDER)
        yield "Cleaning up...", None
        delete_folders_and_files(OUTPUT_FOLDER)

        output_file = song_dir / f"{stem}.MDS.wav"
        # combine_stems_for_all silently skips songs with missing stems; report
        # that here instead of handing a non-existent path to the audio player.
        if not output_file.is_file():
            raise FileNotFoundError(f"Combined output was not created: {output_file}")
        yield "Audio processing completed successfully.", str(output_file)
    except Exception as e:
        error_msg = f"An error occurred: {str(e)}\n{traceback.format_exc()}"
        logging.error(error_msg)
        # Best effort: the next run processes everything in INPUT_FOLDER/wav,
        # so do not leave this upload behind.
        try:
            delete_input_files(INPUT_FOLDER)
        except OSError as cleanup_error:
            logging.error(f"Failed to clean up input files: {cleanup_error}")
        yield error_msg, None
# Gradio UI: upload a WAV and press the button. process_audio streams status
# messages into the log textbox; its final yield fills the audio player.
with gr.Blocks() as demo:
    gr.Markdown("# Music Player and Processor")
    with gr.Row():
        # Gradio's file_types expects extensions with a leading dot; a bare
        # "wav" is not recognised as an extension.
        file_input = gr.File(label="Upload WAV File", file_types=['.wav'])
        process_button = gr.Button("Process Audio")
    log_output = gr.Textbox(label="Processing Log", interactive=False)
    processed_audio_output = gr.Audio(label="Processed Audio")
    process_button.click(
        fn=process_audio,
        inputs=file_input,
        outputs=[log_output, processed_audio_output],
        show_progress=True
    )
demo.launch()