avans06 commited on
Commit
cb7fd1a
·
1 Parent(s): 4abae29

Diarization now supports version selection, with the default set to speaker-diarization-3.1.

Browse files
app.py CHANGED
@@ -19,7 +19,6 @@ from src.diarization.diarization import Diarization
19
  from src.diarization.diarizationContainer import DiarizationContainer
20
  from src.hooks.progressListener import ProgressListener
21
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
22
- from src.hooks.whisperProgressHook import create_progress_listener_handle
23
  from src.modelCache import ModelCache
24
  from src.prompts.jsonPromptStrategy import JsonPromptStrategy
25
  from src.prompts.prependPromptStrategy import PrependPromptStrategy
@@ -32,7 +31,7 @@ import ffmpeg
32
  # UI
33
  import gradio as gr
34
 
35
- from src.download import ExceededMaximumDuration, download_url
36
  from src.utils import optional_int, slugify, str2bool, write_srt, write_srt_original, write_vtt
37
  from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
38
  from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
@@ -100,11 +99,15 @@ class WhisperTranscriber:
100
  self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
101
  print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
102
 
103
- def set_diarization(self, auth_token: str, enable_daemon_process: bool = True, **kwargs):
 
 
104
  if self.diarization is None:
105
  self.diarization = DiarizationContainer(auth_token=auth_token, enable_daemon_process=enable_daemon_process,
106
  auto_cleanup_timeout_seconds=self.app_config.diarization_process_timeout,
107
- cache=self.model_cache)
 
 
108
  # Set parameters
109
  self.diarization_kwargs = kwargs
110
 
@@ -257,6 +260,7 @@ class WhisperTranscriber:
257
  diarization_speakers: int = decodeOptions.pop("diarization_speakers", 2)
258
  diarization_min_speakers: int = decodeOptions.pop("diarization_min_speakers", 1)
259
  diarization_max_speakers: int = decodeOptions.pop("diarization_max_speakers", 8)
 
260
  highlight_words: bool = decodeOptions.pop("highlight_words", False)
261
 
262
  temperature: float = decodeOptions.pop("temperature", None)
@@ -290,9 +294,9 @@ class WhisperTranscriber:
290
 
291
  if diarization:
292
  if diarization_speakers is not None and diarization_speakers < 1:
293
- self.set_diarization(auth_token=self.app_config.auth_token, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
294
  else:
295
- self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
296
  else:
297
  self.unset_diarization()
298
 
@@ -1137,7 +1141,8 @@ def create_ui(app_config: ApplicationConfig):
1137
  gr.Checkbox(label="Diarization", value=app_config.diarization, interactive=has_diarization_libs, elem_id="diarization", info="Whether to perform speaker diarization"),
1138
  gr.Number(label="Diarization - Speakers", precision=0, value=app_config.diarization_speakers, interactive=has_diarization_libs, elem_id="diarization_speakers", info="The number of speakers to detect"),
1139
  gr.Number(label="Diarization - Min Speakers", precision=0, value=app_config.diarization_min_speakers, interactive=has_diarization_libs, elem_id="diarization_min_speakers", info="The minimum number of speakers to detect"),
1140
- gr.Number(label="Diarization - Max Speakers", precision=0, value=app_config.diarization_max_speakers, interactive=has_diarization_libs, elem_id="diarization_max_speakers", info="The maximum number of speakers to detect")
 
1141
  }
1142
 
