import os
from dotenv import load_dotenv
import whisper
from pyannote.audio import Pipeline
import torch
from tqdm import tqdm
from time import time
from transformers import pipeline
from .transcription import Transcription
from .audio_processing import AudioProcessor
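
# load_dotenv() below reads a local `.env` file (if one exists) into the environment,
# so HF_TOKEN can be supplied either there or as a Space/Gradio secret. A hypothetical
# `.env` entry would look like: HF_TOKEN=hf_xxxxxxxx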
load_dotenv()


class Transcriptor:
    """
    A class for transcribing and diarizing audio files.

    This class uses the Whisper model for transcription and the PyAnnote speaker
    diarization pipeline for speaker identification.

    Attributes
    ----------
    model_size : str
        The size of the Whisper model to use for transcription. Available options are:
        - 'tiny': Fastest, lowest accuracy
        - 'base': Fast, good accuracy for many use cases
        - 'small': Balanced speed and accuracy
        - 'medium': High accuracy, slower than smaller models
        - 'large': High accuracy, slower and more resource-intensive
        - 'large-v1': Improved version of the large model
        - 'large-v2': Further improved version of the large model
        - 'large-v3': Latest and most accurate version of the large model
        - 'large-v3-turbo': Optimized version of the large-v3 model for faster processing
    model : whisper.model.Whisper
        The Whisper model for transcription.
    pipeline : pyannote.audio.pipelines.SpeakerDiarization
        The PyAnnote speaker diarization pipeline.

    Usage:
    >>> transcriptor = Transcriptor(model_size="large-v3")
    >>> transcription = transcriptor.transcribe_audio("/path/to/audio.wav")
    >>> transcription.get_name_speakers()
    >>> transcription.save("/path/to/transcripts")

    Note:
    Larger models, especially 'large-v3', provide higher accuracy but require more
    computational resources and may be slower to process audio.
    """

    def __init__(self, model_size: str = "base"):
        self.model_size = model_size
        self.HF_TOKEN = os.environ.get("HF_TOKEN")
        if not self.HF_TOKEN:
            raise ValueError("HF_TOKEN not found. Please set it as a Gradio secret.")
        self._setup()

    def _setup(self):
        """Initialize the Whisper model and diarization pipeline."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print("Initializing Whisper model...")
        if self.model_size == "large-v3-turbo":
            self.model = pipeline(
                task="automatic-speech-recognition",
                model="ylacombe/whisper-large-v3-turbo",
                chunk_length_s=30,
                device=self.device,
            )
        else:
            self.model = whisper.load_model(self.model_size, device=self.device)
        print("Building diarization pipeline...")
        self.pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=self.HF_TOKEN
        ).to(torch.device(self.device))
        print("Setup completed successfully!")

    def transcribe_audio(self, audio_file_path: str, enhanced: bool = False) -> Transcription:
        """
        Transcribe an audio file.

        Parameters:
        -----------
        audio_file_path : str
            Path to the audio file to be transcribed.
        enhanced : bool, optional
            If True, applies audio enhancement techniques to improve transcription quality.
            This includes noise reduction, voice enhancement, and volume boosting.

        Returns:
        --------
        Transcription
            A Transcription object containing the transcribed text and speaker segments.
        """
        try:
            print("Processing audio file...")
            processed_audio = self.process_audio(audio_file_path, enhanced)
            audio_file_path = processed_audio.path
            audio = processed_audio.load_as_array()
            sr = processed_audio.sample_rate
            duration = processed_audio.duration
            print("Diarization in progress...")
            start_time = time()
            diarization = self.perform_diarization(audio_file_path)
            print(f"Diarization completed in {time() - start_time:.2f} seconds.")
            segments = list(diarization.itertracks(yield_label=True))
            transcriptions = self.transcribe_segments(audio, sr, duration, segments)
            return Transcription(audio_file_path, transcriptions, segments)
        except Exception as e:
            raise RuntimeError(f"Failed to process the audio file: {e}") from e

    def process_audio(self, audio_file_path: str, enhanced: bool = False) -> AudioProcessor:
        """
        Process the audio file to ensure it meets the requirements for transcription.

        Parameters:
        -----------
        audio_file_path : str
            Path to the audio file to be processed.
        enhanced : bool, optional
            If True, applies audio enhancement techniques to improve audio quality.
            This includes optimizing noise reduction, voice enhancement, and volume
            boosting parameters based on the audio characteristics.

        Returns:
        --------
        AudioProcessor
            An AudioProcessor object containing the processed audio file.
        """
        processed_audio = AudioProcessor(audio_file_path)
        if processed_audio.format != ".wav":
            processed_audio.convert_to_wav()
        if processed_audio.sample_rate != 16000:
            processed_audio.resample_wav()
        if enhanced:
            parameters = processed_audio.optimize_enhancement_parameters()
            processed_audio.enhance_audio(
                noise_reduce_strength=parameters[0],
                voice_enhance_strength=parameters[1],
                volume_boost=parameters[2],
            )
        processed_audio.display_changes()
        return processed_audio

    def perform_diarization(self, audio_file_path: str):
        """Perform speaker diarization on the audio file."""
        with torch.no_grad():
            return self.pipeline(audio_file_path)

    def transcribe_segments(self, audio, sr, duration, segments):
        """Transcribe audio segments based on diarization."""
        transcriptions = []
        for turn, _, speaker in tqdm(segments, desc="Transcribing segments", unit="segment", ncols=100, colour="green"):
            # Clamp the segment end to the audio duration and slice the waveform for
            # this speaker turn (start/end are in seconds, audio is sampled at sr).
            start = turn.start
            end = min(turn.end, duration)
            segment = audio[int(start * sr):int(end * sr)]
            if self.model_size == "large-v3-turbo":
                result = self.model(segment)
            else:
                result = self.model.transcribe(segment, fp16=self.device == "cuda")
            transcriptions.append((speaker, result['text'].strip()))
        return transcriptions
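

# Minimal usage sketch mirroring the class docstring; the audio path and output
# directory are hypothetical placeholders. Because this module uses relative
# imports, run it as part of its package (e.g. `python -m <package>.<module>`)
# rather than as a standalone script.
if __name__ == "__main__":
    transcriptor = Transcriptor(model_size="base")
    transcription = transcriptor.transcribe_audio("/path/to/audio.wav", enhanced=False)
    transcription.get_name_speakers()  # assign readable names to the detected speakers
    transcription.save("/path/to/transcripts")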