"""Inference wrapper for a pre-trained UETASR Conformer transducer.

Model artifacts (checkpoint shards, BPE vocabulary files and a hyperpyyaml
``config.yaml``) are fetched from the Hugging Face Hub into this module's
directory; the built model is cached per process.
"""

import logging
import os
from functools import lru_cache
from typing import Optional, Union

# Force CPU-only execution. This is set before TensorFlow (and ``decode``,
# which may use TensorFlow at import time) is imported, so the setting is in
# effect before any device can be initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402
from hyperpyyaml import load_hyperpyyaml  # noqa: E402

from decode import get_searcher  # noqa: E402

logger = logging.getLogger(__name__)

# Absolute directory of this module; all downloads are placed under it.
# ``abspath`` guards against ``__file__`` being a bare relative name, for
# which ``os.path.dirname`` would return "".
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))

# File-name prefixes of the artifacts published in the Hub repository.
_CHECKPOINT_PREFIX = "avg_top5_27-32.ckpt"
_BPE_PREFIX = "subword_vietnamese_500"


def _hub_download(
    repo_id: str,
    filename: str,
    subfolder: Optional[str],
    local_dir: Optional[str],
    local_dir_use_symlinks: Union[bool, str],
) -> str:
    """Download ``subfolder/filename`` from ``repo_id`` and return its local path."""
    # NOTE(review): ``local_dir_use_symlinks`` is deprecated in recent
    # huggingface_hub releases; confirm against the pinned version.
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=local_dir,
        local_dir_use_symlinks=local_dir_use_symlinks,
    )


def _get_checkpoint_filename(
    repo_id: str,
    filename: str,
    local_dir: Optional[str] = None,
    local_dir_use_symlinks: Union[bool, str] = "auto",
    subfolder: str = "checkpoints",
) -> str:
    """Download one checkpoint file from ``repo_id`` and return its local path.

    Args:
        repo_id: Hugging Face Hub repository id.
        filename: File name inside ``subfolder``.
        local_dir: Directory to download into (Hub cache when ``None``).
        local_dir_use_symlinks: Forwarded to ``hf_hub_download``.
        subfolder: Repository subfolder holding the checkpoint files.
    """
    return _hub_download(
        repo_id, filename, subfolder, local_dir, local_dir_use_symlinks
    )


def _get_bpe_model_filename(
    repo_id: str,
    filename: str,
    local_dir: Optional[str] = None,
    local_dir_use_symlinks: Union[bool, str] = "auto",
    subfolder: str = "vocabs",
) -> str:
    """Download one BPE vocabulary file from ``repo_id`` and return its local path.

    Args:
        repo_id: Hugging Face Hub repository id.
        filename: File name inside ``subfolder``.
        local_dir: Directory to download into (Hub cache when ``None``).
        local_dir_use_symlinks: Forwarded to ``hf_hub_download``.
        subfolder: Repository subfolder holding the vocabulary files.
    """
    return _hub_download(
        repo_id, filename, subfolder, local_dir, local_dir_use_symlinks
    )


