# audio-transcriptor/pyscript/transcriptor.py
import os
from dotenv import load_dotenv
import whisper
from pyannote.audio import Pipeline
import torch
from tqdm import tqdm
from time import time
from transformers import pipeline
from .transcription import Transcription
from .audio_processing import AudioProcessor
import io
from contextlib import redirect_stdout
import sys
load_dotenv()
class Transcriptor:
"""
A class for transcribing and diarizing audio files.
This class uses the Whisper model for transcription and the PyAnnote speaker diarization pipeline for speaker identification.
Attributes
----------
model_size : str
The size of the Whisper model to use for transcription. Available options are:
- 'tiny': Fastest, lowest accuracy
- 'base': Fast, good accuracy for many use cases
- 'small': Balanced speed and accuracy
- 'medium': High accuracy, slower than smaller models
- 'large-v3': Latest and most accurate version of the large model
- 'large-v3-turbo': Optimized version of the large-v3 model for faster processing
model : whisper.model.Whisper
The Whisper model for transcription.
pipeline : pyannote.audio.pipelines.SpeakerDiarization
The PyAnnote speaker diarization pipeline.
Usage:
    >>> transcriptor = Transcriptor(model_size="large-v3")
    >>> transcription = transcriptor.transcribe_audio("/path/to/audio.wav")
>>> transcription.get_name_speakers()
>>> transcription.save("/path/to/transcripts")
Note:
Larger models, especially 'large-v3', provide higher accuracy but require more
computational resources and may be slower to process audio.
"""
def __init__(self, model_size: str = "base"):
self.model_size = model_size
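        # A Hugging Face access token is required to download the gated
        # pyannote diarization pipeline.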
self.HF_TOKEN = os.getenv("HF_TOKEN")
if not self.HF_TOKEN:
raise ValueError("HF_TOKEN not found. Please store token in .env")
self._setup()
def _setup(self):
"""Initialize the Whisper model and diarization pipeline."""
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {self.device}")
print("Initializing Whisper model...")
if self.model_size == "large-v3-turbo":
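            # The turbo checkpoint runs through the transformers ASR pipeline,
            # which splits long inputs into 30-second windows (chunk_length_s).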
self.model = pipeline(
task="automatic-speech-recognition",
model="ylacombe/whisper-large-v3-turbo",
chunk_length_s=30,
device=self.device,
)
else:
self.model = whisper.load_model(self.model_size, device=self.device)
print("Building diarization pipeline...")
self.pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.1",
use_auth_token=self.HF_TOKEN
).to(torch.device(self.device))
print("Setup completed successfully!")
def transcribe_audio(self, audio_file_path: str, enhanced: bool = False, buffer_logs: bool = False):
"""
Transcribe an audio file.
Parameters:
-----------
audio_file_path : str
Path to the audio file to be transcribed.
enhanced : bool, optional
If True, applies audio enhancement techniques to improve transcription quality.
buffer_logs : bool, optional
If True, captures logs and returns them with the transcription. If False, prints to terminal.
Returns:
--------
Union[Transcription, Tuple[Transcription, str]]
Returns either just the Transcription object (if buffer_logs=False)
or a tuple of (Transcription, logs string) if buffer_logs=True
"""
if buffer_logs:
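            # redirect_stdout captures everything printed during transcription
            # (including the tqdm bar, which writes to sys.stdout), so a caller
            # such as a web UI can display the logs itself.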
logs_buffer = io.StringIO()
with redirect_stdout(logs_buffer):
transcription = self._perform_transcription(audio_file_path, enhanced)
logs = logs_buffer.getvalue()
return transcription, logs
else:
transcription = self._perform_transcription(audio_file_path, enhanced)
return transcription
def _perform_transcription(self, audio_file_path: str, enhanced: bool = False):
"""Internal method to handle the actual transcription process."""
try:
print(f"Received audio_file_path: {audio_file_path}")
print(f"Type of audio_file_path: {type(audio_file_path)}")
if audio_file_path is None:
raise ValueError("No audio file was uploaded. Please upload an audio file.")
if not isinstance(audio_file_path, (str, bytes, os.PathLike)):
raise ValueError(f"Invalid audio file path type: {type(audio_file_path)}")
if not os.path.exists(audio_file_path):
raise FileNotFoundError(f"Audio file not found at path: {audio_file_path}")
print("Processing audio file...")
processed_audio = self.process_audio(audio_file_path, enhanced)
audio_file_path = processed_audio.path
audio, sr, duration = processed_audio.load_as_array(), processed_audio.sample_rate, processed_audio.duration
print("Diarization in progress...")
start_time = time()
diarization = self.perform_diarization(audio_file_path)
print(f"Diarization completed in {time() - start_time:.2f} seconds.")
segments = list(diarization.itertracks(yield_label=True))
transcriptions = self.transcribe_segments(audio, sr, duration, segments)
return Transcription(audio_file_path, transcriptions, segments)
        except Exception as e:
            print(f"Error occurred: {str(e)}")
            raise RuntimeError(f"Failed to process the audio file: {str(e)}") from e
def process_audio(self, audio_file_path: str, enhanced: bool = False) -> AudioProcessor:
"""
Process the audio file to ensure it meets the requirements for transcription.
Parameters:
-----------
audio_file_path : str
Path to the audio file to be processed.
enhanced : bool, optional
If True, applies audio enhancement techniques to improve audio quality.
This includes optimizing noise reduction, voice enhancement, and volume boosting
parameters based on the audio characteristics.
Returns:
--------
AudioProcessor
An AudioProcessor object containing the processed audio file.
"""
processed_audio = AudioProcessor(audio_file_path)
if processed_audio.format != ".wav":
processed_audio.convert_to_wav()
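        # Whisper operates on 16 kHz audio, so resample if necessary.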
if processed_audio.sample_rate != 16000:
processed_audio.resample_wav()
if enhanced:
parameters = processed_audio.optimize_enhancement_parameters()
processed_audio.enhance_audio(noise_reduce_strength=parameters[0],
voice_enhance_strength=parameters[1],
volume_boost=parameters[2])
processed_audio.display_changes()
return processed_audio
def perform_diarization(self, audio_file_path: str):
"""Perform speaker diarization on the audio file."""
with torch.no_grad():
return self.pipeline(audio_file_path)
def transcribe_segments(self, audio, sr, duration, segments):
"""Transcribe audio segments based on diarization."""
transcriptions = []
audio_segments = []
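        # Cut the waveform into per-speaker chunks using the diarization
        # timestamps (seconds converted to sample indices).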
for turn, _, speaker in segments:
start = turn.start
end = min(turn.end, duration)
segment = audio[int(start * sr):int(end * sr)]
audio_segments.append((segment, speaker))
with tqdm(
total=len(audio_segments),
desc="Transcribing segments",
unit="segment",
ncols=100,
colour="green",
file=sys.stdout,
mininterval=0.1,
dynamic_ncols=True,
leave=True
) as pbar:
if self.device == "cuda":
try:
total_memory = torch.cuda.get_device_properties(0).total_memory
reserved_memory = torch.cuda.memory_reserved(0)
allocated_memory = torch.cuda.memory_allocated(0)
free_memory = total_memory - reserved_memory - allocated_memory
memory_per_sample = 1024 * 1024 * 1024 # 1GB
batch_size = max(1, min(4, int((free_memory * 0.7) // memory_per_sample)))
print(f"Using batch size of {batch_size} for GPU processing")
for i in range(0, len(audio_segments), batch_size):
try:
batch = audio_segments[i:i + batch_size]
torch.cuda.empty_cache()
results = self.model([segment for segment, _ in batch])
for (_, speaker), result in zip(batch, results):
transcriptions.append((speaker, result['text'].strip()))
pbar.update(len(batch))
except RuntimeError as e:
if "out of memory" in str(e):
torch.cuda.empty_cache()
for segment, speaker in batch:
results = self.model([segment])
transcriptions.append((speaker, results[0]['text'].strip()))
pbar.update(0.5)
else:
raise e
except Exception as e:
print(f"GPU processing failed: {str(e)}. Falling back to CPU processing...")
self.model = self.model.to('cpu')
self.device = 'cpu'
else:
for segment, speaker in audio_segments:
if self.model_size == "large-v3-turbo":
result = self.model(segment)
transcriptions.append((speaker, result['text'].strip()))
else:
result = self.model.transcribe(segment, fp16=self.device == "cuda")
transcriptions.append((speaker, result['text'].strip()))
pbar.update(1)
return transcriptions
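

if __name__ == "__main__":
    # Minimal usage sketch; the audio and output paths are placeholders.
    transcriptor = Transcriptor(model_size="base")
    transcription = transcriptor.transcribe_audio("path/to/audio.wav")
    transcription.save("path/to/transcripts")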