Awell00's picture
feat!: create app.py to initialize Gradio interface with main function
889d37b verified
raw
history blame
9.37 kB
import yt_dlp
import re
import subprocess
import os
import shutil
from pydub import AudioSegment
import gradio as gr
import traceback
import logging
from inference import proc_folder_direct
from pathlib import Path
# Directory that receives the separated stems and the final combined mix.
# The trailing slash matters: process_audio builds paths by plain string
# concatenation (f'{OUTPUT_FOLDER}{title}/...').
OUTPUT_FOLDER = "separation_results/"
# Root of the input tree; the WAV files to process live in its "wav" subfolder.
INPUT_FOLDER = "input"
# NOTE(review): not referenced anywhere in the visible code.
download_path = ""
def sanitize_filename(filename):
    """Return *filename* with file-system-hostile characters replaced by ``_``.

    The characters replaced are the backslash, forward slash, asterisk,
    question mark, colon, double quote, angle brackets and the pipe.
    """
    unsafe_chars = re.compile(r'[\\/*?:"<>|]')
    return unsafe_chars.sub('_', filename)
def delete_input_files(input_dir):
    """Remove every ``*.wav`` file located directly inside ``<input_dir>/wav``.

    Each removed path is reported on stdout. Files with other extensions and
    anything inside nested directories are left untouched.
    """
    wav_folder = Path(input_dir).joinpath("wav")
    for audio_path in wav_folder.glob("*.wav"):
        audio_path.unlink()
        print(f"Deleted {audio_path}")
def download_youtube_audio_by_title(query, state=True):
    """Look up *query* on YouTube and optionally download the top hit as WAV.

    The video title is reduced to its leading ``"<artist> - <song>"`` part
    (a trailing ``[...]`` / ``(...)`` group is dropped when the title has that
    shape) and then passed through ``sanitize_filename``.

    Args:
        query: Free-text search string. A direct video URL is accepted too.
        state: When True, the existing input WAV files are removed, the audio
            is downloaded and converted to WAV, and the path of that file is
            returned. When False, nothing is deleted or downloaded and only
            the cleaned-up title is returned.

    Returns:
        ``'./input/wav/<title>.wav'`` when *state* is True, otherwise
        ``'<title>'``.

    Raises:
        ValueError: If *query* is empty or the search yields no result.
    """
    if not query or not query.strip():
        raise ValueError("Please enter a song title.")

    ydl_opts = {
        'quiet': True,
        'default_search': 'ytsearch',
        'noplaylist': True,
        'format': 'bestaudio/best',
        'outtmpl': './input/wav/%(title)s.%(ext)s',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
        }],
    }

    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        search_results = ydl.extract_info(query, download=False)

    if not search_results:
        raise ValueError(f"No results found for: {query}")

    # A search returns a playlist-like dict with 'entries'; a direct URL
    # returns the video dict itself.
    entries = search_results.get('entries')
    if entries is None:
        video_info = search_results
    else:
        video_info = next((entry for entry in entries if entry), None)
    if video_info is None:
        raise ValueError(f"No results found for: {query}")

    video_url = video_info['webpage_url']
    video_title = video_info['title']

    match = re.match(r'^(.*? - .*?)(?: \[.*\]|\(.*\))?$', video_title)
    formatted_title = match.group(1) if match else video_title
    formatted_title = sanitize_filename(formatted_title.strip())

    if not state:
        return formatted_title

    # Only wipe the previous inputs once the lookup has succeeded, so a failed
    # search does not leave the input folder empty.
    delete_input_files(INPUT_FOLDER)

    # '%' is the field marker of yt-dlp output templates; a literal percent
    # sign in the title has to be doubled or the template is misparsed.
    escaped_title = formatted_title.replace('%', '%%')
    download_opts = dict(ydl_opts)
    download_opts['outtmpl'] = f'./input/wav/{escaped_title}.%(ext)s'

    with yt_dlp.YoutubeDL(download_opts) as ydl:
        ydl.download([video_url])

    return f'./input/wav/{formatted_title}.wav'
def run_inference(model_type, config_path, start_check_point, input_dir, output_dir, device_ids="0"):
    """Run ``inference.py`` in a child process and wait for it to finish.

    Args:
        model_type: Value passed to ``--model_type``.
        config_path: Value passed to ``--config_path``.
        start_check_point: Value passed to ``--start_check_point``.
        input_dir: Value passed to ``--INPUT_FOLDER``.
        output_dir: Value passed to ``--store_dir``.
        device_ids: Value passed to ``--device_ids``; converted to ``str``.

    Returns:
        The ``subprocess.CompletedProcess`` of the run, with stdout and stderr
        captured as text.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero
            status.
    """
    import sys

    # Launch the script with the interpreter that runs this app, so it sees
    # the same environment; a bare "python" may be a different install or
    # may not exist on PATH at all.
    interpreter = sys.executable or "python"

    # NOTE(review): the upper-case "--INPUT_FOLDER" flag is kept as written;
    # confirm it matches the option name that inference.py actually declares.
    command = [
        interpreter, "inference.py",
        "--model_type", model_type,
        "--config_path", config_path,
        "--start_check_point", start_check_point,
        "--INPUT_FOLDER", input_dir,
        "--store_dir", output_dir,
        "--device_ids", str(device_ids),
    ]
    return subprocess.run(command, check=True, capture_output=True, text=True)
