import re
import time

import ffmpeg
import gradio as gr
import librosa
import noisereduce as nr
import numpy as np
from transformers import (
    SpeechT5Processor,
    SpeechT5ForSpeechToText,
)


HF_MODEL_PATH = 'mohammad-shirkhani/speecht5_asr_finetune_persian'

AUDIO_SAMPLING_RATE = 16000  # Hz
AUDIO_FRAME_MIN_DUR = 4.5  # second
AUDIO_FRAME_MAX_DUR = 11.5  # second
SILENCE_FRAME_DUR = 0.300  # second
SILENCE_FRAME_SHIFT = 0.010  # second
TEXT_GEN_MAX_LEN = 250  # character

model = None
processor = None


def initialize_model():
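    """Load the fine-tuned Persian SpeechT5 ASR model and its processor into the
    module-level globals used by perform_asr()."""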
    global model
    global processor
    model = SpeechT5ForSpeechToText.from_pretrained(HF_MODEL_PATH)
    processor = SpeechT5Processor.from_pretrained(HF_MODEL_PATH)


def handle_user_input(audio_path, video_path):
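    """Gradio callback: transcribe the uploaded audio file and/or the audio track
    of the uploaded video file, returning one transcript per input (None if missing)."""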
    t_start = time.time()

    audio_asr_result = None
    video_asr_result = None

    if audio_path is not None:
        # Load the uploaded audio file and resample it to 16 kHz
        waveform, sample_rate = librosa.load(audio_path, sr=None)
        waveform = librosa.resample(
            waveform,
            orig_sr=sample_rate,
            target_sr=AUDIO_SAMPLING_RATE
        )
        # Perform ASR on the audio waveform
        audio_asr_result = perform_asr(waveform)

    if video_path is not None:
        # Load the uploaded video file and extract its audio
        (
            ffmpeg
            .input(video_path)
            .output('tmp.wav', acodec='pcm_s16le')
            .run(overwrite_output=True)
        )
        # Load the extracted audio file and resample it to 16 kHz
        waveform, sample_rate = librosa.load('tmp.wav', sr=None)
        waveform = librosa.resample(
            waveform,
            orig_sr=sample_rate,
            target_sr=AUDIO_SAMPLING_RATE
        )
        # Perform ASR on the audio waveform
        video_asr_result = perform_asr(waveform)

    delta_t = time.time() - t_start
    print(f'Total Time = {delta_t:5.1f} s\n')

    return audio_asr_result, video_asr_result


def perform_asr(waveform):
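    """Transcribe a waveform sampled at AUDIO_SAMPLING_RATE.

    Pipeline: downmix stereo to mono, split the signal into frames of roughly
    AUDIO_FRAME_MIN_DUR to AUDIO_FRAME_MAX_DUR seconds (cutting at the quietest
    point found in that window), denoise each frame, run one batched generate()
    call over all frames, and clean the decoded text. Returns the per-frame
    transcriptions joined by blank lines.
    """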
    # Mono, nothing to be done :)
    if waveform.ndim == 1:
        pass
    # Stereo, convert to mono by averaging the channels
    elif waveform.ndim == 2 and waveform.shape[1] == 2:
        waveform = np.mean(waveform, axis=1)
    else:
        raise ValueError(f'Bad audio array shape: "{waveform.shape}"')

    t_start = time.time()
    # Split the audio array into smaller frames
    audio_frames = []
    start_idx = 0
    while start_idx != len(waveform):
        frame_end_min = int(
            start_idx + AUDIO_FRAME_MIN_DUR * AUDIO_SAMPLING_RATE
        )
        frame_end_max = int(
            start_idx + AUDIO_FRAME_MAX_DUR * AUDIO_SAMPLING_RATE
        )
        if frame_end_max < len(waveform):
            break_point = search_for_breakpoint(
                waveform,
                frame_end_min,
                frame_end_max
            )
        else:
            break_point = len(waveform)
        audio_frames.append(waveform[start_idx:break_point])
        start_idx = break_point
    delta_t = time.time() - t_start
    print(f'Audio Framing = {delta_t:5.1f} s')

    t_start = time.time()
    # Apply noise reduction on each audio frame
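    # Note: reduce_noise() performs spectral gating; without an explicit noise
    # clip it estimates the noise profile from the frame itself, which seems
    # sufficient for these short segments.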
    audio_frames = [
        nr.reduce_noise(y=frame, sr=AUDIO_SAMPLING_RATE)
        for frame in audio_frames
    ]
    delta_t = time.time() - t_start
    print(f'Noise Reduction = {delta_t:5.1f} s')

    ######################### Method 1 - For Loop #########################
    # transcriptions = []
    # for frame in audio_frames:
    #     inputs = processor(
    #         audio=frame,
    #         sampling_rate=AUDIO_SAMPLING_RATE,
    #         return_tensors='pt'
    #     )
    #     predicted_ids = model.generate(
    #         **inputs,
    #         max_length=TEXT_GEN_MAX_LEN
    #     )
    #     transcription = processor.batch_decode(
    #         predicted_ids,
    #         skip_special_tokens=True
    #     )[0]
    #     transcriptions.append(transcription)

    ######################### Method 2 - Batch ############################
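    # The batched path below pads all frames to the longest one and runs a single
    # generate() call; this is typically faster than the per-frame loop above at
    # the cost of extra memory for the padded batch.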
    t_start = time.time()
    # Process the entire batch of audio frames
    inputs = processor(
        audio=audio_frames,
        sampling_rate=AUDIO_SAMPLING_RATE,
        padding=True,
        return_tensors='pt'
    )
    # Generate predictions for the entire batch
    predicted_ids = model.generate(
        **inputs,
        max_length=TEXT_GEN_MAX_LEN
    )
    # Decode the predicted IDs into transcriptions
    transcriptions = processor.batch_decode(
        predicted_ids,
        skip_special_tokens=True
    )
    delta_t = time.time() - t_start
    print(f'Text Generation = {delta_t:5.1f} s')

    t_start = time.time()
    # Clean the model-generated transcriptions
    transcriptions = [clean_model_answer(t) for t in transcriptions]
    delta_t = time.time() - t_start
    print(f'Text Cleaning = {delta_t:5.1f} s')

    return '\n\n'.join(transcriptions)


def search_for_breakpoint(waveform, begin, end):
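    """Pick a cut point between sample indices `begin` and `end`.

    Slides a SILENCE_FRAME_DUR-long window in SILENCE_FRAME_SHIFT steps over the
    absolute waveform, selects the window with the lowest mean amplitude (the
    quietest stretch, most likely a pause), and returns the index of its center.
    """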
    waveform_ampl = np.abs(waveform)
    frame_size = int(SILENCE_FRAME_DUR * AUDIO_SAMPLING_RATE)
    frame_shift = int(SILENCE_FRAME_SHIFT * AUDIO_SAMPLING_RATE)

    avg_amplitudes = {}
    for start_idx in range(begin, end - frame_size + 1, frame_shift):
        stop_idx = start_idx + frame_size
        avg_amplitudes[start_idx] = np.mean(waveform_ampl[start_idx:stop_idx])

    # Consider the center of the quietest frame as the breakpoint
    best_start_idx = min(avg_amplitudes, key=avg_amplitudes.get)
    break_point = best_start_idx + frame_size // 2
    return break_point


def clean_model_answer(txt):
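    """Normalize whitespace in a decoded transcription.

    The raw output appears to separate characters with single spaces and words
    with wider gaps, so the first substitution drops isolated single spaces and
    the second collapses the remaining whitespace runs into single spaces,
    e.g. 's a l a m  d o s t' -> 'salam dost'.
    """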
    txt = re.sub(r'\s(?!\s)', '', txt)
    txt = re.sub(r'\s+', ' ', txt)
    return txt


if __name__ == '__main__':
    # Initialize the ASR model
    initialize_model()

    # Create a Gradio interface with the required inputs and outputs
    iface = gr.Interface(
        fn=handle_user_input,
        inputs=[
            gr.Audio(label='Upload/Record Audio', type='filepath'),
            gr.Video(label='Upload Video', sources='upload'),
        ],
        outputs=[
            gr.Textbox(label='Audio Transcript', rtl=True),
            gr.Textbox(label='Video Transcript', rtl=True),
        ],
        title='Automatic Speech Recognition for the Farsi Language',
        description='Upload an audio/video file to generate its transcript!',
        examples=[
            ['examples/roya_nonahali.mp3', None],  # Example Audio 1
            ['examples/keikavoos_yakideh.mp3', None],  # Example Audio 2
            ['examples/amirmohammad_samsami.mp3', None],  # Example Audio 3
        ],
        cache_examples=False,
    )

    # Launch the Gradio app
    iface.launch()