# Page header captured together with this file (not code):
# Siddhant's picture
# Update app.py
# ad2eea0 verified
# raw
# history blame
# 10.1 kB
# import base64
# import pathlib
# import tempfile
import os
import subprocess
import sys

# Download the UniDic dictionary at start-up. It is run with the interpreter
# that is executing this app (sys.executable) rather than whatever `python` is
# first on PATH, so the data lands in this environment. As with a plain
# os.system() call, a failed download is not treated as fatal (check=False).
# NOTE(review): confirm UniDic is still needed by the TTS front-end in use.
subprocess.run([sys.executable, "-m", "unidic", "download"], check=False)
import nltk

# NLTK POS-tagger data; presumably required by the model's g2p_en front-end
# (the config directory name contains "g2p_en") — TODO confirm.
nltk.download('averaged_perceptron_tagger_eng')
import gradio as gr
from espnet2.bin.tts_inference import Text2Speech
from espnet2.utils.types import str_or_none

lang = 'English'
# NOTE(review): `tag` is not referenced in this file — the model below is loaded
# from local tts_model/ files, not from this pretrained tag.
tag = 'kan-bayashi/ljspeech_vits' #@param ["kan-bayashi/ljspeech_tacotron2", "kan-bayashi/ljspeech_fastspeech", "kan-bayashi/ljspeech_fastspeech2", "kan-bayashi/ljspeech_conformer_fastspeech2", "kan-bayashi/ljspeech_joint_finetune_conformer_fastspeech2_hifigan", "kan-bayashi/ljspeech_joint_train_conformer_fastspeech2_hifigan", "kan-bayashi/ljspeech_vits"] {type:"string"}
vocoder_tag = "none"

# Text-to-speech model (VITS, per the config path) loaded on the GPU.
text2speech = Text2Speech.from_pretrained(
    train_config="tts_model/exp/tts_train_vits_raw_phn_tacotron_g2p_en_no_space/config.yaml",
    model_file="tts_model/exp/tts_train_vits_raw_phn_tacotron_g2p_en_no_space/train.total_count.ave_10best.pth",
    vocoder_tag=str_or_none(vocoder_tag),
    device="cuda",
    # Only for Tacotron 2 & Transformer
    threshold=0.5,
    # Only for Tacotron 2
    minlenratio=0.0,
    maxlenratio=10.0,
    use_att_constraint=False,
    backward_window=1,
    forward_window=3,
    # Only for FastSpeech & FastSpeech2 & VITS
    speed_control_alpha=1.0,
    # Only for VITS
    noise_scale=0.333,
    noise_scale_dur=0.333,
)
# recorder_js = pathlib.Path('recorder.js').read_text()
# main_js = pathlib.Path('main.js').read_text()
# record_button_js = pathlib.Path('record_button.js').read_text().replace('let recorder_js = null;', recorder_js).replace(
#     'let main_js = null;', main_js)
# def save_base64_video(base64_string):
#     base64_video = base64_string
#     video_data = base64.b64decode(base64_video)
#     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
#         temp_filename = temp_file.name
#         temp_file.write(video_data)
#     print(f"Temporary MP4 file saved as: {temp_filename}")
#     return temp_filename
import numpy as np
from VAD.vad_iterator import VADIterator
import torch
import librosa
# from mlx_lm import load, stream_generate, generate
from LLM.chat import Chat
# from lightning_whisper_mlx import LightningWhisperMLX
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
)
# from melo.api import TTS
# LM_model, LM_tokenizer = load("mlx-community/SmolLM-360M-Instruct")

# Conversation history, seeded with a system prompt asking for replies of
# fewer than 20 words.
# NOTE(review): the meaning of Chat's size argument (2) is defined in
# LLM/chat.py — confirm.
chat = Chat(2)
chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words."})
user_role = "user"
# tts_model = TTS(language="EN_NEWEST", device="auto")
# speaker_id = tts_model.hps.data.spk2id["EN-Newest"]
blocksize = 512

# Throw-away TTS call whose result is unused; presumably a warm-up so the
# first real request does not pay one-off initialisation costs.
with torch.no_grad():
    wav = text2speech("Sid")["wav"]
# tts_model.tts_to_file("text", speaker_id, quiet=True)

# 3000 random float16 samples, used only as warm-up input for the ASR model.
dummy_input = torch.randn(3000, dtype=torch.float16, device="cpu").numpy()
import soundfile as sf
import kaldiio
from espnet2.bin.s2t_inference_ctc import Speech2TextGreedySearch

# Speech-to-text model (OWSM-CTC, greedy search), English ASR on the GPU.
s2t = Speech2TextGreedySearch.from_pretrained(
    "pyf98/owsm_ctc_v3.1_1B",
    device="cuda",
    generate_interctc_outputs=False,
    lang_sym='<eng>',
    task_sym='<asr>',
)

