English ASR sequence-to-sequence model. The model outputs normalized text, labels timestamps, and segments multiple speakers.

# !pip install transformers sentencepiece torch torchaudio

from transformers import SpeechEncoderDecoderModel
from transformers import AutoFeatureExtractor, AutoTokenizer, GenerationConfig
import torchaudio
import torch

# NOTE(review): the checkpoint ID here used to be 'nguyenvulebinh/wav2vec2-bartpho'.
# This card describes an English model, shows English transcripts, and links a
# sample clip hosted in the 'nguyenvulebinh/wavlm-bart' repository, so that
# repository is loaded instead. Confirm with the model owner.
model_path = 'nguyenvulebinh/wavlm-bart'

# Load the speech encoder-decoder model and switch it to evaluation mode.
model = SpeechEncoderDecoderModel.from_pretrained(model_path).eval()
# The feature extractor and tokenizer come from the same checkpoint as the model.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Use the GPU when one is available; otherwise the model stays where it was loaded.
if torch.cuda.is_available():
    model = model.cuda()


def decode_tokens(token_ids, skip_special_tokens=True, time_precision=0.02):
    """Turn one generated id sequence into text with inline timestamp markers.

    Ids below ``tokenizer.vocab_size`` are collected into runs and decoded with
    the module-level ``tokenizer``. Every id at or above that threshold is
    rendered as `` |<seconds>| ``, where
    ``seconds = (id - tokenizer.vocab_size) * time_precision`` formatted with
    two decimals.

    Args:
        token_ids: Iterable of token ids for a single sequence.
        skip_special_tokens: Forwarded to ``tokenizer.decode``.
        time_precision: Multiplier that converts a timestamp id offset into
            the printed time value.

    Returns:
        The concatenated text, with ``"< |"`` collapsed to ``"<|"`` and
        ``"| >"`` collapsed to ``"|>"``.
    """
    first_timestamp_id = tokenizer.vocab_size

    # Alternating structure: list of text ids, timestamp string, list of text ids, ...
    chunks = [[]]
    for token_id in token_ids:
        if token_id < first_timestamp_id:
            chunks[-1].append(token_id)
            continue
        seconds = (token_id - first_timestamp_id) * time_precision
        # A timestamp closes the current text run and opens a fresh one.
        chunks.extend((f" |{seconds:.2f}| ", []))

    rendered = []
    for chunk in chunks:
        if isinstance(chunk, str):
            rendered.append(chunk)
        else:
            rendered.append(tokenizer.decode(chunk, skip_special_tokens=skip_special_tokens))

    text = "".join(rendered)
    # Remove the padding space between angle brackets and the timestamp bars.
    text = text.replace("< |", "<|")
    return text.replace("| >", "|>")

def decode_wav(audio_wavs, asr_model, prefix=""):
    """Transcribe a batch of audio clips with beam-search generation.

    Args:
        audio_wavs: Sized collection of per-clip features. Each item is passed
            to the module-level ``feature_extractor.pad`` as one
            ``"input_values"`` entry. (Presumably 1-D mono waveforms at the
            sampling rate the feature extractor expects -- TODO confirm.)
        asr_model: Model whose ``generate`` method is called. Inputs are moved
            to the device of its first parameter.
        prefix: Text tokenized with the module-level ``tokenizer`` and used as
            the decoder prompt for every clip.

    Returns:
        A list with one string per input clip, produced by ``decode_tokens``.
        An empty list is returned when ``audio_wavs`` is empty.
    """
    # An empty batch has nothing to pad or generate from, so answer it directly
    # instead of failing inside the feature extractor or the model. ``len`` is
    # used rather than truthiness so a stacked tensor batch is accepted as well.
    if len(audio_wavs) == 0:
        return []

    device = next(asr_model.parameters()).device

    # Pad all clips to the longest one and get the matching attention mask.
    input_values = feature_extractor.pad(
        [{"input_values": feature} for feature in audio_wavs],
        padding=True,
        max_length=None,
        pad_to_multiple_of=None,
        return_tensors="pt",
    )

    # One identical decoder prompt per clip. The last token id of each encoded
    # prompt is dropped (presumably the trailing end-of-sequence token, so that
    # generation continues from the prompt -- TODO confirm).
    prompt_ids = tokenizer.batch_encode_plus(
        [prefix] * len(audio_wavs), return_tensors="pt"
    )['input_ids'][..., :-1]

    output_beam_ids = asr_model.generate(
        input_values['input_values'].to(device),
        attention_mask=input_values['attention_mask'].to(device),
        decoder_input_ids=prompt_ids.to(device),
        generation_config=GenerationConfig(decoder_start_token_id=tokenizer.bos_token_id),
        max_length=250,
        num_beams=25,
        no_repeat_ngram_size=4,
        num_return_sequences=1,
        early_stopping=True,
        return_dict_in_generate=True,
        output_scores=True,
    )

    # Convert each returned id sequence into text with timestamp markers.
    output_text = [decode_tokens(sequence) for sequence in output_beam_ids.sequences]

    return output_text


# Sample audio (save it as sample.wav next to this script): https://huggingface.co/nguyenvulebinh/wavlm-bart/resolve/main/sample.wav
# Demo: load the sample clip, transcribe it as a one-item batch, and print the result.
# Element 0 of torchaudio.load's return value is used as the audio tensor;
# squeeze() removes its size-1 dimensions before it is handed to decode_wav.
sample_waveform = torchaudio.load('sample.wav')[0].squeeze()
sample_transcripts = decode_wav([sample_waveform], model)
print(sample_transcripts)

# <|0.06| What are the many parts that make a machine learning system feel like it works so magically cheap? |5.86|>
# <|5.68| Explletability factors important, so they tend to gear towards more simpler models with less parameters, but easier to explain, and on the other spectrum there are |15.86|>

Citation

This repository builds on the idea from the following paper. Please cite the paper if this model helps produce published results or is incorporated into other software.

@INPROCEEDINGS{10446589,
  author={Nguyen, Thai-Binh and Waibel, Alexander},
  booktitle={ICASSP 2024 - 2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 
  title={Synthetic Conversations Improve Multi-Talker ASR}, 
  year={2024},
  volume={},
  number={},
  pages={10461-10465},
  keywords={Systematics;Error analysis;Knowledge based systems;Oral communication;Signal processing;Data models;Acoustics;multi-talker;asr;synthetic conversation},
  doi={10.1109/ICASSP48485.2024.10446589}
}

Contact

[email protected]

Follow

Downloads last month
23
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.