import logging
import os
import re
import shutil
import subprocess
import sys
import traceback
from pathlib import Path

import gradio as gr
import spaces  # required by @spaces.GPU below; kept ahead of transformers/inference, which may pull in torch
import yt_dlp
from pydub import AudioSegment
from transformers import pipeline

from inference import proc_folder_direct

# Root directory for per-song separation results (<OUTPUT_FOLDER><title>/...).
# Keep the trailing slash: paths may be built from it by plain concatenation.
OUTPUT_FOLDER: str = "separation_results/"
# Uploads are copied into <INPUT_FOLDER>/wav before the models run.
INPUT_FOLDER: str = "input"
# Not referenced elsewhere in this module.
download_path: str = ""

# The nine characters Windows forbids in file names (including both path
# separators); each is mapped to "_".
_UNSAFE_FILENAME_CHARS = str.maketrans(dict.fromkeys('\\/*?:"<>|', "_"))


def sanitize_filename(filename):
    """Return *filename* with each of ``\\ / * ? : " < > |`` replaced by ``_``.

    Every other character (spaces, dots, non-ASCII text) passes through
    unchanged, and the result has the same length as the input.
    """
    return filename.translate(_UNSAFE_FILENAME_CHARS)

def delete_input_files(input_dir):
    """Delete every ``*.wav`` file directly inside ``<input_dir>/wav``.

    Each removed path is printed. Files with other extensions and nested
    directories are left alone.
    """
    wav_folder = Path(input_dir, "wav")
    doomed = list(wav_folder.glob("*.wav"))
    for stale_wav in doomed:
        stale_wav.unlink()
        print(f"Deleted {stale_wav}")

def handle_file_upload(file):
    """Copy an uploaded file into ``<INPUT_FOLDER>/wav/<sanitized stem>.wav``.

    Args:
        file: the upload object from Gradio; its ``.name`` is the path of the
            temporary uploaded file, or ``None`` when nothing was uploaded.

    Returns:
        ``(input_path, formatted_title)`` on success, where
        ``formatted_title`` is the sanitized file stem. When ``file`` is
        ``None`` returns ``(None, "No file uploaded")`` instead.
    """
    if file is None:
        return None, "No file uploaded"

    uploaded_path = file.name
    stem = Path(os.path.basename(uploaded_path)).stem
    formatted_title = sanitize_filename(stem)

    # The copy is always given a .wav extension, whatever the upload was called.
    target_dir = os.path.join(INPUT_FOLDER, "wav")
    os.makedirs(target_dir, exist_ok=True)
    input_path = os.path.join(target_dir, f"{formatted_title}.wav")

    shutil.copy(uploaded_path, input_path)
    return input_path, formatted_title
    
def run_inference(model_type, config_path, start_check_point, input_dir, output_dir, device_ids="0"):
    """Run ``inference.py`` in a child process and return the completed process.

    The script is launched with the same interpreter that runs this app
    (``sys.executable``), so it sees the same virtualenv and installed
    packages. A bare ``"python"`` may be missing from PATH or point at a
    different installation. ``inference.py`` is resolved relative to the
    current working directory.

    Args:
        model_type: value for ``--model_type``.
        config_path: value for ``--config_path``.
        start_check_point: value for ``--start_check_point``.
        input_dir: value for ``--INPUT_FOLDER``.
        output_dir: value for ``--store_dir``.
        device_ids: value for ``--device_ids`` (a string, default ``"0"``).

    Returns:
        subprocess.CompletedProcess with text ``stdout``/``stderr`` captured.

    Raises:
        subprocess.CalledProcessError: if the script exits non-zero. The
            captured output is available on the exception.
    """
    command = [
        sys.executable, "inference.py",
        "--model_type", model_type,
        "--config_path", config_path,
        "--start_check_point", start_check_point,
        "--INPUT_FOLDER", input_dir,
        "--store_dir", output_dir,
        "--device_ids", device_ids,
    ]
    return subprocess.run(command, check=True, capture_output=True, text=True)

