File size: 4,106 Bytes
a0dfd75
 
 
 
 
 
 
e9812a3
 
a0dfd75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9812a3
a0dfd75
 
 
 
e9812a3
 
 
 
a0dfd75
 
 
 
e9812a3
a0dfd75
 
 
 
 
 
 
 
 
 
e9812a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0dfd75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9812a3
 
a0dfd75
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import tensorflow as tf
from functools import lru_cache
from huggingface_hub import hf_hub_download
from hyperpyyaml import load_hyperpyyaml
from typing import Union

from decode import get_searcher

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


def _get_checkpoint_filename(
    repo_id: str,
    filename: str,
    local_dir: Union[str, None] = None,  # fixed annotation: default is None, so the type must be optional
    local_dir_use_symlinks: Union[bool, str] = "auto",
    subfolder: str = "checkpoints"
) -> str:
    """Download a model checkpoint file from the Hugging Face Hub.

    Args:
        repo_id: Hub repository id (e.g. "user/model").
        filename: Name of the checkpoint file inside ``subfolder``.
        local_dir: Directory to place the downloaded file in; ``None`` uses
            the default HF cache location.
        local_dir_use_symlinks: Passed through to ``hf_hub_download``.
        subfolder: Repository subfolder holding checkpoints.

    Returns:
        Absolute path to the downloaded (or cached) file.
    """
    model_filename = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=local_dir,
        local_dir_use_symlinks=local_dir_use_symlinks,
    )
    return model_filename


def _get_bpe_model_filename(
    repo_id: str,
    filename: str,
    local_dir: Union[str, None] = None,  # fixed annotation: default is None, so the type must be optional
    local_dir_use_symlinks: Union[bool, str] = "auto",
    subfolder: str = "vocabs"
) -> str:
    """Download a BPE/subword vocabulary file from the Hugging Face Hub.

    Args:
        repo_id: Hub repository id (e.g. "user/model").
        filename: Name of the vocab file inside ``subfolder``.
        local_dir: Directory to place the downloaded file in; ``None`` uses
            the default HF cache location.
        local_dir_use_symlinks: Passed through to ``hf_hub_download``.
        subfolder: Repository subfolder holding vocabularies.

    Returns:
        Absolute path to the downloaded (or cached) file.
    """
    bpe_model_filename = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=local_dir,
        local_dir_use_symlinks=local_dir_use_symlinks,
    )
    return bpe_model_filename


@lru_cache(maxsize=1)
def _get_conformer_pre_trained_model(repo_id: str, checkpoint_dir: str = "checkpoints"):
    """Download checkpoint, subword vocab and config from the HF Hub, build the
    model objects declared in ``config.yaml`` via hyperpyyaml, and restore weights.

    Cached with ``maxsize=1`` so the download/build happens once per process
    (for a single repo_id/checkpoint_dir combination).

    Returns:
        Tuple ``(audio_encoder, encoder_model, jointer, decoder, text_encoder, model)``
        as defined by the keys of the loaded config.
    """
    # A TF checkpoint is two files: the index and a single data shard.
    for postfix in ["index", "data-00000-of-00001"]:
        tmp = _get_checkpoint_filename(
            repo_id=repo_id,
            filename="avg_top5_27-32.ckpt.{}".format(postfix),
            subfolder=checkpoint_dir,
            local_dir=os.path.dirname(__file__),  # noqa
            local_dir_use_symlinks=True,
        )
        print(tmp)

    # SentencePiece-style vocab ships as a .model / .vocab pair.
    for postfix in ["model", "vocab"]:
        tmp = _get_bpe_model_filename(
            repo_id=repo_id,
            filename="subword_vietnamese_500.{}".format(postfix),
            local_dir=os.path.dirname(__file__),  # noqa
            local_dir_use_symlinks=True,
        )
        print(tmp)

    config_path = hf_hub_download(
        repo_id=repo_id,
        filename="config.yaml",
        local_dir=os.path.dirname(__file__),  # noqa
        local_dir_use_symlinks=True,
    )

    # hyperpyyaml instantiates the model objects directly from the YAML.
    with open(config_path, "r") as f:
        config = load_hyperpyyaml(f)

    encoder_model = config["encoder_model"]
    text_encoder = config["text_encoder"]
    jointer = config["jointer_model"]
    decoder = config["decoder_model"]
    # searcher = config["decoder"]
    model = config["model"]
    audio_encoder = config["audio_encoder"]
    # NOTE(review): files were downloaded under os.path.dirname(__file__)/<subfolder>,
    # but this path is relative to the process CWD — it only resolves when the CWD
    # is this file's directory. Confirm against how the module is launched.
    model.load_weights(os.path.join(checkpoint_dir, "avg_top5_27-32.ckpt")).expect_partial()

    return audio_encoder, encoder_model, jointer, decoder, text_encoder, model


def read_audio(in_filename: str):
    """Load a WAV file and return its samples as a (1, num_samples) tensor.

    The channel axis produced by ``decode_wav`` is squeezed away (mono input
    expected) and a leading batch axis of size 1 is added.
    """
    raw = tf.io.read_file(in_filename)
    waveform, _ = tf.audio.decode_wav(raw)
    mono = tf.squeeze(waveform, axis=-1)
    return tf.expand_dims(mono, axis=0)


class UETASRModel:
    """Wrapper around a pre-trained UETASR transducer model: loads the
    components from the HF Hub and transcribes WAV files to text."""

    def __init__(
        self,
        repo_id: str,
        decoding_method: str,
        beam_size: int,
        max_symbols_per_step: int,
    ):
        """Build the model and its searcher.

        Args:
            repo_id: Hugging Face Hub repository id of the pre-trained model.
            decoding_method: Search strategy name understood by
                ``decode.get_searcher`` (semantics defined there).
            beam_size: Beam width passed to the searcher.
            max_symbols_per_step: Cap on emitted symbols per encoder frame,
                passed to the searcher.
        """
        self.featurizer, self.encoder_model, jointer, decoder, text_encoder, self.model = _get_conformer_pre_trained_model(repo_id)
        self.searcher = get_searcher(
            decoding_method,
            decoder,
            jointer,
            text_encoder,
            beam_size,
            max_symbols_per_step,
        )

    def predict(self, in_filename: str):
        """Transcribe a single WAV file and return the decoded string."""
        inputs = read_audio(in_filename)
        features = self.featurizer(inputs)
        # Optional cepstral mean-variance normalization, as configured on the model.
        features = self.model.cmvn(features) if self.model.use_cmvn else features

        # Full-length mask (batch of one, nothing padded); assumes features are
        # (batch, time, feat_dim) — TODO confirm against the featurizer output.
        mask = tf.sequence_mask([tf.shape(features)[1]], maxlen=tf.shape(features)[1])
        mask = tf.expand_dims(mask, axis=1)
        encoder_outputs, encoder_masks = self.encoder_model(
            features, mask, training=False)

        # Derive the post-subsampling sequence length from the encoder mask.
        encoder_mask = tf.squeeze(encoder_masks, axis=1)
        features_length = tf.math.reduce_sum(
            tf.cast(encoder_mask, tf.int32),
            axis=1
        )

        outputs = self.searcher.infer(encoder_outputs, features_length)
        outputs = tf.squeeze(outputs)
        # Convert the scalar string tensor to a Python str.
        outputs = tf.compat.as_str_any(outputs.numpy())

        return outputs