1143
  common_output = lambda : [
@@ -1439,6 +1444,7 @@ if __name__ == '__main__':
1439
  parser.add_argument("--diarization_max_speakers", type=int, default=default_app_config.diarization_max_speakers, help="Maximum number of speakers")
1440
  parser.add_argument("--diarization_process_timeout", type=int, default=default_app_config.diarization_process_timeout, \
1441
  help="Number of seconds before inactivate diarization processes are terminated. Use 0 to close processes immediately, or None for no timeout.")
 
1442
 
1443
  args = parser.parse_args().__dict__
1444
 
 
19
  from src.diarization.diarizationContainer import DiarizationContainer
20
  from src.hooks.progressListener import ProgressListener
21
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
 
22
  from src.modelCache import ModelCache
23
  from src.prompts.jsonPromptStrategy import JsonPromptStrategy
24
  from src.prompts.prependPromptStrategy import PrependPromptStrategy
 
31
  # UI
32
  import gradio as gr
33
 
34
+ from src.download import ExceededMaximumDuration
35
  from src.utils import optional_int, slugify, str2bool, write_srt, write_srt_original, write_vtt
36
  from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
37
  from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
 
99
  self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
100
  print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
101
 
102
+ def set_diarization(self, auth_token: str, enable_daemon_process: bool = True, diarization_version: str = None, **kwargs):
103
+ if diarization_version == None:
104
+ diarization_version = self.app_config.diarization_version
105
  if self.diarization is None:
106
  self.diarization = DiarizationContainer(auth_token=auth_token, enable_daemon_process=enable_daemon_process,
107
  auto_cleanup_timeout_seconds=self.app_config.diarization_process_timeout,
108
+ cache=self.model_cache, diarization_version=diarization_version)
109
+ else:
110
+ self.diarization.diarization_version=diarization_version
111
  # Set parameters
112
  self.diarization_kwargs = kwargs
113
 
 
260
  diarization_speakers: int = decodeOptions.pop("diarization_speakers", 2)
261
  diarization_min_speakers: int = decodeOptions.pop("diarization_min_speakers", 1)
262
  diarization_max_speakers: int = decodeOptions.pop("diarization_max_speakers", 8)
263
+ diarization_version: str = decodeOptions.pop("diarization_version", "speaker-diarization-3.1")
264
  highlight_words: bool = decodeOptions.pop("highlight_words", False)
265
 
266
  temperature: float = decodeOptions.pop("temperature", None)
 
294
 
295
  if diarization:
296
  if diarization_speakers is not None and diarization_speakers < 1:
297
+ self.set_diarization(auth_token=self.app_config.auth_token, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers, diarization_version=diarization_version)
298
  else:
299
+ self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers, diarization_version=diarization_version)
300
  else:
301
  self.unset_diarization()
302
 
 
1141
  gr.Checkbox(label="Diarization", value=app_config.diarization, interactive=has_diarization_libs, elem_id="diarization", info="Whether to perform speaker diarization"),
1142
  gr.Number(label="Diarization - Speakers", precision=0, value=app_config.diarization_speakers, interactive=has_diarization_libs, elem_id="diarization_speakers", info="The number of speakers to detect"),
1143
  gr.Number(label="Diarization - Min Speakers", precision=0, value=app_config.diarization_min_speakers, interactive=has_diarization_libs, elem_id="diarization_min_speakers", info="The minimum number of speakers to detect"),
1144
+ gr.Number(label="Diarization - Max Speakers", precision=0, value=app_config.diarization_max_speakers, interactive=has_diarization_libs, elem_id="diarization_max_speakers", info="The maximum number of speakers to detect"),
1145
+ gr.Dropdown(label="Diarization Version", choices=["speaker-diarization-3.1", "speaker-diarization-3.0", "[email protected]"], value=app_config.diarization_version, elem_id="diarization_version", info="pyannote.audio speaker diarization pipeline v3.1 is expected to be much better (and faster) than v2.x. [Benchmark](https://github.com/pyannote/pyannote-audio?tab=readme-ov-file#benchmark)"),
1146
  }
1147
 
