import io
import os
import sys
from contextlib import redirect_stdout
from time import time

import torch
import whisper
from dotenv import load_dotenv
from pyannote.audio import Pipeline
from tqdm import tqdm
from transformers import pipeline

from .audio_processing import AudioProcessor
from .transcription import Transcription

load_dotenv()

class Transcriptor:
"""
A class for transcribing and diarizing audio files.
This class uses the Whisper model for transcription and the PyAnnote speaker diarization pipeline for speaker identification.
Attributes
----------
model_size : str
The size of the Whisper model to use for transcription. Available options are:
- 'tiny': Fastest, lowest accuracy
- 'base': Fast, good accuracy for many use cases
- 'small': Balanced speed and accuracy
- 'medium': High accuracy, slower than smaller models
- 'large-v3': Latest and most accurate version of the large model
- 'large-v3-turbo': Optimized version of the large-v3 model for faster processing
model : whisper.model.Whisper
The Whisper model for transcription.
pipeline : pyannote.audio.pipelines.SpeakerDiarization
The PyAnnote speaker diarization pipeline.
Usage:
>>> transcript = Transcriptor(model_size="large-v3")
>>> transcription = transcript.transcribe_audio("/path/to/audio.wav")
>>> transcription.get_name_speakers()
>>> transcription.save("/path/to/transcripts")
Note:
Larger models, especially 'large-v3', provide higher accuracy but require more
computational resources and may be slower to process audio.
"""
    def __init__(self, model_size: str = "base"):
        self.model_size = model_size
        self.HF_TOKEN = os.getenv("HF_TOKEN")
        if not self.HF_TOKEN:
            raise ValueError("HF_TOKEN not found. Please store your Hugging Face token as HF_TOKEN in a .env file.")
        self._setup()

    def _setup(self):
        """Initialize the Whisper model and diarization pipeline."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        print("Initializing Whisper model...")
        if self.model_size == "large-v3-turbo":
            # The turbo checkpoint is served through the transformers ASR pipeline.
            self.model = pipeline(
                task="automatic-speech-recognition",
                model="ylacombe/whisper-large-v3-turbo",
                chunk_length_s=30,
                device=self.device,
            )
        else:
            # All other sizes load through the openai-whisper package.
            self.model = whisper.load_model(self.model_size, device=self.device)
        print("Building diarization pipeline...")
        self.pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=self.HF_TOKEN
        ).to(torch.device(self.device))
        print("Setup completed successfully!")

def transcribe_audio(self, audio_file_path: str, enhanced: bool = False, buffer_logs: bool = False):
"""
Transcribe an audio file.
Parameters:
-----------
audio_file_path : str
Path to the audio file to be transcribed.
enhanced : bool, optional
If True, applies audio enhancement techniques to improve transcription quality.
buffer_logs : bool, optional
If True, captures logs and returns them with the transcription. If False, prints to terminal.
Returns:
--------
Union[Transcription, Tuple[Transcription, str]]
Returns either just the Transcription object (if buffer_logs=False)
or a tuple of (Transcription, logs string) if buffer_logs=True
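
        Usage:
        >>> transcription = transcriptor.transcribe_audio("/path/to/audio.wav")
        >>> transcription, logs = transcriptor.transcribe_audio("/path/to/audio.wav", buffer_logs=True)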
"""
if buffer_logs:
logs_buffer = io.StringIO()
with redirect_stdout(logs_buffer):
transcription = self._perform_transcription(audio_file_path, enhanced)
logs = logs_buffer.getvalue()
return transcription, logs
else:
transcription = self._perform_transcription(audio_file_path, enhanced)
return transcription

    def _perform_transcription(self, audio_file_path: str, enhanced: bool = False):
        """Internal method to handle the actual transcription process."""
        try:
            print(f"Received audio_file_path: {audio_file_path} ({type(audio_file_path).__name__})")
            if audio_file_path is None:
                raise ValueError("No audio file was uploaded. Please upload an audio file.")
            if not isinstance(audio_file_path, (str, bytes, os.PathLike)):
                raise ValueError(f"Invalid audio file path type: {type(audio_file_path)}")
            if not os.path.exists(audio_file_path):
                raise FileNotFoundError(f"Audio file not found at path: {audio_file_path}")

            print("Processing audio file...")
            processed_audio = self.process_audio(audio_file_path, enhanced)
            audio_file_path = processed_audio.path
            audio = processed_audio.load_as_array()
            sr = processed_audio.sample_rate
            duration = processed_audio.duration

            print("Diarization in progress...")
            start_time = time()
            diarization = self.perform_diarization(audio_file_path)
            print(f"Diarization completed in {time() - start_time:.2f} seconds.")

            segments = list(diarization.itertracks(yield_label=True))
            transcriptions = self.transcribe_segments(audio, sr, duration, segments)
            return Transcription(audio_file_path, transcriptions, segments)
        except Exception as e:
            print(f"Error occurred: {str(e)}")
            raise RuntimeError(f"Failed to process the audio file: {str(e)}") from e

def process_audio(self, audio_file_path: str, enhanced: bool = False) -> AudioProcessor:
"""
Process the audio file to ensure it meets the requirements for transcription.
Parameters:
-----------
audio_file_path : str
Path to the audio file to be processed.
enhanced : bool, optional
If True, applies audio enhancement techniques to improve audio quality.
This includes optimizing noise reduction, voice enhancement, and volume boosting
parameters based on the audio characteristics.
Returns:
--------
AudioProcessor
An AudioProcessor object containing the processed audio file.
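
        Usage:
        >>> processed = transcriptor.process_audio("/path/to/audio.mp3", enhanced=True)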
"""
        processed_audio = AudioProcessor(audio_file_path)
        # The transcription pipeline expects 16 kHz WAV input.
        if processed_audio.format != ".wav":
            processed_audio.convert_to_wav()
        if processed_audio.sample_rate != 16000:
            processed_audio.resample_wav()
        if enhanced:
            noise_reduce, voice_enhance, volume_boost = processed_audio.optimize_enhancement_parameters()
            processed_audio.enhance_audio(
                noise_reduce_strength=noise_reduce,
                voice_enhance_strength=voice_enhance,
                volume_boost=volume_boost,
            )
        processed_audio.display_changes()
        return processed_audio

def perform_diarization(self, audio_file_path: str):
"""Perform speaker diarization on the audio file."""
with torch.no_grad():
return self.pipeline(audio_file_path)

    def transcribe_segments(self, audio, sr, duration, segments):
        """
        Transcribe audio segments based on diarization.

        Parameters
        ----------
        audio : np.ndarray
            The full audio signal as a 1-D array.
        sr : int
            Sample rate of the audio.
        duration : float
            Duration of the audio in seconds.
        segments : list
            Diarization segments as (turn, track, speaker) tuples.

        Returns
        -------
        list
            A list of (speaker, text) tuples, one per segment.
        """
        transcriptions = []
        audio_segments = []
        for turn, _, speaker in segments:
            start = turn.start
            end = min(turn.end, duration)
            segment = audio[int(start * sr):int(end * sr)]
            audio_segments.append((segment, speaker))

        def transcribe_one(segment):
            """Transcribe a single segment with the currently active model."""
            if self.model_size == "large-v3-turbo":
                result = self.model(segment)
            else:
                result = self.model.transcribe(segment, fp16=self.device == "cuda")
            return result["text"].strip()

        with tqdm(
            total=len(audio_segments),
            desc="Transcribing segments",
            unit="segment",
            ncols=100,
            colour="green",
            file=sys.stdout,
            mininterval=0.1,
            dynamic_ncols=True,
            leave=True
        ) as pbar:
            # Batched GPU inference is only valid for the transformers pipeline
            # ('large-v3-turbo'); whisper models transcribe one segment at a time.
            if self.device == "cuda" and self.model_size == "large-v3-turbo":
                try:
                    total_memory = torch.cuda.get_device_properties(0).total_memory
                    reserved_memory = torch.cuda.memory_reserved(0)
                    allocated_memory = torch.cuda.memory_allocated(0)
                    free_memory = total_memory - reserved_memory - allocated_memory
                    memory_per_sample = 1024 * 1024 * 1024  # assume ~1 GB per sample
                    batch_size = max(1, min(4, int((free_memory * 0.7) // memory_per_sample)))
                    print(f"Using batch size of {batch_size} for GPU processing")
                    for i in range(0, len(audio_segments), batch_size):
                        batch = audio_segments[i:i + batch_size]
                        try:
                            torch.cuda.empty_cache()
                            results = self.model([segment for segment, _ in batch])
                            for (_, speaker), result in zip(batch, results):
                                transcriptions.append((speaker, result["text"].strip()))
                        except RuntimeError as e:
                            if "out of memory" not in str(e):
                                raise
                            # On OOM, clear the cache and retry the batch one segment at a time.
                            torch.cuda.empty_cache()
                            for segment, speaker in batch:
                                transcriptions.append((speaker, transcribe_one(segment)))
                        pbar.update(len(batch))
                except Exception as e:
                    print(f"GPU processing failed: {str(e)}. Falling back to CPU processing...")
                    # Rebuild the pipeline on CPU; the remaining segments are
                    # picked up by the sequential loop below.
                    self.device = "cpu"
                    self.model = pipeline(
                        task="automatic-speech-recognition",
                        model="ylacombe/whisper-large-v3-turbo",
                        chunk_length_s=30,
                        device=self.device,
                    )
            # Sequential path: CPU processing, whisper models on either device,
            # and any segments left unprocessed after a GPU fallback.
            for segment, speaker in audio_segments[len(transcriptions):]:
                transcriptions.append((speaker, transcribe_one(segment)))
                pbar.update(1)
        return transcriptions
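

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API: it assumes HF_TOKEN is
    # set in a .env file and that "audio.wav" is a valid local recording.
    # Because this module uses relative imports, run it from its package, e.g.
    # `python -m <package>.transcriptor` (package name is a placeholder).
    transcriptor = Transcriptor(model_size="base")
    transcription = transcriptor.transcribe_audio("audio.wav")
    transcription.save("transcripts")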