bel32123 commited on
Commit
2676061
·
1 Parent(s): 490c46f

Introduce uncertainty to word error with PER threshold

Browse files
wav2vecasr/MispronounciationDetector.py CHANGED
@@ -1,6 +1,7 @@
1
  from pandas.core.construction import T
2
  import torch
3
  import jiwer
 
4
 
5
  class MispronounciationDetector:
6
  def __init__(self, l2_phoneme_recogniser, g2p, device):
@@ -8,18 +9,19 @@ class MispronounciationDetector:
8
  self.g2p = g2p
9
  self.device = device
10
 
11
- def detect(self, audio, text):
12
  l2_phones = self.phoneme_asr_model.get_l2_phoneme_sequence(audio)
 
13
  native_speaker_phones = self.get_native_speaker_phoneme_sequence(text)
14
  standardised_native_speaker_phones = self.phoneme_asr_model.standardise_g2p_phoneme_sequence(native_speaker_phones)
15
- raw_info = self.get_mispronounciation_output(text, l2_phones, standardised_native_speaker_phones)
16
  return raw_info
17
 
18
  def get_native_speaker_phoneme_sequence(self, text):
19
  phonemes = self.g2p(text)
20
  return phonemes
21
 
22
- def get_mispronounciation_output(self, text, pred_phones, org_label_phones):
23
  """
24
  Aligns the predicted phones from the L2 speaker and the expected native speaker phone to get the errors
25
  :param text: original words read by the user
@@ -101,7 +103,7 @@ class MispronounciationDetector:
101
  # get mispronounced words based on if there are phoneme errors present in the phonemes of that word
102
  aligned_word_error_output = ""
103
  words = text.split(" ")
104
- word_error_bool = self.get_mispronounced_words(error_bool)
105
  wer = sum(word_error_bool) / len(words)
106
 
107
  raw_info = {"ref":ref, "hyp": hyp, "per":per, "phoneme_errors": error_bool, "wer": wer, "words": words, "word_errors":word_error_bool}
@@ -109,16 +111,27 @@ class MispronounciationDetector:
109
  return raw_info
110
 
111
 
112
- def get_mispronounced_words(self, phoneme_error_bool):
113
  # map mispronounced phones back to words that were mispronounce to get WER
114
  word_error_bool = []
115
  phoneme_error_bool.append("|")
116
  word_phones = self.split_lst_by_delim(phoneme_error_bool, "|")
 
 
117
  for phones in word_phones:
118
- if "s" in phones or "d" in phones or "a" in phones:
 
 
 
 
 
 
 
 
119
  word_error_bool.append(True)
120
  else:
121
  word_error_bool.append(False)
 
122
  return word_error_bool
123
 
124
 
 
1
  from pandas.core.construction import T
2
  import torch
3
  import jiwer
4
+ import re
5
 
6
  class MispronounciationDetector:
7
  def __init__(self, l2_phoneme_recogniser, g2p, device):
 
9
  self.g2p = g2p
10
  self.device = device
11
 
12
+ def detect(self, audio, text, phoneme_error_threshold=0.25):
13
  l2_phones = self.phoneme_asr_model.get_l2_phoneme_sequence(audio)
14
+ l2_phones = [re.sub(r'\d', "", phone_str) for phone_str in l2_phones] #g2p has no lexical stress
15
  native_speaker_phones = self.get_native_speaker_phoneme_sequence(text)
16
  standardised_native_speaker_phones = self.phoneme_asr_model.standardise_g2p_phoneme_sequence(native_speaker_phones)
17
+ raw_info = self.get_mispronounciation_output(text, l2_phones, standardised_native_speaker_phones, phoneme_error_threshold)
18
  return raw_info
19
 
20
  def get_native_speaker_phoneme_sequence(self, text):
21
  phonemes = self.g2p(text)
22
  return phonemes
23
 
24
+ def get_mispronounciation_output(self, text, pred_phones, org_label_phones, phoneme_error_threshold):
25
  """
26
  Aligns the predicted phones from the L2 speaker and the expected native speaker phone to get the errors
27
  :param text: original words read by the user
 
103
  # get mispronounced words based on if there are phoneme errors present in the phonemes of that word
104
  aligned_word_error_output = ""
105
  words = text.split(" ")
106
+ word_error_bool = self.get_mispronounced_words(error_bool, phoneme_error_threshold)
107
  wer = sum(word_error_bool) / len(words)
108
 
109
  raw_info = {"ref":ref, "hyp": hyp, "per":per, "phoneme_errors": error_bool, "wer": wer, "words": words, "word_errors":word_error_bool}
 
111
  return raw_info
112
 
113
 
114
+ def get_mispronounced_words(self, phoneme_error_bool, phoneme_error_threshold):
115
  # map mispronounced phones back to words that were mispronounce to get WER
116
  word_error_bool = []
117
  phoneme_error_bool.append("|")
118
  word_phones = self.split_lst_by_delim(phoneme_error_bool, "|")
119
+
120
+ # wrong only if percentage of phones that are wrong > phoneme error threshold
121
  for phones in word_phones:
122
+
123
+ # get count of "s", "d", "a" in phones
124
+ error_count = 0
125
+ for phone in phones:
126
+ if phone == "s" or phone == "d" or phone == "a":
127
+ error_count += 1
128
+
129
+ # check if pass threshold
130
+ if error_count / len(phones) > phoneme_error_threshold:
131
  word_error_bool.append(True)
132
  else:
133
  word_error_bool.append(False)
134
+
135
  return word_error_bool
136
 
137