|
import logging |
|
import os |
|
import sys |
|
|
|
from modules.ffmpeg_env import setup_ffmpeg_path |
|
|
|
# Best-effort bootstrap that runs at import time, before the heavier imports
# below. The two steps are guarded independently so that a failure in one no
# longer skips the other, and failures are reported instead of silently
# discarded. Only `Exception` is caught, so Ctrl+C / SystemExit still work.
try:
    logging.basicConfig(
        # Accept level names in any case, e.g. LOG_LEVEL=debug.
        level=os.getenv("LOG_LEVEL", "INFO").upper(),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
except ValueError:
    # Unrecognised LOG_LEVEL name: keep starting, but say why the requested
    # level was not applied.
    logging.getLogger(__name__).warning(
        "Ignoring invalid LOG_LEVEL=%r", os.getenv("LOG_LEVEL")
    )

try:
    setup_ffmpeg_path()
except Exception:
    logging.getLogger(__name__).warning(
        "ffmpeg path setup failed; continuing anyway", exc_info=True
    )
|
|
|
import argparse |
|
|
|
from modules import config |
|
from modules.api.api_setup import process_api_args, setup_api_args |
|
from modules.api.app_config import app_description, app_title, app_version |
|
from modules.gradio_dcls_fix import dcls_patch |
|
from modules.models_setup import process_model_args, setup_model_args |
|
from modules.utils.env import get_and_update_env |
|
from modules.utils.ignore_warn import ignore_useless_warnings |
|
from modules.utils.torch_opt import configure_torch_optimizations |
|
from modules.webui import webui_config |
|
from modules.webui.app import create_interface, webui_init |
|
|
|
import subprocess |
|
|
|
# Best-effort install of the optional flash-attn package, run on every import
# of this module. Failures are not fatal (check=False); pip reports them itself.
subprocess.run(
    # Use the running interpreter's pip so the package lands in the same
    # environment that will import it. The argument list avoids a shell.
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    # Extend, rather than replace, the inherited environment. Passing only the
    # one variable dropped PATH, HOME, virtualenv and proxy settings.
    # FLASH_ATTENTION_SKIP_CUDA_BUILD presumably tells flash-attn's setup to
    # skip compiling its CUDA extension; confirm against flash-attn's setup.py.
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
|
|
|
# Import-time patches, applied in this fixed order: the fix from
# modules.gradio_dcls_fix first, then ignore_useless_warnings().
for _import_time_patch in (dcls_patch, ignore_useless_warnings):
    _import_time_patch()
del _import_time_patch
|
|
|
|
|
def setup_webui_args(parser: argparse.ArgumentParser):
    """Register the web UI command-line options on *parser*.

    Every option is optional: valued options default to ``None`` and switches
    to ``False``. The effective values are resolved later by
    ``process_webui_args`` through ``get_and_update_env``.

    Args:
        parser: the parser to extend in place. Options are added in a fixed
            order, which is also the order shown in ``--help``.
    """
    store_true = "store_true"
    option_table = [
        ("--server_name", dict(type=str, help="server name")),
        ("--server_port", dict(type=int, help="server port")),
        ("--share", dict(action=store_true, help="share the gradio interface")),
        ("--debug", dict(action=store_true, help="enable debug mode")),
        ("--auth", dict(type=str, help="username:password for authentication")),
        ("--tts_max_len", dict(type=int, help="Max length of text for TTS")),
        ("--ssml_max_len", dict(type=int, help="Max length of text for SSML")),
        ("--max_batch_size", dict(type=int, help="Max batch size for TTS")),
        (
            "--webui_experimental",
            dict(action=store_true, help="Enable webui_experimental features"),
        ),
        (
            "--language",
            dict(type=str, help="Set the default language for the webui"),
        ),
        (
            "--api",
            dict(
                action=store_true,
                help="use api=True to launch the API together with the webui (run launch.py for only API server)",
            ),
        ),
    ]
    for option_string, settings in option_table:
        parser.add_argument(option_string, **settings)
|
|
|
|
|
def process_webui_args(args):
    """Resolve the web UI settings, build the Gradio app, launch it and block.

    Each setting is read through ``get_and_update_env`` with the default shown
    in the call. The app is launched with ``prevent_thread_lock=True`` so the
    FastAPI app can be adjusted (and the API mounted when ``api`` is set). The
    calling thread is then handed to ``demo.block_thread()``.

    Args:
        args: the parsed ``argparse.Namespace`` carrying the options registered
            by ``setup_webui_args``. It is also passed to ``process_api_args``
            when the API is enabled.

    Raises:
        ValueError: if an ``auth`` value is given without a ``:`` separator.
    """
    server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
    server_port = get_and_update_env(args, "server_port", 7860, int)
    share = get_and_update_env(args, "share", False, bool)
    debug = get_and_update_env(args, "debug", False, bool)
    auth = get_and_update_env(args, "auth", None, str)
    # The value is not used in this function. The call is kept for
    # get_and_update_env's side effects (presumably the env update it is
    # named for; verify against modules.utils.env).
    get_and_update_env(args, "language", "zh-CN", str)
    api = get_and_update_env(args, "api", False, bool)

    # Web UI limits and feature flags, each resolved exactly once.
    webui_config.experimental = get_and_update_env(
        args, "webui_experimental", False, bool
    )
    webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
    webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
    webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)

    # Validate credentials before the expensive UI setup below. Split on the
    # first ':' only, so the password itself may contain ':'.
    if auth:
        username, separator, password = auth.partition(":")
        if not separator:
            raise ValueError("auth must be given as 'username:password'")
        auth = (username, password)

    configure_torch_optimizations()
    webui_init()
    demo = create_interface()

    # Interactive API docs are exposed only when the API is enabled and the
    # no_docs runtime switch is off.
    serve_docs = bool(api) and not config.runtime_env_vars.no_docs

    app, _local_url, _share_url = demo.queue().launch(
        server_name=server_name,
        server_port=server_port,
        share=share,
        debug=debug,
        auth=auth,
        show_api=False,
        # Return immediately; the thread is blocked explicitly at the end.
        prevent_thread_lock=True,
        # Open a browser tab automatically only on Windows.
        inbrowser=sys.platform == "win32",
        app_kwargs={
            "title": app_title,
            "description": app_description,
            "version": app_version,
            "redoc_url": "/redoc" if serve_docs else None,
            "docs_url": "/docs" if serve_docs else None,
        },
    )

    # Drop Gradio's CustomCORSMiddleware from the app.
    # NOTE(review): presumably so process_api_args can install its own CORS
    # policy; confirm. Starlette builds its middleware stack lazily, so this
    # only takes effect if no request has been served yet.
    app.user_middleware = [
        middleware
        for middleware in app.user_middleware
        if middleware.cls.__name__ != "CustomCORSMiddleware"
    ]

    if api:
        process_api_args(args, app)

    demo.block_thread()
|
|
|
|
|
if __name__ == "__main__":
    import dotenv

    # Load the env file (path from ENV_FILE, default ".env.webui") before any
    # arguments are parsed or processed.
    env_file_path = os.getenv("ENV_FILE", ".env.webui")
    dotenv.load_dotenv(dotenv_path=env_file_path)

    parser = argparse.ArgumentParser(description="Gradio App")
    # Register option groups: web UI, then model, then API.
    for register_options in (setup_webui_args, setup_model_args, setup_api_args):
        register_options(parser)
    args = parser.parse_args()

    # Model arguments are processed first. process_webui_args then builds,
    # launches and blocks on the UI.
    for consume_args in (process_model_args, process_webui_args):
        consume_args(args)
|
|