import gradio as gr
import openai
from t2a import text_to_audio
import joblib
from sentence_transformers import SentenceTransformer
import numpy as np
import os

# Resources loaded once at import time.
# Regressor loaded from text_reg.joblib and a sentence-embedding encoder.
# NOTE(review): neither `reg` nor `model` appears to be used on the live
# request path — confirm before removing; both are expensive to load.
reg = joblib.load('text_reg.joblib')
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# Fine-tuned OpenAI completion model queried by get_note_text.
finetune = "davinci:ft-personal:autodrummer-v5-2022-11-04-22-34-07"

# UI text passed to gr.Interface as `description` and `article`.
# Read explicitly as UTF-8 so the result does not depend on the host locale.
with open('description.txt', 'r', encoding='utf-8') as f:
    description = f.read()
with open('article.txt', 'r', encoding='utf-8') as f:
    article = f.read()

def get_note_text(prompt):
    """Query the fine-tuned completion model for note text.

    The prompt gets a " ->" suffix (presumably the prompt/completion
    separator used when fine-tuning — TODO confirm). Generation stops at
    "###" or after 200 tokens.

    Args:
        prompt: Free-text style description, e.g. "rock metal".

    Returns:
        The first completion's text with surrounding whitespace stripped.
    """
    completion_args = dict(
        engine=finetune,
        prompt=prompt + " ->",
        temperature=0.5,
        max_tokens=200,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        stop=["###"],
    )
    response = openai.Completion.create(**completion_args)
    first_choice = response.choices[0]
    return first_choice.text.strip()
    
def increment_count(path='count.txt'):
    """Increment the usage counter stored as a single integer in a text file.

    A missing or blank file counts as 0, so the first request creates the
    file instead of failing after the audio has already been generated.

    Args:
        path: Counter file location; defaults to ``count.txt`` in the
            current working directory (the original hard-coded path).

    Returns:
        The new count.

    Raises:
        ValueError: if the file contains something other than an integer.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            raw = f.read().strip()
    except FileNotFoundError:
        raw = ''
    count = int(raw) + 1 if raw else 1
    with open(path, 'w', encoding='utf-8') as f:
        f.write(str(count))
    return count

# BPM for each tempo label offered in the UI.
_TEMPO_BPM = {"fast": 138, "slow": 100}


def _resolve_tempo(tempo):
    """Map a tempo label to BPM; pass any non-string, non-None value through."""
    if tempo is None:
        raise ValueError(
            f"No tempo selected; expected one of {sorted(_TEMPO_BPM)}"
        )
    if isinstance(tempo, str):
        try:
            return _TEMPO_BPM[tempo]
        except KeyError:
            raise ValueError(
                f"Unknown tempo {tempo!r}; expected one of {sorted(_TEMPO_BPM)}"
            ) from None
    # Numeric BPM values were previously forwarded unchanged; keep that.
    return tempo


def get_drummer_output(prompt, tempo):
    """Generate a drum loop for a text prompt and return it as Gradio audio.

    Args:
        prompt: Free-text style description forwarded to get_note_text.
        tempo: "fast" (138 BPM) or "slow" (100 BPM); a numeric BPM is
            passed through unchanged.

    Returns:
        A ``(96000, samples)`` tuple where ``samples`` is a float32 NumPy
        array built from the rendered audio's samples.

    Raises:
        KeyError: if the ``key`` environment variable (OpenAI API key) is
            not set.
        ValueError: if ``tempo`` is None or an unrecognised label. The check
            runs before the API request, so no call is spent on bad input.
    """
    openai.api_key = os.environ['key']
    bpm = _resolve_tempo(tempo)
    note_text = get_note_text(prompt)
    # NOTE: an alternative BPM estimate (reg.predict on model.encode(prompt),
    # plus 20) was tried here and is disabled; tempo comes from the caller.
    audio = text_to_audio(note_text, bpm)
    samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
    increment_count()
    # NOTE(review): 96000 Hz is hard-coded; presumably it matches the rate
    # text_to_audio renders at — TODO confirm (a pydub AudioSegment exposes
    # it as `.frame_rate`).
    return (96000, samples)

# Gradio UI: a free-text prompt plus a fast/slow tempo selector, producing audio.
iface = gr.Interface(
    fn=get_drummer_output,
    inputs=[
        "text",
        # `value` sets the initially selected option. Gradio 3.x does not
        # apply the legacy `default` keyword, which left the radio unselected
        # and sent None as the tempo.
        gr.Radio(["fast", "slow"], label="Tempo", value="fast"),
    ],
    examples=[
        ["hiphop groove 808", "fast"],
        ["rock metal", "fast"],
        ["disco funk", "fast"],
    ],
    outputs="audio",
    title='Autodrummer',
    description=description,
    article=article,
)
iface.launch()