import gradio as gr
import os
import allin1
import time

from pathlib import Path

# HTML banner shown at the top of the page (rendered through gr.HTML below):
# the title, links to the package / paper / visual demo, and a bullet list of
# what the analyzer predicts.
HEADER = """
<header style="text-align: center;">
  <h1>
    All-In-One Music Structure Analyzer 🔮
  </h1>
  <p>
    <a href="https://github.com/mir-aidj/all-in-one">[Python Package]</a>
    <a href="https://arxiv.org/abs/2307.16425">[Paper]</a>
    <a href="https://taejun.kim/music-dissector/">[Visual Demo]</a>
  </p>
</header>
<main
  style="display: flex; justify-content: center;"
>
  <div
    style="display: inline-block;"
  >
    <p>
      This Space demonstrates the music structure analyzer predicts:
      <ul
        style="padding-left: 1rem;"
      >
        <li>BPM</li>
        <li>Beats</li>
        <li>Downbeats</li>
        <li>Functional segment boundaries</li>
        <li>Functional segment labels (e.g. intro, verse, chorus, bridge, outro)</li>
      </ul>
    </p>
    <p>
      For more information, please visit the links above ✨🧸
    </p>
  </div>
</main>
"""

# True unless the CACHE_EXAMPLES environment variable is set to something other
# than '1' (it defaults to '1' when unset). Only referenced by the
# commented-out gr.Examples block further down.
CACHE_EXAMPLES = os.getenv('CACHE_EXAMPLES', '1') == '1'

# NOTE(review): not referenced anywhere in the visible code. Presumably the
# directory where Gradio stores uploaded files -- confirm before removing.
base_dir = "/tmp/gradio/"

# Names (without extension) of the source-separated tracks exposed in the UI.
_STEM_NAMES = ('bass', 'drums', 'other', 'vocals')


def _find_stem_paths(stem_dir):
  """Map each stem name to its ``<name>.wav`` file under *stem_dir* (None if absent)."""
  found = dict.fromkeys(_STEM_NAMES)
  # os.walk yields nothing for a missing directory, so every stem stays None.
  for root, _dirs, files in os.walk(stem_dir):
    for file_name in files:
      name, ext = os.path.splitext(file_name)
      # Compare the exact file name. A substring test on the full path would
      # also match directories or unrelated files that merely contain e.g.
      # "bass.wav" in their name.
      if ext == '.wav' and name in found:
        found[name] = os.path.join(root, file_name)
  return found


def analyze(path):
  """Run the all-in-one analyzer on one audio file and gather its outputs.

  Args:
    path: Path of the audio file to analyze (str or path-like).

  Returns:
    A tuple ``(bpm, fig, elapsed_time, json_path, bass_path, drums_path,
    other_path, vocals_path)``:
      * ``bpm``: ``result.bpm`` as returned by ``allin1.analyze``.
      * ``fig``: the figure returned by ``allin1.visualize``, with its DPI
        set to 300.
      * ``elapsed_time``: seconds spent on analysis plus visualization.
      * ``json_path``: path of this track's JSON file in ``./struct``, or
        None if it does not exist.
      * the four ``*_path`` values: paths of the separated stems under
        ``./demix/htdemucs/<track name>``, each None if not found.
  """
  # perf_counter is monotonic, so the measurement cannot be skewed by
  # wall-clock adjustments the way time.time() can.
  start = time.perf_counter()

  path = Path(path)
  result = allin1.analyze(
    path,
    out_dir='./struct',
    multiprocess=False,
    # The separated stems are returned to the UI below, so the by-products
    # have to be kept (assumes keep_byproducts preserves ./demix -- see the
    # allin1 documentation).
    keep_byproducts=True,
  )

  # ./struct accumulates results from every request, so address this track's
  # file directly instead of taking whichever file a directory walk yields
  # last. Assumes allin1 names the result "<input stem>.json".
  json_structure_output = os.path.join('./struct', f'{path.stem}.json')
  if not os.path.isfile(json_structure_output):
    json_structure_output = None

  fig = allin1.visualize(
    result,
    multiprocess=False,
  )
  fig.set_dpi(300)

  # Sonification via allin1.sonify is currently disabled.

  # Timing covers analysis and visualization only, not the stem lookup below.
  elapsed_time = time.perf_counter() - start

  stems = _find_stem_paths(os.path.join('./demix/htdemucs', path.stem))

  return (
    result.bpm,
    fig,
    elapsed_time,
    json_structure_output,
    stems['bass'],
    stems['drums'],
    stems['other'],
    stems['vocals'],
  )


# Build the UI. Components are created top to bottom inside the Blocks
# context: header, audio input, button, plot, a row of summary outputs, and a
# column with one audio player per separated stem.
with gr.Blocks() as demo:
  gr.HTML(HEADER)

  # The upload is handed to analyze() as a file path (type='filepath').
  input_audio_path = gr.Audio(
    label='Input',
    source='upload',
    type='filepath',
    format='mp3',
    show_download_button=False,
  )
  button = gr.Button('Analyze', variant='primary')
  output_viz = gr.Plot(label='Visualization')
  with gr.Row():
    output_bpm = gr.Textbox(label='BPM', scale=1)
    # Sonification player is currently disabled.
    #output_sonif = gr.Audio(
    #  label='Sonification',
    #  type='filepath',
    #  format='mp3',
    #  show_download_button=False,
    #  scale=9,
    #)
    elapsed_time = gr.Textbox(label='Overall inference time', scale=1)
    json_structure_output = gr.File(label="Json structure")
  # One audio player per separated stem.
  with gr.Column():
    bass = gr.Audio(label='bass', show_share_button=False)
    vocals =gr.Audio(label='vocals', show_share_button=False)
    other = gr.Audio(label='other', show_share_button=False)
    drums =gr.Audio(label='drums', show_share_button=False)
    #bass_path = gr.Textbox(label='bass_path', scale=1)
    #drums_path = gr.Textbox(label='drums_path', scale=1)
    #other_path = gr.Textbox(label='other_path', scale=1)
    #vocals_path = gr.Textbox(label='vocals_path', scale=1)
  # Example inputs are currently disabled. Note the commented-out outputs list
  # still refers to output_sonif, which is disabled above.
  #gr.Examples(
  #  examples=[
  #    './assets/NewJeans - Super Shy.mp3',
  #    './assets/Bruno Mars - 24k Magic.mp3'
  #  ],
  #  inputs=input_audio_path,
  #  outputs=[output_bpm, output_viz, output_sonif],
  #  fn=analyze,
  #  cache_examples=CACHE_EXAMPLES,
  #)
  
  # `outputs` is positional: its order must match the tuple returned by
  # analyze() (bpm, figure, elapsed time, JSON path, bass, drums, other,
  # vocals). This differs from the on-screen order of the stem players.
  button.click(
    fn=analyze,
    inputs=input_audio_path,
    #outputs=[output_bpm, output_viz, output_sonif, elapsed_time],
    outputs=[output_bpm, output_viz, elapsed_time, json_structure_output, bass, drums, other, vocals],
    api_name='analyze',
  )

# Start the Gradio server only when the file is run as a script.
if __name__ == '__main__':
  demo.launch()