|
import os |
|
import tensorflow as tf |
|
from functools import lru_cache |
|
from huggingface_hub import hf_hub_download |
|
from hyperpyyaml import load_hyperpyyaml |
|
from typing import Union |
|
|
|
from decode import get_searcher |
|
|
|
# Hide all CUDA devices so TensorFlow runs inference on CPU only.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
|
|
|
|
def _get_checkpoint_filename(
    repo_id: str,
    filename: str,
    local_dir: Union[str, None] = None,
    local_dir_use_symlinks: Union[bool, str] = "auto",
    subfolder: str = "checkpoints"
) -> str:
    """Download a model checkpoint file from the Hugging Face Hub.

    Args:
        repo_id: Hub repository id, e.g. "user/model".
        filename: Name of the checkpoint file inside ``subfolder``.
        local_dir: Local directory to materialize the file in; ``None``
            uses the default huggingface_hub cache.
        local_dir_use_symlinks: Forwarded to ``hf_hub_download``.
        subfolder: Repository subfolder containing checkpoint files.

    Returns:
        Path to the downloaded checkpoint file.
    """
    # Fixed annotation: the original declared `local_dir: str` with a
    # default of None.
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=local_dir,
        local_dir_use_symlinks=local_dir_use_symlinks,
    )
|
|
|
|
|
def _get_bpe_model_filename(
    repo_id: str,
    filename: str,
    local_dir: Union[str, None] = None,
    local_dir_use_symlinks: Union[bool, str] = "auto",
    subfolder: str = "vocabs"
) -> str:
    """Download a BPE/subword vocabulary file from the Hugging Face Hub.

    Args:
        repo_id: Hub repository id, e.g. "user/model".
        filename: Name of the vocabulary file inside ``subfolder``.
        local_dir: Local directory to materialize the file in; ``None``
            uses the default huggingface_hub cache.
        local_dir_use_symlinks: Forwarded to ``hf_hub_download``.
        subfolder: Repository subfolder containing vocabulary files.

    Returns:
        Path to the downloaded vocabulary file.
    """
    # Fixed annotation: the original declared `local_dir: str` with a
    # default of None.
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=local_dir,
        local_dir_use_symlinks=local_dir_use_symlinks,
    )
|
|
|
|
|
@lru_cache(maxsize=1)
def _get_conformer_pre_trained_model(repo_id: str, checkpoint_dir: str = "checkpoints"):
    """Download and build the pre-trained Conformer transducer model.

    Fetches the checkpoint shards, the subword vocabulary, and the
    hyperpyyaml config into this module's directory, instantiates the
    model objects declared in the config, and restores the averaged
    checkpoint weights.

    Cached with ``lru_cache(maxsize=1)`` so repeated construction with the
    same arguments reuses the already-built model.

    Args:
        repo_id: Hugging Face Hub repository id hosting the model assets.
        checkpoint_dir: Repository subfolder (mirrored locally) that
            contains the checkpoint files.

    Returns:
        Tuple ``(audio_encoder, encoder_model, jointer, decoder,
        text_encoder, model)`` as declared in the downloaded config.
    """
    module_dir = os.path.dirname(__file__)

    # TF checkpoints are sharded: both the .index and .data files are needed.
    for postfix in ["index", "data-00000-of-00001"]:
        tmp = _get_checkpoint_filename(
            repo_id=repo_id,
            filename="avg_top5_27-32.ckpt.{}".format(postfix),
            subfolder=checkpoint_dir,
            local_dir=module_dir,
            local_dir_use_symlinks=True,
        )
        print(tmp)

    # Subword (BPE) model + vocab used by the text encoder.
    for postfix in ["model", "vocab"]:
        tmp = _get_bpe_model_filename(
            repo_id=repo_id,
            filename="subword_vietnamese_500.{}".format(postfix),
            local_dir=module_dir,
            local_dir_use_symlinks=True,
        )
        print(tmp)

    config_path = hf_hub_download(
        repo_id=repo_id,
        filename="config.yaml",
        local_dir=module_dir,
        local_dir_use_symlinks=True,
    )

    with open(config_path, "r") as f:
        config = load_hyperpyyaml(f)

    encoder_model = config["encoder_model"]
    text_encoder = config["text_encoder"]
    jointer = config["jointer_model"]
    decoder = config["decoder_model"]

    model = config["model"]
    audio_encoder = config["audio_encoder"]
    # BUGFIX: anchor the checkpoint prefix at this module's directory — the
    # files were downloaded there above. The original joined only
    # `checkpoint_dir`, which resolved against the process CWD and failed
    # whenever the script ran from another directory.
    ckpt_prefix = os.path.join(module_dir, checkpoint_dir, "avg_top5_27-32.ckpt")
    model.load_weights(ckpt_prefix).expect_partial()

    return audio_encoder, encoder_model, jointer, decoder, text_encoder, model
|
|
|
|
|
def read_audio(in_filename: str):
    """Load a WAV file and return its samples as a batched tensor.

    Decodes the PCM samples and drops the trailing channel axis
    (assumes the file is mono — TODO confirm against callers), then
    prepends a batch dimension of size 1.
    """
    raw = tf.io.read_file(in_filename)
    samples = tf.audio.decode_wav(raw)[0]
    mono = tf.squeeze(samples, axis=-1)
    return tf.expand_dims(mono, axis=0)
|
|
|
|
|
class UETASRModel:
    """Pre-trained UETASR Conformer transducer for WAV-file transcription.

    Builds the model from a Hugging Face Hub repository and exposes a
    single ``predict`` method that transcribes a WAV file to text.
    """

    def __init__(
        self,
        repo_id: str,
        decoding_method: str,
        beam_size: int,
        max_symbols_per_step: int,
    ):
        """Download/build the model and construct the search decoder.

        Args:
            repo_id: Hub repository id holding the pre-trained assets.
            decoding_method: Search strategy name understood by
                ``get_searcher``.
            beam_size: Beam width used by beam-search decoding.
            max_symbols_per_step: Cap on symbols emitted per encoder frame.
        """
        (self.featurizer,
         self.encoder_model,
         jointer,
         decoder,
         text_encoder,
         self.model) = _get_conformer_pre_trained_model(repo_id)
        self.searcher = get_searcher(
            decoding_method,
            decoder,
            jointer,
            text_encoder,
            beam_size,
            max_symbols_per_step,
        )

    def predict(self, in_filename: str):
        """Transcribe the WAV file at ``in_filename`` and return the text."""
        waveform = read_audio(in_filename)
        features = self.featurizer(waveform)
        # Apply cepstral mean/variance normalization only when the model
        # was configured to use it.
        if self.model.use_cmvn:
            features = self.model.cmvn(features)

        # Batch of one with no padding: mask every frame as valid.
        num_frames = tf.shape(features)[1]
        frame_mask = tf.sequence_mask([num_frames], maxlen=num_frames)
        frame_mask = tf.expand_dims(frame_mask, axis=1)

        encoder_outputs, encoder_masks = self.encoder_model(
            features, frame_mask, training=False)

        # Count the valid encoder frames from the mask the encoder returned.
        valid = tf.squeeze(encoder_masks, axis=1)
        encoded_lengths = tf.math.reduce_sum(
            tf.cast(valid, tf.int32),
            axis=1
        )

        hypothesis = self.searcher.infer(encoder_outputs, encoded_lengths)
        hypothesis = tf.squeeze(hypothesis)
        return tf.compat.as_str_any(hypothesis.numpy())
|
|