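"""Gradio demo: Kazakh speech-to-text with SeamlessM4T.

Records or uploads audio, resamples it to 16 kHz mono with FFmpeg, and
transcribes it on CPU with a SeamlessM4T checkpoint from Hugging Face.
"""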
import os
import subprocess
from tempfile import NamedTemporaryFile
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from pydantic import BaseModel, ConfigDict, ValidationError
from transformers import AutoProcessor, SeamlessM4TModel

class AudioInput(BaseModel):
    audio_data: Optional[Tuple[int, np.ndarray]] = None
    audio_path: Optional[str] = None
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def validate_audio(self):
        if self.audio_data is None and self.audio_path is None:
            raise ValueError("Please provide an audio file or record from the microphone.")
        return self
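
# Example (hypothetical values): with type="numpy", Gradio hands the callback a
# (sample_rate, samples) tuple, so a valid input looks like:
#   AudioInput(audio_data=(16000, np.zeros(16000, dtype=np.float32))).validate_audio()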

class SeamlessM4TApp:
    def __init__(self):
        # CPU-only keeps the demo portable; inference will be slow for long clips.
        self.device = "cpu"
        print("Using CPU for inference")

        # "facebook/hf-seamless-m4t-large" is the transformers-compatible
        # checkpoint; the original "facebook/seamless-m4t-large" release targets
        # fairseq2 and does not load through from_pretrained.
        model_name = "facebook/hf-seamless-m4t-large"
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = SeamlessM4TModel.from_pretrained(
            model_name,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float32,
        )
        self.model.eval()

    def preprocess_audio_with_ffmpeg(self, input_path: str, output_path: str) -> bool:
        # Resample to 16 kHz mono, the input format SeamlessM4T's
        # feature extractor expects.
        try:
            command = [
                "ffmpeg",
                "-i", input_path,
                "-ar", "16000",
                "-ac", "1",
                "-y",
                output_path,
            ]
            subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            return True
        except subprocess.CalledProcessError as e:
            print(f"FFmpeg error: {e.stderr.decode('utf-8')}")
            return False
        except Exception as e:
            print(f"Error during FFmpeg processing: {str(e)}")
            return False
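
    # A minimal availability check could be added at startup (sketch, assuming
    # ffmpeg is the only external binary this app needs):
    #   import shutil
    #   assert shutil.which("ffmpeg") is not None, "ffmpeg not found on PATH"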

    def save_audio_to_tempfile(self, audio_data: np.ndarray, sample_rate: int) -> str:
        # Persist the in-memory samples so FFmpeg can read them from disk.
        with NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            sf.write(temp_file.name, audio_data, sample_rate)
            return temp_file.name

    def transcribe(self, audio_input):
        audio_path = None
        processed_audio_path = None
        try:
            # Validate the raw Gradio input before doing any work.
            audio_input = AudioInput(audio_data=audio_input)
            audio_input.validate_audio()

            if audio_input.audio_data is not None:
                sample_rate, audio_data = audio_input.audio_data
                audio_path = self.save_audio_to_tempfile(audio_data, sample_rate)
            else:
                return "Invalid input. Please record audio or upload an audio file.", None, None

            # Resample to 16 kHz mono before feature extraction.
            with NamedTemporaryFile(suffix=".wav", delete=False) as processed_temp_file:
                processed_audio_path = processed_temp_file.name
            if not self.preprocess_audio_with_ffmpeg(audio_path, processed_audio_path):
                return "Error: Failed to preprocess audio. Please check the file format.", None, None

            # The processor expects a waveform array, not a file path.
            audio_array, _ = sf.read(processed_audio_path)
            inputs = self.processor(
                audios=audio_array,
                sampling_rate=16000,
                return_tensors="pt",
            )

            # SeamlessM4T transcribes (rather than translates) when tgt_lang matches
            # the spoken language; generate_speech=False returns text tokens only.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    tgt_lang="kaz",
                    generate_speech=False,
                    max_new_tokens=256,
                )

            transcription = self.processor.decode(
                outputs[0].tolist()[0],
                skip_special_tokens=True,
            )

            transcription_file = f"transcription_{os.path.basename(audio_path)}.txt"
            with open(transcription_file, "w", encoding="utf-8") as f:
                f.write(transcription)

            return transcription, processed_audio_path, transcription_file

        except ValidationError as e:
            print(f"Validation error: {str(e)}")
            return f"Validation error: {str(e)}", None, None
        except Exception as e:
            print(f"Error during transcription: {str(e)}")
            return f"Error during transcription: {str(e)}", None, None
        finally:
            # Clean up the raw recording; keep processed_audio_path, which Gradio
            # still needs to serve for playback after this function returns.
            if audio_path is not None and os.path.exists(audio_path):
                os.remove(audio_path)

app = SeamlessM4TApp()

demo = gr.Blocks()

with demo:
    gr.Markdown("# Kazakh Speech-to-Text using SeamlessM4T")
    gr.Markdown("Record audio or upload an audio file to transcribe speech in Kazakh.")

    with gr.Row():
        audio_input = gr.Audio(
            label="Record or Upload Audio",
            sources=["microphone", "upload"],
            type="numpy",
        )

    with gr.Row():
        transcription_output = gr.Textbox(label="Transcription", lines=4)
        audio_playback = gr.Audio(label="Playback Audio", visible=True)
        download_button = gr.File(label="Download Transcription")

    submit_button = gr.Button("Submit")
    submit_button.click(
        fn=app.transcribe,
        inputs=[audio_input],
        outputs=[transcription_output, audio_playback, download_button],
    )

if __name__ == "__main__":
    demo.launch()
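
# Optional: CPU inference can take a while, so Gradio's request queue may help
# with concurrent users, e.g. demo.queue().launch() instead of demo.launch().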