|
from abc import ABC, abstractmethod |
|
from collections import Counter, deque |
|
import os |
|
import time |
|
|
|
from typing import Any, Deque, Iterator, List, Dict |
|
|
|
from pprint import pprint |
|
from src.hooks.progressListener import ProgressListener |
|
from src.hooks.subTaskProgressListener import SubTaskProgressListener |
|
from src.hooks.whisperProgressHook import create_progress_listener_handle |
|
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache |
|
|
|
from src.segments import merge_timestamps |
|
from src.whisper.abstractWhisperContainer import AbstractWhisperCallback |
|
|
|
|
|
try: |
|
import tensorflow as tf |
|
except ModuleNotFoundError: |
|
|
|
pass |
|
|
|
import torch |
|
|
|
import ffmpeg |
|
import numpy as np |
|
|
|
from src.utils import format_timestamp, len_wide |
|
from enum import Enum |
|
|
|
class NonSpeechStrategy(Enum):
    """
    How sections of the audio without detected speech are handled before transcription.

    Applied by AbstractTranscription.get_merged_timestamps after the detected speech has been merged.
    """
    # Ignore non-speech segments: only the merged speech sections are transcribed.
    SKIP = 1
    # Treat non-speech segments as speech (fill_gaps): gaps up to max_merge_size are absorbed into the
    # preceding section, larger gaps become separate sections (marked 'gap') that are transcribed too.
    CREATE_SEGMENT = 2
    # Expand speech segments into subsequent non-speech segments (expand_gaps): each section is extended
    # up to the start of the next one, and the last one up to the end of the audio.
    EXPAND_SEGMENT = 3
|
|
|
|
|
# Speech probability threshold passed to Silero VAD (the misspelled name is kept for compatibility).
SPEECH_TRESHOLD: float = 0.3

# Merged sections shorter than this many seconds are not sent to Whisper.
MIN_SEGMENT_DURATION: int = 1

# Default length, in seconds, of the prompt window (previous text passed to the next section).
# Not referenced in this module; presumably the default for TranscriptionConfig.max_prompt_window,
# where a value <= 0 disables the window.
MAX_PROMPT_WINDOW: int = 0
# Segments whose no-speech probability is above this value are not added to the prompt window.
PROMPT_NO_SPEECH_PROB: float = 0.1

