Spaces:
Running
on
Zero
Running
on
Zero
import logging
import os
import re
import shutil
import subprocess
import sys
import traceback
from pathlib import Path

# NOTE(review): `spaces` is kept ahead of the torch-backed imports below
# (inference, transformers); ZeroGPU expects it to be imported first — confirm.
import spaces
import gradio as gr
import yt_dlp
from pydub import AudioSegment, silence
from pydub.exceptions import CouldntEncodeError
from transformers import pipeline

from inference import proc_folder_direct
# Text-generation pipeline (GPT-Neo 125M), built once when the module loads.
# NOTE(review): `model` is not referenced elsewhere in the code visible here;
# confirm it is still needed, since loading it costs start-up time and memory.
model = pipeline(task="text-generation", model="EleutherAI/gpt-neo-125M")

# Folder layout shared by the processing steps.
OUTPUT_FOLDER = "separation_results/"  # separated stems and combined mixes
INPUT_FOLDER = "input"  # uploads are staged in its "wav" sub-folder
download_path = ""  # NOTE(review): not referenced elsewhere in the code visible here
def sanitize_filename(filename):
    """
    Replace characters that are illegal in file names on common file systems.

    Each of ``\\ / * ? : " < > |`` is swapped for an underscore; every other
    character is kept unchanged.

    Args:
        filename (str): Candidate file name.
    Returns:
        str: ``filename`` with the forbidden characters replaced by ``_``.
    """
    forbidden = '\\/*?:"<>|'
    replacements = str.maketrans(dict.fromkeys(forbidden, "_"))
    return filename.translate(replacements)
def delete_input_files(input_dir):
    """
    Remove every ``*.wav`` file directly inside ``<input_dir>/wav``.

    Each removed path is printed. If ``<input_dir>/wav`` does not exist,
    ``Path.glob`` finds nothing and the call does nothing.

    Args:
        input_dir (str): Folder holding the ``wav`` staging sub-folder.
    """
    staging = Path(input_dir).joinpath("wav")
    for stale in staging.glob("*.wav"):
        stale.unlink()
        print(f"Deleted {stale}")
def _smart_title(text):
    """
    Title-case ``text`` without capitalizing the letter after an apostrophe.

    ``str.title`` treats an apostrophe as a word break ("don't" -> "Don'T").
    Here a letter run joined by an apostrophe is capitalized as one word
    ("don't" -> "Don't"). Other letter runs are capitalized as ``str.title``
    would ("m4a" -> "M4A", "hip-hop" -> "Hip-Hop"). Because the whole
    apostrophe-joined word is capitalized, "O'NEIL" becomes "O'neil".

    Args:
        text (str): Text to title-case.
    Returns:
        str: The title-cased text.
    """
    return re.sub(
        r"[^\W\d_]+(?:['\u2019][^\W\d_]+)?",
        lambda match: match.group(0).capitalize(),
        text,
    )


