import os
from dotenv import load_dotenv
import whisper
from pyannote.audio import Pipeline
import torch
from tqdm import tqdm
from time import time
from transformers import pipeline
from .transcription import Transcription
from .audio_processing import AudioProcessor
import io
from contextlib import redirect_stdout
import sys

load_dotenv()


class Transcriptor:
    """
    A class for transcribing and diarizing audio files.

    This class uses the Whisper model for transcription and the PyAnnote
    speaker diarization pipeline for speaker identification.

    Attributes
    ----------
    model_size : str
        The size of the Whisper model to use for transcription. Available options are:
        - 'tiny': Fastest, lowest accuracy
        - 'base': Fast, good accuracy for many use cases
        - 'small': Balanced speed and accuracy
        - 'medium': High accuracy, slower than smaller models
        - 'large-v3': Latest and most accurate version of the large model
        - 'large-v3-turbo': Optimized version of the large-v3 model for faster processing
    model : whisper.model.Whisper
        The Whisper model for transcription.
    pipeline : pyannote.audio.pipelines.SpeakerDiarization
        The PyAnnote speaker diarization pipeline.

    Usage:
    >>> transcript = Transcriptor(model_size="large-v3")
    >>> transcription = transcript.transcribe_audio("/path/to/audio.wav")
    >>> transcription.get_name_speakers()
    >>> transcription.save("/path/to/transcripts")

    Note:
        Larger models, especially 'large-v3', provide higher accuracy but require more
        computational resources and may be slower to process audio.
    """
def __init__(self, model_size: str = "base"): | |
self.model_size = model_size | |
self.HF_TOKEN = os.getenv("HF_TOKEN") | |
if not self.HF_TOKEN: | |
raise ValueError("HF_TOKEN not found. Please store token in .env") | |
self._setup() | |

    def _setup(self):
        """Initialize the Whisper model and diarization pipeline."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        print("Initializing Whisper model...")
        if self.model_size == "large-v3-turbo":
            self.model = pipeline(
                task="automatic-speech-recognition",
                model="ylacombe/whisper-large-v3-turbo",
                chunk_length_s=30,
                device=self.device,
            )
        else:
            self.model = whisper.load_model(self.model_size, device=self.device)
        print("Building diarization pipeline...")
        self.pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=self.HF_TOKEN
        ).to(torch.device(self.device))
        print("Setup completed successfully!")

    def transcribe_audio(self, audio_file_path: str, enhanced: bool = False, buffer_logs: bool = False):
        """
        Transcribe an audio file.

        Parameters
        ----------
        audio_file_path : str
            Path to the audio file to be transcribed.
        enhanced : bool, optional
            If True, applies audio enhancement techniques to improve transcription quality.
        buffer_logs : bool, optional
            If True, captures logs and returns them with the transcription.
            If False, prints them to the terminal.

        Returns
        -------
        Union[Transcription, Tuple[Transcription, str]]
            The Transcription object alone (if buffer_logs=False), or a tuple of
            (Transcription, logs string) if buffer_logs=True.
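
        Example (illustrative; the audio path is a placeholder and HF_TOKEN must be configured):
        >>> transcriptor = Transcriptor(model_size="base")
        >>> transcription = transcriptor.transcribe_audio("audio/meeting.wav")
        >>> transcription, logs = transcriptor.transcribe_audio("audio/meeting.wav", buffer_logs=True)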
""" | |
if buffer_logs: | |
logs_buffer = io.StringIO() | |
with redirect_stdout(logs_buffer): | |
transcription = self._perform_transcription(audio_file_path, enhanced) | |
logs = logs_buffer.getvalue() | |
return transcription, logs | |
else: | |
transcription = self._perform_transcription(audio_file_path, enhanced) | |
return transcription | |

    def _perform_transcription(self, audio_file_path: str, enhanced: bool = False):
        """Internal method to handle the actual transcription process."""
        try:
            print(f"Received audio_file_path: {audio_file_path}")
            print(f"Type of audio_file_path: {type(audio_file_path)}")
            if audio_file_path is None:
                raise ValueError("No audio file was uploaded. Please upload an audio file.")
            if not isinstance(audio_file_path, (str, bytes, os.PathLike)):
                raise ValueError(f"Invalid audio file path type: {type(audio_file_path)}")
            if not os.path.exists(audio_file_path):
                raise FileNotFoundError(f"Audio file not found at path: {audio_file_path}")
            print("Processing audio file...")
            processed_audio = self.process_audio(audio_file_path, enhanced)
            audio_file_path = processed_audio.path
            audio, sr, duration = processed_audio.load_as_array(), processed_audio.sample_rate, processed_audio.duration
            print("Diarization in progress...")
            start_time = time()
            diarization = self.perform_diarization(audio_file_path)
            print(f"Diarization completed in {time() - start_time:.2f} seconds.")
            segments = list(diarization.itertracks(yield_label=True))
            transcriptions = self.transcribe_segments(audio, sr, duration, segments)
            return Transcription(audio_file_path, transcriptions, segments)
        except Exception as e:
            print(f"Error occurred: {str(e)}")
            raise RuntimeError(f"Failed to process the audio file: {str(e)}")

    def process_audio(self, audio_file_path: str, enhanced: bool = False) -> AudioProcessor:
        """
        Process the audio file to ensure it meets the requirements for transcription.

        Parameters
        ----------
        audio_file_path : str
            Path to the audio file to be processed.
        enhanced : bool, optional
            If True, applies audio enhancement techniques to improve audio quality.
            This includes optimizing noise reduction, voice enhancement, and volume
            boosting parameters based on the audio characteristics.

        Returns
        -------
        AudioProcessor
            An AudioProcessor object containing the processed audio file.
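
        Example (illustrative; assumes `transcriptor` is an initialized Transcriptor and the path is a placeholder):
        >>> processor = transcriptor.process_audio("audio/interview.mp3", enhanced=True)
        >>> print(processor.path)  # path of the converted 16 kHz WAV file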
""" | |
processed_audio = AudioProcessor(audio_file_path) | |
if processed_audio.format != ".wav": | |
processed_audio.convert_to_wav() | |
if processed_audio.sample_rate != 16000: | |
processed_audio.resample_wav() | |
if enhanced: | |
parameters = processed_audio.optimize_enhancement_parameters() | |
processed_audio.enhance_audio(noise_reduce_strength=parameters[0], | |
voice_enhance_strength=parameters[1], | |
volume_boost=parameters[2]) | |
processed_audio.display_changes() | |
return processed_audio | |

    def perform_diarization(self, audio_file_path: str):
        """Perform speaker diarization on the audio file."""
        with torch.no_grad():
            return self.pipeline(audio_file_path)

    def transcribe_segments(self, audio, sr, duration, segments):
        """Transcribe audio segments based on diarization."""
        transcriptions = []
        audio_segments = []
        for turn, _, speaker in segments:
            start = turn.start
            end = min(turn.end, duration)
            segment = audio[int(start * sr):int(end * sr)]
            audio_segments.append((segment, speaker))
        with tqdm(
            total=len(audio_segments),
            desc="Transcribing segments",
            unit="segment",
            ncols=100,
            colour="green",
            file=sys.stdout,
            mininterval=0.1,
            dynamic_ncols=True,
            leave=True
        ) as pbar:
            if self.device == "cuda" and self.model_size == "large-v3-turbo":
                # Batched GPU path: only the transformers pipeline accepts a list of
                # segments in a single call; classic Whisper models are handled below.
                try:
                    # Estimate a batch size from the currently free GPU memory,
                    # assuming roughly 1 GB per segment and capping at 4.
                    total_memory = torch.cuda.get_device_properties(0).total_memory
                    reserved_memory = torch.cuda.memory_reserved(0)
                    allocated_memory = torch.cuda.memory_allocated(0)
                    free_memory = total_memory - reserved_memory - allocated_memory
                    memory_per_sample = 1024 * 1024 * 1024  # 1 GB
                    batch_size = max(1, min(4, int((free_memory * 0.7) // memory_per_sample)))
                    print(f"Using batch size of {batch_size} for GPU processing")
                    for i in range(0, len(audio_segments), batch_size):
                        batch = audio_segments[i:i + batch_size]
                        try:
                            torch.cuda.empty_cache()
                            results = self.model([segment for segment, _ in batch])
                            for (_, speaker), result in zip(batch, results):
                                transcriptions.append((speaker, result['text'].strip()))
                            pbar.update(len(batch))
                        except RuntimeError as e:
                            if "out of memory" in str(e):
                                # The batch did not fit in memory: retry one segment at a time.
                                torch.cuda.empty_cache()
                                for segment, speaker in batch:
                                    result = self.model(segment)
                                    transcriptions.append((speaker, result['text'].strip()))
                                    pbar.update(1)
                            else:
                                raise
                except Exception as e:
                    print(f"GPU processing failed: {str(e)}. Falling back to CPU processing...")
                    # Rebuild the pipeline on CPU and transcribe the segments that
                    # were not processed before the failure.
                    self.device = "cpu"
                    self.model = pipeline(
                        task="automatic-speech-recognition",
                        model="ylacombe/whisper-large-v3-turbo",
                        chunk_length_s=30,
                        device=self.device,
                    )
                    for segment, speaker in audio_segments[len(transcriptions):]:
                        result = self.model(segment)
                        transcriptions.append((speaker, result['text'].strip()))
                        pbar.update(1)
            else:
                for segment, speaker in audio_segments:
                    if self.model_size == "large-v3-turbo":
                        result = self.model(segment)
                    else:
                        result = self.model.transcribe(segment, fp16=self.device == "cuda")
                    transcriptions.append((speaker, result['text'].strip()))
                    pbar.update(1)
        return transcriptions