1148
  common_output = lambda : [
 
1444
  parser.add_argument("--diarization_max_speakers", type=int, default=default_app_config.diarization_max_speakers, help="Maximum number of speakers")
1445
  parser.add_argument("--diarization_process_timeout", type=int, default=default_app_config.diarization_process_timeout, \
1446
  help="Number of seconds before inactivate diarization processes are terminated. Use 0 to close processes immediately, or None for no timeout.")
1447
+ parser.add_argument('--diarization_version', type=str, default=default_app_config.diarization_version, help='Specify the diarization version, defaulting to speaker-diarization-3.1')
1448
 
1449
  args = parser.parse_args().__dict__
1450
 
requirements-fasterWhisper.txt CHANGED
@@ -20,7 +20,7 @@ sentencepiece
20
  # Needed by diarization
21
  intervaltree
22
  srt
23
- https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
24
 
25
  # Needed by ALMA-GPTQ
26
  accelerate
 
20
  # Needed by diarization
21
  intervaltree
22
  srt
23
+ pyannote.audio
24
 
25
  # Needed by ALMA-GPTQ
26
  accelerate
requirements-whisper.txt CHANGED
@@ -20,7 +20,7 @@ sentencepiece
20
  # Needed by diarization
21
  intervaltree
22
  srt
23
- https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
24
 
25
  # Needed by ALMA-GPTQ
26
  accelerate
 
20
  # Needed by diarization
21
  intervaltree
22
  srt
23
+ pyannote.audio
24
 
25
  # Needed by ALMA-GPTQ
26
  accelerate
requirements.txt CHANGED
@@ -20,7 +20,7 @@ sentencepiece
20
  # Needed by diarization
21
  intervaltree
22
  srt
23
- https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
24
 
25
  # Needed by ALMA-GPTQ
26
  accelerate
 
20
  # Needed by diarization
21
  intervaltree
22
  srt
23
+ pyannote.audio
24
 
25
  # Needed by ALMA-GPTQ
26
  accelerate
src/config.py CHANGED
@@ -78,6 +78,7 @@ class ApplicationConfig:
78
  auth_token: str = None, diarization: bool = False, diarization_speakers: int = 2,
79
  diarization_min_speakers: int = 1, diarization_max_speakers: int = 5,
80
  diarization_process_timeout: int = 60,
 
81
  # Translation
82
  translation_batch_size: int = 2,
83
  translation_no_repeat_ngram_size: int = 4,
@@ -148,6 +149,8 @@ class ApplicationConfig:
148
  self.diarization_min_speakers = diarization_min_speakers
149
  self.diarization_max_speakers = diarization_max_speakers
150
  self.diarization_process_timeout = diarization_process_timeout
 
 
151
  # Translation
152
  self.translation_batch_size = translation_batch_size
153
  self.translation_no_repeat_ngram_size = translation_no_repeat_ngram_size
 
78
  auth_token: str = None, diarization: bool = False, diarization_speakers: int = 2,
79
  diarization_min_speakers: int = 1, diarization_max_speakers: int = 5,
80
  diarization_process_timeout: int = 60,
81
+ diarization_version: str = "speaker-diarization-3.1",
82
  # Translation
83
  translation_batch_size: int = 2,
84
  translation_no_repeat_ngram_size: int = 4,
 
149
  self.diarization_min_speakers = diarization_min_speakers
150
  self.diarization_max_speakers = diarization_max_speakers
151
  self.diarization_process_timeout = diarization_process_timeout
152
+ self.diarization_version = diarization_version
153
+
154
  # Translation
155
  self.translation_batch_size = translation_batch_size
156
  self.translation_no_repeat_ngram_size = translation_no_repeat_ngram_size
src/diarization/diarization.py CHANGED
@@ -26,15 +26,16 @@ class DiarizationEntry:
26
  }
27
 
28
  class Diarization:
29
- def __init__(self, auth_token=None):
30
  if auth_token is None:
31
  auth_token = os.environ.get("HF_ACCESS_TOKEN")
32
  if auth_token is None:
33
  raise ValueError("No HuggingFace API Token provided - please use the --auth_token argument or set the HF_ACCESS_TOKEN environment variable")
34
 
35
- self.auth_token = auth_token
36
- self.initialized = False
37
- self.pipeline = None
 
38
 
39
  @staticmethod
40
  def has_libraries():
@@ -47,17 +48,17 @@ class Diarization:
47
 
48
  def initialize(self):
49
  """
50
- 1.Install pyannote.audio 3.0 with pip install pyannote.audio
51
  2.Accept pyannote/segmentation-3.0 user conditions
52
- 3.Accept pyannote/speaker-diarization-3.0 user conditions
53
  4.Create access token at hf.co/settings/tokens.
54
- https://huggingface.co/pyannote/speaker-diarization-3.0
55
  """
56
  if self.initialized:
57
  return
58
  from pyannote.audio import Pipeline
59
-
60
- self.pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.0", use_auth_token=self.auth_token)
61
  self.initialized = True
62
 
63
  # Load GPU mode if available
@@ -174,7 +175,7 @@ def main():
174
  # Read whisper JSON or SRT file
175
  whisper_result = load_transcript(args.whisper_file)
176
 
177
- diarization = Diarization(auth_token=args.auth_token)
178
  diarization_result = list(diarization.run(args.audio_file, num_speakers=args.num_speakers, min_speakers=args.min_speakers, max_speakers=args.max_speakers))
179
 
