Tingusto commited on
Commit
9e36430
·
verified ·
1 Parent(s): 02d76b7

Update secrets

Browse files
Files changed (1) hide show
  1. pyscript/transcriptor.py +162 -162
pyscript/transcriptor.py CHANGED
@@ -1,163 +1,163 @@
1
- import os
2
- from dotenv import load_dotenv
3
- import whisper
4
- from pyannote.audio import Pipeline
5
- import torch
6
- from tqdm import tqdm
7
- from time import time
8
- from transformers import pipeline
9
- from .transcription import Transcription
10
- from .audio_processing import AudioProcessor
11
-
12
- load_dotenv()
13
-
14
- class Transcriptor:
15
- """
16
- A class for transcribing and diarizing audio files.
17
-
18
- This class uses the Whisper model for transcription and the PyAnnote speaker diarization pipeline for speaker identification.
19
-
20
- Attributes
21
- ----------
22
- model_size : str
23
- The size of the Whisper model to use for transcription. Available options are:
24
- - 'tiny': Fastest, lowest accuracy
25
- - 'base': Fast, good accuracy for many use cases
26
- - 'small': Balanced speed and accuracy
27
- - 'medium': High accuracy, slower than smaller models
28
- - 'large': High accuracy, slower and more resource-intensive
29
- - 'large-v1': Improved version of the large model
30
- - 'large-v2': Further improved version of the large model
31
- - 'large-v3': Latest and most accurate version of the large model
32
- - 'large-v3-turbo': Optimized version of the large-v3 model for faster processing
33
- model : whisper.model.Whisper
34
- The Whisper model for transcription.
35
- pipeline : pyannote.audio.pipelines.SpeakerDiarization
36
- The PyAnnote speaker diarization pipeline.
37
-
38
- Usage:
39
- >>> transcript = Transcriptor(model_size="large-v3")
40
- >>> transcription = transcript.transcribe_audio("/path/to/audio.wav")
41
- >>> transcription.get_name_speakers()
42
- >>> transcription.save("/path/to/transcripts")
43
-
44
- Note:
45
- Larger models, especially 'large-v3', provide higher accuracy but require more
46
- computational resources and may be slower to process audio.
47
- """
48
-
49
- def __init__(self, model_size: str = "base"):
50
- self.model_size = model_size
51
- self.HF_TOKEN = os.getenv("HF_TOKEN")
52
- if not self.HF_TOKEN:
53
- raise ValueError("HF_TOKEN not found. Please store token in .env")
54
- self._setup()
55
-
56
- def _setup(self):
57
- """Initialize the Whisper model and diarization pipeline."""
58
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
59
- print("Initializing Whisper model...")
60
- if self.model_size == "large-v3-turbo":
61
- self.model = pipeline(
62
- task="automatic-speech-recognition",
63
- model="ylacombe/whisper-large-v3-turbo",
64
- chunk_length_s=30,
65
- device=self.device,
66
- )
67
- else:
68
- self.model = whisper.load_model(self.model_size, device=self.device)
69
- print("Building diarization pipeline...")
70
- self.pipeline = Pipeline.from_pretrained(
71
- "pyannote/speaker-diarization-3.1",
72
- use_auth_token=self.HF_TOKEN
73
- ).to(torch.device(self.device))
74
- print("Setup completed successfully!")
75
-
76
- def transcribe_audio(self, audio_file_path: str, enhanced: bool = False) -> Transcription:
77
- """
78
- Transcribe an audio file.
79
-
80
- Parameters:
81
- -----------
82
- audio_file_path : str
83
- Path to the audio file to be transcribed.
84
- enhanced : bool, optional
85
- If True, applies audio enhancement techniques to improve transcription quality.
86
- This includes noise reduction, voice enhancement, and volume boosting.
87
-
88
- Returns:
89
- --------
90
- Transcription
91
- A Transcription object containing the transcribed text and speaker segments.
92
- """
93
- try:
94
- print("Processing audio file...")
95
- processed_audio = self.process_audio(audio_file_path, enhanced)
96
- audio_file_path = processed_audio.path
97
- audio, sr, duration = processed_audio.load_as_array(), processed_audio.sample_rate, processed_audio.duration
98
-
99
- print("Diarization in progress...")
100
- start_time = time()
101
- diarization = self.perform_diarization(audio_file_path)
102
- print(f"Diarization completed in {time() - start_time:.2f} seconds.")
103
- segments = list(diarization.itertracks(yield_label=True))
104
-
105
- transcriptions = self.transcribe_segments(audio, sr, duration, segments)
106
- return Transcription(audio_file_path, transcriptions, segments)
107
- except Exception as e:
108
- raise RuntimeError(f"Failed to process the audio file: {e}")
109
-
110
- def process_audio(self, audio_file_path: str, enhanced: bool = False) -> AudioProcessor:
111
- """
112
- Process the audio file to ensure it meets the requirements for transcription.
113
-
114
- Parameters:
115
- -----------
116
- audio_file_path : str
117
- Path to the audio file to be processed.
118
- enhanced : bool, optional
119
- If True, applies audio enhancement techniques to improve audio quality.
120
- This includes optimizing noise reduction, voice enhancement, and volume boosting
121
- parameters based on the audio characteristics.
122
-
123
- Returns:
124
- --------
125
- AudioProcessor
126
- An AudioProcessor object containing the processed audio file.
127
- """
128
- processed_audio = AudioProcessor(audio_file_path)
129
- if processed_audio.format != ".wav":
130
- processed_audio.convert_to_wav()
131
-
132
- if processed_audio.sample_rate != 16000:
133
- processed_audio.resample_wav()
134
-
135
- if enhanced:
136
- parameters = processed_audio.optimize_enhancement_parameters()
137
- processed_audio.enhance_audio(noise_reduce_strength=parameters[0],
138
- voice_enhance_strength=parameters[1],
139
- volume_boost=parameters[2])
140
-
141
- processed_audio.display_changes()
142
- return processed_audio
143
-
144
- def perform_diarization(self, audio_file_path: str):
145
- """Perform speaker diarization on the audio file."""
146
- with torch.no_grad():
147
- return self.pipeline(audio_file_path)
148
-
149
- def transcribe_segments(self, audio, sr, duration, segments):
150
- """Transcribe audio segments based on diarization."""
151
- transcriptions = []
152
-
153
- for turn, _, speaker in tqdm(segments, desc="Transcribing segments", unit="segment", ncols=100, colour="green"):
154
- start = turn.start
155
- end = min(turn.end, duration)
156
- segment = audio[int(start * sr):int(end * sr)]
157
- if self.model_size == "large-v3-turbo":
158
- result = self.model(segment)
159
- else:
160
- result = self.model.transcribe(segment, fp16=self.device == "cuda")
161
- transcriptions.append((speaker, result['text'].strip()))
162
-
163
  return transcriptions
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ import whisper
4
+ from pyannote.audio import Pipeline
5
+ import torch
6
+ from tqdm import tqdm
7
+ from time import time
8
+ from transformers import pipeline
9
+ from .transcription import Transcription
10
+ from .audio_processing import AudioProcessor
11
+
12
+ load_dotenv()
13
+
14
+ class Transcriptor:
15
+ """
16
+ A class for transcribing and diarizing audio files.
17
+
18
+ This class uses the Whisper model for transcription and the PyAnnote speaker diarization pipeline for speaker identification.
19
+
20
+ Attributes
21
+ ----------
22
+ model_size : str
23
+ The size of the Whisper model to use for transcription. Available options are:
24
+ - 'tiny': Fastest, lowest accuracy
25
+ - 'base': Fast, good accuracy for many use cases
26
+ - 'small': Balanced speed and accuracy
27
+ - 'medium': High accuracy, slower than smaller models
28
+ - 'large': High accuracy, slower and more resource-intensive
29
+ - 'large-v1': Improved version of the large model
30
+ - 'large-v2': Further improved version of the large model
31
+ - 'large-v3': Latest and most accurate version of the large model
32
+ - 'large-v3-turbo': Optimized version of the large-v3 model for faster processing
33
+ model : whisper.model.Whisper
34
+ The Whisper model for transcription.
35
+ pipeline : pyannote.audio.pipelines.SpeakerDiarization
36
+ The PyAnnote speaker diarization pipeline.
37
+
38
+ Usage:
39
+ >>> transcript = Transcriptor(model_size="large-v3")
40
+ >>> transcription = transcript.transcribe_audio("/path/to/audio.wav")
41
+ >>> transcription.get_name_speakers()
42
+ >>> transcription.save("/path/to/transcripts")
43
+
44
+ Note:
45
+ Larger models, especially 'large-v3', provide higher accuracy but require more
46
+ computational resources and may be slower to process audio.
47
+ """
48
+
49
+ def __init__(self, model_size: str = "base"):
50
+ self.model_size = model_size
51
+ self.HF_TOKEN = os.environ.get("HF_TOKEN")
52
+ if not self.HF_TOKEN:
53
+ raise ValueError("HF_TOKEN not found. Please set it as a Gradio secret.")
54
+ self._setup()
55
+
56
+ def _setup(self):
57
+ """Initialize the Whisper model and diarization pipeline."""
58
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
59
+ print("Initializing Whisper model...")
60
+ if self.model_size == "large-v3-turbo":
61
+ self.model = pipeline(
62
+ task="automatic-speech-recognition",
63
+ model="ylacombe/whisper-large-v3-turbo",
64
+ chunk_length_s=30,
65
+ device=self.device,
66
+ )
67
+ else:
68
+ self.model = whisper.load_model(self.model_size, device=self.device)
69
+ print("Building diarization pipeline...")
70
+ self.pipeline = Pipeline.from_pretrained(
71
+ "pyannote/speaker-diarization-3.1",
72
+ use_auth_token=self.HF_TOKEN
73
+ ).to(torch.device(self.device))
74
+ print("Setup completed successfully!")
75
+
76
+ def transcribe_audio(self, audio_file_path: str, enhanced: bool = False) -> Transcription:
77
+ """
78
+ Transcribe an audio file.
79
+
80
+ Parameters:
81
+ -----------
82
+ audio_file_path : str
83
+ Path to the audio file to be transcribed.
84
+ enhanced : bool, optional
85
+ If True, applies audio enhancement techniques to improve transcription quality.
86
+ This includes noise reduction, voice enhancement, and volume boosting.
87
+
88
+ Returns:
89
+ --------
90
+ Transcription
91
+ A Transcription object containing the transcribed text and speaker segments.
92
+ """
93
+ try:
94
+ print("Processing audio file...")
95
+ processed_audio = self.process_audio(audio_file_path, enhanced)
96
+ audio_file_path = processed_audio.path
97
+ audio, sr, duration = processed_audio.load_as_array(), processed_audio.sample_rate, processed_audio.duration
98
+
99
+ print("Diarization in progress...")
100
+ start_time = time()
101
+ diarization = self.perform_diarization(audio_file_path)
102
+ print(f"Diarization completed in {time() - start_time:.2f} seconds.")
103
+ segments = list(diarization.itertracks(yield_label=True))
104
+
105
+ transcriptions = self.transcribe_segments(audio, sr, duration, segments)
106
+ return Transcription(audio_file_path, transcriptions, segments)
107
+ except Exception as e:
108
+ raise RuntimeError(f"Failed to process the audio file: {e}")
109
+
110
+ def process_audio(self, audio_file_path: str, enhanced: bool = False) -> AudioProcessor:
111
+ """
112
+ Process the audio file to ensure it meets the requirements for transcription.
113
+
114
+ Parameters:
115
+ -----------
116
+ audio_file_path : str
117
+ Path to the audio file to be processed.
118
+ enhanced : bool, optional
119
+ If True, applies audio enhancement techniques to improve audio quality.
120
+ This includes optimizing noise reduction, voice enhancement, and volume boosting
121
+ parameters based on the audio characteristics.
122
+
123
+ Returns:
124
+ --------
125
+ AudioProcessor
126
+ An AudioProcessor object containing the processed audio file.
127
+ """
128
+ processed_audio = AudioProcessor(audio_file_path)
129
+ if processed_audio.format != ".wav":
130
+ processed_audio.convert_to_wav()
131
+
132
+ if processed_audio.sample_rate != 16000:
133
+ processed_audio.resample_wav()
134
+
135
+ if enhanced:
136
+ parameters = processed_audio.optimize_enhancement_parameters()
137
+ processed_audio.enhance_audio(noise_reduce_strength=parameters[0],
138
+ voice_enhance_strength=parameters[1],
139
+ volume_boost=parameters[2])
140
+
141
+ processed_audio.display_changes()
142
+ return processed_audio
143
+
144
+ def perform_diarization(self, audio_file_path: str):
145
+ """Perform speaker diarization on the audio file."""
146
+ with torch.no_grad():
147
+ return self.pipeline(audio_file_path)
148
+
149
+ def transcribe_segments(self, audio, sr, duration, segments):
150
+ """Transcribe audio segments based on diarization."""
151
+ transcriptions = []
152
+
153
+ for turn, _, speaker in tqdm(segments, desc="Transcribing segments", unit="segment", ncols=100, colour="green"):
154
+ start = turn.start
155
+ end = min(turn.end, duration)
156
+ segment = audio[int(start * sr):int(end * sr)]
157
+ if self.model_size == "large-v3-turbo":
158
+ result = self.model(segment)
159
+ else:
160
+ result = self.model.transcribe(segment, fp16=self.device == "cuda")
161
+ transcriptions.append((speaker, result['text'].strip()))
162
+
163
  return transcriptions