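# Arabic speech transcription demo (Gradio app).
# Pipeline: decode and resample the upload to 16 kHz mono, split it into
# speech segments with a SpeechBrain VAD, transcribe each segment with an
# Arabic HuBERT ASR model, and return the concatenated text (optionally
# formatted as SRT).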
import shutil
import os
import tempfile
from collections import OrderedDict
from glob import glob
import numpy
import torch
import torchaudio
import torchaudio.functional as F
from pydub import AudioSegment
from tqdm import tqdm
from speechbrain.pretrained import VAD
from speechbrain.pretrained import EncoderASR
import gradio as gr
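# scratch directory for resampled audio and VAD segments; one subdirectory
# is created per uploaded file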
tempdir = tempfile.mkdtemp()
def read_and_resample(filename, outdir):
    # decode the upload and export it as 16 kHz mono wav via ffmpeg (pydub)
    AudioSegment.from_file(filename).export(f"{filename}.wav", format='wav', parameters=["-ar", "16000", "-ac", "1"])
    filename = f"{filename}.wav"
    signal, sr = torchaudio.load(filename)
    if sr != 16_000:
        # downsample to 16 kHz and downmix to mono
        resampled = F.resample(signal, sr, 16_000, lowpass_filter_width=128).mean(dim=0).view(1, -1).cpu()
    else:
        # already 16 kHz: just downmix to mono
        resampled = signal.mean(dim=0).view(1, -1).cpu()
    filename = os.path.basename(filename).split(".")[0]
    # yield chunks of at most one hour (60 * 60 seconds at 16 kHz)
    c_size = 60 * 60 * 16_000
    for i, c in enumerate(range(0, resampled.shape[1], c_size)):
        tempaudio = os.path.join(outdir, f"{filename}-{i}.wav")
        # save the chunk to the temporary directory:
        torchaudio.save(tempaudio, resampled[:, c:c + c_size], 16_000)
        yield (tempaudio, resampled[:, c:c + c_size])
def segment_file(vad, id, prefix, filename, resampled, output_dir):
    # `vad` is the pretrained SpeechBrain VAD model (renamed from `VAD`,
    # which shadowed the imported class)
    min_chunk_size = 4       # seconds
    max_allowed_length = 12  # seconds
    margin = 0.15            # seconds of context kept around each segment
    with torch.no_grad():
        audio_info = vad.get_speech_segments(filename, apply_energy_VAD=True, len_th=0.5,
                                             deactivation_th=0.4, double_check=False, close_th=0.25)
        # save segments, merging consecutive short ones until the running
        # chunk exceeds min_chunk_size:
        s = -1
        for _s, _e in audio_info:
            _s, _e = _s.item(), _e.item()
            _s = max(0, _s - margin)
            e = min(resampled.size(1) / 16_000, _e + margin)
            if s == -1:
                s = _s
            chunk_length = e - s
            if chunk_length > min_chunk_size:
                # split an overly long chunk into equal parts no longer
                # than max_allowed_length each
                no_chunks = int(numpy.ceil(chunk_length / max_allowed_length))
                starts = numpy.linspace(s, e, no_chunks + 1).tolist()
                if chunk_length > max_allowed_length:
                    print("WARNING: segment too long:", chunk_length)
                    print(no_chunks, starts)
                for x in range(no_chunks):
                    start = starts[x]
                    end = starts[x + 1]
                    local_chunk_length = end - start
                    print(f"Saving segment: {start:08.2f}-{end:08.2f}, with length: {local_chunk_length:05.2f} secs")
                    fname = f"{id}-{prefix}-{start:08.2f}-{end:08.2f}.wav"
                    # convert from seconds to samples:
                    start = int(start * 16_000)
                    end = int(end * 16_000)
                    # save segment:
                    torchaudio.save(os.path.join(output_dir, fname), resampled[:, start:end], 16_000)
                s = -1
def format_time(secs: float):
    # format seconds as an SRT timestamp: HH:MM:SS,mmm
    m, s = divmod(secs, 60)
    h, m = divmod(m, 60)
    return "%02d:%02d:%02d,%03d" % (h, m, s, int(secs * 1000 % 1000))
asr_model = EncoderASR.from_hparams(source="asafaya/hubert-large-arabic-transcribe")
vad_model = VAD.from_hparams(source="speechbrain/vad-crdnn-libriparty")
def main(filename, generate_srt=False):
    # sanity-check that pydub/ffmpeg can decode the upload
    try:
        AudioSegment.from_file(filename)
    except Exception:
        return "Please upload a valid audio file"
    outdir = os.path.join(tempdir, os.path.splitext(os.path.basename(filename))[0])
    os.makedirs(outdir, exist_ok=True)
    print("Applying VAD to", filename)
    # directory to save the VAD segments
    segments_dir = os.path.join(outdir, "segments")
    if os.path.exists(segments_dir):
        raise Exception(f"Segments directory already exists: {segments_dir}")
    os.mkdir(segments_dir)
    print("Saving segments to", segments_dir)
    for c, (tempaudio, resampled) in enumerate(read_and_resample(filename, outdir)):
        print(f"Segmenting file: {filename}, with length: {resampled.shape[1] / 16_000:05.2f} secs: {tempaudio}")
        segment_file(vad_model, os.path.basename(tempaudio), c, tempaudio, resampled, segments_dir)
        # os.remove(tempaudio)
    transcriptions = OrderedDict()
    files = glob(os.path.join(segments_dir, "*.wav"))
    print("Start transcribing")
    for f in tqdm(sorted(files)):
        try:
            transcriptions[os.path.basename(f).replace(".wav", "")] = asr_model.transcribe_file(f)
            # os.remove(os.path.basename(f))
        except Exception as e:
            print(e)
            print("Error transcribing file {}".format(f))
            print("Skipping...")
    # shutil.rmtree(outdir)
    fo = ""
    for i, key in enumerate(transcriptions):
        # keys look like: segment-0-00148.72-00156.97
        if len(key) < 2:
            continue
        start_sec = float(key.split("-")[-2])
        end_sec = float(key.split("-")[-1])
        # note: these times are relative to the one-hour chunk the segment
        # came from, not to the start of the original file
        if generate_srt:
            fo += "{}\n".format(i + 1)
            fo += "{} --> {}\n".format(format_time(start_sec), format_time(end_sec))
        fo += "{}\n".format(transcriptions[key])
        if generate_srt:
            fo += "\n"
    return fo
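# Quick local test without launching the UI (illustrative; "sample.mp3" is a
# placeholder for any audio file ffmpeg can decode):
#     print(main("sample.mp3", generate_srt=True))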
outputs = gr.Textbox(label="Transcription")
title = "Arabic Speech Transcription"
description = "Simply upload your audio."
gr.Interface(
    main,
    [gr.Audio(label="Arabic Audio File", type="filepath"), gr.Checkbox(label="Generate SRT")],
    outputs,
    title=title,
    description=description,
).queue().launch()