180
  # Print result
 
26
  }
27
 
28
  class Diarization:
29
+ def __init__(self, auth_token=None, diarization_version=None):
30
  if auth_token is None:
31
  auth_token = os.environ.get("HF_ACCESS_TOKEN")
32
  if auth_token is None:
33
  raise ValueError("No HuggingFace API Token provided - please use the --auth_token argument or set the HF_ACCESS_TOKEN environment variable")
34
 
35
+ self.auth_token = auth_token
36
+ self.initialized = False
37
+ self.pipeline = None
38
+ self.diarization_version = diarization_version
39
 
40
  @staticmethod
41
  def has_libraries():
 
48
 
49
  def initialize(self):
50
  """
51
+ 1.Install pyannote.audio 3.1 with pip install pyannote.audio
52
  2.Accept pyannote/segmentation-3.0 user conditions
53
+ 3.Accept pyannote/speaker-diarization-3.1 user conditions
54
  4.Create access token at hf.co/settings/tokens.
55
+ https://huggingface.co/pyannote/speaker-diarization-3.1
56
  """
57
  if self.initialized:
58
  return
59
  from pyannote.audio import Pipeline
60
+
61
+ self.pipeline = Pipeline.from_pretrained(f"pyannote/{self.diarization_version}", use_auth_token=self.auth_token)
62
  self.initialized = True
63
 
64
  # Load GPU mode if available
 
175
  # Read whisper JSON or SRT file
176
  whisper_result = load_transcript(args.whisper_file)
177
 
178
+ diarization = Diarization(auth_token=args.auth_token, diarization_version=args.diarization_version)
179
  diarization_result = list(diarization.run(args.audio_file, num_speakers=args.num_speakers, min_speakers=args.min_speakers, max_speakers=args.max_speakers))
180
 
181
  # Print result
src/diarization/diarizationContainer.py CHANGED
@@ -4,13 +4,14 @@ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
4
  from src.vadParallel import ParallelContext
5
 
6
  class DiarizationContainer:
7
- def __init__(self, auth_token: str = None, enable_daemon_process: bool = True, auto_cleanup_timeout_seconds=60, cache: ModelCache = None):
8
  self.auth_token = auth_token
9
  self.enable_daemon_process = enable_daemon_process
10
  self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds
11
  self.diarization_context: ParallelContext = None
12
  self.cache = cache
13
  self.model = None
 
14
 
15
  def run(self, audio_file, **kwargs):
16
  # Create parallel context if needed
@@ -37,18 +38,18 @@ class DiarizationContainer:
37
  return self.model.mark_speakers(diarization_result, whisper_result)
38
 
39
  # Create a new diarization model (calling mark_speakers will not initialize pyannote.audio)
40
- model = Diarization(self.auth_token)
41
  return model.mark_speakers(diarization_result, whisper_result)
42
 
43
  def get_model(self):
44
  # Lazy load the model
45
  if (self.model is None):
46
  if self.cache:
47
- print("Loading diarization model from cache")
48
- self.model = self.cache.get("diarization", lambda : Diarization(self.auth_token))
49
  else:
50
- print("Loading diarization model")
51
- self.model = Diarization(self.auth_token)
52
  return self.model
53
 
54
  def execute(self, audio_file, **kwargs):
@@ -66,7 +67,8 @@ class DiarizationContainer:
66
  return {
67
  "auth_token": self.auth_token,
68
  "enable_daemon_process": self.enable_daemon_process,
69
- "auto_cleanup_timeout_seconds": self.auto_cleanup_timeout_seconds
 
70
  }
71
 
72
  def __setstate__(self, state):
@@ -74,5 +76,6 @@ class DiarizationContainer:
74
  self.enable_daemon_process = state["enable_daemon_process"]
75
  self.auto_cleanup_timeout_seconds = state["auto_cleanup_timeout_seconds"]
76
  self.diarization_context = None
 
77
  self.cache = GLOBAL_MODEL_CACHE
78
  self.model = None
 
4
  from src.vadParallel import ParallelContext
5
 
6
  class DiarizationContainer:
7
+ def __init__(self, auth_token: str = None, enable_daemon_process: bool = True, auto_cleanup_timeout_seconds=60, cache: ModelCache = None, diarization_version=None):
8
  self.auth_token = auth_token
9
  self.enable_daemon_process = enable_daemon_process
