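"""
Force-align transcripts to speech with a CTC wav2vec2 model.

For each tab-separated (wav_name, sentence) pair in the script file, a
character-level alignment is written to <output_dir>/<wav_name>.txt.

A usage sketch, assuming the script is saved as `alignment.py` (the file name
is an assumption, not fixed by the code):

    python alignment.py \
        --model_name "arijitx/wav2vec2-xls-r-300m-bengali" \
        --wav_dir ./wavs \
        --text_file script.txt \
        --input_wavs_sr 16000 \
        --output_dir ./out_alignment \
        --cuda
"""
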
import argparse
import os
from dataclasses import dataclass

import torch
import torchaudio
from tqdm import tqdm

from transformers import AutoConfig, AutoModelForCTC, AutoProcessor


class Wav2Vec2Aligner:
    def __init__(self, model_name, input_wavs_sr, cuda):
        self.cuda = cuda
        self.config = AutoConfig.from_pretrained(model_name)
        self.model = AutoModelForCTC.from_pretrained(model_name)
        self.model.eval()
        if self.cuda:
            self.model.to(device="cuda")
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.resampler = torchaudio.transforms.Resample(input_wavs_sr, 16_000)

        # find the CTC blank token id: wav2vec2 checkpoints use [PAD] or <pad> as the blank
        blank_id = 0
        vocab = list(self.processor.tokenizer.get_vocab().keys())
        for i in range(len(vocab)):
            if vocab[i] == "[PAD]" or vocab[i] == "<pad>":
                blank_id = i
        print("Blank Token id [PAD]/<pad>", blank_id)
        self.blank_id = blank_id

    def speech_file_to_array_fn(self, wav_path):
        # load the audio and resample it to the 16 kHz expected by wav2vec2
        speech_array, sampling_rate = torchaudio.load(wav_path)
        speech = self.resampler(speech_array).squeeze().numpy()
        return speech

    def align_single_sample(self, item):
        blank_id = self.blank_id
        # replace spaces with "|", the tokenizer's word delimiter
        transcript = "|".join(item["sent"].split(" "))
        if not os.path.isfile(item["wav_path"]):
            print(item["wav_path"], "not found in wavs directory")
            return

        speech_array = self.speech_file_to_array_fn(item["wav_path"])
        inputs = self.processor(speech_array, sampling_rate=16_000, return_tensors="pt", padding=True)
        if self.cuda:
            inputs = inputs.to(device="cuda")

        with torch.no_grad():
            logits = self.model(inputs.input_values).logits

        # per-frame log-probabilities over the vocabulary
        emissions = torch.log_softmax(logits, dim=-1)
        emission = emissions[0].cpu().detach()

        # prepend an empty label and drop the last entry: the logit columns
        # don't align one-to-one with the tokenizer's vocab
        labels = ([""] + list(self.processor.tokenizer.get_vocab().keys()))[:-1]

        # map each transcript character to its token id, skipping unknown characters
        dictionary = {c: i for i, c in enumerate(labels)}
        tokens = []
        for c in transcript:
            if c in dictionary:
                tokens.append(dictionary[c])

        def get_trellis(emission, tokens, blank_id=0):
            """
            Build a trellis matrix of shape (num_frames + 1, num_tokens + 1)
            that represents the probabilities of each source token being at a certain time step
            """
            num_frames = emission.size(0)
            num_tokens = len(tokens)

            trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
            trellis[:, 0] = 0
            for t in range(num_frames):
                trellis[t + 1, 1:] = torch.maximum(
                    # score for staying on the same token (emitting a blank)
                    trellis[t, 1:] + emission[t, blank_id],
                    # score for advancing to the next token
                    trellis[t, :-1] + emission[t, tokens],
                )
            return trellis

        trellis = get_trellis(emission, tokens, blank_id)

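        # The recurrence computed by get_trellis above, written for a single cell:
        #   trellis[t + 1, j] = max(trellis[t, j] + emission[t, blank_id],           # stay on token j
        #                           trellis[t, j - 1] + emission[t, tokens[j - 1]])  # advance to token j
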
        @dataclass
        class Point:
            token_index: int
            time_index: int
            score: float

        def backtrack(trellis, emission, tokens, blank_id=0):
            """
            Walk backwards from the last (sentence_token, time_step) pair to build the optimal sequence alignment path
            """
            # start from the time step where the final token is most likely to end
            j = trellis.size(1) - 1
            t_start = torch.argmax(trellis[:, j]).item()

            path = []
            for t in range(t_start, 0, -1):
                # score if the path stayed on the same token (emitting a blank) vs.
                # score if it advanced from the previous token
                stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
                changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

                # probability of the label actually emitted at this frame
                # (blank_id rather than a hard-coded 0, since the blank may not be token 0)
                prob = emission[t - 1, tokens[j - 1] if changed > stayed else blank_id].exp().item()

                path.append(Point(j - 1, t - 1, prob))

                if changed > stayed:
                    j -= 1
                    if j == 0:
                        break
            else:
                raise ValueError("Failed to align")
            return path[::-1]

        path = backtrack(trellis, emission, tokens, blank_id)

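        # `path` holds one Point per emission frame in time order; runs of frames
        # that map to the same transcript character are collapsed into Segments below.
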
        @dataclass
        class Segment:
            label: str
            start: int
            end: int
            score: float

            def __repr__(self):
                return f"{self.label}\t{self.score:4.2f}\t{self.start * 20:5d}\t{self.end * 20:5d}"

            @property
            def length(self):
                return self.end - self.start

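        # Note: `start` and `end` are emission-frame indices. The `* 20` in __repr__
        # converts them to milliseconds, assuming the model's ~20 ms frame stride
        # (a standard wav2vec2 encoder emits one frame per 320 samples at 16 kHz).
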
        def merge_repeats(path):
            """
            Merge repeated tokens into a single segment. Note: this shouldn't affect repeated characters from the
            original sentences (e.g. `ll` in `hello`)
            """
            i1, i2 = 0, 0
            segments = []
            while i1 < len(path):
                while i2 < len(path) and path[i1].token_index == path[i2].token_index:
                    i2 += 1
                score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
                segments.append(
                    Segment(
                        transcript[path[i1].token_index],
                        path[i1].time_index,
                        path[i2 - 1].time_index + 1,
                        score,
                    )
                )
                i1 = i2
            return segments

        segments = merge_repeats(path)
        # write one tab-separated line per segment: label, score, start_ms, end_ms
        with open(item["out_path"], "w") as out_align:
            for seg in segments:
                out_align.write(str(seg) + "\n")

    def align_data(self, wav_dir, text_file, output_dir):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        with open(text_file, encoding="utf8") as f:
            lines = f.readlines()

        items = []
        for line in lines:
            if len(line.strip().split("\t")) != 2:
                print("Script lines must be tab-separated: <wav_name>\\t<sentence>, e.g. 00001\\tthis is my sentence")
                exit()

            wav_name, sentence = line.strip().split("\t")
            wav_path = os.path.join(wav_dir, wav_name + ".wav")
            out_path = os.path.join(output_dir, wav_name + ".txt")

            items.append({"sent": sentence, "wav_path": wav_path, "out_path": out_path})
        print("Number of samples found in script file", len(items))

        for item in tqdm(items):
            self.align_single_sample(item)


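# The aligner can also be used programmatically without the CLI; a minimal
# sketch (paths and file names here are illustrative, not fixed by the code):
#
#   aligner = Wav2Vec2Aligner("arijitx/wav2vec2-xls-r-300m-bengali", 16000, cuda=False)
#   aligner.align_data("./wavs", "script.txt", "./out_alignment")
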
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_name", type=str, default="arijitx/wav2vec2-xls-r-300m-bengali", help="wav2vec model name"
    )
    parser.add_argument("--wav_dir", type=str, default="./wavs", help="directory containing wavs")
    parser.add_argument("--text_file", type=str, default="script.txt", help="file containing text")
    parser.add_argument("--input_wavs_sr", type=int, default=16000, help="sampling rate of input audios")
    parser.add_argument(
        "--output_dir", type=str, default="./out_alignment", help="output directory containing the alignment files"
    )
    parser.add_argument("--cuda", action="store_true")

    args = parser.parse_args()

    aligner = Wav2Vec2Aligner(args.model_name, args.input_wavs_sr, args.cuda)
    aligner.align_data(args.wav_dir, args.text_file, args.output_dir)


if __name__ == "__main__":
    main()