bel32123's picture
Adjust to use multitask model
1e93f37
raw
history blame
4.35 kB
import streamlit as st
from speechbrain.pretrained import GraphemeToPhoneme
import os
import torchaudio
from wav2vecasr.MispronounciationDetector import MispronounciationDetector
from wav2vecasr.PhonemeASRModel import Wav2Vec2PhonemeASRModel, Wav2Vec2OptimisedPhonemeASRModel, MultitaskPhonemeASRModel
import torch
@st.cache_resource
def load_model():
    """Build the mispronunciation detector used by the prediction page.

    Loads the multitask phoneme ASR checkpoint and its vocabulary from
    ``<cwd>/wav2vecasr/model`` plus the SpeechBrain SoundChoice G2P model,
    and combines them into a ``MispronounciationDetector`` on CPU. The
    ``st.cache_resource`` decorator lets repeated calls reuse the same
    detector instead of reloading the models.

    Returns:
        The constructed ``MispronounciationDetector``.
    """
    model_dir = os.path.join(os.getcwd(), "wav2vecasr", "model")
    device = "cpu"

    phoneme_asr = MultitaskPhonemeASRModel(
        os.path.join(model_dir, "multitask_best_ctc.pt"),
        os.path.join(model_dir, "vocab"),
        device,
    )
    grapheme_to_phoneme = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
    return MispronounciationDetector(phoneme_asr, grapheme_to_phoneme, device)
def save_file(sound_file):
    """Write an uploaded audio file into ``<cwd>/audio_files``.

    Args:
        sound_file: Uploaded file object (e.g. from ``st.file_uploader``);
            must provide ``.name`` and ``.getbuffer()``.

    Returns:
        The file name the data was stored under inside ``audio_files``.

    Raises:
        ValueError: If the uploaded name has no usable file-name component.
    """
    audio_folder_path = os.path.join(os.getcwd(), 'audio_files')
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(audio_folder_path, exist_ok=True)

    # The upload name is client-supplied: keep only its final component
    # (handling both '/' and '\\' separators) so a name such as
    # "../../x.wav" cannot write outside the audio folder.
    safe_name = os.path.basename(sound_file.name.replace('\\', '/'))
    if safe_name in ('', '.', '..'):
        raise ValueError(f'Invalid upload file name: {sound_file.name!r}')

    with open(os.path.join(audio_folder_path, safe_name), 'wb') as f:
        f.write(sound_file.getbuffer())
    return safe_name
def get_audio(saved_sound_filename):
    """Load a saved upload as a mono waveform resampled to 16 kHz.

    Args:
        saved_sound_filename: Name of a file inside ``<cwd>/audio_files``
            (as returned by ``save_file``).

    Returns:
        A 1-D tensor of samples at 16 kHz. Multi-channel recordings are
        down-mixed by averaging the channels.
    """
    audio_path = os.path.join(os.getcwd(), 'audio_files', saved_sound_filename)
    # Put the file's mtime and size into the cache key: keying on the name
    # alone returns stale audio when a different recording is uploaded under
    # the same file name (save_file overwrites it in place).
    file_stat = os.stat(audio_path)
    return _load_resampled_audio(audio_path, file_stat.st_mtime_ns, file_stat.st_size)


@st.cache_data
def _load_resampled_audio(audio_path, mtime_ns, size):
    """Read *audio_path*, resample it to 16 kHz and down-mix it to mono.

    ``mtime_ns`` and ``size`` are not used in the body; they exist only so
    ``st.cache_data`` hashes them into the cache key. They must NOT be
    renamed with a leading underscore, because Streamlit excludes such
    parameters from hashing.
    """
    audio, org_sr = torchaudio.load(audio_path)
    audio = torchaudio.functional.resample(audio, orig_freq=org_sr, new_freq=16000)
    # torchaudio.load returns (channels, frames). The previous
    # view(shape[1]) only worked for mono input; averaging over channels
    # leaves mono values unchanged and lets stereo recordings through.
    return audio.mean(dim=0)