# Silero VAD processes the audio in chunks of at most this many seconds (one hour).
VAD_MAX_PROCESSING_CHUNK: int = 60 * 60
|
|
|
class TranscriptionConfig(ABC):
    """
    Settings that control how detected speech is merged into sections and how those sections are transcribed.

    All values are stored as-is; they are consumed by AbstractTranscription (merging, non-speech
    handling, prompt window and segment numbering).
    """

    def __init__(self, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
                 segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
                 max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
        # What to do with audio between the speech sections (see NonSpeechStrategy).
        self.non_speech_strategy = non_speech_strategy

        # Padding around each section and merge limits; all four are forwarded to merge_timestamps.
        # max_merge_size also bounds how far CREATE_SEGMENT may expand a section into a gap.
        self.segment_padding_left = segment_padding_left
        self.segment_padding_right = segment_padding_right
        self.max_silent_period = max_silent_period
        self.max_merge_size = max_merge_size

        # Seconds of previous text passed as a prompt to the next section (None or <= 0 disables it).
        self.max_prompt_window = max_prompt_window

        # Starting counter: it is incremented before each section, so the first section gets this value + 1.
        self.initial_segment_index = initial_segment_index
|
|
|
class PeriodicTranscriptionConfig(TranscriptionConfig):
    """
    TranscriptionConfig for VadPeriodicTranscription, which cuts the audio into sections of a fixed length
    instead of detecting speech.
    """

    def __init__(self, periodic_duration: float, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
                 segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
                 max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
        # Everything except the section length is handled by the base configuration.
        super().__init__(
            non_speech_strategy=non_speech_strategy,
            segment_padding_left=segment_padding_left,
            segment_padding_right=segment_padding_right,
            max_silent_period=max_silent_period,
            max_merge_size=max_merge_size,
            max_prompt_window=max_prompt_window,
            initial_segment_index=initial_segment_index,
        )

        # Length of each section, in seconds.
        self.periodic_duration = periodic_duration
|
|
|
class AbstractTranscription(ABC):
    """
    Base class for section-based transcription.

    Subclasses decide which parts of an audio file should be transcribed (get_transcribe_timestamps).
    This class merges those parts according to a TranscriptionConfig, runs Whisper on every merged
    section and maps the results back onto the timeline of the whole file.
    """

    def __init__(self, sampling_rate: int = 16000):
        # Sample rate (Hz) that audio is resampled to when loaded.
        self.sampling_rate = sampling_rate

    def get_audio_segment(self, str, start_time: str = None, duration: str = None):
        """
        Load part of an audio file as a float32 waveform at `self.sampling_rate` (see load_audio).

        The first parameter is the path of the audio file. It is named `str` (shadowing the builtin
        inside this method only) and kept that way so keyword callers keep working.
        `start_time` and `duration` use the FFMPEG time duration syntax, or None to disable.
        """
        return load_audio(str, self.sampling_rate, start_time, duration)

    def is_transcribe_timestamps_fast(self):
        """
        Determine if get_transcribe_timestamps is fast enough to not need parallelization.
        """
        return False

    @abstractmethod
    def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
        """
        Get the start and end timestamps of the sections that should be transcribed by this VAD method.

        Parameters
        ----------
        audio: str
            The audio file.
        config: TranscriptionConfig
            The transcription configuration.
        start_time: float
            Start of the range to analyse, in seconds.
        end_time: float
            End of the range to analyse, in seconds.

        Returns
        -------
        A list of dicts with 'start' and 'end' timestamps, in fractional seconds.
        """
        return

    def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: TranscriptionConfig, total_duration: float):
        """
        Merge the given sections with merge_timestamps and apply the configured non-speech strategy.

        Parameters
        ----------
        timestamps: List[Dict[str, Any]]
            Sections with 'start' and 'end' in seconds, e.g. from get_transcribe_timestamps.
        config: TranscriptionConfig
            The transcription configuration (padding, merge limits and non-speech strategy).
        total_duration: float
            Duration of the audio in seconds; non-speech handling may extend sections up to it.

        Returns
        -------
        A list of start and end timestamps, in fractional seconds. Sections created for non-speech
        carry 'gap': True, and expanded sections carry 'expand_amount'.

        Raises
        ------
        ValueError
            If `config.non_speech_strategy` is not a known NonSpeechStrategy.
        """
        merged = merge_timestamps(timestamps, config.max_silent_period, config.max_merge_size,
                                  config.segment_padding_left, config.segment_padding_right)

        if config.non_speech_strategy != NonSpeechStrategy.SKIP:
            if config.non_speech_strategy == NonSpeechStrategy.CREATE_SEGMENT:
                # Absorb small gaps into the preceding section and turn larger ones into sections of their own
                merged = self.fill_gaps(merged, total_duration=total_duration, max_expand_size=config.max_merge_size)
            elif config.non_speech_strategy == NonSpeechStrategy.EXPAND_SEGMENT:
                # Extend every section up to the start of the next one
                merged = self.expand_gaps(merged, total_duration=total_duration)
            else:
                raise ValueError("Unknown non-speech strategy: " + str(config.non_speech_strategy))

            print("Transcribing non-speech:")
            pprint(merged)

        return merged

    def transcribe(self, audio: str, whisperCallable: AbstractWhisperCallback, config: TranscriptionConfig,
                   progressListener: ProgressListener = None):
        """
        Transcribe the given audio file.

        Parameters
        ----------
        audio: str
            The audio file.
        whisperCallable: AbstractWhisperCallback
            A callback object to call to transcribe each section.
        config: TranscriptionConfig
            The transcription configuration.
        progressListener: ProgressListener
            Optional listener. Per-section progress is forwarded through a SubTaskProgressListener, and
            `on_finished()` is always called when this method ends, including on errors.

        Returns
        -------
        A dict with 'text' (the concatenated text of all sections), 'segments' (the transcribed segments
        on the timeline of the whole file; the first/last segment of each section is flagged with
        'segment_first'/'segment_last') and 'language' ("" when nothing was transcribed).
        """
        try:
            max_audio_duration = self.get_audio_duration(audio, config)
            timestamp_segments = self.get_transcribe_timestamps(audio, config, 0, max_audio_duration)

            # Merge the detected sections and apply the non-speech strategy
            merged = self.get_merged_timestamps(timestamp_segments, config, max_audio_duration)

            # Recently transcribed segments whose text is passed to the next section as a prompt
            prompt_window = deque()

            print("Processing timestamps:")
            pprint(merged)

            result = {
                'text': "",
                'segments': [],
                'language': ""
            }
            languageCounter = Counter()
            detected_language = None
            # Whisper result of the most recently transcribed section (None until one has run)
            segment_result = None

            segment_index = config.initial_segment_index

            progress_total_duration = sum([segment['end'] - segment['start'] for segment in merged])
            # Each merged section gets an equal share of the progress. `merged` is empty when no
            # speech was found, so guard the division.
            sub_task_fraction = 1 / len(merged) if len(merged) > 0 else 0

            for idx, segment in enumerate(merged):
                segment_index += 1
                segment_start = segment['start']
                segment_end = segment['end']
                segment_expand_amount = segment.get('expand_amount', 0)
                segment_gap = segment.get('gap', False)

                segment_duration = segment_end - segment_start

                if segment_duration < MIN_SEGMENT_DURATION:
                    continue

                # Audio of this section only
                segment_audio = self.get_audio_segment(audio, start_time=str(segment_start), duration=str(segment_duration))
                # Text of the previous segments, used as a prompt
                segment_prompt = ' '.join([prompt_segment['text'] for prompt_segment in prompt_window]) if len(prompt_window) > 0 else None

                # Most common language among the sections transcribed so far (gap sections excluded)
                detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None

                print(f"Running whisper {idx}: from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
                      segment_duration, "expanded: ", segment_expand_amount, ", prompt: ", segment_prompt, ", detected language: ", detected_language)

                perf_start_time = time.perf_counter()

                # NOTE(review): sub_task_start/sub_task_total are fractions of the section count, while
                # base_task_total is a duration in seconds unless the listener is itself a
                # SubTaskProgressListener; confirm the intended units against SubTaskProgressListener.
                scaled_progress_listener = SubTaskProgressListener(progressListener,
                    base_task_total=progressListener.sub_task_total if isinstance(progressListener, SubTaskProgressListener) else progress_total_duration,
                    sub_task_start=idx * sub_task_fraction,
                    sub_task_total=sub_task_fraction)
                segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language, progress_listener=scaled_progress_listener)

                perf_end_time = time.perf_counter()
                print("\tWhisper took {} seconds".format(perf_end_time - perf_start_time))

                # Move the segments from the section's timeline onto the timeline of the whole file
                adjusted_segments: List[Dict[str, Any]] = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)

                if len(adjusted_segments) > 0:
                    adjusted_segments[0]["segment_first"] = True
                    adjusted_segments[-1]["segment_last"] = True

                # Mark how far each segment reaches into the expanded (non-speech) tail of the section.
                # adjusted_segments are on the absolute timeline, so compare with the absolute end of
                # the unexpanded section, not with its relative length.
                if segment_expand_amount > 0:
                    unexpanded_end = segment_end - segment_expand_amount

                    for adjusted_segment in adjusted_segments:
                        adjusted_segment_end = adjusted_segment['end']

                        if adjusted_segment_end > unexpanded_end:
                            adjusted_segment["expand_amount"] = adjusted_segment_end - unexpanded_end

                result['text'] += segment_result['text']
                result['segments'].extend(adjusted_segments)

                # Gap sections do not vote on the detected language
                if not segment_gap:
                    languageCounter[segment_result['language']] += 1

                self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)

            # Most common language over all non-gap sections (including the last one); otherwise the
            # language of the last transcribed section, and "" when nothing was transcribed.
            if len(languageCounter) > 0:
                result['language'] = languageCounter.most_common(1)[0][0]
            elif segment_result is not None:
                result['language'] = segment_result['language']
        finally:
            # Notify the progress listener that we are done, even on failure
            if progressListener is not None:
                progressListener.on_finished()
        return result

    def get_audio_duration(self, audio: str, config: TranscriptionConfig):
        """Return the duration of `audio` in seconds; `config` is not used by this default implementation."""
        return get_audio_duration(audio)

    def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
        """
        Add the segments of the section that just finished to `prompt_window` and drop stale ones.

        Does nothing unless `config.max_prompt_window` is a positive number of seconds. Segments of gap
        sections and segments whose 'no_speech_prob' exceeds PROMPT_NO_SPEECH_PROB are not added.
        The oldest segments are dropped while their unexpanded end ('end' minus 'expand_amount') lies
        more than `max_prompt_window` seconds before `segment_end`.
        """
        if config.max_prompt_window is None or not config.max_prompt_window > 0:
            return

        if not segment_gap:
            prompt_window.extend(segment for segment in adjusted_segments
                                 if segment.get('no_speech_prob', 0) <= PROMPT_NO_SPEECH_PROB)

        window_start = segment_end - config.max_prompt_window
        while len(prompt_window) > 0:
            oldest = prompt_window[0]
            # Time added by expanding into non-speech is not counted towards the window
            if oldest.get('end', 0) - oldest.get('expand_amount', 0) < window_start:
                prompt_window.popleft()
            else:
                break

    def include_gaps(self, segments: Iterator[dict], min_gap_length: float, total_duration: float):
        """
        Return `segments` with explicit gap entries ({'start', 'end', 'gap': True}) inserted between them,
        before the first one (from 0) and after the last one (up to `total_duration`, if given).

        Gaps shorter than `min_gap_length` are left out (None keeps every gap); overlapping segments
        produce no gap. The input segment dicts are included as-is, not copied.
        """
        result = []
        last_end_time = 0

        for segment in segments:
            segment_start = float(segment['start'])
            segment_end = float(segment['end'])

            # Only a positive distance is a gap
            if segment_start > last_end_time:
                delta = segment_start - last_end_time

                if min_gap_length is None or delta >= min_gap_length:
                    result.append({'start': last_end_time, 'end': segment_start, 'gap': True})

            last_end_time = segment_end
            result.append(segment)

        # Trailing gap, measured from the end of the last segment
        if total_duration is not None and last_end_time < total_duration:
            delta = total_duration - last_end_time

            if min_gap_length is None or delta >= min_gap_length:
                result.append({'start': last_end_time, 'end': total_duration, 'gap': True})

        return result

    def expand_gaps(self, segments: List[Dict[str, Any]], total_duration: float):
        """
        Extend each segment up to the start of the next one, and the last one up to `total_duration`.

        Audio before the first segment becomes a separate gap entry. Every extended segment is a copy
        carrying 'expand_amount' (seconds added); segments that overlap their successor are kept unchanged.
        """
        result = []

        if len(segments) == 0:
            return result

        if segments[0]['start'] > 0:
            result.append({'start': 0, 'end': segments[0]['start'], 'gap': True})

        for current_segment, next_segment in zip(segments, segments[1:]):
            delta = next_segment['start'] - current_segment['end']

            # Expand only if the gap actually exists
            if delta >= 0:
                current_segment = current_segment.copy()
                current_segment['expand_amount'] = delta
                current_segment['end'] = next_segment['start']

            result.append(current_segment)

        result.append(segments[-1])

        if total_duration is not None:
            last_segment = result[-1]

            if last_segment['end'] < total_duration:
                last_segment = last_segment.copy()
                # Record the expansion like for every other segment (and like fill_gaps does)
                last_segment['expand_amount'] = total_duration - last_segment['end']
                last_segment['end'] = total_duration
                result[-1] = last_segment

        return result

    def fill_gaps(self, segments: List[Dict[str, Any]], total_duration: float, max_expand_size: float = None):
        """
        Make the non-speech between segments transcribable.

        A gap of at most `max_expand_size` seconds is absorbed into the preceding segment (a copy with
        'expand_amount'); any other non-negative gap becomes its own {'start', 'end', 'gap': True}
        entry. The same rule applies to the tail up to `total_duration`, and audio before the first
        segment always becomes a gap entry.
        """
        result = []

        if len(segments) == 0:
            return result

        if segments[0]['start'] > 0:
            result.append({'start': 0, 'end': segments[0]['start'], 'gap': True})

        for current_segment, next_segment in zip(segments, segments[1:]):
            delta = next_segment['start'] - current_segment['end']
            expanded = max_expand_size is not None and delta <= max_expand_size

            if expanded:
                # Just expand the current segment
                current_segment = current_segment.copy()
                current_segment['expand_amount'] = delta
                current_segment['end'] = next_segment['start']

            result.append(current_segment)

            if delta >= 0 and not expanded:
                result.append({'start': current_segment['end'], 'end': next_segment['start'], 'gap': True})

        result.append(segments[-1])

        if total_duration is not None:
            last_segment = result[-1]
            delta = total_duration - last_segment['end']

            if delta > 0:
                if max_expand_size is not None and delta <= max_expand_size:
                    last_segment = last_segment.copy()
                    last_segment['expand_amount'] = delta
                    last_segment['end'] = total_duration
                    result[-1] = last_segment
                else:
                    result.append({'start': last_segment['end'], 'end': total_duration, 'gap': True})

        return result

    def adjust_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
        """
        Shift segments (and their words) by `adjust_seconds`, optionally clipping them to a source length.

        Parameters
        ----------
        segments: Iterator[dict]
            Segments with 'start'/'end' relative to the source and optionally 'text' and 'words'
            (word dicts with 'start', 'end' and 'word').
        adjust_seconds: float
            Offset added to every start/end time.
        max_source_time: float
            If set, segments starting after this (relative) time are dropped and segment ends are
            clamped to it before shifting.

        Returns
        -------
        A new list of segment dicts. Segments with text and a non-empty word list that last longer than
        10 seconds are split into shorter sub-segments (see _split_long_segment). The input segments and
        word dicts are not modified.
        """
        result = []

        for segment in segments:
            segment_start = float(segment['start'])
            segment_end = float(segment['end'])

            if max_source_time is not None:
                if segment_start > max_source_time:
                    continue
                segment_end = min(max_source_time, segment_end)

            segment_words = segment.get("words")

            # Without words a long segment cannot be split, so it is kept whole below
            if "text" in segment and segment_words and segment_end - segment_start > 10:
                result.extend(self._split_long_segment(segment, segment_words, adjust_seconds))
                continue

            new_segment = segment.copy()
            new_segment['start'] = segment_start + adjust_seconds
            new_segment["end"] = segment_end + adjust_seconds

            if segment_words:
                # Shift copies so the caller's word dicts are left untouched
                new_segment["words"] = [self._shift_word(word, adjust_seconds) for word in segment_words]

            result.append(new_segment)

        return result

    @staticmethod
    def _shift_word(word: dict, adjust_seconds: float) -> dict:
        """Return a copy of `word` with its 'start' and 'end' moved by `adjust_seconds`."""
        return {**word, "start": word["start"] + adjust_seconds, "end": word["end"] + adjust_seconds}

    def _split_long_segment(self, segment: dict, segment_words: List[dict], adjust_seconds: float) -> List[dict]:
        """
        Split one long segment into sub-segments at sentence and clause boundaries.

        Every sub-segment gets the segment's other fields plus its own 'start', 'end', 'text' and
        'words', with all times shifted by `adjust_seconds`. A sub-segment is closed after a word when
        its text ends with "." or when it has grown long (measured with len_wide) and ends at a comma,
        question mark or space, or the next word starts with "," or is "and"/"or"/"but".
        """
        # Fields shared by every sub-segment (everything except text, timing and words)
        template = {key: value for key, value in segment.items() if key not in ("text", "start", "end", "words")}

        pieces = []
        sub_segment = dict(template)
        sub_text = ""
        sub_words = []
        word_length = 0
        # NOTE(review): becomes True as soon as any word has more than one character;
        # confirm this matches the intended "wide text" heuristic.
        is_wide = False

        for idx, original_word in enumerate(segment_words):
            next_word = segment_words[idx + 1] if idx + 1 < len(segment_words) else None
            word = self._shift_word(original_word, adjust_seconds)

            if "start" not in sub_segment:
                sub_segment["start"] = float(word["start"])
            if not is_wide and len(word["word"]) > 1:
                is_wide = True

            sub_text += word["word"]
            sub_words.append(word)
            word_length += len_wide(word["word"])

            stripped = sub_text.rstrip()
            should_split = (
                stripped.endswith(".")
                or (word_length > 90 and stripped.endswith((",", "?")))
                or (word_length > 80 and is_wide and stripped.endswith((",", "?", "、", "。")))
                or (word_length > 90 and is_wide and sub_text.endswith(" "))
                or (word_length > 120 and next_word
                    and (next_word["word"].lstrip().startswith(",") or next_word["word"].strip() in ["and", "or", "but"]))
                or (word_length > 180 and sub_text.endswith(" "))
            )

            if should_split:
                sub_segment["text"] = sub_text
                sub_segment["end"] = float(word["end"])
                sub_segment["words"] = sub_words
                pieces.append(sub_segment)

                sub_segment = dict(template)
                sub_text = ""
                sub_words = []
                word_length = 0
                is_wide = False

        # Flush the words after the last split point
        if "start" in sub_segment:
            sub_segment["text"] = sub_text
            sub_segment["end"] = float(sub_words[-1]["end"])
            sub_segment["words"] = sub_words
            pieces.append(sub_segment)

        return pieces

    def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
        """Return new {'start', 'end'} dicts with both values multiplied by `factor`; other keys are dropped."""
        return [{'start': entry['start'] * factor, 'end': entry['end'] * factor} for entry in timestamps]