def standardize_title(input_title):
    """
    Standardize a track title into "Artist - Title" form.

    Processing steps:
    1. Remove bracketed annotations.
    2. Remove filler words ("official", "video", "lyrics", ...).
    3. Split what remains on "-" and ",".
    4. Take the last non-empty fragment as the title and join the preceding
       fragments as the artist.
    5. Append a featured artist written as "(feat. X)", "[feat. X]",
       "(with X)" or "[with X]" to the artist.
    6. Whitespace-normalize and title-case both parts.

    Args:
        input_title (str): The original title.
    Returns:
        str: Standardized title in "Artist - Title" format. The artist is
        "Unknown Artist" when fewer than two fragments remain and no featured
        artist is found.
    """
    # Remove content within parentheses or brackets; featured artists are
    # recovered from the untouched input_title further down.
    title_cleaned = re.sub(r"[\(\[].*?[\)\]]", "", input_title)
    # Remove filler words (whole words only, any case).
    unnecessary_words = ["official", "video", "hd", "4k", "lyrics", "music", "audio", "visualizer", "remix"]
    title_cleaned = re.sub(
        r"\b(?:{})\b".format("|".join(unnecessary_words)), "", title_cleaned, flags=re.IGNORECASE
    )
    # Split on "-" / "," and drop fragments the removals above left empty,
    # e.g. the trailing "-" of "Artist - Song - Official Video".
    parts = [part.strip() for part in re.split(r"\s*-\s*|\s*,\s*", title_cleaned) if part.strip()]
    if len(parts) >= 2:
        artist_part = ", ".join(parts[:-1])
        title_part = parts[-1]
    else:
        artist_part = "Unknown Artist"
        title_part = parts[0] if parts else ""
    # Featured artist in parentheses or brackets (both are stripped above).
    featured = re.search(r"[\(\[](?:with|feat\.?)\s+(.*?)[\)\]]", input_title, re.IGNORECASE)
    if featured and featured.group(1).strip():
        additional_artist = featured.group(1).strip()
        if artist_part == "Unknown Artist":
            artist_part = additional_artist
        else:
            artist_part = f"{artist_part}, {additional_artist}"
    # Clean up whitespace and capitalize.
    artist_part = _smart_title(re.sub(r"\s+", " ", artist_part))
    title_part = _smart_title(re.sub(r"\s+", " ", title_part))
    return f"{artist_part} - {title_part}".strip()
def handle_file_upload(file):
    """
    Stage an uploaded file in ``<INPUT_FOLDER>/wav`` under a standardized name.

    The upload's base name, minus its extension, is normalized with
    ``standardize_title`` and ``sanitize_filename``. The file is then copied to
    ``<INPUT_FOLDER>/wav/<formatted_title>.wav``.

    Args:
        file: The uploaded file — either a path (``str`` / ``os.PathLike``) or an
            object exposing the temp-file path as ``.name``.
    Returns:
        tuple: ``(input_path, formatted_title)`` on success, or
        ``(None, "No file uploaded")`` when ``file`` is None.
    """
    if file is None:
        return None, "No file uploaded"
    # Depending on version/configuration Gradio hands over either a path string
    # or a tempfile-like object; accept both.
    source_path = file if isinstance(file, (str, os.PathLike)) else file.name
    # Drop the extension before standardizing; otherwise "Song.m4a" becomes the
    # title "Song.M4A" and the staged file ends up as "Song.M4A.wav".
    stem = os.path.splitext(os.path.basename(source_path))[0]
    formatted_title = sanitize_filename(standardize_title(stem).strip())
    input_path = os.path.join(INPUT_FOLDER, "wav", f"{formatted_title}.wav")
    os.makedirs(os.path.dirname(input_path), exist_ok=True)
    # NOTE(review): the bytes are copied verbatim, so a non-WAV upload (e.g.
    # .m4a) is stored under a .wav name — confirm the separation code decodes
    # by content rather than by extension.
    shutil.copy(source_path, input_path)
    return input_path, formatted_title
def run_inference(model_type, config_path, start_check_point, input_dir, output_dir, device_ids="0"):
    """
    Run ``inference.py`` in a child process with the given model settings.

    The script is launched with the interpreter running this app
    (``sys.executable``) rather than whatever ``python`` resolves to on PATH.
    That way it sees the same environment and installed packages.
    ``inference.py`` is resolved relative to the current working directory.

    Args:
        model_type (str): Type of the model
        config_path (str): Path to the model configuration
        start_check_point (str): Path to the model checkpoint
        input_dir (str): Input directory
        output_dir (str): Output directory
        device_ids (str): GPU device IDs to use
    Returns:
        subprocess.CompletedProcess: Completed run with captured text stdout/stderr.
    Raises:
        subprocess.CalledProcessError: The script exited with a non-zero status.
    """
    command = [
        sys.executable, "inference.py",
        "--model_type", model_type,
        "--config_path", config_path,
        "--start_check_point", start_check_point,
        "--INPUT_FOLDER", input_dir,
        "--store_dir", output_dir,
        "--device_ids", device_ids,
    ]
    # List form (no shell), so paths with spaces or shell metacharacters are
    # passed through unchanged.
    return subprocess.run(command, check=True, capture_output=True, text=True)