10
  self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds
11
  self.diarization_context: ParallelContext = None
12
  self.cache = cache
13
  self.model = None
14
+ self.diarization_version = diarization_version
15
 
16
  def run(self, audio_file, **kwargs):
17
  # Create parallel context if needed
 
38
  return self.model.mark_speakers(diarization_result, whisper_result)
39
 
40
  # Create a new diarization model (calling mark_speakers will not initialize pyannote.audio)
41
+ model = Diarization(self.auth_token, self.diarization_version)
42
  return model.mark_speakers(diarization_result, whisper_result)
43
 
44
  def get_model(self):
45
  # Lazy load the model
46
  if (self.model is None):
47
  if self.cache:
48
+ print(f"Loading {self.diarization_version} model from cache")
49
+ self.model = self.cache.get(self.diarization_version, lambda : Diarization(self.auth_token, self.diarization_version))
50
  else:
51
+ print(f"Loading {self.diarization_version} model")
52
+ self.model = Diarization(self.auth_token, self.diarization_version)
53
  return self.model
54
 
55
  def execute(self, audio_file, **kwargs):
 
67
  return {
68
  "auth_token": self.auth_token,
69
  "enable_daemon_process": self.enable_daemon_process,
70
+ "auto_cleanup_timeout_seconds": self.auto_cleanup_timeout_seconds,
71
+ "diarization_version": self.diarization_version
72
  }
73
 
74
  def __setstate__(self, state):
 
76
  self.enable_daemon_process = state["enable_daemon_process"]
77
  self.auto_cleanup_timeout_seconds = state["auto_cleanup_timeout_seconds"]
78
  self.diarization_context = None
79
+ self.diarization_version = state["diarization_version"]
80
  self.cache = GLOBAL_MODEL_CACHE
81
  self.model = None
src/utils.py CHANGED
@@ -150,7 +150,7 @@ def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: i
150
  yield segment
151
 
152
  if segment_longest_speaker is not None:
153
- segment_longest_speaker = segment_longest_speaker.replace("SPEAKER", "S")
154
 
155
  subtitle_start = segment['start']
156
  subtitle_end = segment['end']
@@ -160,7 +160,9 @@ def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: i
160
  if len(words) == 0:
161
  # Prepend the longest speaker ID if available
162
  if segment_longest_speaker is not None:
163
- text = f"({segment_longest_speaker}) {text}"
 
 
164
 
165
  result = {
166
  'start': subtitle_start,
@@ -175,12 +177,16 @@ def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: i
175
  continue
176
 
177
  if segment_longest_speaker is not None:
178
- # Add the beginning
179
- words.insert(0, {
180
- 'start': subtitle_start,
181
- 'end' : subtitle_start,
182
- 'word' : f"({segment_longest_speaker})"
183
- })
 
 
 
 
184
 
185
  text_words = [text] if not highlight_words and text_original is not None and len(text_original) > 0 else [ this_word["word"] for this_word in words ]
186
 
 
150
  yield segment
151
 
152
  if segment_longest_speaker is not None:
153
+ segment_longest_speaker = "(" + segment_longest_speaker.replace("SPEAKER", "S") + ")"
154
 
155
  subtitle_start = segment['start']
156
  subtitle_end = segment['end']
 
160
  if len(words) == 0:
161
  # Prepend the longest speaker ID if available
162
  if segment_longest_speaker is not None:
163
+ text = f"{segment_longest_speaker} {text}"
164
+ if text_original is not None and len(text_original) > 0:
165
+ text_original = f"{segment_longest_speaker} {text_original}"
166
 
167
  result = {
168
  'start': subtitle_start,
 
177
  continue
178
 
179
  if segment_longest_speaker is not None:
180
+ if words[0].get('word') != segment_longest_speaker:
181
+ # Add the beginning
182
+ words.insert(0, {
183
+ 'start': subtitle_start,
184
+ 'end' : subtitle_start,
185
+ 'word' : segment_longest_speaker
186
+ })
187
+ text = f"{segment_longest_speaker} {text}"
188
+ if text_original is not None and len(text_original) > 0:
189
+ text_original = f"{segment_longest_speaker} {text_original}"
190
 
191
  text_words = [text] if not highlight_words and text_original is not None and len(text_original) > 0 else [ this_word["word"] for this_word in words ]
192