File size: 3,851 Bytes
7bcf8d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5442f52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f53272
5442f52
 
 
 
 
 
 
 
 
 
 
 
4090e0d
7bcf8d7
d4e3aaf
 
 
 
 
a5b1a23
d4e3aaf
 
 
 
 
 
 
 
7bcf8d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b7f5da
7bcf8d7
 
 
6f53272
 
 
7bcf8d7
 
 
 
 
cd75560
 
 
7bcf8d7
 
 
a45002a
 
7bcf8d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import json

import librosa
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from torchaudio.models.decoder import ctc_decoder
from transformers import Wav2Vec2ForCTC, AutoProcessor

# Sample rate (Hz) passed to the processor; all audio is brought to this rate
# before inference.
ASR_SAMPLING_RATE = 16_000

# ISO language code -> human-readable language name. Each non-blank line of the
# file is "<iso> <name>": the first space separates the code from the name.
ASR_LANGUAGES = {}
with open("data/asr/all_langs.tsv", encoding="utf-8") as f:
    for line in f:
        # Drop the trailing newline so names are clean display strings.
        line = line.rstrip()
        if not line:
            # Tolerate blank lines, e.g. a trailing empty line at end of file.
            continue
        iso, name = line.split(" ", 1)
        ASR_LANGUAGES[iso] = name

# Multilingual MMS checkpoint; per-language adapters are loaded at call time
# via model.load_adapter().
MODEL_ID = "facebook/mms-1b-all"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)


# lm_decoding_config = {}
# lm_decoding_configfile = hf_hub_download(
#     repo_id="facebook/mms-cclms",
#     filename="decoding_config.json",
#     subfolder="mms-1b-all",
# )

# with open(lm_decoding_configfile) as f:
#     lm_decoding_config = json.loads(f.read())

# # allow language model decoding for "eng"

# decoding_config = lm_decoding_config["eng"]

# lm_file = hf_hub_download(
#     repo_id="facebook/mms-cclms",
#     filename=decoding_config["lmfile"].rsplit("/", 1)[1],
#     subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
# )
# token_file = hf_hub_download(
#     repo_id="facebook/mms-cclms",
#     filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
#     subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
# )
# lexicon_file = None
# if decoding_config["lexiconfile"] is not None:
#     lexicon_file = hf_hub_download(
#         repo_id="facebook/mms-cclms",
#         filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
#         subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
#     )
    
# beam_search_decoder = ctc_decoder(
#     lexicon=lexicon_file,
#     tokens=token_file,
#     lm=lm_file,
#     nbest=1,
#     beam_size=500,
#     beam_size_token=50,
#     lm_weight=float(decoding_config["lmweight"]),
#     word_score=float(decoding_config["wordscore"]),
#     sil_score=float(decoding_config["silweight"]),
#     blank_token="<s>",
# )


def _select_device():
    """Return the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available() and mps.is_built():
        return torch.device("mps")
    return torch.device("cpu")


def _load_audio(audio_data):
    """Return mono float32 samples resampled to ASR_SAMPLING_RATE.

    Args:
        audio_data: Either a ``(sample_rate, samples)`` tuple (microphone
            recording) or a path to an audio file (upload).

    Raises:
        ValueError: If no audio was provided.
    """
    if audio_data is None:
        raise ValueError("No audio provided")

    if isinstance(audio_data, tuple):
        # Microphone input: raw samples plus their sample rate.
        sr, samples = audio_data
        samples = np.asarray(samples)
        if np.issubdtype(samples.dtype, np.integer):
            # Integer PCM -> float in [-1.0, 1.0). For int16 this divides by
            # 32768, the same normalization as before.
            samples = samples / float(np.iinfo(samples.dtype).max + 1)
        samples = samples.astype(np.float32)
        if samples.ndim > 1:
            # NOTE(review): assumes a (num_samples, num_channels) layout for
            # multi-channel microphone input — confirm against the caller.
            samples = samples.mean(axis=1)
        if sr != ASR_SAMPLING_RATE:
            # Resample rather than reject: browsers commonly record at other rates.
            samples = librosa.resample(
                samples, orig_sr=sr, target_sr=ASR_SAMPLING_RATE
            )
        return samples

    # File upload: librosa decodes, resamples and downmixes in one call.
    return librosa.load(audio_data, sr=ASR_SAMPLING_RATE, mono=True)[0]


def transcribe(audio_data, lang="eng (English)"):
    """Transcribe speech with the MMS model using greedy CTC decoding.

    Args:
        audio_data: Either a ``(sample_rate, samples)`` tuple from a
            microphone recording, or a path to an audio file.
        lang: Language label whose first whitespace-separated token is the
            ISO code, e.g. ``"eng (English)"``.

    Returns:
        The decoded transcription string.

    Raises:
        ValueError: If ``audio_data`` is None.
    """
    audio_samples = _load_audio(audio_data)

    lang_code = lang.split()[0]
    # Select the language-specific tokenizer vocabulary and adapter weights.
    processor.tokenizer.set_target_lang(lang_code)
    model.load_adapter(lang_code)

    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )

    device = _select_device()
    model.to(device)
    inputs = inputs.to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy (argmax) CTC decoding. Beam-search decoding with an n-gram LM
    # (the commented-out ``beam_search_decoder`` setup) is currently disabled.
    ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(ids)


# Example [audio file path, language label] pairs — presumably fed to the demo
# UI's examples list. Labels use the same "<iso> (<name>)" form as the `lang`
# argument of transcribe().
ASR_EXAMPLES = [
    ["assets/english.mp3", "eng (English)"],
    # ["assets/tamil.mp3", "tam (Tamil)"],
    # ["assets/burmese.mp3",  "mya (Burmese)"],
]

# User-facing note explaining that decoding is greedy (no LM beam search).
ASR_NOTE = """
The demo above does not use beam-search decoding with a language model.
Check out the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for better accuracy.
"""