|
|
|
|
|
class VadSileroTranscription(AbstractTranscription):
    """
    Transcription driven by the Silero voice activity detector, loaded through torch.hub.

    Speech is detected chunk by chunk (at most VAD_MAX_PROCESSING_CHUNK seconds at a time), and only the
    detected sections (after merging per the TranscriptionConfig) are sent to Whisper.
    """

    def __init__(self, sampling_rate: int = 16000, cache: ModelCache = None):
        super().__init__(sampling_rate=sampling_rate)
        self.model = None
        self.cache = cache
        self._initialize_model()

    def _initialize_model(self):
        """Set `self.model` and `self.get_speech_timestamps`, going through `self.cache` when one is set."""
        if (self.cache is not None):
            model_key = "VadSileroTranscription"
            self.model, self.get_speech_timestamps = self.cache.get(model_key, self._create_model)
            print("Loaded Silerio model from cache.")
        else:
            self.model, self.get_speech_timestamps = self._create_model()
            print("Created Silerio model")

    def _create_model(self):
        """
        Load the Silero VAD model (pinned to the v4.0 tag) and its `get_speech_timestamps` helper.

        If torch.hub.load fails, the copy that torch.hub previously cached for the same tag is loaded
        from disk instead; when no such copy exists the original exception is re-raised.

        Returns
        -------
        A `(model, get_speech_timestamps)` tuple. `get_speech_timestamps(audio, model, ...)` returns a list
        of dicts with 'start'/'end' in samples (unless `return_seconds=True`); see
        https://github.com/snakers4/silero-vad/blob/master/src/silero_vad/utils_vad.py

        Side effect: sets torch's CPU thread count to 1 (`torch.set_num_threads(1)`).
        """
        repo_owner = "snakers4"
        repo_name = "silero-vad"
        ref = "v4.0"
        repo = f"{repo_owner}/{repo_name}:{ref}"

        try:
            model, utils = torch.hub.load(repo_or_dir=repo, model='silero_vad', trust_repo=True)
        except Exception as e:
            # torch.hub caches "owner/name:ref" in "<hub_dir>/owner_name_ref" (with '/' in ref replaced by '_')
            hub_dir = torch.hub.get_dir()
            owner_name_branch = '_'.join([repo_owner, repo_name, ref.replace('/', '_')])
            repo_dir = os.path.join(hub_dir, owner_name_branch)
            if os.path.exists(repo_dir):
                print(f"vad.py: torch.hub.load({repo}) Exception: {str(e)}, Using cache found in {repo_dir}\n")
                model, utils = torch.hub.load(repo_or_dir=repo_dir, model='silero_vad', source="local")
            else:
                raise

        torch.set_num_threads(1)
        # utils = (get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks)
        (get_speech_timestamps, _, _, _, _) = utils

        return model, get_speech_timestamps

    def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
        """
        Run Silero VAD over [start_time, end_time) of `audio`.

        Returns
        -------
        The detected speech as dicts with 'start'/'end' in seconds, on the timeline of the whole file.
        """
        result = []

        print("Getting timestamps from audio file: {}, start: {}, end: {}".format(audio, start_time, end_time))
        perf_start_time = time.perf_counter()

        # Process the audio in chunks (presumably to bound memory use on long files)
        chunk_start = start_time

        while (chunk_start < end_time):
            chunk_duration = min(end_time - chunk_start, VAD_MAX_PROCESSING_CHUNK)

            print("Processing VAD in chunk from {} to {}".format(format_timestamp(chunk_start), format_timestamp(chunk_start + chunk_duration)))
            wav = self.get_audio_segment(audio, str(chunk_start), str(chunk_duration))

            sample_timestamps = self.get_speech_timestamps(wav, self.model, sampling_rate=self.sampling_rate, threshold=SPEECH_TRESHOLD)
            seconds_timestamps = self.multiply_timestamps(sample_timestamps, factor=1 / self.sampling_rate)
            # The timestamps are relative to the chunk: clip them to the chunk length, then shift by chunk_start
            adjusted = self.adjust_timestamp(seconds_timestamps, adjust_seconds=chunk_start, max_source_time=chunk_duration)

            result.extend(adjusted)
            chunk_start += chunk_duration

        perf_end_time = time.perf_counter()
        print("VAD processing took {} seconds".format(perf_end_time - perf_start_time))

        return result

    def __getstate__(self):
        # Only the sampling rate is pickled; the model is reloaded when unpickling.
        return { 'sampling_rate': self.sampling_rate }

    def __setstate__(self, state):
        self.sampling_rate = state['sampling_rate']
        self.model = None
        # Use the global cache so unpickled copies (presumably in worker processes) share one loaded model
        self.cache = GLOBAL_MODEL_CACHE
        self._initialize_model()
