# import os
# import gradio as gr
# from scipy.io.wavfile import write
# import subprocess
# import torch

# from audio_separator import Separator  # Ensure this is correctly implemented

# def inference(audio):
#     os.makedirs("out", exist_ok=True)
#     audio_path = 'test.wav'
#     write(audio_path, audio[0], audio[1])
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     if device=='cuda':
#         use_cuda=True
#         print(f"Using device: {device}")
#     else:
#         use_cuda=False
#         print(f"Using device: {device}")
#     try:
        
#         # Using subprocess.run for better control
#         command = f"python3 -m demucs.separate -n htdemucs_6s -d {device} {audio_path} -o out"
#         process = subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
#         print("Demucs script output:", process.stdout.decode())
#     except subprocess.CalledProcessError as e:
#         print("Error in Demucs script:", e.stderr.decode())
#         return None

#     try:
#         # Separating the stems using your custom separator
#         separator = Separator("./out/htdemucs_6s/test/vocals.wav", model_name='UVR_MDXNET_KARA_2', use_cuda=use_cuda, output_format='mp3')
#         primary_stem_path, secondary_stem_path = separator.separate()
#     except Exception as e:
#         print("Error in custom separation:", str(e))
#         return None

#     # Collecting all file paths
#     files = [f"./out/htdemucs_6s/test/{stem}.wav" for stem in ["vocals", "bass", "drums", "other", "piano", "guitar"]]
#     files.extend([secondary_stem_path,primary_stem_path ])

#     # Check if files exist
#     existing_files = [file for file in files if os.path.isfile(file)]
#     if not existing_files:
#         print("No files were created.")
#         return None

#     return existing_files

# # Gradio Interface
# title = "Source Separation Demo"
# description = "Music Source Separation in the Waveform Domain. To use it, simply upload your audio."
# gr.Interface(
#     inference, 
#     gr.components.Audio(type="numpy", label="Input"), 
#     [gr.components.Audio(type="filepath", label=stem) for stem in ["Full Vocals","Bass", "Drums", "Other", "Piano", "Guitar", "Lead Vocals", "Backing Vocals" ]],
#     title=title,
#     description=description,
# ).launch()

 

import os
import subprocess
import sys

import gradio as gr
import torch
from scipy.io.wavfile import write

# Assuming audio_separator is available in your environment
from audio_separator import Separator

# Demucs model whose six-stem output feeds the UI; also names its output sub-directory.
_DEMUCS_MODEL = "htdemucs_6s"
# Demucs stems in the same order as the first six output players in the UI.
_DEMUCS_STEMS = ("vocals", "bass", "drums", "other", "piano", "guitar")
# Six Demucs stems plus the lead/backing split produced by the karaoke model.
_NUM_STEM_OUTPUTS = len(_DEMUCS_STEMS) + 2


def _failure(message):
    """Return the output tuple for an aborted run: every player hidden, plus *message*."""
    return tuple(gr.Audio(value=None, visible=False) for _ in range(_NUM_STEM_OUTPUTS)) + (message,)


def _stem_update(path):
    """Return an Audio update showing *path* if it is an existing file, else a hidden player."""
    if path and os.path.isfile(path):
        return gr.Audio(path, visible=True)
    return gr.Audio(value=None, visible=False)


def inference(audio, vocals, bass, drums, other, piano, guitar, lead_vocals, backing_vocals):
    """Separate an uploaded clip into stems and build the UI updates.

    Runs Demucs (``htdemucs_6s``) on the input. If lead and/or backing vocals
    are requested, it also runs the ``UVR_MDXNET_KARA_2`` separator on the
    Demucs vocals stem.

    Args:
        audio: ``(sample_rate, samples)`` tuple from a ``gr.Audio(type="numpy")``
            input, or ``None`` when nothing was uploaded.
        vocals, bass, drums, other, piano, guitar, lead_vocals, backing_vocals:
            Checkbox flags selecting which stems to reveal.

    Returns:
        A tuple of eight ``gr.Audio`` updates (one per stem, hidden unless
        selected and present on disk) followed by a status string.
    """
    if audio is None:
        return _failure("Please upload an audio file.")

    # NOTE(review): fixed paths mean concurrent requests would share test.wav
    # and out/; confirm the app's queue serializes calls.
    os.makedirs("out", exist_ok=True)
    audio_path = 'test.wav'
    sample_rate, samples = audio
    write(audio_path, sample_rate, samples)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Argument list (no shell) so nothing is re-parsed by a shell; sys.executable
    # runs demucs under the same interpreter/venv as this app.
    command = [sys.executable, "-m", "demucs.separate", "-n", _DEMUCS_MODEL,
               "-d", device, audio_path, "-o", "out"]
    try:
        process = subprocess.run(command, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        print("Error in Demucs script:", e.stderr.decode(errors="replace"))
        return _failure("Failed to process audio.")
    print("Demucs script output:", process.stdout.decode(errors="replace"))

    # Demucs writes stems to out/<model>/<input file name without extension>/.
    track_name = os.path.splitext(os.path.basename(audio_path))[0]
    stem_dir = os.path.join("out", _DEMUCS_MODEL, track_name)
    demucs_paths = [os.path.join(stem_dir, f"{stem}.wav") for stem in _DEMUCS_STEMS]

    primary_stem_path = secondary_stem_path = None
    # The karaoke model is expensive; only run it when its outputs are wanted.
    if lead_vocals or backing_vocals:
        try:
            separator = Separator(os.path.join(stem_dir, "vocals.wav"), model_name='UVR_MDXNET_KARA_2',
                                  use_cuda=device == 'cuda', output_format='wav')
            primary_stem_path, secondary_stem_path = separator.separate()
        except Exception as e:
            print("Error in custom separation:", str(e))
            return _failure("Failed to process audio.")

    # NOTE(review): the earlier version of this script mapped the *secondary*
    # stem to "Lead Vocals" and the primary to "Backing Vocals"; confirm which
    # stem UVR_MDXNET_KARA_2 reports as primary.
    paths = demucs_paths + [primary_stem_path, secondary_stem_path]
    wanted = (vocals, bass, drums, other, piano, guitar, lead_vocals, backing_vocals)
    updates = tuple(_stem_update(path if want else None) for path, want in zip(paths, wanted))
    return updates + ("Done! Successfully processed.",)

# Opt-in toggle per stem. The label order must match both the output players
# declared below and the flag parameters of inference().
checkbox_labels = [
    "Full Vocals",
    "Bass",
    "Drums",
    "Other",
    "Piano",
    "Guitar",
    "Lead Vocals",
    "Backing Vocals",
]
checkboxes = [gr.components.Checkbox(label=stem_name) for stem_name in checkbox_labels]

# Gradio Interface
title = "Source Separation Demo"
description = "Music Source Separation in the Waveform Domain. Upload your audio to begin."
iface = gr.Interface(
    inference,
    # Input clip first, then one checkbox per stem (matches inference's parameters).
    [gr.components.Audio(type="numpy", label="Input")] + checkboxes,
    # One initially hidden player per stem, revealed by inference(), plus a status message.
    [gr.Audio(label=label, visible=False) for label in checkbox_labels] + [gr.Label()],
    title=title,
    description=description,
)

# Launch only when run as a script, so importing this module (e.g. by tooling
# that reloads the app) does not start a server as a side effect.
if __name__ == "__main__":
    iface.launch()