jhj0517 commited on
Commit
7716c2c
·
unverified ·
2 Parent(s): f9b7286 84055fe

Merge pull request #154 from jhj0517/feature/add-local-model-path

Browse files
app.py CHANGED
@@ -17,8 +17,10 @@ class App:
17
  self.app = gr.Blocks(css=CSS, theme=self.args.theme)
18
  self.whisper_inf = WhisperInference() if self.args.disable_faster_whisper else FasterWhisperInference()
19
  if isinstance(self.whisper_inf, FasterWhisperInference):
 
20
  print("Use Faster Whisper implementation")
21
  else:
 
22
  print("Use Open AI Whisper implementation")
23
  print(f"Device \"{self.whisper_inf.device}\" is detected")
24
  self.nllb_inf = NLLBInference()
@@ -296,6 +298,8 @@ parser.add_argument('--password', type=str, default=None, help='Gradio authentic
296
  parser.add_argument('--theme', type=str, default=None, help='Gradio Blocks theme')
297
  parser.add_argument('--colab', type=bool, default=False, nargs='?', const=True, help='Is colab user or not')
298
  parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=True, help='enable api or not')
 
 
299
  _args = parser.parse_args()
300
 
301
  if __name__ == "__main__":
 
17
  self.app = gr.Blocks(css=CSS, theme=self.args.theme)
18
  self.whisper_inf = WhisperInference() if self.args.disable_faster_whisper else FasterWhisperInference()
19
  if isinstance(self.whisper_inf, FasterWhisperInference):
20
+ self.whisper_inf.model_dir = args.faster_whisper_model_dir
21
  print("Use Faster Whisper implementation")
22
  else:
23
+ self.whisper_inf.model_dir = args.whisper_model_dir
24
  print("Use Open AI Whisper implementation")
25
  print(f"Device \"{self.whisper_inf.device}\" is detected")
26
  self.nllb_inf = NLLBInference()
 
298
  parser.add_argument('--theme', type=str, default=None, help='Gradio Blocks theme')
299
  parser.add_argument('--colab', type=bool, default=False, nargs='?', const=True, help='Is colab user or not')
300
  parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=True, help='enable api or not')
301
+ parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
302
+ parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
303
  _args = parser.parse_args()
304
 
305
  if __name__ == "__main__":
modules/faster_whisper_inference.py CHANGED
@@ -32,7 +32,7 @@ class FasterWhisperInference(BaseInterface):
32
  self.available_compute_types = ctranslate2.get_supported_compute_types(
33
  "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
34
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
35
- self.default_beam_size = 1
36
 
37
  def transcribe_file(self,
38
  files: list,
@@ -311,7 +311,7 @@ class FasterWhisperInference(BaseInterface):
311
  self.model = faster_whisper.WhisperModel(
312
  device=self.device,
313
  model_size_or_path=model_size,
314
- download_root=os.path.join("models", "Whisper", "faster-whisper"),
315
  compute_type=self.current_compute_type
316
  )
317
 
 
32
  self.available_compute_types = ctranslate2.get_supported_compute_types(
33
  "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
34
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
35
+ self.model_dir = os.path.join("models", "Whisper", "faster-whisper")
36
 
37
  def transcribe_file(self,
38
  files: list,
 
311
  self.model = faster_whisper.WhisperModel(
312
  device=self.device,
313
  model_size_or_path=model_size,
314
+ download_root=self.model_dir,
315
  compute_type=self.current_compute_type
316
  )
317
 
modules/whisper_Inference.py CHANGED
@@ -26,7 +26,7 @@ class WhisperInference(BaseInterface):
26
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
27
  self.available_compute_types = ["float16", "float32"]
28
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
29
- self.default_beam_size = 1
30
 
31
  def transcribe_file(self,
32
  files: list,
@@ -288,7 +288,7 @@ class WhisperInference(BaseInterface):
288
  self.model = whisper.load_model(
289
  name=model_size,
290
  device=self.device,
291
- download_root=os.path.join("models", "Whisper")
292
  )
293
 
294
  @staticmethod
 
26
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
27
  self.available_compute_types = ["float16", "float32"]
28
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
29
+ self.model_dir = os.path.join("models", "Whisper")
30
 
31
  def transcribe_file(self,
32
  files: list,
 
288
  self.model = whisper.load_model(
289
  name=model_size,
290
  device=self.device,
291
+ download_root=self.model_dir
292
  )
293
 
294
  @staticmethod
user-start-webui.bat CHANGED
@@ -10,9 +10,10 @@ set SHARE=
10
  set THEME=
11
  set DISABLE_FASTER_WHISPER=
12
  set API_OPEN=
 
 
13
 
14
 
15
- :: Set args accordingly
16
  if not "%SERVER_NAME%"=="" (
17
  set SERVER_NAME_ARG=--server_name %SERVER_NAME%
18
  )
@@ -37,7 +38,13 @@ if /I "%DISABLE_FASTER_WHISPER%"=="true" (
37
  if /I "%API_OPEN%"=="true" (
38
  set API_OPEN=--api_open
39
  )
 
 
 
 
 
 
40
 
41
  :: Call the original .bat script with optional arguments
42
- start-webui.bat %SERVER_NAME_ARG% %SERVER_PORT_ARG% %USERNAME_ARG% %PASSWORD_ARG% %SHARE_ARG% %THEME_ARG% %DISABLE_FASTER_WHISPER_ARG% %API_OPEN%
43
  pause
 
10
  set THEME=
11
  set DISABLE_FASTER_WHISPER=
12
  set API_OPEN=
13
+ set WHISPER_MODEL_DIR=
14
+ set FASTER_WHISPER_MODEL_DIR=
15
 
16
 
 
17
  if not "%SERVER_NAME%"=="" (
18
  set SERVER_NAME_ARG=--server_name %SERVER_NAME%
19
  )
 
38
  if /I "%API_OPEN%"=="true" (
39
  set API_OPEN=--api_open
40
  )
41
+ if not "%WHISPER_MODEL_DIR%"=="" (
42
+ set WHISPER_MODEL_DIR_ARG=--whisper_model_dir "%WHISPER_MODEL_DIR%"
43
+ )
44
+ if not "%FASTER_WHISPER_MODEL_DIR%"=="" (
45
+ set FASTER_WHISPER_MODEL_DIR_ARG=--faster_whisper_model_dir "%FASTER_WHISPER_MODEL_DIR%"
46
+ )
47
 
48
  :: Call the original .bat script with optional arguments
49
+ start-webui.bat %SERVER_NAME_ARG% %SERVER_PORT_ARG% %USERNAME_ARG% %PASSWORD_ARG% %SHARE_ARG% %THEME_ARG% %DISABLE_FASTER_WHISPER_ARG% %API_OPEN% %WHISPER_MODEL_DIR_ARG% %FASTER_WHISPER_MODEL_DIR_ARG%
50
  pause