import os
from argparse import ArgumentParser
from pathlib import Path

import pyrootutils
import torch
from loguru import logger

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from tools.inference_engine import TTSInferenceEngine
from tools.llama.generate import launch_thread_safe_queue
from tools.schema import ServeTTSRequest
from tools.vqgan.inference import load_model as load_decoder_model
from tools.webui import build_app
from tools.webui.inference import get_inference_wrapper

# Make einx happy
os.environ["EINX_FILTER_TRACEBACK"] = "false"

# Fetch the fish-speech-1.5 checkpoints from Hugging Face into ./checkpoints
os.system(
    "huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5"
)


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.5",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-gradio-length", type=int, default=0)
    parser.add_argument("--theme", type=str, default="light")

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    args.precision = torch.half if args.half else torch.bfloat16

    # Check if MPS or CUDA is available
    if torch.backends.mps.is_available():
        args.device = "mps"
        logger.info("mps is available, running on mps.")
    elif not torch.cuda.is_available():
        logger.info("CUDA is not available, running on CPU.")
        args.device = "cpu"

    logger.info("Loading Llama model...")
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )

    logger.info("Loading VQ-GAN model...")
    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("Decoder model loaded, warming up...")

    # Create the inference engine
    inference_engine = TTSInferenceEngine(
        llama_queue=llama_queue,
        decoder_model=decoder_model,
        compile=args.compile,
        precision=args.precision,
    )

    # Dry run to check if the model is loaded correctly and avoid the first-time latency
    list(
        inference_engine.inference(
            ServeTTSRequest(
                text="Hello world.",
                references=[],
                reference_id=None,
                max_new_tokens=1024,
                chunk_length=200,
                top_p=0.7,
                repetition_penalty=1.5,
                temperature=0.7,
                format="wav",
            )
        )
    )

    logger.info("Warming up done, launching the web UI...")

    # Get the inference function with the immutable arguments
    inference_fct = get_inference_wrapper(inference_engine)

    app = build_app(inference_fct, args.theme)
    app.launch(show_api=True, share=True)
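
# A minimal usage sketch. The script name "app.py" is an assumption; adjust it to
# wherever this file lives in your checkout. The flags below are the ones defined
# in parse_args():
#
#   python app.py --half --compile --theme dark
#
# --half switches inference precision from bfloat16 to float16, and --compile is
# forwarded to the Llama loader and the inference engine, trading a longer warm-up
# for faster generation.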