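"""Gradio demo: Kazakh speech-to-text with SeamlessM4T.

Records or uploads audio, resamples it to 16 kHz mono with FFmpeg, and runs
CPU inference through the Hugging Face Transformers SeamlessM4T checkpoint.
Requires gradio, numpy, soundfile, torch, transformers, pydantic, and an
ffmpeg binary on PATH.
"""
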
import gradio as gr
import numpy as np
import soundfile as sf
import torch
from transformers import AutoProcessor, SeamlessM4TModel
from tempfile import NamedTemporaryFile
import subprocess
import os
from pydantic import BaseModel, ValidationError, ConfigDict
from typing import Optional, Tuple

class AudioInput(BaseModel):
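    """Wrapper for Gradio audio input: a (sample_rate, waveform) tuple or a file path."""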
    audio_data: Optional[Tuple[int, np.ndarray]] = None
    audio_path: Optional[str] = None
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def validate_audio(self):
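        """Ensure at least one audio source was provided."""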
        if self.audio_data is None and self.audio_path is None:
            raise ValueError("Please provide an audio file or record from the microphone.")
        return self

class SeamlessM4TApp:
    def __init__(self):
        self.device = "cpu"
        print("Using CPU for inference")
        
        # Load model and processor; use the Transformers-converted checkpoint
        # (the original facebook/seamless-m4t-large repo ships fairseq2 weights
        # that AutoProcessor/SeamlessM4TModel cannot load directly)
        model_name = "facebook/hf-seamless-m4t-large"
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = SeamlessM4TModel.from_pretrained(
            model_name,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float32
        )
        self.model.eval()

    def preprocess_audio_with_ffmpeg(self, input_path: str, output_path: str) -> bool:
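        """Resample the input to 16 kHz mono WAV, the format SeamlessM4T expects."""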
        try:
            command = [
                "ffmpeg",
                "-i", input_path,
                "-ar", "16000",
                "-ac", "1",
                "-y",
                output_path
            ]
            subprocess.run(command, check=True, capture_output=True)
            return True
        except subprocess.CalledProcessError as e:
            print(f"FFmpeg error: {e.stderr.decode('utf-8')}")
            return False
        except Exception as e:
            print(f"Error during FFmpeg processing: {str(e)}")
            return False

    def save_audio_to_tempfile(self, audio_data: np.ndarray, sample_rate: int) -> str:
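        """Write a raw waveform to a temporary WAV file and return its path."""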
        with NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            sf.write(temp_file.name, audio_data, sample_rate)
            return temp_file.name

    def transcribe(self, audio_input):
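        """Transcribe Kazakh speech; returns (text, processed audio path, transcript file)."""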
        try:
            # Validate the audio input
            audio_input = AudioInput(audio_data=audio_input)
            audio_input.validate_audio()

            # With type="numpy", both microphone and uploaded audio arrive as
            # a (sample_rate, ndarray) tuple
            if audio_input.audio_data is not None:
                sample_rate, audio_data = audio_input.audio_data
                audio_path = self.save_audio_to_tempfile(audio_data, sample_rate)
            else:
                return "Invalid input. Please record audio or upload an audio file.", None, None

            # Preprocess the audio using FFmpeg
            with NamedTemporaryFile(suffix=".wav", delete=False) as processed_temp_file:
                processed_audio_path = processed_temp_file.name
                if not self.preprocess_audio_with_ffmpeg(audio_path, processed_audio_path):
                    return "Error: Failed to preprocess audio. Please check the file format.", None, None

            # Load the resampled waveform; the processor expects a raw array
            # (keyword "audios"), not a file path
            audio_array, sample_rate = sf.read(processed_audio_path, dtype="float32")
            inputs = self.processor(
                audios=audio_array,
                sampling_rate=sample_rate,
                return_tensors="pt"
            )

            # Generate text tokens; generate_speech=False skips speech synthesis,
            # and tgt_lang="kaz" on Kazakh input yields plain speech recognition
            # (SeamlessM4T has no Whisper-style task="transcribe" argument)
            with torch.no_grad():
                output_tokens = self.model.generate(
                    **inputs,
                    tgt_lang="kaz",
                    generate_speech=False,
                    max_new_tokens=256
                )

            # The first element holds the token-id sequences; decode batch item 0
            transcription = self.processor.decode(
                output_tokens[0].tolist()[0],
                skip_special_tokens=True
            )

            # Save the transcription to a text file named after the temp recording
            transcription_file = f"transcription_{os.path.splitext(os.path.basename(audio_path))[0]}.txt"
            with open(transcription_file, "w", encoding="utf-8") as f:
                f.write(transcription)

            return transcription, processed_audio_path, transcription_file

        except ValidationError as e:
            print(f"Validation error: {str(e)}")
            return f"Validation error: {str(e)}", None, None
        except Exception as e:
            print(f"Error during transcription: {str(e)}")
            return f"Error during transcription: {str(e)}", None, None
        finally:
            # Clean up the raw recording only; processed_audio_path must survive
            # this call because Gradio reads it to serve the playback component
            if "audio_path" in locals() and os.path.exists(audio_path):
                os.remove(audio_path)

# Initialize the app
app = SeamlessM4TApp()

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Kazakh Speech-to-Text using SeamlessM4T")
    gr.Markdown("Record audio or upload an audio file to transcribe speech in Kazakh.")
    
    with gr.Row():
        audio_input = gr.Audio(
            label="Record or Upload Audio",
            sources=["microphone", "upload"],
            type="numpy"
        )
    
    with gr.Row():
        transcription_output = gr.Textbox(label="Transcription", lines=4)
        audio_playback = gr.Audio(label="Playback Audio", visible=True)
        download_button = gr.File(label="Download Transcription")
    
    submit_button = gr.Button("Submit")
    submit_button.click(
        fn=app.transcribe,
        inputs=[audio_input],
        outputs=[transcription_output, audio_playback, download_button]
    )

# Launch the app when run directly
if __name__ == "__main__":
    demo.launch()