@lru_cache(maxsize=1)
def _get_conformer_pre_trained_model(
    repo_id: str, checkpoint_dir: str = "checkpoints"
):
    """Download artifacts for ``repo_id`` and build the pre-trained model.

    Fetches the checkpoint shards, the BPE files and ``config.yaml`` into this
    module's directory, instantiates the components declared in the config
    and restores the checkpoint weights into ``config["model"]``. Only the
    most recent ``(repo_id, checkpoint_dir)`` result is kept in the cache.

    Returns:
        Tuple ``(audio_encoder, encoder_model, jointer, decoder, text_encoder,
        model)`` read from the config keys ``audio_encoder``,
        ``encoder_model``, ``jointer_model``, ``decoder_model``,
        ``text_encoder`` and ``model``.
    """
    checkpoint_files = []
    for postfix in ("index", "data-00000-of-00001"):
        path = _get_checkpoint_filename(
            repo_id=repo_id,
            filename=f"{_CHECKPOINT_PREFIX}.{postfix}",
            subfolder=checkpoint_dir,
            local_dir=_MODULE_DIR,
            local_dir_use_symlinks=True,
        )
        logger.debug("Checkpoint file: %s", path)
        checkpoint_files.append(path)

    # NOTE(review): if config.yaml refers to these vocab files by a relative
    # path, that path resolves against the CWD, not _MODULE_DIR — confirm.
    for postfix in ("model", "vocab"):
        path = _get_bpe_model_filename(
            repo_id=repo_id,
            filename=f"{_BPE_PREFIX}.{postfix}",
            local_dir=_MODULE_DIR,
            local_dir_use_symlinks=True,
        )
        logger.debug("BPE file: %s", path)

    config_path = hf_hub_download(
        repo_id=repo_id,
        filename="config.yaml",
        local_dir=_MODULE_DIR,
        local_dir_use_symlinks=True,
    )
    # SECURITY: hyperpyyaml tags (e.g. ``!new:``) instantiate arbitrary Python
    # objects, so only load configs from trusted repositories.
    with open(config_path, "r", encoding="utf-8") as f:
        config = load_hyperpyyaml(f)

    encoder_model = config["encoder_model"]
    text_encoder = config["text_encoder"]
    jointer = config["jointer_model"]
    decoder = config["decoder_model"]
    model = config["model"]
    audio_encoder = config["audio_encoder"]

    # Restore from the directory the shards were actually downloaded to. A
    # bare ``checkpoint_dir``-relative path would resolve against the process
    # CWD and fail whenever the CWD is not this module's directory.
    checkpoint_prefix = os.path.join(
        os.path.dirname(checkpoint_files[0]), _CHECKPOINT_PREFIX
    )
    model.load_weights(checkpoint_prefix).expect_partial()

    return audio_encoder, encoder_model, jointer, decoder, text_encoder, model


def read_audio(in_filename: str) -> tf.Tensor:
    """Load a WAV file as a float32 tensor of shape ``(1, num_samples)``.

    Only the first channel is kept (``desired_channels=1``), so multi-channel
    files are accepted instead of failing in ``tf.squeeze``. The sample rate
    is neither checked nor resampled.
    NOTE(review): the model presumably expects one fixed sample rate —
    confirm and validate it upstream.
    """
    raw = tf.io.read_file(in_filename)
    audio, _sample_rate = tf.audio.decode_wav(raw, desired_channels=1)
    # (num_samples, 1) -> (num_samples,) -> (1, num_samples)
    return tf.expand_dims(tf.squeeze(audio, axis=-1), axis=0)


class UETASRModel:
    """Speech recogniser combining the pre-trained model with a decoding searcher."""

    def __init__(
        self,
        repo_id: str,
        decoding_method: str,
        beam_size: int,
        max_symbols_per_step: int,
    ):
        """Load (or reuse the cached) pre-trained model and build the searcher.

        Args:
            repo_id: Hugging Face Hub repository holding the model artifacts.
            decoding_method: Forwarded to ``decode.get_searcher``.
            beam_size: Forwarded to ``decode.get_searcher``.
            max_symbols_per_step: Forwarded to ``decode.get_searcher``.
        """
        (
            self.featurizer,
            self.encoder_model,
            jointer,
            decoder,
            text_encoder,
            self.model,
        ) = _get_conformer_pre_trained_model(repo_id)
        self.searcher = get_searcher(
            decoding_method,
            decoder,
            jointer,
            text_encoder,
            beam_size,
            max_symbols_per_step,
        )

    def predict(self, in_filename: str) -> str:
        """Transcribe the WAV file at ``in_filename`` and return the text."""
        inputs = read_audio(in_filename)
        features = self.featurizer(inputs)
        if self.model.use_cmvn:
            features = self.model.cmvn(features)

        # Single utterance: a mask of length tf.shape(features)[1] that is
        # all True, shaped (1, 1, T) after expand_dims.
        # Presumably axis 1 is time — confirm.
        num_frames = tf.shape(features)[1]
        mask = tf.sequence_mask([num_frames], maxlen=num_frames)
        mask = tf.expand_dims(mask, axis=1)

        encoder_outputs, encoder_masks = self.encoder_model(
            features, mask, training=False
        )
        encoder_mask = tf.squeeze(encoder_masks, axis=1)
        # Count of valid encoder frames per utterance.
        features_length = tf.math.reduce_sum(
            tf.cast(encoder_mask, tf.int32), axis=1
        )

        outputs = self.searcher.infer(encoder_outputs, features_length)
        # Batch of one: squeeze away size-1 dims, then convert to a Python str.
        outputs = tf.squeeze(outputs)
        return tf.compat.as_str_any(outputs.numpy())