def move_stems_to_parent(input_dir):
    """
    Lift each model's stem file out of its per-model folder into the song folder.

    Expects ``<input_dir>/<song>/<model>/<song>_<suffix>.wav`` and moves the file
    to ``<input_dir>/<song>/<target>.wav`` according to the table below. A
    missing stem file is reported and skipped.

    Args:
        input_dir (str): Root folder holding one sub-folder per song.
    """
    # model folder -> (suffix of the file the model wrote, destination name, label)
    stem_map = {
        "htdemucs": ("bass", "bass.wav", "Bass"),
        "mel_band_roformer": ("vocals", "vocals.wav", "Vocals"),
        "scnet": ("other", "other.wav", "Other"),
        # BS Roformer's "other" output is used as the instrumental stem.
        "bs_roformer": ("other", "instrumental.wav", "Instrumental"),
    }
    for subdir, _dirs, _files in os.walk(input_dir):
        if subdir == input_dir:
            continue
        # Match on the folder's own name, not a substring of the whole path, so
        # a song title containing e.g. "scnet" cannot trigger the wrong rule.
        entry = stem_map.get(os.path.basename(subdir))
        if entry is None:
            continue
        suffix, target_name, label = entry
        parent_dir = os.path.dirname(subdir)
        song_name = os.path.basename(parent_dir)
        source_path = os.path.join(subdir, f"{song_name}_{suffix}.wav")
        if os.path.exists(source_path):
            shutil.move(source_path, os.path.join(parent_dir, target_name))
        else:
            print(f"{label} file not found: {source_path}")
def combine_stems_for_all(input_dir, output_format, codec="aac"):
    """
    Mix the four stems of every complete song folder into one audio file.

    A sub-folder of ``input_dir`` counts as complete when it holds
    ``vocals.wav``, ``bass.wav``, ``other.wav`` and ``instrumental.wav``. For
    each complete folder, the stems are overlaid, trailing silence is trimmed,
    and the mix is written to ``<subfolder>/<folder name>.<output_format>``.
    Folders missing any stem are skipped.

    Args:
        input_dir (str): Root folder holding one sub-folder per song.
        output_format (str): Output file extension / format, e.g. ``"m4a"``.
        codec (str): ffmpeg audio codec used for encoding (default ``"aac"``).
    Returns:
        str: Path of the last file exported (the only one when a single song
        folder is complete).
    Raises:
        FileNotFoundError: No song folder contained all four stems.
        CouldntEncodeError: ffmpeg failed to encode the mix.
    """
    extension = output_format.lower()
    # ffmpeg has no muxer named "m4a"; .m4a files are written by the "ipod"
    # (MP4) muxer, so map the extension to a container name for export().
    ffmpeg_muxers = {"m4a": "ipod"}
    container = ffmpeg_muxers.get(extension, extension)
    output_file = None
    for subdir, _, _ in os.walk(input_dir):
        if subdir == input_dir:
            continue
        song_name = os.path.basename(subdir).strip()  # Remove any trailing spaces
        print(f"Processing {subdir}")
        stem_paths = {
            "vocals": os.path.join(subdir, "vocals.wav"),
            "bass": os.path.join(subdir, "bass.wav"),
            "others": os.path.join(subdir, "other.wav"),
            "instrumental": os.path.join(subdir, "instrumental.wav"),
        }
        # Skip if not all stems are present
        if not all(os.path.exists(path) for path in stem_paths.values()):
            print(f"Skipping {subdir}, not all stems are present.")
            continue
        stems = {name: AudioSegment.from_file(path) for name, path in stem_paths.items()}
        # overlay() keeps the length of the segment it is called on, so the mix
        # is as long as the vocals stem.
        combined = stems["vocals"].overlay(stems["bass"]).overlay(stems["others"]).overlay(stems["instrumental"])
        trimmed_combined = trim_silence_at_end(combined)
        output_file = os.path.join(subdir, f"{song_name}.{extension}")
        try:
            trimmed_combined.export(output_file, format=container, codec=codec)
            print(f"Exported combined stems to {output_file}")
        except CouldntEncodeError as e:
            print(f"Encoding failed: {e}")
            raise
    if output_file is None:
        raise FileNotFoundError(
            f"No folder under {input_dir!r} contained all stems; nothing was combined."
        )
    return output_file
