File size: 6,792 Bytes
9e36430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02d76b7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os
from dotenv import load_dotenv
import whisper
from pyannote.audio import Pipeline
import torch
from tqdm import tqdm
from time import time
from transformers import pipeline
from .transcription import Transcription
from .audio_processing import AudioProcessor

# Load environment variables (notably HF_TOKEN) from a local .env file at import time.
load_dotenv()

class Transcriptor:
    """
    A class for transcribing and diarizing audio files.

    This class uses the Whisper model for transcription and the PyAnnote speaker diarization pipeline for speaker identification.

    Attributes
    ----------
    model_size : str
        The size of the Whisper model to use for transcription. Available options are:
        - 'tiny': Fastest, lowest accuracy
        - 'base': Fast, good accuracy for many use cases
        - 'small': Balanced speed and accuracy
        - 'medium': High accuracy, slower than smaller models
        - 'large': High accuracy, slower and more resource-intensive
        - 'large-v1': Improved version of the large model
        - 'large-v2': Further improved version of the large model
        - 'large-v3': Latest and most accurate version of the large model
        - 'large-v3-turbo': Optimized version of the large-v3 model for faster processing
    model : whisper.model.Whisper
        The Whisper model for transcription (a transformers ASR pipeline when
        model_size == 'large-v3-turbo').
    pipeline : pyannote.audio.pipelines.SpeakerDiarization
        The PyAnnote speaker diarization pipeline.

    Usage:
        >>> transcript = Transcriptor(model_size="large-v3")
        >>> transcription = transcript.transcribe_audio("/path/to/audio.wav")
        >>> transcription.get_name_speakers()
        >>> transcription.save("/path/to/transcripts")

    Note:
        Larger models, especially 'large-v3', provide higher accuracy but require more 
        computational resources and may be slower to process audio.
    """

    def __init__(self, model_size: str = "base"):
        """
        Initialize the transcriptor and build its models.

        Parameters:
        -----------
        model_size : str, optional
            Whisper model size to load (see class docstring for options).

        Raises:
        -------
        ValueError
            If the HF_TOKEN environment variable is not set (required by PyAnnote).
        """
        self.model_size = model_size
        self.HF_TOKEN = os.environ.get("HF_TOKEN")
        if not self.HF_TOKEN:
            raise ValueError("HF_TOKEN not found. Please set it as a Gradio secret.")
        self._setup()

    def _setup(self):
        """Initialize the Whisper model and diarization pipeline."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print("Initializing Whisper model...")
        if self.model_size == "large-v3-turbo":
            # The turbo variant ships as a transformers ASR pipeline rather
            # than an openai-whisper checkpoint.
            self.model = pipeline(
                task="automatic-speech-recognition",
                model="ylacombe/whisper-large-v3-turbo",
                chunk_length_s=30,
                device=self.device,
            )
        else:
            self.model = whisper.load_model(self.model_size, device=self.device)
        print("Building diarization pipeline...")
        self.pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1", 
            use_auth_token=self.HF_TOKEN
        ).to(torch.device(self.device))
        print("Setup completed successfully!")

    def transcribe_audio(self, audio_file_path: str, enhanced: bool = False) -> Transcription:
        """
        Transcribe an audio file.

        Parameters:
        -----------
        audio_file_path : str
            Path to the audio file to be transcribed.
        enhanced : bool, optional
            If True, applies audio enhancement techniques to improve transcription quality.
            This includes noise reduction, voice enhancement, and volume boosting.

        Returns:
        --------
        Transcription
            A Transcription object containing the transcribed text and speaker segments.

        Raises:
        -------
        RuntimeError
            If any stage of processing, diarization, or transcription fails.
            The original exception is attached as __cause__.
        """
        try:
            print("Processing audio file...")
            processed_audio = self.process_audio(audio_file_path, enhanced)
            audio_file_path = processed_audio.path
            audio, sr, duration = processed_audio.load_as_array(), processed_audio.sample_rate, processed_audio.duration
            
            print("Diarization in progress...")
            start_time = time()
            diarization = self.perform_diarization(audio_file_path)
            print(f"Diarization completed in {time() - start_time:.2f} seconds.")
            segments = list(diarization.itertracks(yield_label=True))

            transcriptions = self.transcribe_segments(audio, sr, duration, segments)
            return Transcription(audio_file_path, transcriptions, segments)
        except Exception as e:
            # Chain the cause so the original traceback is preserved for debugging.
            raise RuntimeError(f"Failed to process the audio file: {e}") from e

    def process_audio(self, audio_file_path: str, enhanced: bool = False) -> AudioProcessor:
        """
        Process the audio file to ensure it meets the requirements for transcription.

        Converts to WAV and resamples to 16 kHz when necessary (the rate both
        Whisper and PyAnnote expect), then optionally enhances the audio.

        Parameters:
        -----------
        audio_file_path : str
            Path to the audio file to be processed.
        enhanced : bool, optional
            If True, applies audio enhancement techniques to improve audio quality.
            This includes optimizing noise reduction, voice enhancement, and volume boosting
            parameters based on the audio characteristics.

        Returns:
        --------
        AudioProcessor
            An AudioProcessor object containing the processed audio file.
        """
        processed_audio = AudioProcessor(audio_file_path)
        if processed_audio.format != ".wav":
            processed_audio.convert_to_wav()
        
        if processed_audio.sample_rate != 16000:
            processed_audio.resample_wav()
        
        if enhanced:
            parameters = processed_audio.optimize_enhancement_parameters()
            processed_audio.enhance_audio(noise_reduce_strength=parameters[0], 
                                        voice_enhance_strength=parameters[1], 
                                        volume_boost=parameters[2])
        
        processed_audio.display_changes()
        return processed_audio

    def perform_diarization(self, audio_file_path: str):
        """
        Perform speaker diarization on the audio file.

        Parameters:
        -----------
        audio_file_path : str
            Path to the (processed) WAV file to diarize.

        Returns:
        --------
        The PyAnnote diarization result, iterable via itertracks().
        """
        # Inference only — disable autograd to save memory.
        with torch.no_grad():
            return self.pipeline(audio_file_path)

    def transcribe_segments(self, audio, sr, duration, segments):
        """
        Transcribe audio segments based on diarization.

        Parameters:
        -----------
        audio : array-like
            The full audio signal as returned by AudioProcessor.load_as_array().
        sr : int
            Sample rate of `audio`.
        duration : float
            Total audio duration in seconds; segment ends are clipped to it.
        segments : list
            Diarization tracks as (turn, track, speaker_label) tuples.

        Returns:
        --------
        list[tuple[str, str]]
            (speaker_label, transcribed_text) pairs, in segment order.
        """
        transcriptions = []
        
        for turn, _, speaker in tqdm(segments, desc="Transcribing segments", unit="segment", ncols=100, colour="green"):
            start = turn.start
            end = min(turn.end, duration)
            if end <= start:
                # Diarization may emit turns at/past the clipped audio end;
                # skip them rather than feeding an empty array to the model.
                continue
            segment = audio[int(start * sr):int(end * sr)]
            if self.model_size == "large-v3-turbo":
                result = self.model(segment)
            else:
                result = self.model.transcribe(segment, fp16=self.device == "cuda")
            transcriptions.append((speaker, result['text'].strip()))
        
        return transcriptions