import gradio as gr
import librosa
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch

# Checkpoint used for Japanese speech -> hiragana transcription.
model_name = "vumichien/wav2vec2-large-xlsr-japanese-hiragana"
# The processor turns raw waveforms into model inputs (processor(...)) and
# CTC token ids back into text (processor.batch_decode(...)); the model
# produces per-frame logits. Both are loaded once, at import time.
processor = Wav2Vec2Processor.from_pretrained(pretrained_model_name_or_path=model_name)
model = Wav2Vec2ForCTC.from_pretrained(pretrained_model_name_or_path=model_name)


def process_audio_file(file, target_sr=16000):
    """Load an audio file and convert it into model inputs.

    Args:
        file: Filepath of the recording to transcribe (the Gradio audio
            components in this app are configured with ``type="filepath"``).
        target_sr: Sampling rate, in Hz, that the waveform is resampled to
            and that is reported to the processor. Defaults to 16000.

    Returns:
        The processor output as PyTorch tensors. It carries ``input_values``
        and, depending on the processor configuration, ``attention_mask``.
    """
    # Let librosa resample while loading. Its default (sr=22050) resampled
    # every file to 22.05 kHz first, which then forced a second resample to
    # 16 kHz. The positional ``librosa.resample(y, sr, target)`` form also
    # raises TypeError on librosa >= 0.10, where the rates are keyword-only.
    data, _ = librosa.load(file, sr=target_sr)
    return processor(data, sampling_rate=target_sr, return_tensors="pt", padding=True)


def transcribe(micro, file):
    """Transcribe one recording to hiragana text.

    Args:
        micro: Filepath of the microphone recording, or ``None``.
        file: Filepath of the uploaded audio file, or ``None``.

    Returns:
        The decoded transcription. Returns an empty string when neither
        input holds audio, e.g. when a ``live=True`` interface fires after
        both inputs were cleared. When both inputs are set, the uploaded
        file takes precedence.
    """
    # Prefer the uploaded file; fall back to the microphone recording.
    input_audio = file if file is not None else micro
    if input_audio is None:
        # Nothing to transcribe; avoids handing None to the audio loader.
        return ""
    inputs = process_audio_file(input_audio)
    with torch.no_grad():
        # attention_mask may be absent depending on the processor config;
        # pass None in that case instead of raising AttributeError.
        output_logit = model(
            inputs.input_values,
            attention_mask=inputs.get("attention_mask"),
        ).logits
    # Greedy CTC decoding: pick the most likely token at every frame.
    pred_ids = torch.argmax(output_logit, dim=-1)
    return processor.batch_decode(pred_ids)[0]


description = "A simple interface to transcribe from spoken Japanese to Hiragana."
article = 'Author: <a href="https://huggingface.co/vumichien">Vu Minh Chien</a>.'

# Two optional audio sources, passed positionally to transcribe(micro, file)
# as filepaths. NOTE(review): `source=` / `optional=` are older Gradio
# keyword arguments; confirm they match the pinned gradio version.
inputs = [
    gr.Audio(source="microphone", type="filepath", optional=True),
    gr.Audio(source="upload", type="filepath", optional=True),
]
# A single text box displays the string returned by transcribe().
outputs = ["textbox"]
        
# examples = [["samples/BASIC5000_0001.wav",""],
#             ["samples/BASIC5000_0005.wav",""]
#         ]
# Assemble the UI: the microphone and upload inputs feed transcribe(), and
# the result is shown in a text box. live=True re-runs transcribe() whenever
# an input changes, so there is no submit step.
iface = gr.Interface(
    fn=transcribe,
    inputs=inputs,
    outputs=outputs,
    layout="horizontal",
    theme="huggingface",
    title="Transcribe Japanese audio to Hiragana",
    description=description,
    article=article,
    allow_flagging='never',
    # examples=examples,
    live=True,
)

# Start the server only when run as a script, so importing this module
# (e.g. to reuse transcribe()) does not block inside launch().
if __name__ == "__main__":
    # NOTE(review): enable_queue, layout and theme="huggingface" are older
    # Gradio options; confirm they are still accepted by the pinned version.
    iface.launch(enable_queue=True)