|
|
|
|
|
class VadPeriodicTranscription(AbstractTranscription):
    """
    Transcription that does not detect speech: the audio is simply cut into sections of
    `config.periodic_duration` seconds.
    """

    def __init__(self, sampling_rate: int = 16000):
        super().__init__(sampling_rate=sampling_rate)

    def is_transcribe_timestamps_fast(self):
        # Computing the sections is pure arithmetic, so there is nothing to parallelize
        return True

    def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig, start_time: float, end_time: float):
        """
        Split [start_time, end_time) into consecutive sections of `config.periodic_duration` seconds.

        The last section is shortened to end at `end_time`; sections shorter than MIN_SEGMENT_DURATION
        seconds are dropped. `audio` is not read.

        Raises
        ------
        ValueError
            If `config.periodic_duration` is None or not positive (the loop would otherwise never end).
        """
        periodic_duration = config.periodic_duration
        if periodic_duration is None or periodic_duration <= 0:
            raise ValueError("periodic_duration must be a positive number of seconds, got: " + str(periodic_duration))

        result = []
        start_timestamp = start_time

        while (start_timestamp < end_time):
            end_timestamp = min(start_timestamp + periodic_duration, end_time)
            segment_duration = end_timestamp - start_timestamp

            # Same cut-off that transcribe() applies, so no useless section is produced
            if (segment_duration >= MIN_SEGMENT_DURATION):
                result.append( { 'start': start_timestamp, 'end': end_timestamp } )

            start_timestamp = end_timestamp

        return result