def move_stems_to_parent(input_dir):
    """Lift each model's stem out of its model folder into the song folder.

    The tree below *input_dir* is walked. For every directory whose own name
    is one of the known model names, the file ``<song>_<stem>.wav`` inside it
    (``<song>`` being the name of the directory's parent) is moved to the
    parent directory under a fixed name:

    ==================  ==================  ==================
    model directory     source file         destination file
    ==================  ==================  ==================
    htdemucs            <song>_bass.wav     bass.wav
    mel_band_roformer   <song>_vocals.wav   vocals.wav
    scnet               <song>_other.wav    other.wav
    bs_roformer         <song>_other.wav    instrumental.wav
    ==================  ==================  ==================

    Progress and missing files are reported on stdout.

    Args:
        input_dir: Root directory of the separation results.
    """
    # model directory name -> (stem suffix, destination file name, log label)
    stem_rules = {
        "htdemucs": ("bass", "bass.wav", "Bass"),
        "mel_band_roformer": ("vocals", "vocals.wav", "Vocals"),
        "scnet": ("other", "other.wav", "Other"),
        "bs_roformer": ("other", "instrumental.wav", "Instrumental"),
    }

    for subdir, _dirs, _files in os.walk(input_dir):
        if subdir == input_dir:
            continue

        # Match on the directory's own name, not on a substring of the whole
        # path: a song title containing a model name (e.g. "scnet") must not
        # decide which rule is applied to its model folders.
        model_name = os.path.basename(subdir)
        rule = stem_rules.get(model_name)
        if rule is None:
            continue
        stem_suffix, target_name, label = rule

        parent_dir = os.path.dirname(subdir)
        song_name = os.path.basename(parent_dir)

        print(f"Processing {model_name} in {subdir}")
        source_path = os.path.join(subdir, f"{song_name}_{stem_suffix}.wav")
        if os.path.exists(source_path):
            target_path = os.path.join(parent_dir, target_name)
            print(f"Moving {source_path} to {target_path}")
            shutil.move(source_path, target_path)
        else:
            print(f"{label} file not found: {source_path}")
def combine_stems_for_all(input_dir):
    """Mix the four stems of every song folder into ``<song>.MDS.wav``.

    Every directory below *input_dir* (the root itself is skipped) is
    expected to hold ``vocals.wav``, ``bass.wav``, ``other.wav`` and
    ``instrumental.wav``. Folders missing any of them are reported and
    skipped. Otherwise the stems are overlaid in that order, starting from
    the vocals, and the mix is exported as WAV next to them.
    """
    stem_file_names = ("vocals.wav", "bass.wav", "other.wav", "instrumental.wav")

    for folder, _subdirs, _files in os.walk(input_dir):
        if folder == input_dir:
            continue

        track_name = os.path.basename(folder)
        print(f"Processing {folder}")

        stem_files = [os.path.join(folder, file_name) for file_name in stem_file_names]
        if any(not os.path.exists(stem_file) for stem_file in stem_files):
            print(f"Skipping {folder}, not all stems are present.")
            continue

        vocals, bass, others, instrumental = (
            AudioSegment.from_file(stem_file) for stem_file in stem_files
        )
        mix = vocals.overlay(bass)
        mix = mix.overlay(others)
        mix = mix.overlay(instrumental)

        target = os.path.join(folder, f"{track_name}.MDS.wav")
        mix.export(target, format="wav")
        print(f"Exported combined stems to {target}")
def delete_folders_and_files(input_dir):
    """Remove intermediate separation artefacts below *input_dir*.

    Two passes are made over the tree rooted at *input_dir*:

    1. In every directory except the root itself, the model folders
       (``htdemucs``, ``mel_band_roformer``, ``scnet``, ``bs_roformer``) and
       the per-stem files (``bass.wav``, ``vocals.wav``, ``other.wav``,
       ``instrumental.wav``) are deleted.
    2. Every directory whose name ends in ``_vocals`` is deleted, including
       ones directly under the root.

    Each deletion is reported on stdout.

    Args:
        input_dir: Root directory to clean up.
    """
    folders_to_delete = ('htdemucs', 'mel_band_roformer', 'scnet', 'bs_roformer')
    files_to_delete = ('bass.wav', 'vocals.wav', 'other.wav', 'instrumental.wav')

    # Bottom-up, so children have already been visited when a parent removes
    # them.
    for root, _dirs, _files in os.walk(input_dir, topdown=False):
        if root == input_dir:
            continue

        for folder in folders_to_delete:
            folder_path = os.path.join(root, folder)
            if os.path.isdir(folder_path):
                print(f"Deleting folder: {folder_path}")
                shutil.rmtree(folder_path)

        for file in files_to_delete:
            file_path = os.path.join(root, file)
            if os.path.isfile(file_path):
                print(f"Deleting file: {file_path}")
                os.remove(file_path)

    # Walk the directory that was passed in, not the module-level output
    # folder, so the function cleans the tree it was asked to clean.
    for root, dirs, _files in os.walk(input_dir):
        kept = []
        for dir_name in dirs:
            if dir_name.endswith('_vocals'):
                dir_path = os.path.join(root, dir_name)
                print(f"Deleting folder: {dir_path}")
                shutil.rmtree(dir_path)
            else:
                kept.append(dir_name)
        # Prune in place so os.walk does not try to descend into the
        # directories that were just removed.
        dirs[:] = kept

    print("Cleanup completed.")
