# @title Define funcs
import torch
import torchaudio
from typing import Callable, List, Optional
import torch.nn.functional as F
import warnings


def get_speech_probs(audio: torch.Tensor,
                     threshold: float = 0.5,
                     sampling_rate: int = 16000,
                     window_size_samples: int = 512,
                     progress_tracking_callback: Optional[Callable[[float], None]] = None,
                     model=None):
    """Run the VAD model over `audio` window by window and return speech probabilities.

    Args:
        audio: Mono audio. Non-tensors are converted with `torch.Tensor(...)`;
            leading singleton dimensions are squeezed away.
        threshold: Not used by this function (thresholding is done in
            `probs2speech_timestamps`); kept so existing calls keep working.
        sampling_rate: Sampling rate of `audio`. If it is an exact multiple of
            16000 above 16000, the audio is decimated to 16000 by keeping every
            `sampling_rate // 16000`-th sample (plain slicing, no filtering).
        window_size_samples: Number of samples fed to the model per call. The
            final chunk is zero-padded up to this length.
        progress_tracking_callback: Optional; called after every window with the
            percentage (0-100) of the audio processed so far.
        model: VAD model providing `reset_states()` and callable as
            `model(chunk, sampling_rate)` with a result supporting `.item()`.
            Defaults to the module-level `model` loaded by `torch.hub.load`
            at the end of this file.

    Returns:
        One float per window of `window_size_samples` samples (at the possibly
        decimated rate), in order.

    Raises:
        TypeError: `audio` is not a tensor and cannot be converted to one.
        ValueError: `audio` still has more than one dimension after squeezing.
    """
    if model is None:
        # Backward compatibility: this function used to read the global `model`
        # implicitly; keep that as the default when no model is passed.
        model = globals()['model']

    if not torch.is_tensor(audio):
        try:
            audio = torch.Tensor(audio)
        except Exception as err:  # narrower than a bare except; keeps the cause
            raise TypeError("Audio cannot be casted to tensor. Cast it manually") from err

    if len(audio.shape) > 1:
        for _ in range(len(audio.shape)):  # trying to squeeze empty dimensions
            audio = audio.squeeze(0)
        if len(audio.shape) > 1:
            raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?")

    if sampling_rate > 16000 and (sampling_rate % 16000 == 0):
        step = sampling_rate // 16000
        sampling_rate = 16000
        audio = audio[::step]
        warnings.warn('Sampling rate is a multiply of 16000, casting to 16000 manually!')

    if sampling_rate == 8000 and window_size_samples > 768:
        warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 768 for 8000 sample rate!')
    if window_size_samples not in [256, 512, 768, 1024, 1536]:
        warnings.warn('Unusual window_size_samples! Supported window_size_samples:\n - [512, 1024, 1536] for 16000 sampling_rate\n - [256, 512, 768] for 8000 sampling_rate')

    # Clear the model's internal state so results don't depend on a previous call.
    model.reset_states()

    audio_length_samples = len(audio)
    speech_probs = []
    for current_start_sample in range(0, audio_length_samples, window_size_samples):
        chunk = audio[current_start_sample: current_start_sample + window_size_samples]
        if len(chunk) < window_size_samples:
            # The model is always fed full windows: zero-pad the tail chunk.
            chunk = F.pad(chunk, (0, int(window_size_samples - len(chunk))))
        speech_prob = model(chunk, sampling_rate).item()
        speech_probs.append(speech_prob)

        # calculate progress and send it to the callback function
        progress = min(current_start_sample + window_size_samples, audio_length_samples)
        progress_percent = (progress / audio_length_samples) * 100
        if progress_tracking_callback:
            progress_tracking_callback(progress_percent)

    return speech_probs


