Spaces:
Runtime error
Runtime error
File size: 5,277 Bytes
89c1bb5 42fff29 e4528d2 89c1bb5 42fff29 e4528d2 42fff29 89c1bb5 42fff29 89c1bb5 42fff29 89c1bb5 4570d8a 89c1bb5 4570d8a 89c1bb5 4570d8a 89c1bb5 4570d8a 42fff29 4570d8a 42fff29 89c1bb5 42fff29 89c1bb5 42fff29 89c1bb5 42fff29 89c1bb5 42fff29 89c1bb5 42fff29 89c1bb5 42fff29 e4528d2 89c1bb5 e4528d2 89c1bb5 e4528d2 89c1bb5 e4528d2 42fff29 e4528d2 42fff29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
# Dependency note: this app needs SpeechBrain's `develop` branch:
#     git+https://github.com/speechbrain/speechbrain.git@develop
# Install it outside this script (for example by listing the line above in
# requirements.txt). A `!pip install ...` line is notebook shell syntax and is a
# SyntaxError when this file is executed as a plain Python script.
import re
import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline
import torch
import numpy as np
import torchaudio
from speechbrain.inference.classifiers import EncoderClassifier
# --- Speech-to-text: Whisper checkpoint and its matching processor ---
whisper_model_name = "openai/whisper-large"
processor = WhisperProcessor.from_pretrained(whisper_model_name)
model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)

# --- Zero-shot text classifier (BART-MNLI) ---
# NOTE(review): no function in the visible part of this file references
# `lang_detect_model`; confirm whether it is still needed before keeping the load.
lang_detect_model = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)

# --- Spoken-language identification: SpeechBrain VoxLingua107 ECAPA model ---
# Downloaded model files are cached under the local "tmp" directory.
language_id = EncoderClassifier.from_hparams(
    source="speechbrain/lang-id-voxlingua107-ecapa",
    savedir="tmp",
)
# Function to transcribe audio to text using Whisper model
def transcribe_audio(audio_file, sample_rate=48000):
    """
    Transcribe speech to text with the Whisper model.

    Args:
        audio_file: The audio to transcribe. Accepted forms:
            * a ``(sample_rate, samples)`` tuple, which is what a Gradio
              ``Audio`` component configured with ``type="numpy"`` delivers;
            * a ``str`` path to an audio file (loaded with torchaudio);
            * a list of sample arrays, concatenated in order (``None``
              entries are skipped);
            * a single array-like of samples.
        sample_rate: Sampling rate in Hz assumed for inputs that do not carry
            their own rate (bare arrays and lists). Defaults to 48000, the
            rate the previous implementation assumed.

    Returns:
        The decoded transcription as a string.

    Raises:
        ValueError: If no audio is supplied, the audio is empty, or it cannot
            be reduced to a 1D mono signal.
    """
    # Whisper's feature extractor expects 16 kHz audio, so everything is
    # resampled to this rate before feature extraction.
    target_rate = 16000

    if audio_file is None:
        raise ValueError("No audio was provided.")

    if isinstance(audio_file, str):
        # torchaudio returns a (channels, frames) float tensor plus the file's rate.
        waveform, sample_rate = torchaudio.load(audio_file)
        audio = waveform.numpy()
    elif (
        isinstance(audio_file, tuple)
        and len(audio_file) == 2
        and isinstance(audio_file[0], (int, np.integer))
        and not np.isscalar(audio_file[1])
    ):
        # Gradio numpy payload: (sample_rate, samples).
        sample_rate, audio = audio_file
        audio = np.asarray(audio)
    elif isinstance(audio_file, list):
        clips = [np.asarray(clip) for clip in audio_file if clip is not None]
        if not clips:
            raise ValueError("No audio was provided.")
        audio = np.concatenate(clips)
    else:
        audio = np.asarray(audio_file)

    # Integer PCM (e.g. int16 from a microphone) is scaled to floats in [-1, 1].
    if np.issubdtype(audio.dtype, np.integer):
        full_scale = np.iinfo(audio.dtype).max
        audio = audio.astype(np.float32) / full_scale
    else:
        audio = audio.astype(np.float32)

    # Mix multi-channel audio down to mono. The channel axis is taken to be the
    # shorter one, so both (channels, frames) and (frames, channels) layouts work.
    if audio.ndim == 2:
        channel_axis = int(np.argmin(audio.shape))
        audio = audio.mean(axis=channel_axis)

    if audio.ndim != 1:
        raise ValueError("The audio input must be a 1D array (mono).")
    if audio.size == 0:
        raise ValueError("The audio input is empty.")

    if int(sample_rate) != target_rate:
        audio = torchaudio.functional.resample(
            torch.from_numpy(audio),
            orig_freq=int(sample_rate),
            new_freq=target_rate,
        ).numpy()

    # Build Whisper input features and decode the generated token ids.
    inputs = processor(audio, sampling_rate=target_rate, return_tensors="pt")
    generated_ids = model.generate(inputs["input_features"])
    return processor.decode(generated_ids[0], skip_special_tokens=True)