def mispronounciation_detection_section():
    """Render the upload/predict page and display detection results.

    Asks for a .wav recording and the text being read. When *Predict* is
    pressed with both inputs present, the recording is saved and loaded,
    then passed with the text to the detector from ``load_model``.
    Phoneme-level and word-level results are rendered afterwards. If either
    input is missing, an error message is shown instead.
    """
    st.write('# Prediction')
    st.write('1. Upload a recording of you saying the text in .wav format')
    uploaded_file = st.file_uploader(' ', type='wav')
    st.write('2. Input the text you are saying in your recording')
    text = st.text_input(
        "Enter the text you want to read 👇",
        label_visibility='collapsed',
    )

    if not st.button('Predict'):
        return
    # Whitespace-only text has nothing to pronounce; treat it as missing.
    if uploaded_file is None or not text.strip():
        st.error('The audio or text has not been properly input', icon="🚨")
        return

    # Load from the name save_file actually wrote so both helpers agree on
    # the path.
    saved_name = save_file(uploaded_file)
    audio = get_audio(saved_name)
    mispronunciation_detector = load_model()

    st.write('# Detection Results')
    with st.spinner('Predicting...'):
        raw_info = mispronunciation_detector.detect(
            audio, text.strip(), phoneme_error_threshold=0.25
        )

    _render_phoneme_analysis(raw_info)
    st.divider()
    _render_word_analysis(raw_info)


def _render_phoneme_analysis(raw_info):
    """Show the phoneme error rate and the ref / hyp / error rows from *raw_info*.

    Reads the ``'per'``, ``'ref'``, ``'hyp'`` and ``'phoneme_errors'`` keys
    of the dict returned by ``detect``.
    """
    st.write('#### Phoneme Level Analysis')
    st.write(f"Phoneme Error Rate: {round(raw_info['per'], 2)}")
    # Built line by line so no source indentation leaks into the Markdown:
    # 4+ leading spaces would stop the ``` fence from being recognised.
    aligned_rows = "\n".join([
        "<style>",
        "textarea {",
        "white-space: nowrap;",
        "}",
        "</style>",
        "```",
        f"{raw_info['ref']}",
        f"{raw_info['hyp']}",
        f"{raw_info['phoneme_errors']}",
        "```",
    ])
    st.markdown(aligned_rows, unsafe_allow_html=True)


def _render_word_analysis(raw_info):
    """Show the word error rate and the words, with erroneous ones in bold.

    Reads the ``'words'``, ``'word_errors'`` and ``'wer'`` keys of the dict
    returned by ``detect``; ``words`` and ``word_errors`` are paired by
    position.
    """
    words_md = [
        f"**{word}**" if has_error else word
        for word, has_error in zip(raw_info["words"], raw_info["word_errors"])
    ]
    st.write('#### Word Level Analysis')
    st.write(f"Word Error Rate: {round(raw_info['wer'], 2)} and the following words in bold have errors:")
    st.markdown(" ".join(words_md))
if __name__ == '__main__':
    st.write('___')
    # Sidebar navigation between the landing page and the detector page.
    st.sidebar.title('Pronunciation Evaluation')
    # Streamlit discourages empty widget labels (they hurt accessibility and
    # may be rejected in future); use a real label and hide it instead.
    select = st.sidebar.selectbox(
        'Page',
        ['Main Page', 'Mispronunciation Detection'],
        key='1',
        label_visibility='collapsed',
    )
    st.sidebar.write(select)
    if select == 'Mispronunciation Detection':
        mispronounciation_detection_section()
    else:
        # Landing page.
        st.write('# Pronunciation Evaluation')
        st.write('This app is designed to detect mispronunciation of English words by English learners who speak Asian languages such as Korean, Mandarin and Vietnamese.')
        st.write('Wav2Vec2.0 was used to detect the phonemes from the learner and this output is compared with the correct phoneme sequence generated from input text')