def move_stems_to_parent(input_dir):
    """Hoist the one stem kept from each model's output into its song directory.

    Expected layout: ``<input_dir>/<song>/<model>/<song>_<stem>.wav``. For
    each recognised ``<model>`` directory, the chosen stem file is moved to
    ``<input_dir>/<song>/<target>.wav``. Missing files are reported, not
    raised.

    Model directories are recognised by their exact directory name. The
    former substring test on the whole path misrouted every model directory
    of a song whose title contained a model name, e.g. "my_scnet_mix".
    """
    # model dir name -> (stem suffix written by that model, target file, log label)
    stem_targets = {
        "htdemucs": ("bass", "bass.wav", "Bass"),
        "mel_band_roformer": ("vocals", "vocals.wav", "Vocals"),
        "scnet": ("other", "other.wav", "Other"),
        # The bs_roformer output is looked up as "<song>_other.wav" but kept as the instrumental.
        "bs_roformer": ("other", "instrumental.wav", "Instrumental"),
    }
    top = os.path.normpath(input_dir)

    for subdir, _dirs, _files in os.walk(input_dir):
        if subdir == input_dir:
            continue

        model = os.path.basename(subdir)
        if model not in stem_targets:
            continue

        parent_dir = os.path.dirname(subdir)
        if os.path.normpath(parent_dir) == top:
            # A song directory that happens to share a model's name, not a model output dir.
            continue

        song_name = os.path.basename(parent_dir)
        stem, target_name, label = stem_targets[model]

        print(f"Processing {model} in {subdir}")
        source_path = os.path.join(subdir, f"{song_name}_{stem}.wav")
        if os.path.exists(source_path):
            target_path = os.path.join(parent_dir, target_name)
            print(f"Moving {source_path} to {target_path}")
            shutil.move(source_path, target_path)
        else:
            print(f"{label} file not found: {source_path}")

def combine_stems_for_all(input_dir):
    """Mix each song directory's hoisted stems into ``<song>.MDS.wav``.

    Every directory below ``input_dir`` (``input_dir`` itself excluded) is
    examined. A directory is mixed only when ``vocals.wav``, ``bass.wav``,
    ``other.wav`` and ``instrumental.wav`` all exist; otherwise it is
    reported and skipped. The stems are layered with pydub ``overlay`` in
    that order, starting from the vocals.
    NOTE(review): pydub's ``overlay`` keeps the length of the segment it is
    called on, so the mix is presumably as long as ``vocals.wav`` - confirm.
    """
    stem_files = (
        ("vocals", "vocals.wav"),
        ("bass", "bass.wav"),
        ("others", "other.wav"),
        ("instrumental", "instrumental.wav"),
    )

    for folder, _subdirs, _files in os.walk(input_dir):
        if folder == input_dir:
            continue

        song_name = os.path.basename(folder)
        print(f"Processing {folder}")

        paths = [os.path.join(folder, file_name) for _label, file_name in stem_files]
        if any(not os.path.exists(path) for path in paths):
            print(f"Skipping {folder}, not all stems are present.")
            continue

        # Decode all four first, then layer them in order.
        segments = [AudioSegment.from_file(path) for path in paths]
        mix = segments[0]
        for segment in segments[1:]:
            mix = mix.overlay(segment)

        output_file = os.path.join(folder, f"{song_name}.MDS.wav")
        mix.export(output_file, format="wav")
        print(f"Exported combined stems to {output_file}")

def delete_folders_and_files(input_dir):
    """Remove intermediate separation artefacts below ``input_dir``.

    Pass 1 runs in every directory under ``input_dir`` (``input_dir`` itself
    excluded). It deletes the per-model output folders and the individual
    stem files (``bass/vocals/other/instrumental.wav``).

    Pass 2 deletes ``<song>_vocals`` directories. The bs_roformer step reads
    the whole ``mel_band_roformer`` folder, which presumably also holds
    ``<song>_vocals.wav``, so it leaves a separate ``<song>_vocals`` output
    tree. A ``*_vocals`` directory is only removed when its ``<song>``
    sibling also exists. A song whose own title ends in "_vocals" (and its
    final ``.MDS.wav``) is therefore kept.

    Both passes operate on ``input_dir``. Pass 2 previously walked the global
    ``OUTPUT_FOLDER`` regardless of the argument.
    """
    folders_to_delete = ['htdemucs', 'mel_band_roformer', 'scnet', 'bs_roformer']
    files_to_delete = ['bass.wav', 'vocals.wav', 'other.wav', 'instrumental.wav']

    # Bottom-up, so children are visited before their parent removes them.
    for root, _dirs, _files in os.walk(input_dir, topdown=False):
        if root == input_dir:
            continue

        for folder in folders_to_delete:
            folder_path = os.path.join(root, folder)
            if os.path.isdir(folder_path):
                print(f"Deleting folder: {folder_path}")
                shutil.rmtree(folder_path)

        for file_name in files_to_delete:
            file_path = os.path.join(root, file_name)
            if os.path.isfile(file_path):
                print(f"Deleting file: {file_path}")
                os.remove(file_path)

    vocals_suffix = "_vocals"
    for root, dirs, _files in os.walk(input_dir):
        siblings = set(dirs)
        for dir_name in list(dirs):
            owner = dir_name[:-len(vocals_suffix)]
            if not dir_name.endswith(vocals_suffix) or owner not in siblings:
                continue
            dir_path = os.path.join(root, dir_name)
            print(f"Deleting folder: {dir_path}")
            shutil.rmtree(dir_path)
            # Prune so os.walk does not try to descend into the removed tree.
            dirs.remove(dir_name)

    print("Cleanup completed.")

