# [hosting-page banner captured along with this file; not part of the source] Spaces: Runtime error / Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools as it
from typing import Any, Dict, List

import torch
from fairseq.data.dictionary import Dictionary
from fairseq.models.fairseq_model import FairseqModel
class BaseDecoder:
    """Shared scaffolding for decoders that turn model emissions into hypotheses.

    The base class resolves the blank and silence symbol indices from the
    target dictionary, runs the first model to obtain emissions, and provides
    a helper that collapses an index path into a token sequence. Subclasses
    must implement :meth:`decode`.
    """

    def __init__(self, tgt_dict: Dictionary) -> None:
        """Store the dictionary and resolve the blank and silence indices.

        Args:
            tgt_dict: Target dictionary; queried through ``indices``,
                ``index()``, ``bos()``, ``eos()`` and ``len()``.
        """
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)

        # Blank: the explicit "<ctc_blank>" entry when the dictionary has one,
        # otherwise fall back to the BOS index.
        self.blank = (
            tgt_dict.index("<ctc_blank>")
            if "<ctc_blank>" in tgt_dict.indices
            else tgt_dict.bos()
        )

        # Silence: prefer "<sep>", then "|", and finally the EOS index.
        if "<sep>" in tgt_dict.indices:
            self.silence = tgt_dict.index("<sep>")
        elif "|" in tgt_dict.indices:
            self.silence = tgt_dict.index("|")
        else:
            self.silence = tgt_dict.eos()

    def generate(
        self, models: List[FairseqModel], sample: Dict[str, Any], **unused
    ) -> List[List[Dict[str, torch.LongTensor]]]:
        """Compute emissions for ``sample`` and decode them.

        Args:
            models: Models to run; forwarded to :meth:`get_emissions`.
            sample: Batch dict whose ``"net_input"`` entry holds the keyword
                arguments for the model.
            **unused: Accepted and ignored.

        Returns:
            Whatever :meth:`decode` returns for the computed emissions.
        """
        # "prev_output_tokens" is dropped before the model call; every other
        # net_input entry is passed through as a keyword argument.
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions = self.get_emissions(models, encoder_input)
        return self.decode(emissions)

    def get_emissions(
        self,
        models: List[FairseqModel],
        encoder_input: Dict[str, Any],
    ) -> torch.FloatTensor:
        """Run the first model and return its emissions on the CPU.

        Args:
            models: List of models; only ``models[0]`` is used.
            encoder_input: Keyword arguments for the model call.

        Returns:
            The emissions with their first two dimensions swapped, cast to
            float32, moved to the CPU and made contiguous.
        """
        model = models[0]
        encoder_out = model(**encoder_input)
        # Use get_logits when the model offers it; otherwise take
        # log-probabilities from get_normalized_probs.
        if hasattr(model, "get_logits"):
            emissions = model.get_logits(encoder_out)
        else:
            emissions = model.get_normalized_probs(encoder_out, log_probs=True)
        # NOTE(review): presumably (time, batch, vocab) -> (batch, time, vocab);
        # confirm against the model's output layout.
        return emissions.transpose(0, 1).float().cpu().contiguous()

    def get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
        """Collapse runs of repeated indices and drop the blank index.

        Args:
            idxs: One-dimensional sequence of indices (tensor or any iterable
                of ints).

        Returns:
            A ``LongTensor`` with one entry per run of equal indices, with
            entries equal to ``self.blank`` removed.
        """
        # Convert a tensor to plain ints once, so grouping and filtering do
        # not operate on 0-d tensors element by element.
        if isinstance(idxs, torch.Tensor):
            idxs = idxs.tolist()
        collapsed = (key for key, _ in it.groupby(idxs))
        kept = [idx for idx in collapsed if idx != self.blank]
        return torch.LongTensor(kept)

    def decode(
        self,
        emissions: torch.FloatTensor,
    ) -> List[List[Dict[str, torch.LongTensor]]]:
        """Turn emissions into hypotheses; must be provided by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError