# Page header: ajchri5 · "Update app.py" · 89c1bb5 (verified)
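# Dependency note: this app needs the SpeechBrain "develop" branch. Install it
# before running (the "!pip" notebook magic is not valid in a plain .py file):
#   pip install git+https://github.com/speechbrain/speechbrain.git@develop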
import re
import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline
import torch
import numpy as np
import torchaudio
from speechbrain.inference.classifiers import EncoderClassifier
# --- Model setup: every statement below runs once, when the module is imported ---

# Whisper checkpoint shared by the processor (feature extraction + decoding)
# and the generation model used for speech-to-text.
whisper_model_name = "openai/whisper-large"
processor = WhisperProcessor.from_pretrained(
    whisper_model_name,
)
model = WhisperForConditionalGeneration.from_pretrained(
    whisper_model_name,
)

# Zero-shot text classifier.
# NOTE(review): nothing in the visible source references lang_detect_model;
# confirm whether it is still needed before paying its load cost.
lang_detect_model = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)

# SpeechBrain spoken-language identifier; its files are stored under ./tmp.
language_id = EncoderClassifier.from_hparams(
    source="speechbrain/lang-id-voxlingua107-ecapa",
    savedir="tmp",
)
# Sampling rate (Hz) that Whisper's feature extractor requires.
_WHISPER_SAMPLE_RATE = 16000


def _clip_to_whisper_waveform(clip):
    """Convert one audio clip to a mono float32 waveform sampled at 16 kHz.

    Args:
        clip: One of
            * a ``(sample_rate, samples)`` tuple as produced by
              ``gr.Audio(type="numpy")``; 2-D ``samples`` are laid out as
              ``(samples, channels)``;
            * a ``str`` path readable by ``torchaudio.load``;
            * a bare array of samples. No rate is available in this case, so
              the data is assumed to already be at 16 kHz, and a 2-D array is
              read as ``(channels, samples)`` like the previous implementation.

    Returns:
        A 1-D ``numpy.float32`` array sampled at ``_WHISPER_SAMPLE_RATE``.

    Raises:
        ValueError: If the clip cannot be reduced to a 1-D signal.
    """
    if isinstance(clip, str):
        waveform, sample_rate = torchaudio.load(clip)  # (channels, samples)
        samples, channel_axis = waveform.numpy(), 0
    elif (
        isinstance(clip, tuple)
        and len(clip) == 2
        and np.ndim(clip[0]) == 0
        and np.ndim(clip[1]) >= 1
    ):
        sample_rate, samples = clip
        samples, channel_axis = np.asarray(samples), 1
    else:
        sample_rate = _WHISPER_SAMPLE_RATE
        samples, channel_axis = np.asarray(clip), 0

    if np.issubdtype(samples.dtype, np.signedinteger):
        # Integer PCM (e.g. int16 from the microphone) -> floats in [-1, 1).
        full_scale = float(np.iinfo(samples.dtype).max) + 1.0
        samples = samples.astype(np.float32) / full_scale
    else:
        samples = samples.astype(np.float32)

    if samples.ndim == 2:
        # Down-mix multi-channel audio by averaging across the channel axis.
        samples = samples.mean(axis=channel_axis)
    if samples.ndim != 1:
        raise ValueError("The audio input must be a 1D array (mono).")

    if int(sample_rate) != _WHISPER_SAMPLE_RATE:
        # Whisper rejects any rate other than 16 kHz, so resample instead of
        # merely declaring a different rate to the processor.
        resampled = torchaudio.functional.resample(
            torch.from_numpy(np.ascontiguousarray(samples)),
            int(sample_rate),
            _WHISPER_SAMPLE_RATE,
        )
        samples = resampled.numpy()
    return samples


def transcribe_audio(audio_file):
    """Transcribe audio to text using the Whisper model.

    Handles both file input and live audio input.

    Args:
        audio_file: A single clip, or a list of clips that are transcribed as
            one concatenated recording. Each clip may be a
            ``(sample_rate, samples)`` tuple, a ``str`` file path, or a bare
            sample array (see ``_clip_to_whisper_waveform``). ``None`` entries
            in a list are skipped.

    Returns:
        The decoded transcription with special tokens removed.

    Raises:
        ValueError: If no audio is supplied or a clip is not reducible to mono.
    """
    if audio_file is None:
        raise ValueError("No audio was provided.")

    clips = audio_file if isinstance(audio_file, list) else [audio_file]
    waveforms = [_clip_to_whisper_waveform(c) for c in clips if c is not None]
    if not waveforms:
        raise ValueError("No audio was provided.")
    audio = np.concatenate(waveforms)

    # Every waveform has been resampled, so the declared rate is now truthful.
    inputs = processor(
        audio, sampling_rate=_WHISPER_SAMPLE_RATE, return_tensors="pt"
    )
    generated_ids = model.generate(inputs["input_features"])
    return processor.decode(generated_ids[0], skip_special_tokens=True)
def detect_language_speechbrain(audio_file):
    """Identify the spoken language of a clip with SpeechBrain's language-ID model.

    Args:
        audio_file: Either a ``(sample_rate, samples)`` tuple as produced by
            ``gr.Audio(type="numpy")`` (2-D ``samples`` are laid out as
            ``(samples, channels)``), or a path readable by ``torchaudio.load``.

    Returns:
        A ``(language, confidence)`` pair: the label reported by the model for
        the clip and its score converted from log scale to linear scale.

    Raises:
        ValueError: If no audio is supplied or the samples are not 1-D or 2-D.
    """
    # The VoxLingua107 ECAPA model is trained on 16 kHz single-channel audio,
    # and classify_batch does no normalisation of its own.
    target_rate = 16000

    if audio_file is None:
        raise ValueError("No audio was provided.")

    if isinstance(audio_file, tuple):
        sample_rate, samples = audio_file
        samples = np.asarray(samples)
        if np.issubdtype(samples.dtype, np.signedinteger):
            # Integer PCM (e.g. int16 from the microphone) -> floats in [-1, 1).
            full_scale = float(np.iinfo(samples.dtype).max) + 1.0
            samples = samples.astype(np.float32) / full_scale
        signal = torch.from_numpy(np.ascontiguousarray(samples, dtype=np.float32))
        if signal.ndim == 2:
            signal = signal.mean(dim=1)  # average over the channel axis
    else:
        signal, sample_rate = torchaudio.load(audio_file)  # (channels, samples)
        signal = signal.mean(dim=0)

    if signal.ndim != 1:
        raise ValueError("The audio input must be a 1D array (mono).")

    if int(sample_rate) != target_rate:
        signal = torchaudio.functional.resample(signal, int(sample_rate), target_rate)

    # Batch of exactly one clip; otherwise each channel would be scored as a
    # separate utterance and .item() below would fail.
    prediction = language_id.classify_batch(signal.unsqueeze(0))
    language = prediction[3][0]  # text label of the first (only) batch entry
    confidence = prediction[1].exp()  # log score -> linear scale
    return language, confidence.item()
# Filler words dropped from transcriptions (matched case-insensitively).
_FILLER_WORDS = re.compile(
    r'\b(uh|um|like|you know|so|actually|basically)\b', flags=re.IGNORECASE
)
# Punctuation marks that survive the character filter.
_KEPT_PUNCTUATION = frozenset(",.'?!")
_WHITESPACE_RUN = re.compile(r'\s+')
_SPACE_BEFORE_PUNCT = re.compile(r'\s([?.!,])')


def cleanup_text(text):
    """Clean a transcription for display.

    Removes filler words, drops characters other than letters, digits,
    whitespace and ``, . ' ? !``, normalises spacing, and upper-cases the first
    character.

    Letters and digits of every script are kept (together with their combining
    marks), so transcriptions in languages other than English are not erased.
    Only the first character's case is changed; the rest of the text keeps its
    original casing.

    Args:
        text: Raw transcription.

    Returns:
        The cleaned transcription; an empty string if nothing remains.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import list.
    import unicodedata

    # Step 1: remove filler words like "uh", "um", etc.
    text = _FILLER_WORDS.sub('', text)

    # Step 2: drop unwanted characters. Combining marks (category M*) are kept
    # because many scripts need them to render letters correctly.
    text = "".join(
        ch
        for ch in text
        if ch.isalnum()
        or ch.isspace()
        or ch in _KEPT_PUNCTUATION
        or unicodedata.category(ch)[0] == "M"
    )

    # Step 3: collapse whitespace runs and remove the space before punctuation.
    text = _WHITESPACE_RUN.sub(' ', text)
    text = _SPACE_BEFORE_PUNCT.sub(r'\1', text)

    # Step 4: trim leading/trailing whitespace.
    text = text.strip()

    # Step 5: upper-case only the first character. str.capitalize() would also
    # lower-case everything after it (names, acronyms, "I").
    return text[:1].upper() + text[1:]
def process_audio(audio_file):
    """Run the full pipeline for one recording: transcribe, identify language, tidy text.

    Args:
        audio_file: The audio value handed over by the UI; it is forwarded
            unchanged to the transcription and language-detection functions.

    Returns:
        A 3-tuple ``(cleaned_text, language, confidence)`` on success. If any
        step raises, ``("Error: <message>", "", "")`` is returned instead so
        the message can be shown in the transcription box.
    """
    try:
        raw_text = transcribe_audio(audio_file)
        # A blank transcription is reported as an error like any other failure.
        if not raw_text.strip():
            raise ValueError("Transcription is empty.")
        # Language detection runs before the text is cleaned up.
        language, confidence = detect_language_speechbrain(audio_file)
        outcome = (cleanup_text(raw_text), language, confidence)
    except Exception as e:
        # UI boundary: surface the failure text instead of crashing the handler.
        outcome = (f"Error: {str(e)}", "", "")
    return outcome
# Gradio interface: one column holding the audio input, three read-only result
# boxes and the button that triggers processing.
# NOTE(review): the nesting below is reconstructed from the order of the
# `with` statements — confirm it matches the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # The audio value is handed to the callback in "numpy" form.
            audio_input = gr.Audio(
                label="Record your voice",
                type="numpy",
                scale=1,
            )
            output_text = gr.Textbox(
                label="Transcription",
                scale=1,
            )
            output_lang = gr.Textbox(
                label="Detected Language",
                scale=1,
            )
            output_score = gr.Textbox(
                label="Confidence Score",
                scale=1,
            )
            process_btn = gr.Button("Process Audio")
            # process_audio returns three values, mapped to the outputs in order.
            process_btn.click(
                fn=process_audio,
                inputs=[audio_input],
                outputs=[output_text, output_lang, output_score],
            )

# Start the app as soon as the module is executed.
demo.launch(debug=True)