File size: 5,277 Bytes
89c1bb5
 
42fff29
 
 
 
e4528d2
89c1bb5
 
42fff29
 
e4528d2
42fff29
 
 
89c1bb5
42fff29
 
89c1bb5
 
 
42fff29
 
89c1bb5
 
 
 
4570d8a
 
89c1bb5
 
4570d8a
 
89c1bb5
 
 
 
 
 
 
 
4570d8a
 
89c1bb5
4570d8a
42fff29
 
 
4570d8a
42fff29
 
89c1bb5
 
 
 
 
 
 
 
 
 
 
 
42fff29
 
 
89c1bb5
 
 
 
 
42fff29
89c1bb5
 
 
 
 
 
 
 
 
42fff29
89c1bb5
 
42fff29
89c1bb5
42fff29
 
89c1bb5
42fff29
e4528d2
 
89c1bb5
e4528d2
 
89c1bb5
 
 
 
e4528d2
89c1bb5
 
 
e4528d2
 
 
42fff29
 
 
 
 
e4528d2
42fff29
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# IPython/Colab shell magic (installs SpeechBrain from its develop branch). This line is
# not valid Python syntax: outside a notebook, list the dependency in requirements.txt.
!pip install git+https://github.com/speechbrain/speechbrain.git@develop

import logging
import re
import unicodedata

import gradio as gr
import numpy as np
import torch
import torchaudio
from speechbrain.inference.classifiers import EncoderClassifier
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline

# Load Whisper model for transcription.
# The processor (used below both to build input features and to decode token ids)
# and the seq2seq model are loaded from the same checkpoint. Like everything at
# module level, this runs once when the file is imported/executed.
whisper_model_name = "openai/whisper-large"
processor = WhisperProcessor.from_pretrained(whisper_model_name)
model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)

# Initialize the language detection model (using zero-shot classification for language detection)
# NOTE(review): `lang_detect_model` is not referenced anywhere else in the visible code
# (spoken-language ID is done by the SpeechBrain model below), yet loading it at import
# time costs startup time and memory — confirm it is still needed.
lang_detect_model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

# Load the SpeechBrain language ID model (VoxLingua107 checkpoint).
# savedir="tmp" is a relative path, i.e. resolved against the current working directory;
# presumably where SpeechBrain keeps the fetched model files — confirm if deploying read-only.
language_id = EncoderClassifier.from_hparams(source="speechbrain/lang-id-voxlingua107-ecapa", savedir="tmp")

# Whisper's feature extractor and the VoxLingua107 language-ID model both expect
# 16 kHz mono audio; every input is normalized to this before inference.
_TARGET_SAMPLE_RATE = 16000


def _to_mono_16k(audio, sampling_rate=_TARGET_SAMPLE_RATE):
    """Normalize an audio input to a 1-D float32 NumPy array sampled at 16 kHz.

    Accepted forms of ``audio``:

    * a ``(rate, data)`` tuple, as produced by ``gr.Audio(type="numpy")``;
      ``data`` may be ``(samples,)`` or ``(samples, channels)``, usually int16;
    * a ``str`` path to an audio file, decoded with ``torchaudio.load``;
    * a bare array-like of samples, assumed to be recorded at ``sampling_rate``;
    * a ``list`` of any of the above (``None`` items skipped), concatenated in order.

    Signed-integer PCM is scaled into [-1, 1). Two-dimensional audio is averaged to
    mono over its shorter axis, which covers both Gradio's (samples, channels) and
    torchaudio's (channels, samples) layouts.

    Raises:
        ValueError: if no audio is given, the audio is empty, or it has more than
            two dimensions.
    """
    if audio is None:
        raise ValueError("No audio provided.")

    if isinstance(audio, list):
        chunks = [_to_mono_16k(a, sampling_rate) for a in audio if a is not None]
        if not chunks:
            raise ValueError("No audio provided.")
        return np.concatenate(chunks)

    if isinstance(audio, tuple) and len(audio) == 2 and np.isscalar(audio[0]):
        sampling_rate, data = audio
    elif isinstance(audio, str):
        signal, sampling_rate = torchaudio.load(audio)
        data = signal.numpy()
    else:
        data = audio
    data = np.asarray(data)

    # Models expect floating-point waveforms; microphone recordings arrive as int PCM.
    if np.issubdtype(data.dtype, np.signedinteger):
        data = data.astype(np.float32) / float(-np.iinfo(data.dtype).min)
    else:
        data = data.astype(np.float32)

    # Check before averaging: the mean of an empty axis would yield NaNs, not an error.
    if data.size == 0:
        raise ValueError("Audio input is empty.")
    if data.ndim == 2:
        data = data.mean(axis=int(np.argmin(data.shape)))
    if data.ndim != 1:
        raise ValueError("The audio input must be a 1D array (mono).")

    sampling_rate = int(sampling_rate)
    if sampling_rate != _TARGET_SAMPLE_RATE:
        data = torchaudio.functional.resample(
            torch.from_numpy(np.ascontiguousarray(data)), sampling_rate, _TARGET_SAMPLE_RATE
        ).numpy()
    return data