# Warm-up ASR pass on the dummy input zero-padded to 16000 * 30 samples
# (30 s at 16 kHz), timed with CUDA events; the elapsed time is printed.
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
speech = librosa.util.fix_length(dummy_input, size=(16000 * 30))
res = s2t(speech)
end_event.record()
torch.cuda.synchronize()
print(f"ASR warm-up: {start_event.elapsed_time(end_event):.1f} ms")
def int2float(sound):
    """
    Convert 16-bit PCM samples to float32 in [-1.0, 1.0).

    Adapted from https://github.com/snakers4/silero-vad. That version only
    scaled when ``np.abs(sound).max() > 0`` on the integer data, which has two
    problems:

    * ``np.abs`` of int16 -32768 wraps back to -32768, so a chunk made only of
      0 and -32768 samples was returned unscaled.
    * An empty chunk raised ``ValueError`` from ``max()``.

    Scaling an all-zero chunk is a no-op, so the scale is applied
    unconditionally instead.

    Args:
        sound: array-like of int16 samples (any shape).

    Returns:
        A float32 array (a fresh copy; the input is never modified) with
        singleton dimensions squeezed out.
    """
    # astype() always copies, so the in-place scale below cannot touch the
    # caller's buffer.
    sound = np.asarray(sound).astype("float32")
    sound *= 1 / 32768
    return sound.squeeze()  # depends on the use case
# Module-level state shared across `transcribe` calls.
text_str = ""                 # latest assistant reply ("" until the first one)
vad_output = None             # speech chunks buffered for the current utterance, or None
audio_output = None           # latest (sample_rate, int16 samples) TTS reply, or None
min_speech_ms = 500           # utterances shorter than this are ignored
max_speech_ms = float("inf")  # utterances longer than this are ignored
# ASR_model = LightningWhisperMLX(model="distil-large-v3", batch_size=6, quant=None)
# ASR_processor = AutoProcessor.from_pretrained("distil-whisper/distil-large-v3")
# ASR_model = AutoModelForSpeechSeq2Seq.from_pretrained(
#     "distil-whisper/distil-large-v3",
#     torch_dtype="float16",
# ).to("cpu")

# Tokenizer and weights come from the same checkpoint so the vocabulary and
# chat template are guaranteed to match the model being run.
LM_MODEL_ID = "HuggingFaceTB/SmolLM-360M-Instruct"
LM_tokenizer = AutoTokenizer.from_pretrained(LM_MODEL_ID)
LM_model = AutoModelForCausalLM.from_pretrained(
    LM_MODEL_ID, torch_dtype="float16", trust_remote_code=True
).to("cuda")
LM_pipe = pipeline(
    "text-generation", model=LM_model, tokenizer=LM_tokenizer, device="cuda"
)

# Warm-up generation whose output is discarded, timed with CUDA events; the
# elapsed time is printed.
dummy_input_text = "Write me a poem about Machine Learning."
dummy_chat = [{"role": "user", "content": dummy_input_text}]
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
# Greedy decoding. `temperature` is omitted: it only applies to sampling, and
# transformers warns when it is combined with do_sample=False.
LM_pipe(
    dummy_chat,
    max_new_tokens=32,
    min_new_tokens=0,
    do_sample=False,
)
end_event.record()
torch.cuda.synchronize()
print(f"LLM warm-up: {start_event.elapsed_time(end_event):.1f} ms")
# vad_model, _ = torch.hub.load("snakers4/silero-vad:v4.0", "silero_vad")
# vad_iterator = VADIterator(
#     vad_model,
#     threshold=0.3,
#     sampling_rate=16000,
#     min_silence_duration_ms=250,
#     speech_pad_ms=500,
# )
import webrtcvad
import time
def _count_speech_frames(pcm16, sample_rate=16000, frame_ms=20, mode=3):
    """Count the complete frames of ``pcm16`` that webrtcvad classifies as speech.

    webrtcvad only accepts 16-bit mono PCM at 8/16/32/48 kHz, in 10, 20 or
    30 ms frames. A trailing partial frame is ignored.

    Args:
        pcm16: 1-D int16 numpy array of samples at ``sample_rate``.
        sample_rate: Sample rate of ``pcm16`` in Hz.
        frame_ms: Frame length in milliseconds.
        mode: webrtcvad aggressiveness, 0 (least) to 3 (most aggressive).

    Returns:
        Number of frames flagged as speech.
    """
    # One detector per call, not one per frame.
    vad = webrtcvad.Vad()
    vad.set_mode(mode)
    frame_len = sample_rate * frame_ms // 1000
    return sum(
        vad.is_speech(pcm16[start:start + frame_len].tobytes(), sample_rate)
        for start in range(0, len(pcm16) - frame_len + 1, frame_len)
    )