def probs2speech_timestamps(speech_probs,
                            audio_length_samples,
                            threshold: float = 0.5,
                            sampling_rate: int = 16000,
                            min_speech_duration_ms: int = 250,
                            max_speech_duration_s: float = float('inf'),
                            min_silence_duration_ms: int = 100,
                            window_size_samples: int = 512,
                            speech_pad_ms: int = 30,
                            return_seconds: bool = False,
                            rounding: int = 1,):
    """Turn per-window speech probabilities into speech segments.

    Args:
        speech_probs: Per-window probabilities, as returned by `get_speech_probs`
            called with the same `sampling_rate` and `window_size_samples`.
        audio_length_samples: Length of the audio in samples at the original
            `sampling_rate` (i.e. before any decimation done by
            `get_speech_probs`). NOTE(review): assumed to be `len(audio)` of the
            audio passed to `get_speech_probs` — confirm against callers.
        threshold: A window with probability >= `threshold` starts/continues
            speech; speech is only considered ended once probabilities fall
            below `threshold - 0.15`.
        sampling_rate: Sampling rate of the original audio. Exact multiples of
            16000 above 16000 are handled on a 16000 grid, exactly as
            `get_speech_probs` does.
        min_speech_duration_ms: Segments ended by silence (or by the end of the
            audio) that are not longer than this are discarded.
        max_speech_duration_s: Segments longer than this are split: at the last
            silence longer than 98 ms if one was seen, otherwise at the current
            window.
        min_silence_duration_ms: Silence required before a segment is closed.
        window_size_samples: Window size used by `get_speech_probs`.
        speech_pad_ms: Padding added to each side of a segment. When the gap
            between two segments is shorter than twice this, the gap is split
            evenly between them instead.
        return_seconds: Return times in seconds instead of samples.
        rounding: Number of decimals for seconds when `return_seconds` is set.

    Returns:
        List of dicts with 'start' and 'end' keys, in samples at the original
        `sampling_rate` (or in seconds if `return_seconds` is set).
    """
    # Mirror get_speech_probs: exact multiples of 16 kHz were decimated to
    # 16 kHz, so every position below lives on the 16 kHz grid and is scaled
    # back by `step` at the end. Other rates are processed unchanged (step 1).
    if sampling_rate > 16000 and sampling_rate % 16000 == 0:
        step = sampling_rate // 16000
        sampling_rate = 16000
        # Same length as audio[::step] in get_speech_probs.
        audio_length_samples = (audio_length_samples + step - 1) // step
    else:
        step = 1

    min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
    speech_pad_samples = sampling_rate * speech_pad_ms / 1000
    max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
    # Only silences longer than 98 ms are remembered as split points for over-long segments.
    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000

    triggered = False
    speeches = []
    current_speech = {}
    neg_threshold = threshold - 0.15  # hysteresis: leaving speech needs a lower prob than entering it
    temp_end = 0  # to save potential segment end (and tolerate some silence)
    prev_end = next_start = 0  # to save potential segment limits in case of maximum segment size reached

    for i, speech_prob in enumerate(speech_probs):
        cur_sample = window_size_samples * i

        if (speech_prob >= threshold) and temp_end:
            # Speech resumed before the silence was long enough: cancel the pending end.
            temp_end = 0
            if next_start < prev_end:
                next_start = cur_sample

        if (speech_prob >= threshold) and not triggered:
            triggered = True
            current_speech['start'] = cur_sample
            continue

        if triggered and cur_sample - current_speech['start'] > max_speech_samples:
            if prev_end:
                # Split at the last sufficiently long silence seen in this segment.
                current_speech['end'] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                if next_start < prev_end:  # previously reached silence (< neg_thres) and is still not speech (< thres)
                    triggered = False
                else:
                    current_speech['start'] = next_start
                prev_end = next_start = temp_end = 0
            else:
                # No usable silence: hard cut at the current window.
                current_speech['end'] = cur_sample
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

        if (speech_prob < neg_threshold) and triggered:
            if not temp_end:
                temp_end = cur_sample
            if cur_sample - temp_end > min_silence_samples_at_max_speech:  # condition to avoid cutting in very short silence
                prev_end = temp_end
            if cur_sample - temp_end < min_silence_samples:
                continue
            current_speech['end'] = temp_end
            if (current_speech['end'] - current_speech['start']) > min_speech_samples:
                speeches.append(current_speech)
            current_speech = {}
            prev_end = next_start = temp_end = 0
            triggered = False
            continue

    # A segment still open at the end of the audio runs to the last sample.
    if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
        current_speech['end'] = audio_length_samples
        speeches.append(current_speech)

    # Pad segments, clamped to [0, audio_length_samples]; close neighbours share their gap.
    for i, speech in enumerate(speeches):
        if i == 0:
            speech['start'] = int(max(0, speech['start'] - speech_pad_samples))
        if i != len(speeches) - 1:
            silence_duration = speeches[i + 1]['start'] - speech['end']
            if silence_duration < 2 * speech_pad_samples:
                speech['end'] += int(silence_duration // 2)
                speeches[i + 1]['start'] = int(max(0, speeches[i + 1]['start'] - silence_duration // 2))
            else:
                speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
                speeches[i + 1]['start'] = int(max(0, speeches[i + 1]['start'] - speech_pad_samples))
        else:
            speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))

    if return_seconds:
        for speech_dict in speeches:
            speech_dict['start'] = round(speech_dict['start'] / sampling_rate, rounding)
            speech_dict['end'] = round(speech_dict['end'] / sampling_rate, rounding)
    elif step > 1:
        # Map positions from the 16 kHz grid back to the original sampling rate.
        for speech_dict in speeches:
            speech_dict['start'] *= step
            speech_dict['end'] *= step

    return speeches


def make_visualization(probs, step):
    """Plot speech probabilities as an area chart.

    Args:
        probs: Sequence of per-window speech probabilities.
        step: Spacing between consecutive probabilities on the x axis. The axis
            is labelled 'seconds', so presumably the window duration in seconds
            (e.g. window_size_samples / sampling_rate) — confirm at call sites.
    """
    import pandas as pd  # imported lazily: only needed for plotting
    pd.DataFrame({'probs': probs},
                 index=[x * step for x in range(len(probs))]).plot(figsize=(16, 8),
                                                                   kind='area', ylim=[0, 1.05],
                                                                   xlim=[0, len(probs) * step],
                                                                   xlabel='seconds',
                                                                   ylabel='speech probability',
                                                                   colormap='tab20')


torch.set_num_threads(1)

USE_ONNX = True  # set to False to load the PyTorch (non-ONNX) model instead
# force_reload=True makes torch.hub fetch the repository again on every run.
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_vad',
                              force_reload=True,
                              onnx=USE_ONNX)

# Only read_audio is kept from the returned helper tuple.
(_, _, read_audio, *_) = utils