# Function to detect language using SpeechBrain's language ID model
def detect_language_speechbrain(audio_file):
    """
    Identify the spoken language with the SpeechBrain language-ID model.

    Args:
        audio_file: Either a path to an audio file (loaded with torchaudio) or
            a ``(sample_rate, samples)`` tuple such as a Gradio ``Audio``
            component with ``type="numpy"`` delivers.

    Returns:
        A ``(language, confidence)`` tuple: the predicted language label and
        its score converted from log scale to linear scale, as a float.
    """
    # The audio is brought to 16 kHz mono before classification.
    # NOTE(review): 16 kHz single-channel is taken from the model card of
    # speechbrain/lang-id-voxlingua107-ecapa - confirm if the model changes.
    target_rate = 16000

    if isinstance(audio_file, tuple):
        sample_rate, samples = audio_file
        samples = np.asarray(samples)
        if np.issubdtype(samples.dtype, np.integer):
            # Scale integer PCM to floats in [-1, 1].
            full_scale = np.iinfo(samples.dtype).max
            samples = samples.astype(np.float32) / full_scale
        else:
            samples = samples.astype(np.float32)
        signal = torch.from_numpy(samples)
        if signal.ndim == 1:
            signal = signal.unsqueeze(0)
        elif signal.shape[0] > signal.shape[1]:
            # (frames, channels) -> (channels, frames), matching torchaudio.load.
            signal = signal.T
    else:
        # Load the audio using torchaudio: (channels, frames) tensor + rate.
        signal, sample_rate = torchaudio.load(audio_file)

    # Mix down to one channel. Without this, a stereo signal is classified as
    # a batch of two utterances and `.item()` below fails.
    signal = signal.mean(dim=0, keepdim=True)

    if int(sample_rate) != target_rate:
        signal = torchaudio.functional.resample(
            signal, orig_freq=int(sample_rate), new_freq=target_rate
        )

    # Use SpeechBrain to classify the language of the audio.
    prediction = language_id.classify_batch(signal)
    language = prediction[3][0]  # Predicted label of the only batch item
    confidence = prediction[1].exp()  # Log score -> linear scale
    return language, confidence.item()
# Cleanup function to remove filler words and clean the transcription
def cleanup_text(text):
    """
    Clean a transcription for display.

    Removes filler words, drops symbols other than basic punctuation,
    normalizes whitespace, and upper-cases the first character. Letters,
    combining marks and digits of any script are kept so that non-English
    transcriptions survive the cleanup, and the casing of the rest of the
    text is left untouched so proper nouns keep their capitals.

    Args:
        text: The raw transcription string.

    Returns:
        The cleaned transcription string (empty if nothing remains).
    """
    import unicodedata  # Local import: only this function needs it.

    # Step 1: Remove filler words like "uh", "um", etc.
    text = re.sub(
        r'\b(uh|um|like|you know|so|actually|basically)\b',
        '',
        text,
        flags=re.IGNORECASE,
    )

    # Step 2: Keep letters (L*), combining marks (M*), digits (N*), whitespace
    # and the punctuation , . ' ? ! -- everything else is dropped. Marks are
    # kept because many scripts spell vowels and accents with them.
    kept_punctuation = ",.'?!"
    text = "".join(
        ch
        for ch in text
        if ch in kept_punctuation
        or ch.isspace()
        or unicodedata.category(ch)[0] in "LMN"
    )

    # Step 3: Collapse runs of whitespace and remove the space before punctuation.
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'\s([?.!,])', r'\1', text)

    # Step 4: Trim leading/trailing whitespace.
    text = text.strip()

    # Step 5: Upper-case only the first character. str.capitalize() would also
    # lower-case the remainder of the text.
    return text[:1].upper() + text[1:]
# Main function to process the audio, transcribe it, and detect the language
def process_audio(audio_file):
    """
    Run the full pipeline for one audio input.

    Transcribes the audio, identifies the spoken language, and cleans the
    transcription. Any exception raised along the way is turned into an
    error message so the UI always receives three values.

    Args:
        audio_file: The audio input, forwarded unchanged to the transcription
            and language-detection functions.

    Returns:
        ``(cleaned_text, language, confidence)`` on success, or
        ``("Error: <message>", "", "")`` on failure.
    """
    try:
        raw_text = transcribe_audio(audio_file)
        # A blank transcription is reported as an error, not an empty result.
        if raw_text.strip() == "":
            raise ValueError("Transcription is empty.")
        detected = detect_language_speechbrain(audio_file)
        lang_label, score = detected
        tidy_text = cleanup_text(raw_text)
        outcome = (tidy_text, lang_label, score)
    except Exception as exc:
        # UI boundary: show the failure in the transcription box.
        outcome = (f"Error: {str(exc)}", "", "")
    return outcome
# Gradio interface
# Build the UI: one audio input, three text outputs, and a button that runs
# process_audio on the recorded/uploaded audio.
# NOTE(review): the original indentation was lost, so the nesting below
# (all components inside a single Row > Column) is a reconstruction - confirm
# against the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Input for live audio (microphone) or an uploaded clip.
            audio_input = gr.Audio(label="Record your voice", type="numpy", scale=1)
            # Outputs: cleaned transcription, detected language, confidence score.
            output_text = gr.Textbox(label="Transcription", scale=1)
            output_lang = gr.Textbox(label="Detected Language", scale=1)
            output_score = gr.Textbox(label="Confidence Score", scale=1)
            process_btn = gr.Button("Process Audio")

    # The three return values of process_audio map onto the three text boxes.
    process_btn.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=[output_text, output_lang, output_score],
    )

# The stray trailing "|" after this call was a syntax error and has been removed.
demo.launch(debug=True)