import os
from dotenv import load_dotenv
import whisper
from pyannote.audio import Pipeline
import torch
from tqdm import tqdm
from time import time
from transformers import pipeline
from .transcription import Transcription
from .audio_processing import AudioProcessor
import io
from contextlib import redirect_stdout
import sys

load_dotenv()
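
# NOTE: the diarization step needs a Hugging Face token. The
# pyannote/speaker-diarization-3.1 pipeline is gated on the Hugging Face Hub,
# so HF_TOKEN must belong to an account that has accepted its terms of use.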

class Transcriptor:
    """

    A class for transcribing and diarizing audio files.



    This class uses the Whisper model for transcription and the PyAnnote speaker diarization pipeline for speaker identification.



    Attributes

    ----------

    model_size : str

        The size of the Whisper model to use for transcription. Available options are:

        - 'tiny': Fastest, lowest accuracy

        - 'base': Fast, good accuracy for many use cases

        - 'small': Balanced speed and accuracy

        - 'medium': High accuracy, slower than smaller models

        - 'large-v3': Latest and most accurate version of the large model

        - 'large-v3-turbo': Optimized version of the large-v3 model for faster processing

    model : whisper.model.Whisper

        The Whisper model for transcription.

    pipeline : pyannote.audio.pipelines.SpeakerDiarization

        The PyAnnote speaker diarization pipeline.



    Usage:

        >>> transcript = Transcriptor(model_size="large-v3")

        >>> transcription = transcript.transcribe_audio("/path/to/audio.wav")

        >>> transcription.get_name_speakers()

        >>> transcription.save("/path/to/transcripts")



    Note:

        Larger models, especially 'large-v3', provide higher accuracy but require more 

        computational resources and may be slower to process audio.

    """

    def __init__(self, model_size: str = "base"):
        self.model_size = model_size
        self.HF_TOKEN = os.getenv("HF_TOKEN")
        if not self.HF_TOKEN:
            raise ValueError("HF_TOKEN not found. Please store token in .env")
        self._setup()

    def _setup(self):
        """Initialize the Whisper model and diarization pipeline."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        print("Initializing Whisper model...")
        if self.model_size == "large-v3-turbo":
            self.model = pipeline(
                task="automatic-speech-recognition",
                model="ylacombe/whisper-large-v3-turbo",
                chunk_length_s=30,
                device=self.device,
            )
        else:
            self.model = whisper.load_model(self.model_size, device=self.device)
        print("Building diarization pipeline...")
        self.pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1", 
            use_auth_token=self.HF_TOKEN
        ).to(torch.device(self.device))
        print("Setup completed successfully!")

    def transcribe_audio(self, audio_file_path: str, enhanced: bool = False, buffer_logs: bool = False):
        """

        Transcribe an audio file.



        Parameters:

        -----------

        audio_file_path : str

            Path to the audio file to be transcribed.

        enhanced : bool, optional

            If True, applies audio enhancement techniques to improve transcription quality.

        buffer_logs : bool, optional

            If True, captures logs and returns them with the transcription. If False, prints to terminal.



        Returns:

        --------

        Union[Transcription, Tuple[Transcription, str]]

            Returns either just the Transcription object (if buffer_logs=False) 

            or a tuple of (Transcription, logs string) if buffer_logs=True

        """
        if buffer_logs:
            logs_buffer = io.StringIO()
            with redirect_stdout(logs_buffer):
                transcription = self._perform_transcription(audio_file_path, enhanced)
                logs = logs_buffer.getvalue()
                return transcription, logs
        else:
            transcription = self._perform_transcription(audio_file_path, enhanced)
            return transcription

    def _perform_transcription(self, audio_file_path: str, enhanced: bool = False):
        """Internal method to handle the actual transcription process."""
        try:
            print(f"Received audio_file_path: {audio_file_path}")
            print(f"Type of audio_file_path: {type(audio_file_path)}")
            
            if audio_file_path is None:
                raise ValueError("No audio file was uploaded. Please upload an audio file.")
            
            if not isinstance(audio_file_path, (str, bytes, os.PathLike)):
                raise ValueError(f"Invalid audio file path type: {type(audio_file_path)}")
            
            if not os.path.exists(audio_file_path):
                raise FileNotFoundError(f"Audio file not found at path: {audio_file_path}")
            
            print("Processing audio file...")
            processed_audio = self.process_audio(audio_file_path, enhanced)
            audio_file_path = processed_audio.path
            audio = processed_audio.load_as_array()
            sr = processed_audio.sample_rate
            duration = processed_audio.duration
            
            print("Diarization in progress...")
            start_time = time()
            diarization = self.perform_diarization(audio_file_path)
            print(f"Diarization completed in {time() - start_time:.2f} seconds.")
            segments = list(diarization.itertracks(yield_label=True))

            transcriptions = self.transcribe_segments(audio, sr, duration, segments)
            return Transcription(audio_file_path, transcriptions, segments)
                
        except Exception as e:
            print(f"Error occurred: {e}")
            raise RuntimeError(f"Failed to process the audio file: {e}") from e

    def process_audio(self, audio_file_path: str, enhanced: bool = False) -> AudioProcessor:
        """

        Process the audio file to ensure it meets the requirements for transcription.



        Parameters:

        -----------

        audio_file_path : str

            Path to the audio file to be processed.

        enhanced : bool, optional

            If True, applies audio enhancement techniques to improve audio quality.

            This includes optimizing noise reduction, voice enhancement, and volume boosting

            parameters based on the audio characteristics.



        Returns:

        --------

        AudioProcessor

            An AudioProcessor object containing the processed audio file.

        """
        processed_audio = AudioProcessor(audio_file_path)
        if processed_audio.format != ".wav":
            processed_audio.convert_to_wav()
        
        if processed_audio.sample_rate != 16000:
            processed_audio.resample_wav()
        
        if enhanced:
            parameters = processed_audio.optimize_enhancement_parameters()
            processed_audio.enhance_audio(
                noise_reduce_strength=parameters[0],
                voice_enhance_strength=parameters[1],
                volume_boost=parameters[2],
            )
        
        processed_audio.display_changes()
        return processed_audio

    def perform_diarization(self, audio_file_path: str):
        """Perform speaker diarization on the audio file."""
        with torch.no_grad():
            return self.pipeline(audio_file_path)

    def transcribe_segments(self, audio, sr, duration, segments):
        """Transcribe audio segments based on diarization."""
        transcriptions = []
        
        # Slice the waveform into per-speaker chunks, clamping each end to the audio duration.
        audio_segments = []
        for turn, _, speaker in segments:
            start = turn.start
            end = min(turn.end, duration)
            segment = audio[int(start * sr):int(end * sr)]
            audio_segments.append((segment, speaker))
        
        with tqdm(
            total=len(audio_segments),
            desc="Transcribing segments",
            unit="segment",
            ncols=100,
            colour="green",
            file=sys.stdout,
            mininterval=0.1,
            dynamic_ncols=True,
            leave=True
        ) as pbar:
            if self.device == "cuda":
                try:
                    total_memory = torch.cuda.get_device_properties(0).total_memory
                    reserved_memory = torch.cuda.memory_reserved(0)
                    allocated_memory = torch.cuda.memory_allocated(0)
                    free_memory = total_memory - reserved_memory - allocated_memory
                    
                    memory_per_sample = 1024 * 1024 * 1024  # 1GB
                    batch_size = max(1, min(4, int((free_memory * 0.7) // memory_per_sample)))
                    print(f"Using batch size of {batch_size} for GPU processing")
                    
                    for i in range(0, len(audio_segments), batch_size):
                        try:
                            batch = audio_segments[i:i + batch_size]
                            torch.cuda.empty_cache()
                            results = self.model([segment for segment, _ in batch])
                            for (_, speaker), result in zip(batch, results):
                                transcriptions.append((speaker, result['text'].strip()))
                            pbar.update(len(batch))
                        except RuntimeError as e:
                            if "out of memory" in str(e):
                                torch.cuda.empty_cache()
                                for segment, speaker in batch:
                                    results = self.model([segment])
                                    transcriptions.append((speaker, results[0]['text'].strip()))
                                    pbar.update(0.5)
                            else:
                                raise e
                except Exception as e:
                    print(f"GPU processing failed: {str(e)}. Falling back to CPU processing...")
                    self.model = self.model.to('cpu')
                    self.device = 'cpu'
            else:
                for segment, speaker in audio_segments:
                    if self.model_size == "large-v3-turbo":
                        result = self.model(segment)
                        transcriptions.append((speaker, result['text'].strip()))
                    else:
                        result = self.model.transcribe(segment, fp16=self.device == "cuda")
                        transcriptions.append((speaker, result['text'].strip()))
                    pbar.update(1)
        
        return transcriptions
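

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch. The importing module path, audio file, and
# output directory are assumptions to adapt to your project; HF_TOKEN must be
# available in the environment or a .env file.
#
#     from transcriptor import Transcriptor  # assumed module name
#
#     transcriptor = Transcriptor(model_size="base")
#     transcription, logs = transcriptor.transcribe_audio(
#         "meeting.wav", enhanced=True, buffer_logs=True
#     )
#     print(logs)
#     transcription.save("transcripts/")
# ---------------------------------------------------------------------------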