Spaces:
Runtime error
Runtime error
Fix threads argument
Browse files
app.py
CHANGED
@@ -29,7 +29,7 @@ import ffmpeg
|
|
29 |
import gradio as gr
|
30 |
|
31 |
from src.download import ExceededMaximumDuration, download_url
|
32 |
-
from src.utils import slugify, write_srt, write_vtt
|
33 |
from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
34 |
from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
|
35 |
from src.whisper.whisperFactory import create_whisper_container
|
@@ -596,9 +596,14 @@ if __name__ == '__main__':
|
|
596 |
help="the Whisper implementation to use")
|
597 |
parser.add_argument("--compute_type", type=str, default=default_app_config.compute_type, choices=["default", "auto", "int8", "int8_float16", "int16", "float16", "float32"], \
|
598 |
help="the compute type to use for inference")
|
|
|
|
|
599 |
|
600 |
args = parser.parse_args().__dict__
|
601 |
|
602 |
updated_config = default_app_config.update(**args)
|
603 |
|
|
|
|
|
|
|
604 |
create_ui(app_config=updated_config)
|
|
|
29 |
import gradio as gr
|
30 |
|
31 |
from src.download import ExceededMaximumDuration, download_url
|
32 |
+
from src.utils import optional_int, slugify, write_srt, write_vtt
|
33 |
from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
34 |
from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
|
35 |
from src.whisper.whisperFactory import create_whisper_container
|
|
|
596 |
help="the Whisper implementation to use")
|
597 |
parser.add_argument("--compute_type", type=str, default=default_app_config.compute_type, choices=["default", "auto", "int8", "int8_float16", "int16", "float16", "float32"], \
|
598 |
help="the compute type to use for inference")
|
599 |
+
parser.add_argument("--threads", type=optional_int, default=0,
|
600 |
+
help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
601 |
|
602 |
args = parser.parse_args().__dict__
|
603 |
|
604 |
updated_config = default_app_config.update(**args)
|
605 |
|
606 |
+
if (threads := args.pop("threads")) > 0:
|
607 |
+
torch.set_num_threads(threads)
|
608 |
+
|
609 |
create_ui(app_config=updated_config)
|
cli.py
CHANGED
@@ -113,6 +113,9 @@ def cli():
|
|
113 |
device: str = args.pop("device")
|
114 |
os.makedirs(output_dir, exist_ok=True)
|
115 |
|
|
|
|
|
|
|
116 |
whisper_implementation = args.pop("whisper_implementation")
|
117 |
print(f"Using {whisper_implementation} for Whisper")
|
118 |
|
|
|
113 |
device: str = args.pop("device")
|
114 |
os.makedirs(output_dir, exist_ok=True)
|
115 |
|
116 |
+
if (threads := args.pop("threads")) > 0:
|
117 |
+
torch.set_num_threads(threads)
|
118 |
+
|
119 |
whisper_implementation = args.pop("whisper_implementation")
|
120 |
print(f"Using {whisper_implementation} for Whisper")
|
121 |
|