# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Deep speech decoder.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import itertools | |
from nltk.metrics import distance | |
import numpy as np | |
class DeepSpeechDecoder(object):
  """Greedy decoder implementation for Deep Speech model."""

  def __init__(self, labels, blank_index=28):
    """Decoder initialization.

    Arguments:
      labels: a string specifying the speech labels for the decoder to use.
      blank_index: an integer specifying index for the blank character.
        Defaults to 28.
    """
    # e.g. labels = "[a-z]' _"
    self.labels = labels
    self.blank_index = blank_index
    self.int_to_char = dict([(i, c) for (i, c) in enumerate(labels)])

  def convert_to_string(self, sequence):
    """Convert a sequence of indexes into the corresponding string."""
    return ''.join([self.int_to_char[i] for i in sequence])

  def wer(self, decode, target):
    """Computes the Word Error Rate (WER).

    WER is defined as the edit distance between the two provided sentences
    after tokenizing them into words.

    Args:
      decode: a string of the decoded output.
      target: a string for the ground truth label.

    Returns:
      A float number for the WER of the current decode-target pair.
    """
    # Map each distinct word to a single character so the word-level edit
    # distance can be computed as a character-level edit distance.
    words = set(decode.split() + target.split())
    word2char = dict(zip(words, range(len(words))))

    new_decode = [chr(word2char[w]) for w in decode.split()]
    new_target = [chr(word2char[w]) for w in target.split()]

    return distance.edit_distance(''.join(new_decode), ''.join(new_target))
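
  # Worked example for the word-to-char mapping in wer() above (values are
  # illustrative only): with decode = "hello there" and target = "hello world",
  # the three unique words map to three distinct single characters, so the
  # word-level comparison reduces to the edit distance between two
  # two-character strings, which is 1 (one substitution).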

  def cer(self, decode, target):
    """Computes the Character Error Rate (CER).

    CER is defined as the edit distance between the two given strings.

    Args:
      decode: a string of the decoded output.
      target: a string for the ground truth label.

    Returns:
      A float number denoting the CER for the current sentence pair.
    """
    return distance.edit_distance(decode, target)

  def decode(self, logits):
    """Decode the best guess from logits using a greedy algorithm."""
    # Choose the class with maximum probability at each time step.
    best = list(np.argmax(logits, axis=1))
    # Merge repeated chars.
    merge = [k for k, _ in itertools.groupby(best)]
    # Remove the blank index from the decoded sequence.
    merge_remove_blank = [k for k in merge if k != self.blank_index]
    return self.convert_to_string(merge_remove_blank)
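

# A minimal usage sketch, not part of the original module. The label string,
# blank index, and logits below are hypothetical placeholders chosen only to
# exercise the decoder end to end; a real model would produce logits of shape
# [time_steps, num_classes] from audio features.
if __name__ == "__main__":
  demo_labels = "abcdefghijklmnopqrstuvwxyz' _"  # 29 symbols; index 28 is the blank.
  decoder = DeepSpeechDecoder(demo_labels, blank_index=28)

  # One-hot style logits spelling "c c a <blank> b", which the greedy decoder
  # should collapse to "cab" after merging repeats and dropping the blank.
  demo_logits = np.zeros((5, len(demo_labels)))
  for t, idx in enumerate([2, 2, 0, 28, 1]):
    demo_logits[t, idx] = 1.0

  decoded = decoder.decode(demo_logits)
  print("decoded:", decoded)                            # -> cab
  print("CER vs 'cab':", decoder.cer(decoded, "cab"))   # -> 0
  print("WER vs 'cab':", decoder.wer(decoded, "cab"))   # -> 0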