def _respond_to_speech(array):
    """Transcribe one utterance, ask the chat LLM, and synthesize the reply.

    Appends the user prompt and the assistant reply to the global ``chat``.

    Args:
        array: float32 waveform of the finished utterance at 16 kHz.

    Returns:
        ``(generated_text, (fs, int16 samples))``.
    """
    print(len(array))
    start_time = time.time()
    # Drop the first whitespace-separated token of the ASR hypothesis.
    # NOTE(review): presumably a language/task tag emitted by OWSM-CTC — confirm.
    prompt = " ".join(s2t(array)[0][0].split()[1:])
    print(prompt)
    print("--- %s seconds ---" % (time.time() - start_time))
    chat.append({"role": user_role, "content": prompt})
    # Greedy decoding. `temperature` is omitted: it only applies to sampling,
    # and transformers warns when it is combined with do_sample=False.
    output = LM_pipe(
        chat.to_list(),
        max_new_tokens=64,
        min_new_tokens=0,
        do_sample=False,
    )
    print("--- %s seconds ---" % (time.time() - start_time))
    generated_text = output[0]['generated_text'][-1]["content"]
    chat.append({"role": "assistant", "content": generated_text})
    with torch.no_grad():
        wav = text2speech(generated_text)["wav"].view(-1).cpu().numpy()
    # Clip before converting: a sample at +1.0 or beyond would otherwise
    # overflow int16 and wrap to a large negative value (an audible click).
    audio_chunk = (np.clip(wav, -1.0, 1.0) * 32767).astype(np.int16)
    print(text2speech.fs)
    print("--- %s seconds ---" % (time.time() - start_time))
    return generated_text, (text2speech.fs, audio_chunk)


def transcribe(stream, new_chunk):
    """Gradio streaming callback implementing the speech -> reply loop.

    Each chunk is resampled to 16 kHz and checked with webrtcvad. Speech
    chunks are buffered. The first non-speech chunk after buffered speech
    ends the utterance. If the utterance length is within
    [min_speech_ms, max_speech_ms], it then goes through ASR -> LLM -> TTS.

    Args:
        stream: Gradio "state" value; passed through unchanged.
        new_chunk: ``(sample_rate, samples)`` from the streaming microphone
            component. The sample buffer is reinterpreted as int16 PCM.

    Returns:
        ``(stream, text, audio)``. ``text`` is the latest assistant reply
        ("" before the first one). ``audio`` is the latest
        ``(fs, int16 samples)`` TTS output (None before the first one).
    """
    global text_str
    global audio_output
    global vad_output

    # Nothing to analyse: keep showing the previous reply.
    if new_chunk is None or np.size(new_chunk[1]) == 0:
        return stream, text_str, audio_output

    orig_sr, y = new_chunk
    audio_int16 = np.frombuffer(y, dtype=np.int16)
    audio_float32 = int2float(audio_int16)
    sr = 16000
    audio_float32 = librosa.resample(audio_float32, orig_sr=orig_sr, target_sr=sr)
    print(sr)
    print(audio_float32.shape)

    # Run VAD on the 16 kHz stream (20 ms = 320-sample frames). webrtcvad
    # rejects rates such as 44.1 kHz, so running it on the resampled audio
    # works for any microphone rate. More than 10 speech frames (> 200 ms of
    # speech) marks the chunk as speech.
    pcm16 = (np.clip(audio_float32, -1.0, 1.0) * 32767).astype(np.int16)
    vad_curr = _count_speech_frames(pcm16, sample_rate=sr) > 10

    if vad_curr:
        if vad_output is None:
            vad_output = []
        vad_output.append(torch.from_numpy(audio_float32))
    elif vad_output is not None:
        print("VAD: end of speech detected")
        array = torch.cat(vad_output).cpu().numpy()
        # Reset the buffer so this utterance is handled exactly once. Otherwise
        # every later silent chunk would re-run ASR/LLM/TTS on it.
        vad_output = None
        duration_ms = len(array) / sr * 1000
        if min_speech_ms <= duration_ms <= max_speech_ms:
            text_str, audio_output = _respond_to_speech(array)
    return stream, text_str, audio_output
# Live Gradio UI: microphone audio is streamed into `transcribe`, and the latest
# reply text and audio are shown. The leading "state" input/output is passed
# back unchanged by `transcribe`.
demo = gr.Interface(
    transcribe,
    ["state", gr.Audio(sources=["microphone"], streaming=True, waveform_options=gr.WaveformOptions(sample_rate=16000))],
    ["state", "text", gr.Audio(label="Output", autoplay=True)],
    live=True,
)
# with demo:
#     start_button = gr.Button("Record Screen 🔴")
#     video_component = gr.Video(interactive=True, show_share_button=True, include_audio=True)
#     def toggle_button_label(returned_string):
#         if returned_string.startswith("Record"):
#             return gr.Button(value="Stop Recording ⚪"), None
#         else:
#             try:
#                 temp_filename = save_base64_video(returned_string)
#             except Exception as e:
#                 return gr.Button(value="Record Screen 🔴"), gr.Warning(f'Failed to convert video to mp4:\n{e}')
#             return gr.Button(value="Record Screen 🔴"), gr.Video(value=temp_filename, interactive=True,
#                                                                 show_share_button=True)
#     start_button.click(toggle_button_label, start_button, [start_button, video_component], js=record_button_js)

# `share` must be passed as a keyword. The first positional parameter of
# launch() is `inline`, so a string argument such as "share=True" never
# enables sharing.
demo.launch(share=True)