# Origin: uploaded by sarinam in commit 574ab7e ("Initial commit"), 955 bytes.
from pathlib import Path

import numpy as np
import resampy
from espnet2.bin.asr_inference import Speech2Text
from espnet_model_zoo.downloader import ModelDownloader
# Model tag -> file name of the zip archive that DemoASR joins onto its
# ``model_path`` and hands to ModelDownloader.download_and_unpack().
TAGS_TO_MODELS = dict(
    phones='asr_tts-phn_en.zip',
    STT='asr_stt_en.zip',
    TTS='asr_tts_en.zip',
)
class DemoASR:
    """Single-utterance speech recognizer built on ESPnet's ``Speech2Text``.

    The model archive is chosen by a tag from ``TAGS_TO_MODELS`` and decoded
    with fixed demo settings (beam 15, CTC weight 0.4, 1-best output).
    """

    # Sample rate (Hz) every waveform is converted to before decoding.
    TARGET_SAMPLE_RATE = 16000

    def __init__(self, model_path, model_tag, device):
        """Download/unpack the selected model archive and build the recognizer.

        Args:
            model_path: Directory holding the model archives; a ``str`` or
                any ``os.PathLike`` is accepted.
            model_tag: Key of ``TAGS_TO_MODELS`` selecting the archive.
            device: Device for ``Speech2Text``; it is passed as ``str(device)``.

        Raises:
            KeyError: If ``model_tag`` is not a key of ``TAGS_TO_MODELS``.
        """
        try:
            archive_name = TAGS_TO_MODELS[model_tag]
        except KeyError:
            # Same exception type as before, but tell the caller what is valid.
            raise KeyError(
                f"unknown model_tag {model_tag!r}; "
                f"expected one of {sorted(TAGS_TO_MODELS)}"
            ) from None
        self.model_tag = model_tag

        downloader = ModelDownloader()
        # download_and_unpack() yields the keyword arguments (model/config
        # files) that are splatted straight into Speech2Text below.
        model_kwargs = downloader.download_and_unpack(
            str(Path(model_path) / archive_name)
        )
        self.speech2text = Speech2Text(
            **model_kwargs,
            device=str(device),
            minlenratio=0.0,
            maxlenratio=0.0,
            ctc_weight=0.4,
            beam_size=15,
            batch_size=1,
            nbest=1,
        )

    def recognize_speech(self, audio, sr):
        """Transcribe one utterance and return the best hypothesis text.

        Args:
            audio: Waveform as an array. A 2-D array is read as
                ``(samples, channels)`` and only the first channel is used.
                Integer PCM (e.g. ``int16``) is scaled to float32 in [-1, 1).
            sr: Sample rate of ``audio`` in Hz.

        Returns:
            The decoded text, or ``""`` when the recognizer produced no
            hypothesis at all.
        """
        audio = np.asarray(audio)
        if audio.ndim == 2:
            # NOTE(review): assumes (samples, channels) layout -- same as the
            # original ``audio.T[0]``; confirm against callers.
            audio = audio[:, 0]

        if np.issubdtype(audio.dtype, np.integer):
            # resampy only accepts floating-point input. Map integer PCM onto
            # [-1, 1): signed types are centred on 0, unsigned ones (e.g.
            # 8-bit WAV) on the middle of their range.
            info = np.iinfo(audio.dtype)
            offset = (int(info.max) + int(info.min) + 1) / 2
            scale = (int(info.max) - int(info.min) + 1) / 2
            audio = (audio.astype(np.float32) - offset) / scale

        if sr != self.TARGET_SAMPLE_RATE:
            audio = resampy.resample(audio, sr, self.TARGET_SAMPLE_RATE)

        nbests = self.speech2text(audio)
        if not nbests:
            # The beam search can come back empty (it does not retry when
            # minlenratio is 0.0); report that as an empty transcript rather
            # than failing with IndexError.
            return ""
        text, *_ = nbests[0]
        return text