# Function to transcribe audio to text using Whisper model
def transcribe_audio(audio_file, sampling_rate=_TARGET_SAMPLE_RATE):
    """
    Transcribe speech to text with the Whisper model.

    Args:
        audio_file: a Gradio ``(rate, data)`` tuple, an audio file path, a bare
            sample array, or a list of those (see ``_to_mono_16k``).
        sampling_rate: sample rate of a bare array input; tuples and file paths
            carry their own rate and ignore this.

    Returns:
        The decoded transcription with Whisper's special tokens removed.

    Raises:
        ValueError: if the audio is missing, empty, or has more than two dimensions.
    """
    audio = _to_mono_16k(audio_file, sampling_rate)

    # The Whisper feature extractor raises if sampling_rate differs from its native
    # 16 kHz, so the audio is resampled above and the true rate is declared here.
    # NOTE(review): the extractor pads/truncates to a fixed 30 s window by default,
    # so longer recordings are likely cut off — confirm if long inputs matter.
    input_features = processor(audio, return_tensors="pt", sampling_rate=_TARGET_SAMPLE_RATE)

    generated_ids = model.generate(input_features["input_features"])
    return processor.decode(generated_ids[0], skip_special_tokens=True)


# Function to detect language using SpeechBrain's language ID model
def detect_language_speechbrain(audio_file, sampling_rate=_TARGET_SAMPLE_RATE):
    """
    Identify the spoken language with SpeechBrain's VoxLingua107 classifier.

    Args:
        audio_file: same accepted forms as ``transcribe_audio``.
        sampling_rate: sample rate of a bare array input (ignored otherwise).

    Returns:
        ``(language, confidence)``: the predicted text label and ``exp`` of the
        model's best score as a Python float.

    Raises:
        ValueError: if the audio is missing, empty, or has more than two dimensions.
    """
    # classify_batch treats dim 0 as the batch, so pass exactly one mono 16 kHz
    # waveform of shape (1, time). A raw multi-channel torchaudio.load result would
    # be read as several clips and break the .item() call below.
    waveform = torch.from_numpy(_to_mono_16k(audio_file, sampling_rate)).unsqueeze(0)
    prediction = language_id.classify_batch(waveform)

    language = prediction[3][0]  # text label of the top prediction
    confidence = prediction[1].exp()  # score is in log space; exp gives a linear-scale value
    return language, confidence.item()

# Cleanup function to remove filler words and clean the transcription
# Default filler words removed by cleanup_text (whole words, case-insensitive).
# NOTE: "like", "so" and "actually" are also ordinary words ("I like it" -> "I it");
# pass a narrower `filler_words` tuple where that matters.
_FILLER_WORDS = ("uh", "um", "like", "you know", "so", "actually", "basically")

# Punctuation kept by cleanup_text; all other symbols except letters/marks/digits are dropped.
_KEPT_PUNCTUATION = frozenset(",.'?!")


