import os

import librosa
import numpy as np
import torch
import torchaudio
from speechbrain.inference.interfaces import Pretrained


class ASR(Pretrained):
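    """SpeechBrain ``Pretrained`` wrapper around a Whisper ASR model, with
    helpers for repetition filtering and VAD-based chunked transcription of
    long audio files."""
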
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def encode_batch_whisper(self, device, wavs, wav_lens=None, normalize=False):
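        """Transcribe a padded batch of waveforms with the Whisper model.

        ``wav_lens`` holds the relative length of each item in the batch
        (1.0 for the longest). Returns the beam-search hypotheses decoded
        to plain strings."""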
        wavs = wavs.to(device)
        wav_lens = wav_lens.to(device)
        # Forward encoder + decoder
        tokens = torch.tensor([[1, 1]]) * self.mods.whisper.config.decoder_start_token_id
        tokens = tokens.to(device)
        enc_out, logits, _ = self.mods.whisper(wavs, tokens)
        log_probs = self.hparams.log_softmax(logits)
        hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
        predicted_words = [
            self.mods.whisper.tokenizer.decode(token, skip_special_tokens=True).strip()
            for token in hyps
        ]
        return predicted_words

    def filter_repetitions(self, seq, max_repetition_length):
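        """Collapse runaway n-gram repetitions in a token sequence.

        For each n-gram size ``n`` (up to half the sequence length), at most
        ``max(max_repetition_length // n, 1)`` consecutive repetitions are
        kept and the surplus occurrences are dropped."""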
        seq = list(seq)
        output = []
        max_n = len(seq) // 2
        for n in range(max_n, 0, -1):
            max_repetitions = max(max_repetition_length // n, 1)
            # Don't need to iterate over impossible n values:
            # len(seq) can change a lot during iteration
            if (len(seq) <= n * 2) or (len(seq) <= max_repetition_length):
                continue
            iterator = enumerate(seq)
            # Fill first buffers:
            buffers = [[next(iterator)[1]] for _ in range(n)]
            for seq_index, token in iterator:
                current_buffer = seq_index % n
                if token != buffers[current_buffer][-1]:
                    # No repeat, we can flush some tokens
                    buf_len = sum(map(len, buffers))
                    flush_start = (current_buffer - buf_len) % n
                    # Keep n-1 tokens, but possibly mark some for removal
                    for flush_index in range(buf_len - buf_len % n):
                        if (buf_len - flush_index) > n - 1:
                            to_flush = buffers[(flush_index + flush_start) % n].pop(0)
                        else:
                            to_flush = None
                        # Here, repetitions get removed:
                        if (flush_index // n < max_repetitions) and to_flush is not None:
                            output.append(to_flush)
                        elif (flush_index // n >= max_repetitions) and to_flush is None:
                            output.append(to_flush)
                buffers[current_buffer].append(token)
            # At the end, final flush
            current_buffer += 1
            buf_len = sum(map(len, buffers))
            flush_start = (current_buffer - buf_len) % n
            for flush_index in range(buf_len):
                to_flush = buffers[(flush_index + flush_start) % n].pop(0)
                # Here, repetitions just get removed:
                if flush_index // n < max_repetitions:
                    output.append(to_flush)
            seq = []
            to_delete = 0
            for token in output:
                if token is None:
                    to_delete += 1
                elif to_delete > 0:
                    to_delete -= 1
                else:
                    seq.append(token)
            output = []
        return seq

    def classify_file_whisper_mkd(self, file, vad_model, device):
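        """Transcribe ``file``, yielding the text of one chunk at a time.

        Recordings of 30 seconds or more are split at speech boundaries found
        by ``vad_model``, and the resulting segments are merged into chunks of
        at most 30 seconds before being transcribed with
        ``encode_batch_whisper``."""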
        # Target sample rate and maximum chunk length (in seconds)
        sr = 16000
        max_segment_length = 30
        # waveform, sr = librosa.load(file, sr=sr)
        waveform, file_sr = torchaudio.load(file)
        waveform = waveform.mean(dim=0, keepdim=True)  # convert to mono
        # resample if not 16 kHz
        if file_sr != sr:
            waveform = torchaudio.transforms.Resample(file_sr, sr)(waveform)
        # limit to 1 min
        # waveform = waveform[:, :60*sr]
        waveform = waveform.squeeze()
        audio_length = len(waveform) / sr
        print(f"Audio length: {audio_length:.2f} seconds")

        if audio_length >= max_segment_length:
            print(f"Audio is too long ({audio_length:.2f} seconds), splitting into segments")
            # save waveform temporarily
            torchaudio.save("temp.wav", waveform.unsqueeze(0), sr)
            # get boundaries based on VAD
            boundaries = vad_model.get_speech_segments(
                "temp.wav",
                large_chunk_size=30,
                small_chunk_size=10,
                apply_energy_VAD=True,
                double_check=True,
            )
            # remove temp file
            os.remove("temp.wav")

            # Merge the VAD segments into chunks of at most max_segment_length
            segments = []
            current_start = boundaries[0][0].item()
            current_end = boundaries[0][1].item()
            for i in range(1, len(boundaries)):
                next_start = boundaries[i][0].item()
                next_end = boundaries[i][1].item()
                # Check if the current segment can merge with the next segment
                if (current_end - current_start) + (next_end - next_start) <= max_segment_length:
                    # Extend the current segment
                    current_end = next_end
                else:
                    # Add the current segment to the result and start a new one
                    segments.append([current_start, current_end])
                    current_start = next_start
                    current_end = next_end
            # Add the last segment
            segments.append([current_start, current_end])

            # Process each segment
            for i, segment in enumerate(segments):
                start, end = segment
                start = int(start * sr)
                end = int(end * sr)
                segment = waveform[start:end]
                print(f"Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")
                # import soundfile as sf
                # sf.write(f"outputs/segment_{i}.wav", segment, sr)
                # Fake a batch for the segment
                batch = segment.unsqueeze(0).to(device)
                rel_length = torch.tensor([1.0]).to(device)  # Adjust if necessary
                # Pass the segment through the ASR model
                segment_output = self.encode_batch_whisper(device, batch, rel_length)
                yield segment_output
        else:
            # Short enough to transcribe in a single pass; fake a batch of size 1
            batch = waveform.unsqueeze(0).to(device)
            rel_length = torch.tensor([1.0]).to(device)
            outputs = self.encode_batch_whisper(device, batch, rel_length)
            yield outputs
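

# A minimal usage sketch (not part of the original interface). It assumes this
# file is shipped as ``custom_interface.py`` in a SpeechBrain-style model repo
# and that a pretrained VAD such as ``speechbrain/vad-crdnn-libriparty`` is
# used for segmentation; the ASR repo id below is a placeholder, not a real
# source.
#
# from speechbrain.inference.interfaces import foreign_class
# from speechbrain.inference.VAD import VAD
#
# asr = foreign_class(
#     source="your-username/your-whisper-asr-repo",  # placeholder repo id
#     pymodule_file="custom_interface.py",
#     classname="ASR",
# )
# vad = VAD.from_hparams(source="speechbrain/vad-crdnn-libriparty")
# for chunk_text in asr.classify_file_whisper_mkd("example.wav", vad, device="cpu"):
#     print(chunk_text)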