@spaces.GPU(duration=300)  # Adjust the duration as needed
def process_audio(uploaded_file):
    """Run the full separation pipeline on an upload and stream progress.

    Generator used as the Gradio click handler. Each ``yield`` is a
    ``(log_message, processed_audio_path)`` pair. The path is ``None``
    until the last step, which yields
    ``<OUTPUT_FOLDER>/<title>/<title>.MDS.wav``.

    Steps: copy the upload to ``<INPUT_FOLDER>/wav``, then run SCNet, Mel
    Band Roformer (with instrumental extraction) and HTDemucs on that
    folder. Next, run BS Roformer on the Mel Band Roformer instrumental.
    Finally, hoist the kept stems, mix them, and delete intermediate files.

    Any exception is logged and reported through the log output instead of
    propagating to Gradio.
    """
    wav_input_dir = f"{INPUT_FOLDER}/wav"
    try:
        yield "Processing audio...", None

        if not uploaded_file:
            raise ValueError("Please upload a WAV file.")

        # Each proc_folder_direct call below processes the whole input folder.
        # Clear WAVs left by an earlier run that failed before its own cleanup,
        # otherwise they would be separated (and billed GPU time) again.
        delete_input_files(INPUT_FOLDER)

        input_path, formatted_title = handle_file_upload(uploaded_file)
        if input_path is None:
            raise ValueError("File upload failed.")

        song_dir = os.path.join(OUTPUT_FOLDER, formatted_title)
        mel_dir = os.path.join(song_dir, "mel_band_roformer")

        yield "Starting SCNet inference...", None
        proc_folder_direct("scnet", "configs/config_scnet_other.yaml", "results/model_scnet_other.ckpt", wav_input_dir, OUTPUT_FOLDER)

        yield "Starting Mel Band Roformer inference...", None
        proc_folder_direct("mel_band_roformer", "configs/config_mel_band_roformer_vocals.yaml", "results/model_mel_band_roformer_vocals.ckpt", wav_input_dir, OUTPUT_FOLDER, extract_instrumental=True)

        yield "Starting HTDemucs inference...", None
        proc_folder_direct("htdemucs", "configs/config_htdemucs_bass.yaml", "results/model_htdemucs_bass.th", wav_input_dir, OUTPUT_FOLDER)

        # Rename the instrumental to "<title>.wav" so the BS Roformer output is
        # filed under the song's own name. Presumably the output folder is named
        # after the input stem - confirm. os.replace also overwrites a stale
        # target on Windows, where os.rename would fail.
        os.replace(
            os.path.join(mel_dir, f"{formatted_title}_instrumental.wav"),
            os.path.join(mel_dir, f"{formatted_title}.wav"),
        )

        yield "Starting BS Roformer inference...", None
        proc_folder_direct("bs_roformer", "configs/config_bs_roformer_instrumental.yaml", "results/model_bs_roformer_instrumental.ckpt", mel_dir, OUTPUT_FOLDER)

        yield "Deleting input files...", None
        delete_input_files(INPUT_FOLDER)

        yield "Moving stems to parent...", None
        move_stems_to_parent(OUTPUT_FOLDER)

        yield "Combining stems...", None
        combine_stems_for_all(OUTPUT_FOLDER)

        yield "Cleaning up...", None
        delete_folders_and_files(OUTPUT_FOLDER)

        yield "Audio processing completed successfully.", os.path.join(song_dir, f"{formatted_title}.MDS.wav")
    except Exception as e:
        error_msg = f"An error occurred: {str(e)}\n{traceback.format_exc()}"
        logging.error(error_msg)
        yield error_msg, None

# Gradio UI: upload a WAV, preview it, run the separation pipeline, and play the
# combined result. Components are laid out in the order they are created below.
with gr.Blocks() as demo:
    gr.Markdown("# Music Player and Processor")

    # file_types limits the picker to .wav files.
    file_upload = gr.File(label="Upload WAV file", file_types=[".wav"])

    # Preview of the raw upload, filled by the change handler below.
    audio_output = gr.Audio(label="Audio Player")

    process_button = gr.Button("Process Audio")

    # Receive the (message, path) pairs yielded by process_audio.
    log_output = gr.Textbox(label="Processing Log", interactive=False)
    processed_audio_output = gr.Audio(label="Processed Audio")

    # Point the preview player at the uploaded temp file; clear it when the
    # upload is removed (file is falsy).
    file_upload.change(
        fn=lambda file: file.name if file else None,
        inputs=file_upload,
        outputs=audio_output
    )

    # process_audio is a generator: every yielded pair updates both outputs.
    process_button.click(
        fn=process_audio,
        inputs=file_upload,
        outputs=[log_output, processed_audio_output],
        show_progress=True
    )

# Starts the server as soon as this module runs (there is no __main__ guard).
demo.launch()