import os

# Run on CPU only: hide GPUs before TensorFlow initializes its devices.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf  # noqa: E402
from functools import lru_cache
from huggingface_hub import hf_hub_download
from hyperpyyaml import load_hyperpyyaml
from typing import Optional, Union

from decode import get_searcher


def _get_checkpoint_filename(
    repo_id: str,
    filename: str,
    local_dir: Optional[str] = None,
    local_dir_use_symlinks: Union[bool, str] = "auto",
    subfolder: str = "checkpoints",
) -> str:
    """Download a model checkpoint file from the Hugging Face Hub."""
    model_filename = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=local_dir,
        local_dir_use_symlinks=local_dir_use_symlinks,
    )
    return model_filename


def _get_bpe_model_filename(
    repo_id: str,
    filename: str,
    local_dir: Optional[str] = None,
    local_dir_use_symlinks: Union[bool, str] = "auto",
    subfolder: str = "vocabs",
) -> str:
    """Download a subword (BPE) model file from the Hugging Face Hub."""
    bpe_model_filename = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=local_dir,
        local_dir_use_symlinks=local_dir_use_symlinks,
    )
    return bpe_model_filename


@lru_cache(maxsize=1)
def _get_conformer_pre_trained_model(repo_id: str, checkpoint_dir: str = "checkpoints"):
    """Download the pre-trained Conformer checkpoint, BPE model and config,
    then build the model components declared in the HyperPyYAML config."""
    # Fetch both files of the TensorFlow checkpoint (index + data shard).
    for postfix in ["index", "data-00000-of-00001"]:
        tmp = _get_checkpoint_filename(
            repo_id=repo_id,
            filename="avg_top5_27-32.ckpt.{}".format(postfix),
            subfolder=checkpoint_dir,
            local_dir=os.path.dirname(__file__),  # noqa
            local_dir_use_symlinks=True,
        )
        print(tmp)

    # Fetch the subword (BPE) model and its vocabulary.
    for postfix in ["model", "vocab"]:
        tmp = _get_bpe_model_filename(
            repo_id=repo_id,
            filename="subword_vietnamese_500.{}".format(postfix),
            local_dir=os.path.dirname(__file__),  # noqa
            local_dir_use_symlinks=True,
        )
        print(tmp)

    config_path = hf_hub_download(
        repo_id=repo_id,
        filename="config.yaml",
        local_dir=os.path.dirname(__file__),  # noqa
        local_dir_use_symlinks=True,
    )

    # HyperPyYAML instantiates the model objects while parsing the config.
    with open(config_path, "r") as f:
        config = load_hyperpyyaml(f)

    encoder_model = config["encoder_model"]
    text_encoder = config["text_encoder"]
    jointer = config["jointer_model"]
    decoder = config["decoder_model"]
    # searcher = config["decoder"]
    model = config["model"]
    audio_encoder = config["audio_encoder"]
    # The checkpoint was downloaded next to this file, so resolve its path
    # relative to the module directory instead of the working directory.
    ckpt_prefix = os.path.join(os.path.dirname(__file__), checkpoint_dir, "avg_top5_27-32.ckpt")
    model.load_weights(ckpt_prefix).expect_partial()
    return audio_encoder, encoder_model, jointer, decoder, text_encoder, model


def read_audio(in_filename: str):
    """Read a mono WAV file into a batched waveform of shape (1, num_samples)."""
    audio = tf.io.read_file(in_filename)
    audio = tf.audio.decode_wav(audio)[0]  # (num_samples, num_channels)
    # Drop the channel axis (mono audio expected) and add a batch axis.
    audio = tf.expand_dims(tf.squeeze(audio, axis=-1), axis=0)
    return audio


class UETASRModel:
    """Wrapper around the pre-trained UETASR Conformer transducer."""

    def __init__(
        self,
        repo_id: str,
        decoding_method: str,
        beam_size: int,
        max_symbols_per_step: int,
    ):
        (self.featurizer, self.encoder_model, jointer, decoder,
         text_encoder, self.model) = _get_conformer_pre_trained_model(repo_id)
        # Build the label searcher (e.g. greedy or beam search) used at decode time.
        self.searcher = get_searcher(
            decoding_method,
            decoder,
            jointer,
            text_encoder,
            beam_size,
            max_symbols_per_step,
        )
    def predict(self, in_filename: str):
        """Transcribe a WAV file and return the decoded text."""
        inputs = read_audio(in_filename)

        # Extract acoustic features, applying CMVN normalization if configured.
        features = self.featurizer(inputs)
        features = self.model.cmvn(features) if self.model.use_cmvn else features

        # Build a (batch, 1, time) mask covering the full feature sequence.
        mask = tf.sequence_mask([tf.shape(features)[1]], maxlen=tf.shape(features)[1])
        mask = tf.expand_dims(mask, axis=1)

        encoder_outputs, encoder_masks = self.encoder_model(
            features, mask, training=False)

        # Recover the encoder output lengths from the (subsampled) mask.
        encoder_mask = tf.squeeze(encoder_masks, axis=1)
        features_length = tf.math.reduce_sum(
            tf.cast(encoder_mask, tf.int32),
            axis=1,
        )

        outputs = self.searcher.infer(encoder_outputs, features_length)
        outputs = tf.squeeze(outputs)
        outputs = tf.compat.as_str_any(outputs.numpy())
        return outputs
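

# Minimal usage sketch. The repo id, decoding method and search settings below
# are illustrative assumptions, not values shipped with this module; substitute
# the checkpoint repo and searcher options your deployment actually uses.
if __name__ == "__main__":
    import sys

    asr = UETASRModel(
        repo_id="user/uetasr-conformer",  # hypothetical Hub repo id
        decoding_method="greedy",         # assumed to be accepted by get_searcher
        beam_size=1,
        max_symbols_per_step=5,
    )
    # Transcribe a 16 kHz mono WAV file given on the command line.
    print(asr.predict(sys.argv[1]))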