|
|
|
def get_audio_duration(file: str):
    """
    Return the duration of `file` in seconds, as reported in the "format" section of ffprobe's output.
    """
    probe = ffmpeg.probe(file)
    container_info = probe["format"]
    return float(container_info["duration"])
|
|
|
def load_audio(file: str, sample_rate: int = 16000,
               start_time: str = None, duration: str = None):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sample_rate: int
        The sample rate to resample the audio if necessary

    start_time: str
        The start time, using the standard FFMPEG time duration syntax, or None to disable.

    duration: str
        The duration, using the standard FFMPEG time duration syntax, or None to disable.

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype, scaled to [-1, 1).

    Raises
    ------
    RuntimeError
        If ffmpeg fails; the message contains ffmpeg's stderr output.
    """
    try:
        # `-threads 0` lets ffmpeg choose the thread count
        inputArgs = {'threads': 0}

        if start_time is not None:
            inputArgs['ss'] = start_time
        if duration is not None:
            inputArgs['t'] = duration

        # Decode to raw 16-bit little-endian mono PCM at the requested rate on stdout
        out, _ = (
            ffmpeg.input(file, **inputArgs)
            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sample_rate)
            .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
        )
    except ffmpeg.Error as e:
        # Decode leniently so a non-UTF-8 ffmpeg message cannot hide the real failure
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode(errors='replace')}") from e

    # frombuffer already yields a 1-D array; astype makes the (writable) float copy
    return np.frombuffer(out, np.int16).astype(np.float32) / 32768.0