import os
from argparse import ArgumentParser
from pathlib import Path

import pyrootutils
import torch
from loguru import logger

# Register the project root (marked by .project-root) on sys.path so the
# tools.* imports below resolve
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from tools.inference_engine import TTSInferenceEngine
from tools.llama.generate import launch_thread_safe_queue
from tools.schema import ServeTTSRequest
from tools.vqgan.inference import load_model as load_decoder_model
from tools.webui import build_app
from tools.webui.inference import get_inference_wrapper

# Make einx happy
os.environ["EINX_FILTER_TRACEBACK"] = "false"
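
# Fetch the Fish Speech 1.5 weights from the Hugging Face Hub into checkpoints/
# at startup, before the models below try to load them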
os.system("huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5")

def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.5",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-gradio-length", type=int, default=0)
    parser.add_argument("--theme", type=str, default="light")

    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    args.precision = torch.half if args.half else torch.bfloat16
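    # bf16 is the safer default: it keeps fp32's exponent range, while fp16
    # (--half) uses the same memory but overflows more easily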

    # Pick the best available device: prefer MPS (Apple Silicon), then CUDA,
    # and fall back to CPU if neither is available
    if torch.backends.mps.is_available():
        args.device = "mps"
        logger.info("mps is available, running on mps.")
    elif not torch.cuda.is_available():
        logger.info("CUDA is not available, running on CPU.")
        args.device = "cpu"
logger.info("Loading Llama model...") | |
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )

    logger.info("Loading VQ-GAN model...")
    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )
logger.info("Decoder model loaded, warming up...") | |
# Create the inference engine | |
inference_engine = TTSInferenceEngine( | |
llama_queue=llama_queue, | |
decoder_model=decoder_model, | |
compile=args.compile, | |
precision=args.precision, | |
) | |

    # Dry run to verify that the models are loaded correctly and to absorb
    # the first-request latency up front
    list(
        inference_engine.inference(
            ServeTTSRequest(
                text="Hello world.",
                references=[],
                reference_id=None,
                max_new_tokens=1024,
                chunk_length=200,
                top_p=0.7,
                repetition_penalty=1.5,
                temperature=0.7,
                format="wav",
            )
        )
    )
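    # inference() yields its results lazily, so the list() call above is what
    # forces the full forward pass (and, with --compile, the torch.compile trace)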
logger.info("Warming up done, launching the web UI...") | |
# Get the inference function with the immutable arguments | |
inference_fct = get_inference_wrapper(inference_engine) | |
app = build_app(inference_fct, args.theme) | |
app.launch(show_api=True, share=True) | |
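
# Example invocation (assuming this script is saved as app.py; all flags are
# optional and default to the values set in parse_args):
#   python app.py --half --compile --theme dark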