# Space: kzs2t / app.py
# Author: bektim
# Commit: 5ff9030 ("Update app.py", verified)
import gradio as gr
import numpy as np
import soundfile as sf
import torch
from transformers import AutoProcessor, SeamlessM4TModel
from tempfile import NamedTemporaryFile
import subprocess
import os
from pydantic import BaseModel, ValidationError, ConfigDict
from typing import Optional, Tuple
class AudioInput(BaseModel):
    """Holds the audio handed to the transcriber.

    ``audio_data`` is a ``(sample_rate, samples)`` pair (the shape a
    numpy-typed ``gr.Audio`` component delivers); ``audio_path`` is an
    alternative file location. Call ``validate_audio`` to make sure at least
    one of them was supplied.
    """

    # np.ndarray has no pydantic schema; allow it to be checked as a plain
    # arbitrary type instead.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    audio_data: Optional[Tuple[int, np.ndarray]] = None
    audio_path: Optional[str] = None

    def validate_audio(self):
        """Return ``self`` when an audio source is present.

        Raises:
            ValueError: if both ``audio_data`` and ``audio_path`` are None.
        """
        has_source = self.audio_data is not None or self.audio_path is not None
        if not has_source:
            raise ValueError("Please provide an audio file or record from the microphone.")
        return self