def cleanup_text(text, filler_words=_FILLER_WORDS):
    """
    Clean a transcription for display.

    The function performs these steps in order:

    * Removes ``filler_words`` (whole words, case-insensitive).
    * Drops every character except Unicode letters, combining marks, digits,
      whitespace and ``, . ' ? !``.
    * Collapses whitespace.
    * Removes the space before punctuation and the stray commas left behind by
      removed fillers.
    * Upper-cases the first character, leaving the rest of the text unchanged.

    Letters are recognized by Unicode category, so transcriptions in any script
    (Cyrillic, Devanagari, CJK, accented Latin, ...) are preserved.

    Args:
        text: raw transcription; ``None`` or ``""`` yields ``""``.
        filler_words: words/phrases to strip; an empty tuple disables this step.

    Returns:
        The cleaned text.
    """
    if not text:
        return ""

    # Step 1: remove filler words as whole words.
    if filler_words:
        pattern = r"\b(?:" + "|".join(map(re.escape, filler_words)) + r")\b"
        text = re.sub(pattern, "", text, flags=re.IGNORECASE)

    # Step 2: keep letters (L*), combining marks (M*, required by e.g. Indic scripts),
    # numbers (N*), whitespace and the allowed punctuation.
    text = "".join(
        ch
        for ch in text
        if ch.isspace() or ch in _KEPT_PUNCTUATION or unicodedata.category(ch)[0] in "LMN"
    )

    # Step 3: normalize spacing and punctuation.
    text = re.sub(r"\s+", " ", text)
    text = re.sub(r"\s([?.!,])", r"\1", text)  # no space before punctuation
    # Removing "uh," / "um," leaves ",," runs and leading/trailing commas behind.
    text = re.sub(r",{2,}", ",", text)
    text = text.strip(" ,")

    # Step 4: capitalize only the first character; str.capitalize() would also
    # lowercase the rest (proper nouns, "I", acronyms).
    return text[:1].upper() + text[1:]

# Main function to process the audio, transcribe it, and detect the language
def process_audio(audio_file):
    """
    Gradio callback: transcribe the recording, detect its language, clean the text.

    Args:
        audio_file: the value delivered by the ``gr.Audio`` input component.

    Returns:
        ``(cleaned_text, language, confidence)`` on success. On any failure the
        first element is ``"Error: <message>"`` and the other two are ``""``, so
        the three output textboxes always receive displayable values.
    """
    try:
        transcription = transcribe_audio(audio_file)
        if not transcription.strip():  # empty or whitespace-only transcription
            raise ValueError("Transcription is empty.")

        language, confidence = detect_language_speechbrain(audio_file)
        cleaned_text = cleanup_text(transcription)
    except Exception as e:
        # UI boundary: show the failure to the user instead of crashing the callback,
        # but keep the full traceback in the server log for debugging.
        logging.getLogger(__name__).exception("Audio processing failed")
        return f"Error: {e}", "", ""

    return cleaned_text, language, confidence

# Gradio interface: an audio input, three output textboxes and a button that runs
# process_audio on the recording.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # type="numpy" delivers the recording as a (sample_rate, samples) tuple.
            audio_input = gr.Audio(label="Record your voice", type="numpy", scale=1)
            output_text = gr.Textbox(label="Transcription", scale=1)
            output_lang = gr.Textbox(label="Detected Language", scale=1)
            output_score = gr.Textbox(label="Confidence Score", scale=1)
            process_btn = gr.Button("Process Audio")

        # process_audio returns three values, mapped to the three textboxes in order.
        process_btn.click(fn=process_audio, inputs=[audio_input], outputs=[output_text, output_lang, output_score])

# Launch after the Blocks context has closed (the documented pattern, so the layout is
# finalized first), and only when run as a script: launch(debug=True) blocks the main
# thread, so merely importing this module must not start the server.
if __name__ == "__main__":
    demo.launch(debug=True)