# audio-transcriptor/pyscript/transcriptor.py
import os
from time import time

import torch
import whisper
from dotenv import load_dotenv
from pyannote.audio import Pipeline
from tqdm import tqdm
from transformers import pipeline

from .transcription import Transcription
from .audio_processing import AudioProcessor

load_dotenv()
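# Example .env entry picked up by load_dotenv() above (the token value below
# is a placeholder, not a real token):
#
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx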
class Transcriptor:
    """
    A class for transcribing and diarizing audio files.

    This class uses the Whisper model for transcription and the PyAnnote
    speaker diarization pipeline for speaker identification.

    Attributes
    ----------
    model_size : str
        The size of the Whisper model to use for transcription. Available options are:
        - 'tiny': Fastest, lowest accuracy
        - 'base': Fast, good accuracy for many use cases
        - 'small': Balanced speed and accuracy
        - 'medium': High accuracy, slower than smaller models
        - 'large': High accuracy, slower and more resource-intensive
        - 'large-v1': Improved version of the large model
        - 'large-v2': Further improved version of the large model
        - 'large-v3': Latest and most accurate version of the large model
        - 'large-v3-turbo': Optimized version of the large-v3 model for faster processing
    model : whisper.model.Whisper
        The Whisper model used for transcription.
    pipeline : pyannote.audio.pipelines.SpeakerDiarization
        The PyAnnote speaker diarization pipeline.

    Usage:
    >>> transcriptor = Transcriptor(model_size="large-v3")
    >>> transcription = transcriptor.transcribe_audio("/path/to/audio.wav")
    >>> transcription.get_name_speakers()
    >>> transcription.save("/path/to/transcripts")

    Note:
    Larger models, especially 'large-v3', provide higher accuracy but require more
    computational resources and may be slower to process audio.
    """
    def __init__(self, model_size: str = "base"):
        self.model_size = model_size
        self.HF_TOKEN = os.environ.get("HF_TOKEN")
        if not self.HF_TOKEN:
            raise ValueError("HF_TOKEN not found. Please set it as a Gradio secret.")
        self._setup()
    def _setup(self):
        """Initialize the Whisper model and diarization pipeline."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print("Initializing Whisper model...")
        if self.model_size == "large-v3-turbo":
            # The turbo checkpoint is loaded through the Hugging Face ASR
            # pipeline rather than openai-whisper, so self.model is called
            # directly on raw audio (see transcribe_segments).
            self.model = pipeline(
                task="automatic-speech-recognition",
                model="ylacombe/whisper-large-v3-turbo",
                chunk_length_s=30,
                device=self.device,
            )
        else:
            self.model = whisper.load_model(self.model_size, device=self.device)
        print("Building diarization pipeline...")
        self.pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=self.HF_TOKEN
        ).to(torch.device(self.device))
        print("Setup completed successfully!")
    def transcribe_audio(self, audio_file_path: str, enhanced: bool = False) -> Transcription:
        """
        Transcribe an audio file.

        Parameters
        ----------
        audio_file_path : str
            Path to the audio file to be transcribed.
        enhanced : bool, optional
            If True, applies audio enhancement techniques to improve transcription
            quality. This includes noise reduction, voice enhancement, and volume
            boosting.

        Returns
        -------
        Transcription
            A Transcription object containing the transcribed text and speaker segments.
        """
        try:
            print("Processing audio file...")
            processed_audio = self.process_audio(audio_file_path, enhanced)
            audio_file_path = processed_audio.path
            audio = processed_audio.load_as_array()
            sr = processed_audio.sample_rate
            duration = processed_audio.duration
            print("Diarization in progress...")
            start_time = time()
            diarization = self.perform_diarization(audio_file_path)
            print(f"Diarization completed in {time() - start_time:.2f} seconds.")
            segments = list(diarization.itertracks(yield_label=True))
            transcriptions = self.transcribe_segments(audio, sr, duration, segments)
            return Transcription(audio_file_path, transcriptions, segments)
        except Exception as e:
            raise RuntimeError(f"Failed to process the audio file: {e}") from e
    def process_audio(self, audio_file_path: str, enhanced: bool = False) -> AudioProcessor:
        """
        Process the audio file to ensure it meets the requirements for transcription.

        Parameters
        ----------
        audio_file_path : str
            Path to the audio file to be processed.
        enhanced : bool, optional
            If True, applies audio enhancement techniques to improve audio quality.
            This includes optimizing noise reduction, voice enhancement, and volume
            boosting parameters based on the audio characteristics.

        Returns
        -------
        AudioProcessor
            An AudioProcessor object containing the processed audio file.
        """
        processed_audio = AudioProcessor(audio_file_path)
        # Whisper and the diarization pipeline expect 16 kHz WAV input.
        if processed_audio.format != ".wav":
            processed_audio.convert_to_wav()
        if processed_audio.sample_rate != 16000:
            processed_audio.resample_wav()
        if enhanced:
            noise_reduce, voice_enhance, volume_boost = processed_audio.optimize_enhancement_parameters()
            processed_audio.enhance_audio(
                noise_reduce_strength=noise_reduce,
                voice_enhance_strength=voice_enhance,
                volume_boost=volume_boost,
            )
        processed_audio.display_changes()
        return processed_audio
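    # Sketch: process_audio can also be called on its own to inspect what the
    # pre-processing does before committing to a full transcription. The path
    # below is hypothetical.
    #
    #     transcriptor = Transcriptor(model_size="base")
    #     processed = transcriptor.process_audio("/path/to/interview.mp3", enhanced=True)
    #     print(processed.path, processed.sample_rate, processed.duration)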
    def perform_diarization(self, audio_file_path: str):
        """Perform speaker diarization on the audio file."""
        with torch.no_grad():
            return self.pipeline(audio_file_path)
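    # Sketch of how the returned pyannote Annotation is consumed downstream:
    # itertracks(yield_label=True) yields (turn, track_id, speaker_label)
    # tuples, where `turn` carries `.start` and `.end` in seconds, e.g.
    #
    #     for turn, _, speaker in diarization.itertracks(yield_label=True):
    #         print(f"{speaker}: {turn.start:.1f}s -> {turn.end:.1f}s")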
    def transcribe_segments(self, audio, sr, duration, segments):
        """Transcribe audio segments based on diarization."""
        transcriptions = []
        for turn, _, speaker in tqdm(segments, desc="Transcribing segments", unit="segment", ncols=100, colour="green"):
            # Clamp the segment end to the audio duration before slicing.
            start = turn.start
            end = min(turn.end, duration)
            segment = audio[int(start * sr):int(end * sr)]
            if self.model_size == "large-v3-turbo":
                # Hugging Face ASR pipeline: call the model directly on the array.
                result = self.model(segment)
            else:
                # openai-whisper: fp16 inference is only useful on GPU.
                result = self.model.transcribe(segment, fp16=self.device == "cuda")
            transcriptions.append((speaker, result['text'].strip()))
        return transcriptions
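
if __name__ == "__main__":
    # Minimal usage sketch mirroring the class docstring; the audio path and
    # output directory are placeholders, not files shipped with this repo.
    transcriptor = Transcriptor(model_size="base")
    transcription = transcriptor.transcribe_audio("/path/to/audio.wav", enhanced=False)
    transcription.save("/path/to/transcripts")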