File size: 3,836 Bytes
7bcf8d7 0a6cd62 7bcf8d7 5442f52 90945f2 5442f52 4090e0d 7bcf8d7 90945f2 d4e3aaf bcdaadd 9fea840 d4e3aaf 90945f2 7bcf8d7 1b7f5da 7bcf8d7 90945f2 6f53272 7bcf8d7 cd75560 7bcf8d7 a45002a 90945f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import librosa
from transformers import Wav2Vec2ForCTC, AutoProcessor
import torch
import numpy as np
from huggingface_hub import hf_hub_download
from torchaudio.models.decoder import ctc_decoder
# Sampling rate (Hz) used for every waveform handed to the processor/model.
ASR_SAMPLING_RATE = 16_000

# Language table: first field of each line (ISO code) -> rest of the line
# (language name).
# NOTE(review): the file is named ".tsv" but fields are split on a single
# space -- confirm the separator against the data file.
ASR_LANGUAGES = {}
# Explicit UTF-8 so language names decode the same way on every platform
# instead of depending on the locale's default encoding.
with open("data/asr/all_langs.tsv", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing empty line) instead of
            # failing the tuple unpacking below.
            continue
        iso, name = line.split(" ", 1)
        # Drop the line terminator so names do not carry a trailing "\n".
        ASR_LANGUAGES[iso] = name.rstrip("\n")

# Checkpoint identifier shared by the processor and the CTC model, both of
# which are loaded once at import time.
MODEL_ID = "facebook/mms-1b-all"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
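# NOTE: The language-model (beam-search) decoding setup below is disabled and
# kept for reference only. It calls `json.loads`, but `json` is not among the
# imports at the top of this file, so re-enabling it also needs `import json`.
# lm_decoding_config = {}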
# lm_decoding_configfile = hf_hub_download(
# repo_id="facebook/mms-cclms",
# filename="decoding_config.json",
# subfolder="mms-1b-all",
# )
# with open(lm_decoding_configfile) as f:
# lm_decoding_config = json.loads(f.read())
# # allow language model decoding for "eng"
# decoding_config = lm_decoding_config["eng"]
# lm_file = hf_hub_download(
# repo_id="facebook/mms-cclms",
# filename=decoding_config["lmfile"].rsplit("/", 1)[1],
# subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
# )
# token_file = hf_hub_download(
# repo_id="facebook/mms-cclms",
# filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
# subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
# )
# lexicon_file = None
# if decoding_config["lexiconfile"] is not None:
# lexicon_file = hf_hub_download(
# repo_id="facebook/mms-cclms",
# filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
# subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
# )
# beam_search_decoder = ctc_decoder(
# lexicon=lexicon_file,
# tokens=token_file,
# lm=lm_file,
# nbest=1,
# beam_size=500,
# beam_size_token=50,
# lm_weight=float(decoding_config["lmweight"]),
# word_score=float(decoding_config["wordscore"]),
# sil_score=float(decoding_config["silweight"]),
# blank_token="<s>",
# )
def transcribe(audio_data, lang="eng (English)"):
    """Transcribe audio with the MMS CTC model using greedy decoding.

    Args:
        audio_data: Either a ``(sample_rate, samples)`` tuple (microphone
            capture; samples are divided by 32768 and cast to float32, i.e.
            they are treated as 16-bit PCM) or a value accepted by
            ``librosa.load``, typically the path of an uploaded file.
        lang: Language label whose first whitespace-separated token is the
            language code passed to the tokenizer and to the adapter loader,
            e.g. ``"eng (English)"``.

    Returns:
        The transcription produced by ``processor.decode`` for the
        highest-scoring token at every frame.

    Raises:
        ValueError: If ``audio_data`` is ``None`` (no audio was supplied).
    """
    if audio_data is None:
        # Fail early with a clear message instead of an opaque error from
        # deep inside the audio loader.
        raise ValueError("No audio provided: record or upload audio first.")

    if isinstance(audio_data, tuple):
        # microphone
        sr, audio_samples = audio_data
        audio_samples = (audio_samples / 32768.0).astype(np.float32)
        if sr != ASR_SAMPLING_RATE:
            audio_samples = librosa.resample(
                audio_samples, orig_sr=sr, target_sr=ASR_SAMPLING_RATE
            )
    else:
        # file upload: librosa decodes, downmixes to mono and resamples
        audio_samples = librosa.load(
            audio_data, sr=ASR_SAMPLING_RATE, mono=True
        )[0]

    # Switch the tokenizer vocabulary and the model adapter to the
    # requested language before running inference.
    lang_code = lang.split()[0]
    processor.tokenizer.set_target_lang(lang_code)
    model.load_adapter(lang_code)

    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )

    # set device
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif (
        hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    ):
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    model.to(device)
    inputs = inputs.to(device)

    with torch.no_grad():
        outputs = model(**inputs).logits

    # Greedy (argmax) CTC decoding is used for every language. The
    # LM beam-search path is disabled; it would look like:
    #   beam_search_result = beam_search_decoder(outputs.to("cpu"))
    #   transcription = " ".join(beam_search_result[0][0].words).strip()
    ids = torch.argmax(outputs, dim=-1)[0]
    transcription = processor.decode(ids)

    return transcription
# Example inputs as [audio file path, language label] pairs; the label uses
# the same "<code> (<name>)" form as the default `lang` of `transcribe`.
ASR_EXAMPLES = [
    ["assets/english.mp3", "eng (English)"],
    # ["assets/tamil.mp3", "tam (Tamil)"],
    # ["assets/burmese.mp3", "mya (Burmese)"],
]

# Markdown note explaining that decoding here is done without a language model.
ASR_NOTE = """
The above demo doesn't use beam-search decoding using a language model.
Checkout the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for better accuracy.
"""