def process_audio(song_title):
    """Run the full separation pipeline for *song_title*, reporting progress.

    This is a generator: it yields ``(status_message, audio_path)`` tuples.
    ``audio_path`` is ``None`` for every progress update and for errors; the
    final tuple of a successful run carries the path of the combined
    ``<title>.MDS.wav`` file.

    Any exception raised along the way is caught, logged and yielded as an
    error message (including the traceback) instead of propagating.

    NOTE(review): the title is resolved with ``state=False``, which does not
    download anything, so the WAV is expected to be in ``input/wav`` already
    (e.g. from a prior download via ``download_youtube_audio_by_title``).

    Args:
        song_title: Song title / search text entered by the user.
    """
    try:
        yield "Finding audio...", None

        # Validate the argument itself; an empty or blank title cannot be
        # searched for.
        if not song_title or not song_title.strip():
            raise ValueError("Please enter a song title.")

        formatted_title = download_youtube_audio_by_title(song_title, False)

        wav_dir = f"{INPUT_FOLDER}/wav"
        song_dir = f'{OUTPUT_FOLDER}{formatted_title}'
        mel_dir = f'{song_dir}/mel_band_roformer'

        yield "Starting SCNet inference...", None
        proc_folder_direct("scnet", "configs/config_scnet_other.yaml", "results/model_scnet_other.ckpt", wav_dir, OUTPUT_FOLDER)

        yield "Starting Mel Band Roformer inference...", None
        proc_folder_direct("mel_band_roformer", "configs/config_mel_band_roformer_vocals.yaml", "results/model_mel_band_roformer_vocals.ckpt", wav_dir, OUTPUT_FOLDER, extract_instrumental=True)

        yield "Starting HTDemucs inference...", None
        proc_folder_direct("htdemucs", "configs/config_htdemucs_bass.yaml", "results/model_htdemucs_bass.th", wav_dir, OUTPUT_FOLDER)

        # The BS Roformer pass below reads the mel_band_roformer folder as its
        # input; give the instrumental the plain song name before that pass.
        source_path = f'{mel_dir}/{formatted_title}_instrumental.wav'
        destination_path = f'{mel_dir}/{formatted_title}.wav'
        os.rename(source_path, destination_path)

        yield "Starting BS Roformer inference...", None
        proc_folder_direct("bs_roformer", "configs/config_bs_roformer_instrumental.yaml", "results/model_bs_roformer_instrumental.ckpt", mel_dir, OUTPUT_FOLDER)

        yield "Moving input files...", None
        delete_input_files(INPUT_FOLDER)

        yield "Moving stems to parent...", None
        move_stems_to_parent(OUTPUT_FOLDER)

        yield "Combining stems...", None
        combine_stems_for_all(OUTPUT_FOLDER)

        yield "Cleaning up...", None
        delete_folders_and_files(OUTPUT_FOLDER)

        yield "Audio processing completed successfully.", f'{song_dir}/{formatted_title}.MDS.wav'

    except Exception as e:
        # Top-level boundary of the UI callback: report the failure in the
        # log box rather than letting it escape the generator.
        error_msg = f"An error occurred: {str(e)}\n{traceback.format_exc()}"
        logging.error(error_msg)
        yield error_msg, None
# Build the Gradio UI: a title box with a "Play" button that downloads and
# plays the song, and a "Process Audio" button that runs the separation
# pipeline while streaming its status messages into a log box.
# NOTE(review): the original indentation was lost; which widgets sit inside
# the Row is an assumption — confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# Music Player and Processor")
    with gr.Row():
        title_input = gr.Textbox(label="Enter Song Title")
        play_button = gr.Button("Play")
    audio_output = gr.Audio(label="Audio Player")
    process_button = gr.Button("Process Audio")
    log_output = gr.Textbox(label="Processing Log", interactive=False)
    processed_audio_output = gr.Audio(label="Processed Audio")
    # Called with only the query, so `state` keeps its default (True): the
    # audio is downloaded and the returned WAV path feeds the player.
    play_button.click(
        fn=download_youtube_audio_by_title,
        inputs=title_input,
        outputs=audio_output
    )
    # process_audio is a generator; each yielded (message, audio) pair maps
    # onto the two outputs listed here.
    process_button.click(
        fn=process_audio,
        inputs=title_input,
        outputs=[log_output, processed_audio_output],
        show_progress=True
    )
# Start the web server for the interface.
demo.launch()