Spaces:
Runtime error
Runtime error
File size: 6,953 Bytes
8031785 44d964a 1acaa19 44d964a 74b7d77 8031785 74b7d77 8031785 74b7d77 ce65f49 8031785 ce65f49 8031785 44d964a d20404a 295de00 1acaa19 8031785 1acaa19 33ee1bb 1acaa19 f55c594 a1da02d 3cd7d59 1acaa19 44d964a 1acaa19 44d964a d20404a 1acaa19 295de00 44d964a 1acaa19 8031785 1acaa19 33ee1bb 1acaa19 f55c594 a1da02d 3cd7d59 a1da02d 44d964a 1acaa19 44d964a 4a46425 44d964a 1acaa19 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
from enum import Enum
import urllib
import os
from typing import List
from urllib.parse import urlparse
import json5
import torch
from tqdm import tqdm
class ModelConfig:
    """Plain record describing one model: its name, source URL, optional local path and type."""

    def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
        """
        Create a model configuration.

        name: Name of the model
        url: URL to download the model from
        path: Path to the model file. If not set, the model will be downloaded from the URL.
        type: Type of model. Can be whisper or huggingface.
        """
        # Every argument is stored unchanged under an attribute of the same name.
        self.name, self.url = name, url
        self.path, self.type = path, type
# String spellings of the VAD initial prompt modes, in declaration order.
VAD_INITIAL_PROMPT_MODE_VALUES = [
    "prepend_all_segments",
    "prepend_first_segment",
    "json_prompt_mode",
]
class VadInitialPromptMode(Enum):
    """Enumeration of the supported VAD initial prompt modes."""

    PREPEND_ALL_SEGMENTS = 1
    # The misspelled name is the original public member and stays canonical so
    # existing callers keep working; the correctly spelled name below is an
    # alias that resolves to the very same member.
    PREPREND_FIRST_SEGMENT = 2
    PREPEND_FIRST_SEGMENT = 2
    JSON_PROMPT_MODE = 3

    @staticmethod
    def from_string(s: str):
        """
        Convert a textual mode name into a VadInitialPromptMode member.

        Matching is case-insensitive and ignores surrounding whitespace, so a
        value such as "prepend_first_segment " is accepted.

        s: The mode name, or None.

        Returns the matching member, or None when s is None, empty or blank.
        Raises ValueError when s is any other unrecognized value.
        """
        if s is None:
            return None

        normalized = s.strip().lower()
        if not normalized:
            return None

        # Built inside the function: a dict assigned in the Enum body would
        # itself be turned into an enum member.
        modes = {
            "prepend_all_segments": VadInitialPromptMode.PREPEND_ALL_SEGMENTS,
            "prepend_first_segment": VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
            "json_prompt_mode": VadInitialPromptMode.JSON_PROMPT_MODE,
        }
        try:
            return modes[normalized]
        except KeyError:
            raise ValueError(f"Invalid value for VadInitialPromptMode: {s}") from None
