Spaces:
Runtime error
Runtime error
import string
from collections import defaultdict
from itertools import permutations, product
from pathlib import Path

import langcodes
import streamlit as st
import torchaudio
from allosaurus.app import read_recognizer
def get_supported_codes(model=None):
    """Return the language codes the Allosaurus recognizer can use.

    The list always starts with "ipa", the default universal recognizer.
    After it come all three-letter lowercase codes for which
    ``model.is_available(code)`` is true, in alphabetical order.

    Args:
        model: An already-loaded recognizer to query. If None (the default),
            one is loaded with ``read_recognizer()``. Passing a model lets the
            caller avoid loading a second copy.

    Returns:
        list[str] of codes whose first element is "ipa".
    """
    if model is None:
        model = read_recognizer()

    supported_codes = ["ipa"]  # default option
    # Use product, not permutations: permutations skips every code with a
    # repeated letter, e.g. "ell" (Greek), "kaa", "aaa".
    for letters in product(string.ascii_lowercase, repeat=3):
        code = "".join(letters)
        # "ipa" is already first in the list; don't add it a second time.
        if code != "ipa" and model.is_available(code):
            supported_codes.append(code)
    return supported_codes
def get_path_to_wav_format(uploaded_file, suppress_outputs=False):
    """Save an uploaded audio file to disk and return the path of a .wav version.

    Allosaurus requires .wav input. A .wav upload is saved as-is. An .mp3
    upload is saved, decoded with torchaudio, and re-saved next to it as
    16-bit signed-integer PCM .wav.

    Args:
        uploaded_file: Object exposing ``.name`` (client-supplied file name)
            and ``.getvalue()`` (raw bytes), e.g. a Streamlit UploadedFile.
        suppress_outputs: If True, skip the ``st.info`` conversion messages.

    Returns:
        pathlib.Path (relative to the current working directory) of the
        .wav file.

    Raises:
        ValueError: If the file extension is neither .wav nor .mp3
            (case-insensitive). Nothing is written to disk in that case.
    """
    # Keep only the final path component so a client-supplied name such as
    # "../x.wav" cannot write outside the working directory.
    actual_file_path = Path(Path(uploaded_file.name).name)
    # Compare the real extension, not a substring: "a.wav.mp3" is an mp3,
    # and "x.WAV" is a wav.
    suffix = actual_file_path.suffix.lower()
    if suffix not in (".wav", ".mp3"):
        raise ValueError(
            f"Unsupported audio file {uploaded_file.name!r}: expected a .wav or .mp3 file"
        )

    actual_file_path.write_bytes(uploaded_file.getvalue())
    if suffix == ".wav":
        return actual_file_path

    # .mp3: convert to .wav with torchaudio.
    new_desired_path = actual_file_path.with_suffix(".wav")
    # Write signed-integer PCM. The linked question describes the
    # "wave.Error: unknown format: 3" this is meant to prevent.
    # https://stackoverflow.com/questions/60352850/wave-error-unknown-format-3-arises-when-trying-to-convert-a-wav-file-into-text
    encoding = "PCM_S"
    bits_per_sample = 16
    waveform, sample_rate = torchaudio.load(actual_file_path)
    if not suppress_outputs:
        st.info(f"Allosaurus requires .wav files. Converting with torchaudio, encoding={encoding}, bits_per_sample={bits_per_sample}")
        st.info(f"Uploaded file sample_rate: {sample_rate}")
    torchaudio.save(
        new_desired_path,
        waveform,
        sample_rate,
        encoding=encoding,
        bits_per_sample=bits_per_sample,
    )
    return new_desired_path
def get_langcode_description(input_code, url=False):
    """Return a human-readable description of a recognizer language code.

    Args:
        input_code: A language code such as "eng". "ipa", "" or None select
            the default universal recognizer.
        url: If True, render the language name as a Markdown link to its
            page under https://iso639-3.sil.org/code/.

    Returns:
        - The universal-default description for "ipa" or a falsy code.
        - Otherwise the language's display name from ``langcodes``,
          optionally as a Markdown link.
        - If ``langcodes`` cannot handle the code, a note that the code was
          not recognized.
    """
    default_code = "ipa"  # the default allosaurus recognizer
    if not input_code or input_code == default_code:
        return "the default universal setting, not specific to any language"

    try:
        lang = langcodes.get(input_code)
        alpha3 = lang.to_alpha3()
        display_name = lang.display_name()
    except (langcodes.LanguageTagError, LookupError):
        # LanguageTagError: the code could not be parsed.
        # LookupError: per the langcodes docs, to_alpha3() raises it when no
        # three-letter code exists.
        # Either way, don't mislabel the code as the universal default.
        return f"an unrecognized language code ({input_code})"

    if url:
        return f"[{display_name}](https://iso639-3.sil.org/code/{alpha3})"
    return display_name
def get_langcode_with_description(input_code):
    """Build a "<code>: <description>" label for *input_code*.

    Suitable as the ``format_func`` of a Streamlit selectbox listing
    language codes. The description comes from ``get_langcode_description``
    (plain text, no URL).
    """
    description = get_langcode_description(input_code)
    return "{}: {}".format(input_code, description)
if __name__ == "__main__":
    # Let the user pick the recognizer language; "ipa" (the universal
    # default) is preselected.
    supported_codes = get_supported_codes()
    index_of_desired_default = supported_codes.index("ipa")
    langcode = st.selectbox(
        "ISO code for input language. Allosaurus doesn't need this, but it can improve accuracy",
        options=supported_codes,
        index=index_of_desired_default,
        format_func=get_langcode_with_description,
    )

    model = read_recognizer()
    description = get_langcode_description(langcode, url=True)
    st.write(f"Instructing Allosaurus to recognize using language {langcode}. That is, {description}")

    # `or []` tolerates a None result, so len() and iteration below are safe.
    uploaded_files = st.file_uploader(
        "Choose a file",
        type=[
            ".wav",
            ".mp3",
        ],
        accept_multiple_files=True,
    ) or []

    results = {}  # uploaded file name -> model.recognize() output, for display
    uploaded_files_count = len(uploaded_files)
    # Above this many files, the per-file conversion messages are suppressed.
    suppress_output_threshold = 2
    my_bar = st.progress(0)
    for i, uploaded_file in enumerate(uploaded_files):
        if uploaded_file is not None:
            wav_file = get_path_to_wav_format(
                uploaded_file,
                suppress_outputs=uploaded_files_count > suppress_output_threshold,
            )
            results[uploaded_file.name] = model.recognize(wav_file, langcode)
        # Parenthesize: `i+1/count` evaluated to i + (1/count). That exceeds
        # 1.0 from the second file on, outside st.progress's 0.0-1.0 float
        # range.
        my_bar.progress((i + 1) / uploaded_files_count)
    st.write(results)