class SeamlessM4TApp:
    """Kazakh speech-to-text on CPU using Meta's SeamlessM4T model.

    The processor and model are loaded once in ``__init__``. ``transcribe`` is
    the per-request entry point. It returns ``(text, audio_path, txt_path)`` on
    success and ``(message, None, None)`` on failure.
    """

    MODEL_NAME = "facebook/seamless-m4t-large"
    # FFmpeg resamples every clip to this rate (Hz), and the feature extractor
    # is told the same rate.
    TARGET_SAMPLE_RATE = 16000
    # SeamlessM4T language code. With speech input, setting ``tgt_lang`` to the
    # spoken language makes the model transcribe rather than translate.
    TARGET_LANGUAGE = "kaz"
    MAX_NEW_TOKENS = 256

    def __init__(self):
        self.device = "cpu"
        print("Using CPU for inference")
        # Load model and processor
        self.processor = AutoProcessor.from_pretrained(self.MODEL_NAME)
        self.model = SeamlessM4TModel.from_pretrained(
            self.MODEL_NAME,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float32,
        )
        # Inference only: disable training-time behaviour such as dropout.
        self.model.eval()

    def preprocess_audio_with_ffmpeg(self, input_path: str, output_path: str) -> bool:
        """Convert ``input_path`` to 16 kHz mono audio written to ``output_path``.

        Returns:
            True on success. False if FFmpeg exits non-zero or cannot be
            started (e.g. not installed); the reason is printed.
        """
        command = [
            "ffmpeg",
            "-i", input_path,
            "-ar", str(self.TARGET_SAMPLE_RATE),
            "-ac", "1",
            # output_path is usually a pre-created temp file, so allow overwrite.
            "-y",
            output_path,
        ]
        try:
            subprocess.run(
                command,
                check=True,
                # Never let FFmpeg read from (or block on) the server's stdin.
                stdin=subprocess.DEVNULL,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            return True
        except subprocess.CalledProcessError as e:
            # FFmpeg diagnostics are not guaranteed to be valid UTF-8; a strict
            # decode here would raise out of the error handler itself.
            print(f"FFmpeg error: {e.stderr.decode('utf-8', errors='replace')}")
            return False
        except OSError as e:
            # e.g. FileNotFoundError when the ffmpeg binary is missing.
            print(f"Error during FFmpeg processing: {str(e)}")
            return False

    def save_audio_to_tempfile(self, audio_data: np.ndarray, sample_rate: int) -> str:
        """Write ``audio_data`` to a new temporary ``.wav`` file and return its path.

        The caller owns the returned file and is responsible for deleting it.
        """
        # Create-and-close first: sf.write reopens the file by name, which an
        # open NamedTemporaryFile does not allow on Windows.
        with NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            path = temp_file.name
        try:
            sf.write(path, audio_data, sample_rate)
        except Exception:
            # Don't leak the empty temp file when the write fails.
            self._remove_file(path)
            raise
        return path

    @staticmethod
    def _remove_file(path):
        """Best-effort delete of ``path``; ignores None and already-missing files."""
        if not path:
            return
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            # Runs from ``finally`` blocks: never let cleanup replace the result.
            print(f"Could not remove temporary file {path}: {e}")

    def _generate_transcription(self, waveform):
        """Run SeamlessM4T speech-to-text on a 16 kHz mono float waveform."""
        # NOTE(review): older transformers releases name this keyword ``audios``;
        # verify against the installed version.
        inputs = self.processor(
            audio=waveform,
            return_tensors="pt",
            sampling_rate=self.TARGET_SAMPLE_RATE,
        )
        with torch.no_grad():
            outputs = self.model.generate(
                # Keyword args: positionally, features would bind to ``input_ids``.
                **inputs,
                tgt_lang=self.TARGET_LANGUAGE,
                # Text tokens only; the default would also run the vocoder and
                # return a waveform instead of token ids.
                generate_speech=False,
                max_new_tokens=self.MAX_NEW_TOKENS,
            )
        # Depending on the generation config this is either a (batch, seq) id
        # tensor or a generate-output object carrying ``.sequences``.
        sequences = getattr(outputs, "sequences", outputs)
        return self.processor.decode(sequences[0].tolist(), skip_special_tokens=True)

    def transcribe(self, audio_input):
        """Transcribe Kazakh speech.

        Args:
            audio_input: ``(sample_rate, samples)`` pair from a numpy-typed
                ``gr.Audio`` component, or None if nothing was provided.

        Returns:
            ``(transcription, processed_audio_path, transcription_file)`` where
            the audio path is the 16 kHz mono WAV fed to the model and the text
            file is written to the current working directory. On any failure:
            ``(error_message, None, None)``.
        """
        audio_path = None
        processed_audio_path = None
        # The processed WAV is returned for playback and read by the caller
        # after this method returns, so it must survive cleanup on success.
        keep_processed_audio = False
        try:
            # Validate the audio input
            audio_input = AudioInput(audio_data=audio_input)
            audio_input.validate_audio()
            if audio_input.audio_data is None:
                return "Invalid input. Please record audio or upload an audio file.", None, None
            sample_rate, audio_data = audio_input.audio_data
            audio_path = self.save_audio_to_tempfile(audio_data, sample_rate)

            # Resample / down-mix with FFmpeg into a fresh temp file.
            with NamedTemporaryFile(suffix=".wav", delete=False) as processed_temp_file:
                processed_audio_path = processed_temp_file.name
            if not self.preprocess_audio_with_ffmpeg(audio_path, processed_audio_path):
                return "Error: Failed to preprocess audio. Please check the file format.", None, None

            # The feature extractor needs the samples themselves, not a path.
            waveform, _ = sf.read(processed_audio_path, dtype="float32")
            transcription = self._generate_transcription(waveform)

            # Save transcription to file
            transcription_file = f"transcription_{os.path.basename(audio_path)}.txt"
            with open(transcription_file, "w", encoding="utf-8") as f:
                f.write(transcription)

            keep_processed_audio = True
            return transcription, processed_audio_path, transcription_file
        except ValidationError as e:
            print(f"Validation error: {str(e)}")
            return f"Validation error: {str(e)}", None, None
        except Exception as e:
            print(f"Error during transcription: {str(e)}")
            return f"Error during transcription: {str(e)}", None, None
        finally:
            # Clean up temporary files
            self._remove_file(audio_path)
            if not keep_processed_audio:
                self._remove_file(processed_audio_path)
# Build the model wrapper once at import time; every request reuses it.
app = SeamlessM4TApp()

# Gradio UI: one audio input row, one row of outputs, and a submit button
# that routes the recorded/uploaded clip through ``app.transcribe``.
with gr.Blocks() as demo:
    gr.Markdown("# Kazakh Speech-to-Text using SeamlessM4T")
    gr.Markdown("Record audio or upload an audio file to transcribe speech in Kazakh.")

    with gr.Row():
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="numpy",
            label="Record or Upload Audio",
        )

    with gr.Row():
        transcription_output = gr.Textbox(lines=4, label="Transcription")
        audio_playback = gr.Audio(visible=True, label="Playback Audio")
        download_button = gr.File(label="Download Transcription")

    submit_button = gr.Button("Submit")
    # Order must match the (text, audio, file) tuple returned by transcribe.
    result_components = [transcription_output, audio_playback, download_button]
    submit_button.click(
        fn=app.transcribe,
        inputs=[audio_input],
        outputs=result_components,
    )

# Launch the app
demo.launch()