class ApplicationConfig:
    """
    Container for the application's settings.

    Every constructor argument is stored as an instance attribute of the same
    name. update() relies on that one-to-one mapping, because it rebuilds a
    config by passing the instance's __dict__ back into the constructor.
    """

    def __init__(self, models: List["ModelConfig"] = None, input_audio_max_duration: int = 600,
                 share: bool = False, server_name: str = None, server_port: int = 7860,
                 queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
                 whisper_implementation: str = "whisper",
                 default_model_name: str = "medium", default_vad: str = "silero-vad",
                 vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
                 auto_parallel: bool = False, output_dir: str = None,
                 model_dir: str = None, device: str = None,
                 verbose: bool = True, task: str = "transcribe", language: str = None,
                 vad_initial_prompt_mode: str = "prepend_first_segment",
                 vad_merge_window: float = 5, vad_max_merge_size: float = 30,
                 vad_padding: float = 1, vad_prompt_window: float = 3,
                 temperature: float = 0, best_of: int = 5, beam_size: int = 5,
                 patience: float = None, length_penalty: float = None,
                 suppress_tokens: str = "-1", initial_prompt: str = None,
                 condition_on_previous_text: bool = True, fp16: bool = True,
                 compute_type: str = "float16",
                 temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
                 logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
                 # Word timestamp settings
                 word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
                 append_punctuations: str = "\"\'.。,,!!??::”)]}、",
                 highlight_words: bool = False,
                 # Diarization
                 auth_token: str = None, diarization: bool = False, diarization_speakers: int = 2,
                 diarization_min_speakers: int = 1, diarization_max_speakers: int = 5,
                 diarization_process_timeout: int = 60):
        """
        Store every setting on the instance under the parameter's own name.

        models: List of model configurations. When omitted (None), each
            instance receives its own new empty list.

        All remaining arguments are stored unchanged.
        """
        # A literal [] default would be one list shared by every instance
        # constructed without an explicit `models` argument.
        self.models = [] if models is None else models

        # WebUI settings
        self.input_audio_max_duration = input_audio_max_duration
        self.share = share
        self.server_name = server_name
        self.server_port = server_port
        self.queue_concurrency_count = queue_concurrency_count
        self.delete_uploaded_files = delete_uploaded_files

        self.whisper_implementation = whisper_implementation
        self.default_model_name = default_model_name
        self.default_vad = default_vad
        self.vad_parallel_devices = vad_parallel_devices
        self.vad_cpu_cores = vad_cpu_cores
        self.vad_process_timeout = vad_process_timeout
        self.auto_parallel = auto_parallel
        self.output_dir = output_dir

        self.model_dir = model_dir
        self.device = device
        self.verbose = verbose
        self.task = task
        self.language = language
        self.vad_initial_prompt_mode = vad_initial_prompt_mode
        self.vad_merge_window = vad_merge_window
        self.vad_max_merge_size = vad_max_merge_size
        self.vad_padding = vad_padding
        self.vad_prompt_window = vad_prompt_window
        self.temperature = temperature
        self.best_of = best_of
        self.beam_size = beam_size
        self.patience = patience
        self.length_penalty = length_penalty
        self.suppress_tokens = suppress_tokens
        self.initial_prompt = initial_prompt
        self.condition_on_previous_text = condition_on_previous_text
        self.fp16 = fp16
        self.compute_type = compute_type
        self.temperature_increment_on_fallback = temperature_increment_on_fallback
        self.compression_ratio_threshold = compression_ratio_threshold
        self.logprob_threshold = logprob_threshold
        self.no_speech_threshold = no_speech_threshold

        # Word timestamp settings
        self.word_timestamps = word_timestamps
        self.prepend_punctuations = prepend_punctuations
        self.append_punctuations = append_punctuations
        self.highlight_words = highlight_words

        # Diarization settings
        self.auth_token = auth_token
        self.diarization = diarization
        self.diarization_speakers = diarization_speakers
        self.diarization_min_speakers = diarization_min_speakers
        self.diarization_max_speakers = diarization_max_speakers
        self.diarization_process_timeout = diarization_process_timeout

    def get_model_names(self):
        """Return the `name` of every entry in `self.models`, in order."""
        return [x.name for x in self.models]

    def update(self, **new_values):
        """
        Return a new ApplicationConfig with the given attributes overridden.

        The new instance is constructed from this instance's attributes and
        then each key in new_values is set on it. This instance is left
        unmodified. The copy is shallow: attribute values such as the models
        list are the same objects in both instances.
        """
        result = ApplicationConfig(**self.__dict__)

        for key, value in new_values.items():
            setattr(result, key, value)
        return result

    @staticmethod
    def create_default(**kwargs):
        """
        Load the configuration file and apply any keyword overrides.

        The file path is read from the WHISPER_WEBUI_CONFIG environment
        variable, falling back to "config.json5".
        """
        app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))

        # Update with kwargs
        if kwargs:
            app_config = app_config.update(**kwargs)
        return app_config

    @staticmethod
    def parse_file(config_path: str):
        """
        Build an ApplicationConfig from a JSON5 file.

        The optional top-level "models" entry is converted into ModelConfig
        objects; all remaining top-level keys are passed to the constructor
        as keyword arguments.
        """
        # json5 is only needed when a config file is actually parsed.
        import json5

        with open(config_path, "r", encoding="utf-8") as f:
            # Load using json5
            data = json5.load(f)
            data_models = data.pop("models", [])

            models = [ModelConfig(**x) for x in data_models]

            return ApplicationConfig(models, **data)
|