def trim_silence_at_end(audio_segment, silence_thresh=-50, chunk_size=10):
    """
    Trim the silence at the very end of an audio segment.

    Only silence that runs up to the end is removed; quiet passages inside the
    track are left alone.

    Args:
        audio_segment (AudioSegment): Input audio segment
        silence_thresh (int): Silence threshold in dBFS
        chunk_size (int): Size of chunks to analyze in ms
    Returns:
        AudioSegment: ``audio_segment`` without its trailing silence (empty if
        the whole segment is below the threshold).
    """
    # Leading silence of the reversed audio equals the trailing silence of the
    # original. Cutting at the start of the last silent range found anywhere
    # (detect_silence) would also chop off a loud ending that follows a short
    # mid-song pause. Scanning only the tail is also much cheaper than a
    # whole-track detect_silence pass.
    trailing_ms = silence.detect_leading_silence(
        audio_segment.reverse(),
        silence_threshold=silence_thresh,
        chunk_size=chunk_size,
    )
    return audio_segment[:len(audio_segment) - trailing_ms]
def delete_folders_and_files(input_dir):
    """
    Remove intermediate separation artifacts below ``input_dir``.

    In every folder below ``input_dir`` (the root itself is skipped), the
    per-model folders and the four loose stem WAVs are deleted. Afterwards,
    every folder under ``input_dir`` whose name ends in ``_vocals`` is removed.
    Other files, such as the combined mix, are kept.

    Args:
        input_dir (str): Root output folder to clean.
    """
    folders_to_delete = ['htdemucs', 'mel_band_roformer', 'scnet', 'bs_roformer']
    files_to_delete = ['bass.wav', 'vocals.wav', 'other.wav', 'instrumental.wav']
    # Bottom-up, so nested model folders are handled before their parents.
    for root, _dirs, _files in os.walk(input_dir, topdown=False):
        if root == input_dir:
            continue
        for folder in folders_to_delete:
            folder_path = os.path.join(root, folder)
            if os.path.isdir(folder_path):
                print(f"Deleting folder: {folder_path}")
                shutil.rmtree(folder_path)
        for file_name in files_to_delete:
            file_path = os.path.join(root, file_name)
            if os.path.isfile(file_path):
                print(f"Deleting file: {file_path}")
                os.remove(file_path)
    # Presumably by-products of running BS Roformer over the "<song>_vocals.wav"
    # left in the mel_band_roformer folder — TODO confirm.
    for root, dirs, _files in os.walk(input_dir):
        kept = []
        for dir_name in dirs:
            if dir_name.endswith('_vocals'):
                dir_path = os.path.join(root, dir_name)
                print(f"Deleting folder: {dir_path}")
                shutil.rmtree(dir_path)
            else:
                kept.append(dir_name)
        # Prune removed folders so os.walk does not try to descend into them.
        dirs[:] = kept
    print("Cleanup completed.")
