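# Pipeline overview: incoming audio is converted to 16 kHz mono WAV, split
# into hour-long chunks, segmented into short utterances with a SpeechBrain
# VAD, transcribed with an Arabic HuBERT ASR model, and optionally rendered
# as an SRT subtitle file.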
import os
import shutil
import tempfile
from collections import OrderedDict
from glob import glob

import numpy
import torch
import torchaudio
import torchaudio.functional as F
from pydub import AudioSegment
from tqdm import tqdm
from speechbrain.pretrained import VAD
from speechbrain.pretrained import EncoderASR

import gradio as gr

# Working directory for intermediate WAVs and segments.
tempdir = tempfile.mkdtemp()
def read_and_resample(filename, outdir):
    # convert the input to 16 kHz mono WAV via ffmpeg (any container pydub supports)
    AudioSegment.from_file(filename).export(f"{filename}.wav", format="wav", parameters=["-ar", "16000", "-ac", "1"])
    filename = f"{filename}.wav"
    signal, sr = torchaudio.load(filename)
    if sr != 16_000:
        # safety net: downsample to 16 kHz and mono if ffmpeg did not already
        resampled = F.resample(signal, sr, 16_000, lowpass_filter_width=128).mean(dim=0).view(1, -1).cpu()
    else:
        resampled = signal.mean(dim=0).view(1, -1).cpu()
    # strip the directory and extension to build output names:
    filename = os.path.basename(filename).split(".")[0]
    # yield chunks of 60 minutes each.
    c_size = 60 * 60 * 16_000
    for i, c in enumerate(range(0, resampled.shape[1], c_size)):
        tempaudio = os.path.join(outdir, f"{filename}-{i}.wav")
        # save the chunk to the working directory:
        torchaudio.save(tempaudio, resampled[:, c:c + c_size], 16_000)
        yield (tempaudio, resampled[:, c:c + c_size])
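# Usage sketch (hypothetical path): a 90-minute recording yields two chunk
# files, "<name>-0.wav" (60 min) and "<name>-1.wav" (30 min), each paired
# with its waveform tensor:
#   for path, waveform in read_and_resample("lecture.mp3", tempdir):
#       print(path, waveform.shape)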
def segment_file(vad, file_id, prefix, filename, resampled, output_dir):
    min_chunk_size = 4        # seconds; merge VAD segments until at least this long
    max_allowed_length = 12   # seconds; split anything longer into equal parts
    margin = 0.15             # seconds of padding around each VAD boundary
    with torch.no_grad():
        audio_info = vad.get_speech_segments(filename, apply_energy_VAD=True, len_th=0.5,
                                             deactivation_th=0.4, double_check=False, close_th=0.25)
    # save segments:
    s = -1
    for _s, _e in audio_info:
        _s, _e = _s.item(), _e.item()
        _s = max(0, _s - margin)
        e = min(resampled.size(1) / 16_000, _e + margin)
        if s == -1:
            s = _s
        chunk_length = e - s
        if chunk_length > min_chunk_size:
            no_chunks = int(numpy.ceil(chunk_length / max_allowed_length))
            starts = numpy.linspace(s, e, no_chunks + 1).tolist()
            if chunk_length > max_allowed_length:
                print("WARNING: segment too long:", chunk_length)
                print(no_chunks, starts)
            for x in range(no_chunks):
                start = starts[x]
                end = starts[x + 1]
                local_chunk_length = end - start
                print(f"Saving segment: {start:08.2f}-{end:08.2f}, with length: {local_chunk_length:05.2f} secs")
                fname = f"{file_id}-{prefix}-{start:08.2f}-{end:08.2f}.wav"
                # convert from seconds to samples:
                start = int(start * 16_000)
                end = int(end * 16_000)
                # save segment:
                torchaudio.save(os.path.join(output_dir, fname), resampled[:, start:end], 16_000)
            s = -1
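# For example, a merged 30 s speech region is split into
# ceil(30 / 12) = 3 equal chunks of 10 s each; numpy.linspace
# produces the 4 evenly spaced boundaries.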
def format_time(secs: float):
    # render a position in seconds as an SRT timestamp: H:MM:SS,mmm
    m, s = divmod(secs, 60)
    h, m = divmod(m, 60)
    return "%d:%02d:%02d,%03d" % (h, m, s, int(secs * 1000 % 1000))
asr_model = EncoderASR.from_hparams(source="asafaya/hubert-large-arabic-transcribe")
vad_model = VAD.from_hparams(source="speechbrain/vad-crdnn-libriparty")
def main(filename, generate_srt=False):
    try:
        AudioSegment.from_file(filename)
    except Exception:
        return "Please upload a valid audio file"
    outdir = os.path.join(tempdir, os.path.basename(filename).split(".")[0])
    if not os.path.exists(outdir):
        os.mkdir(outdir)
    print("Applying VAD to", filename)
    # directory to save the VAD segments
    segments_dir = os.path.join(outdir, "segments")
    if os.path.exists(segments_dir):
        raise Exception(f"Segments directory already exists: {segments_dir}")
    os.mkdir(segments_dir)
    print("Saving segments to", segments_dir)
    for c, (tempaudio, resampled) in enumerate(read_and_resample(filename, outdir)):
        print(f"Segmenting file: {filename}, with length: {resampled.shape[1] / 16_000:05.2f} secs: {tempaudio}")
        segment_file(vad_model, os.path.basename(tempaudio), c, tempaudio, resampled, segments_dir)
        # os.remove(tempaudio)
    transcriptions = OrderedDict()
    files = glob(os.path.join(segments_dir, "*.wav"))
    print("Start transcribing")
    for f in tqdm(sorted(files)):
        try:
            transcriptions[os.path.basename(f).replace(".wav", "")] = asr_model.transcribe_file(f)
            # os.remove(os.path.basename(f))
        except Exception as e:
            print(e)
            print("Error transcribing file {}".format(f))
            print("Skipping...")
    # shutil.rmtree(outdir)
    fo = ""
    for i, key in enumerate(transcriptions):
        line = key
        if len(line) < 2:
            continue
        # keys carry the timestamps, e.g. segment-0-00148.72-00156.97
        start_sec = float(line.split("-")[-2])
        end_sec = float(line.split("-")[-1])
        if generate_srt:
            fo += "{}\n".format(i + 1)
            fo += "{} --> ".format(format_time(start_sec))
            fo += "{}\n".format(format_time(end_sec))
        fo += "{}\n".format(transcriptions[key])
        if generate_srt:
            fo += "\n"
    return fo
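# Direct (non-Gradio) usage sketch, with a hypothetical local file:
#   print(main("recording.mp3", generate_srt=True))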
outputs = gr.outputs.Textbox(label="Transcription")
title = "Arabic Speech Transcription"
description = "Upload an Arabic audio file to transcribe it. Check the box to format the output as SRT subtitles."
gr.Interface(main,
             [gr.inputs.Audio(label="Arabic Audio File", type="filepath"),
              gr.inputs.Checkbox(label="Generate SRT")],
             outputs, title=title, description=description, enable_queue=True).launch()