jhj0517 commited on
Commit
b7ba0cb
·
unverified ·
2 Parent(s): 6960e0b 0a34091

Merge pull request #219 from jhj0517/refactor/args

Browse files
app.py CHANGED
@@ -18,9 +18,10 @@ class App:
18
  self.app = gr.Blocks(css=CSS, theme=self.args.theme)
19
  self.whisper_inf = WhisperFactory.create_whisper_inference(
20
  whisper_type=self.args.whisper_type,
21
- model_dir=self.args.faster_whisper_model_dir,
 
 
22
  output_dir=self.args.output_dir,
23
- args=self.args
24
  )
25
  print(f"Use \"{self.args.whisper_type}\" implementation")
26
  print(f"Device \"{self.whisper_inf.device}\" is detected")
 
18
  self.app = gr.Blocks(css=CSS, theme=self.args.theme)
19
  self.whisper_inf = WhisperFactory.create_whisper_inference(
20
  whisper_type=self.args.whisper_type,
21
+ whisper_model_dir=self.args.whisper_model_dir,
22
+ faster_whisper_model_dir=self.args.faster_whisper_model_dir,
23
+ insanely_fast_whisper_model_dir=self.args.insanely_fast_whisper_model_dir,
24
  output_dir=self.args.output_dir,
 
25
  )
26
  print(f"Use \"{self.args.whisper_type}\" implementation")
27
  print(f"Device \"{self.whisper_inf.device}\" is detected")
modules/whisper/faster_whisper_inference.py CHANGED
@@ -17,15 +17,20 @@ from modules.whisper.whisper_base import WhisperBase
17
 
18
  class FasterWhisperInference(WhisperBase):
19
  def __init__(self,
20
- model_dir: str,
21
- output_dir: str,
22
- args: Namespace
23
  ):
24
  super().__init__(
25
  model_dir=model_dir,
26
- output_dir=output_dir,
27
- args=args
28
  )
 
 
 
 
 
29
  self.model_paths = self.get_model_paths()
30
  self.device = self.get_device()
31
  self.available_models = self.model_paths.keys()
 
17
 
18
  class FasterWhisperInference(WhisperBase):
19
  def __init__(self,
20
+ model_dir: Optional[str] = None,
21
+ diarization_model_dir: Optional[str] = None,
22
+ output_dir: Optional[str] = None,
23
  ):
24
  super().__init__(
25
  model_dir=model_dir,
26
+ diarization_model_dir=diarization_model_dir,
27
+ output_dir=output_dir
28
  )
29
+ if model_dir is None:
30
+ model_dir = os.path.join("models", "Whisper", "faster-whisper")
31
+ self.model_dir = model_dir
32
+ os.makedirs(self.model_dir, exist_ok=True)
33
+
34
  self.model_paths = self.get_model_paths()
35
  self.device = self.get_device()
36
  self.available_models = self.model_paths.keys()
modules/whisper/insanely_fast_whisper_inference.py CHANGED
@@ -17,15 +17,20 @@ from modules.whisper.whisper_base import WhisperBase
17
 
18
  class InsanelyFastWhisperInference(WhisperBase):
19
  def __init__(self,
20
- model_dir: str,
21
- output_dir: str,
22
- args: Namespace
23
  ):
24
  super().__init__(
25
  model_dir=model_dir,
26
  output_dir=output_dir,
27
- args=args
28
  )
 
 
 
 
 
29
  openai_models = whisper.available_models()
30
  distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
31
  self.available_models = openai_models + distil_models
 
17
 
18
  class InsanelyFastWhisperInference(WhisperBase):
19
  def __init__(self,
20
+ model_dir: Optional[str] = None,
21
+ diarization_model_dir: Optional[str] = None,
22
+ output_dir: Optional[str] = None,
23
  ):
24
  super().__init__(
25
  model_dir=model_dir,
26
  output_dir=output_dir,
27
+ diarization_model_dir=diarization_model_dir
28
  )
29
+ if model_dir is None:
30
+ model_dir = os.path.join("models", "Whisper", "insanely-fast-whisper")
31
+ self.model_dir = model_dir
32
+ os.makedirs(self.model_dir, exist_ok=True)
33
+
34
  openai_models = whisper.available_models()
35
  distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
36
  self.available_models = openai_models + distil_models
modules/whisper/whisper_Inference.py CHANGED
@@ -12,14 +12,14 @@ from modules.whisper.whisper_parameter import *
12
 