def process_audio(uploaded_file):
    """
    Run the full separation-and-remix pipeline on one uploaded file.

    Pipeline:
    1. Stage the upload in ``<INPUT_FOLDER>/wav``.
    2. Run SCNet, Mel Band Roformer (with instrumental extraction) and HTDemucs
       on that folder.
    3. Run BS Roformer on the Mel Band Roformer output folder.
    4. Move the stems into the song folder and mix them into an ``m4a`` file.
    5. Delete the intermediates.

    Args:
        uploaded_file: Uploaded file object from the Gradio ``File`` component.
    Yields:
        tuple: ``(status_message, output_file_path)``. The path is None for
        progress updates and errors, and the combined file on success.
    """
    # NOTE(review): `spaces` is imported at the top of the file, but nothing
    # visible here carries `@spaces.GPU`, which is what requests a GPU on
    # ZeroGPU hardware — confirm whether inference.proc_folder_direct handles it.
    try:
        yield "Processing audio...", None
        if not uploaded_file:
            raise ValueError("Please upload a WAV file.")
        input_path, formatted_title = handle_file_upload(uploaded_file)
        if input_path is None:
            raise ValueError("File upload failed.")

        wav_dir = os.path.join(INPUT_FOLDER, "wav")
        mel_dir = os.path.join(OUTPUT_FOLDER, formatted_title, "mel_band_roformer")

        # Run inference for different models
        yield "Starting SCNet inference...", None
        proc_folder_direct("scnet", "configs/config_scnet_other.yaml", "results/model_scnet_other.ckpt", wav_dir, OUTPUT_FOLDER)
        yield "Starting Mel Band Roformer inference...", None
        proc_folder_direct("mel_band_roformer", "configs/config_mel_band_roformer_vocals.yaml", "results/model_mel_band_roformer_vocals.ckpt", wav_dir, OUTPUT_FOLDER, extract_instrumental=True)
        yield "Starting HTDemucs inference...", None
        proc_folder_direct("htdemucs", "configs/config_htdemucs_bass.yaml", "results/model_htdemucs_bass.th", wav_dir, OUTPUT_FOLDER)

        # Rename the instrumental to "<title>.wav" before BS Roformer reads this
        # folder — presumably so its outputs are filed under the same song name.
        os.rename(
            os.path.join(mel_dir, f"{formatted_title}_instrumental.wav"),
            os.path.join(mel_dir, f"{formatted_title}.wav"),
        )
        yield "Starting BS Roformer inference...", None
        proc_folder_direct("bs_roformer", "configs/config_bs_roformer_instrumental.yaml", "results/model_bs_roformer_instrumental.ckpt", mel_dir, OUTPUT_FOLDER)

        # Clean up and organize files
        yield "Moving input files...", None
        delete_input_files(INPUT_FOLDER)
        yield "Moving stems to parent...", None
        move_stems_to_parent(OUTPUT_FOLDER)
        yield "Combining stems...", None
        output_file = combine_stems_for_all(OUTPUT_FOLDER, "m4a")
        yield "Cleaning up...", None
        delete_folders_and_files(OUTPUT_FOLDER)
        yield "Audio processing completed successfully.", output_file
    except Exception as e:
        error_msg = f"An error occurred: {str(e)}\n{traceback.format_exc()}"
        logging.error(error_msg)
        # proc_folder_direct is given the whole staging folder, not one file, so
        # a WAV left behind by a failed run would be processed again next time.
        try:
            delete_input_files(INPUT_FOLDER)
        except OSError:
            logging.exception("Could not clear staged input files")
        yield error_msg, None
# Set up Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Music Player and Processor")
    # Accept WAV as well as M4A: process_audio asks for a WAV file and stages
    # every upload under a ".wav" name, so WAV uploads must not be rejected.
    file_upload = gr.File(label="Upload WAV or M4A file", file_types=[".wav", ".m4a"])
    process_button = gr.Button("Process Audio")
    log_output = gr.Textbox(label="Processing Log", interactive=False)
    processed_audio_output = gr.File(label="Processed Audio")
    # process_audio is a generator; each (status, file) pair it yields updates
    # the log box and the download slot.
    process_button.click(
        fn=process_audio,
        inputs=file_upload,
        outputs=[log_output, processed_audio_output],
        show_progress=True,
    )

# Launch the Gradio app
demo.launch()