13
  class WhisperInference(WhisperBase):
14
  def __init__(self,
15
- model_dir: str,
16
- output_dir: str,
17
- args: Namespace
18
  ):
19
  super().__init__(
20
  model_dir=model_dir,
21
  output_dir=output_dir,
22
- args=args
23
  )
24
 
25
  def transcribe(self,
 
12
 
13
  class WhisperInference(WhisperBase):
14
  def __init__(self,
15
+ model_dir: Optional[str] = None,
16
+ diarization_model_dir: Optional[str] = None,
17
+ output_dir: Optional[str] = None,
18
  ):
19
  super().__init__(
20
  model_dir=model_dir,
21
  output_dir=output_dir,
22
+ diarization_model_dir=diarization_model_dir
23
  )
24
 
25
  def transcribe(self,
modules/whisper/whisper_base.py CHANGED
@@ -6,7 +6,6 @@ from abc import ABC, abstractmethod
6
  from typing import BinaryIO, Union, Tuple, List
7
  import numpy as np
8
  from datetime import datetime
9
- from argparse import Namespace
10
  from faster_whisper.vad import VadOptions
11
  from dataclasses import astuple
12
 
@@ -20,26 +19,34 @@ from modules.vad.silero_vad import SileroVAD
20
 
21
  class WhisperBase(ABC):
22
  def __init__(self,
23
- model_dir: str,
24
- output_dir: str,
25
- args: Namespace
26
  ):
27
- self.model = None
28
- self.current_model_size = None
 
 
 
 
 
29
  self.model_dir = model_dir
30
  self.output_dir = output_dir
31
  os.makedirs(self.output_dir, exist_ok=True)
32
  os.makedirs(self.model_dir, exist_ok=True)
 
 
 
 
 
 
 
33
  self.available_models = whisper.available_models()
34
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
35
  self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
36
  self.device = self.get_device()
37
  self.available_compute_types = ["float16", "float32"]
38
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
39
- self.diarizer = Diarizer(
40
- model_dir=args.diarization_model_dir
41
- )
42
- self.vad = SileroVAD()
43
 
44
  @abstractmethod
45
  def transcribe(self,
@@ -47,6 +54,7 @@ class WhisperBase(ABC):
47
  progress: gr.Progress,
48
  *whisper_params,
49
  ):
 
50
  pass
51
 
52
  @abstractmethod
@@ -55,6 +63,7 @@ class WhisperBase(ABC):
55
  compute_type: str,
56
  progress: gr.Progress
57
  ):
 
58
  pass
59
 
60
  def run(self,
 
6
  from typing import BinaryIO, Union, Tuple, List
7
  import numpy as np
8
  from datetime import datetime
 
9
  from faster_whisper.vad import VadOptions
10
  from dataclasses import astuple
11
 
 
19
 
20
  class WhisperBase(ABC):
21
  def __init__(self,
22
+ model_dir: Optional[str] = None,
23
+ diarization_model_dir: Optional[str] = None,
24
+ output_dir: Optional[str] = None,
25
  ):
26
+ if model_dir is None:
27
+ model_dir = os.path.join("models", "Whisper")
28
+ if diarization_model_dir is None:
29
+ diarization_model_dir = os.path.join("models", "Diarization")
30
+ if output_dir is None:
31
+ output_dir = os.path.join("outputs")
32
+
33
  self.model_dir = model_dir
34
  self.output_dir = output_dir
35
  os.makedirs(self.output_dir, exist_ok=True)
36
  os.makedirs(self.model_dir, exist_ok=True)
37
+ self.diarizer = Diarizer(
38
+ model_dir=diarization_model_dir
39
+ )
40
+ self.vad = SileroVAD()
41
+
42
+ self.model = None
43
+ self.current_model_size = None
44
  self.available_models = whisper.available_models()
45
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
46
  self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
47
  self.device = self.get_device()
48
  self.available_compute_types = ["float16", "float32"]
49
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
 
 
 
 
50
 
51
  @abstractmethod
52
  def transcribe(self,
 
54
  progress: gr.Progress,
55
  *whisper_params,
56
  ):
57
+ """Inference whisper model to transcribe"""
58
  pass
59
 
60
  @abstractmethod
 
63
  compute_type: str,
64
  progress: gr.Progress
65
  ):
66
+ """Initialize whisper model"""
67
  pass
68
 
69
  def run(self,
modules/whisper/whisper_factory.py CHANGED
@@ -1,4 +1,4 @@
1
- from argparse import Namespace
2
  import os
3
 
4
  from modules.whisper.faster_whisper_inference import FasterWhisperInference
@@ -11,27 +11,32 @@ class WhisperFactory:
11
  @staticmethod
12
  def create_whisper_inference(
13
  whisper_type: str,
14
- model_dir: str,
15
- output_dir: str,
16
- args: Namespace
 
 
17
  ) -> "WhisperBase":
18
  """
19
  Create a whisper inference class based on the provided whisper_type.
20
 
21
  Parameters
22
  ----------
23
- whisper_type: str
24
- The repository name of whisper inference to use. Supported values are:
25
- - "faster-whisper" from
26
- - "whisper"
27
- - insanely-fast-whisper", "insanely_fast_whisper", "insanelyfastwhisper",
28
- "insanely-faster-whisper", "insanely_faster_whisper", "insanelyfasterwhisper"
29
- model_dir: str
30
- The directory path where the whisper model is located.
31
- output_dir: str
32
- The directory path where the output files will be saved.
33
- args: Any
34
- Additional arguments to be passed to the whisper inference object.
 
 
 
35
 
36
  Returns
37
  -------
@@ -51,10 +56,26 @@ class WhisperFactory:
51
  ]
52
 
53
  if whisper_type in faster_whisper_typos:
54
- return FasterWhisperInference(model_dir, output_dir, args)
 
 
 
 
55
  elif whisper_type in whisper_typos:
56
- return WhisperInference(model_dir, output_dir, args)
 
 
 
 
57
  elif whisper_type in insanely_fast_whisper_typos:
58
- return InsanelyFastWhisperInference(model_dir, output_dir, args)
 
 
 
 
59
  else:
60
- return FasterWhisperInference(model_dir, output_dir, args)
 
 
 
 
 
1
+ from typing import Optional
2
  import os
3
 
4
  from modules.whisper.faster_whisper_inference import FasterWhisperInference
 
11
  @staticmethod
12
  def create_whisper_inference(
13
  whisper_type: str,
14
+ whisper_model_dir: Optional[str] = None,
15
+ faster_whisper_model_dir: Optional[str] = None,
16
+ insanely_fast_whisper_model_dir: Optional[str] = None,
17
+ diarization_model_dir: Optional[str] = None,
18
+ output_dir: Optional[str] = None,
19
  ) -> "WhisperBase":
20
  """
21
  Create a whisper inference class based on the provided whisper_type.
22
 
23
  Parameters
24
  ----------
25
+ whisper_type : str
26
+ The type of Whisper implementation to use. Supported values (case-insensitive):
27
+ - "faster-whisper": https://github.com/SYSTRAN/faster-whisper
28
+ - "whisper": https://github.com/openai/whisper
29
+ - "insanely-fast-whisper": https://github.com/Vaibhavs10/insanely-fast-whisper
30
+ whisper_model_dir : str
31
+ Directory path for the Whisper model.
32
+ faster_whisper_model_dir : str
33
+ Directory path for the Faster Whisper model.
34
+ insanely_fast_whisper_model_dir : str
35
+ Directory path for the Insanely Fast Whisper model.
36
+ diarization_model_dir : str
37
+ Directory path for the diarization model.
38
+ output_dir : str
39
+ Directory path where output files will be saved.
40
 
41
  Returns
42
  -------
 
56
  ]
57
 
58
  if whisper_type in faster_whisper_typos:
59
+ return FasterWhisperInference(
60
+ model_dir=faster_whisper_model_dir,
61
+ output_dir=output_dir,
62
+ diarization_model_dir=diarization_model_dir
63
+ )
64
  elif whisper_type in whisper_typos:
65
+ return WhisperInference(
66
+ model_dir=whisper_model_dir,
67
+ output_dir=output_dir,
68
+ diarization_model_dir=diarization_model_dir
69
+ )
70
  elif whisper_type in insanely_fast_whisper_typos:
71
+ return InsanelyFastWhisperInference(
72
+ model_dir=insanely_fast_whisper_model_dir,
73
+ output_dir=output_dir,
74
+ diarization_model_dir=diarization_model_dir
75
+ )
76
  else:
77
+ return FasterWhisperInference(
78
+ model_dir=faster_whisper_model_dir,
79
+ output_dir=output_dir,
80
+ diarization_model_dir=diarization_model_dir
81
+ )