|
import contextlib |
|
import functools |
|
import hashlib |
|
import logging |
|
import os |
|
|
|
import requests |
|
import torch |
|
import tqdm |
|
|
|
from TTS.tts.layers.bark.model import GPT, GPTConfig |
|
from TTS.tts.layers.bark.model_fine import FineGPT, FineGPTConfig |
|
|
|
# Mixed-precision context used during inference: bf16 autocast on CUDA devices that
# support it, otherwise a no-op context manager with the same call signature.
if (
    torch.cuda.is_available()
    and hasattr(torch, "autocast")
    and hasattr(torch.cuda, "is_bf16_supported")
    and torch.cuda.is_bf16_supported()
):
    # `torch.autocast(device_type="cuda", ...)` is the documented replacement for the
    # deprecated `torch.cuda.amp.autocast`, which emits a FutureWarning on every call in
    # current torch releases (and this context is entered on every inference call).
    autocast = functools.partial(torch.autocast, device_type="cuda", dtype=torch.bfloat16)
else:

    @contextlib.contextmanager
    def autocast():
        """No-op stand-in used when bf16 CUDA autocast is unavailable."""
        yield
|
|
|
|
|
|
|
|
|
# Module-level logger used by the helpers in this module.
logger = logging.getLogger(__name__)

# `scaled_dot_product_attention` is missing from older torch builds; warn once at import
# time so users know an upgrade would speed up inference.
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
    logger.warning(
        "torch version does not support flash attention. You will get significantly faster"
        " inference speed by upgrade torch to newest version / nightly."
    )
|
|
|
|
|
def _md5(fname): |
|
hash_md5 = hashlib.md5() |
|
with open(fname, "rb") as f: |
|
for chunk in iter(lambda: f.read(4096), b""): |
|
hash_md5.update(chunk) |
|
return hash_md5.hexdigest() |
|
|
|
|
|
def _download(from_s3_path, to_local_path, CACHE_DIR):
    """Stream ``from_s3_path`` into ``to_local_path`` while showing a tqdm progress bar.

    Data is written to ``<to_local_path>.part`` first and renamed onto ``to_local_path``
    only after the transfer finished and its size matched. An interrupted, failed or
    truncated download therefore never leaves a file at ``to_local_path``, which a later
    "does the checkpoint exist?" check would otherwise accept as valid.

    Args:
        from_s3_path: URL to fetch.
        to_local_path: Destination file path.
        CACHE_DIR: Directory created (if missing) before downloading.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        ValueError: if the number of bytes received differs from the advertised
            Content-Length.
    """
    os.makedirs(CACHE_DIR, exist_ok=True)
    tmp_path = os.fspath(to_local_path) + ".part"
    # 1 MiB blocks: far fewer Python-level iterations than 1 KiB for GB-sized files.
    block_size = 1 << 20
    try:
        # closing() guarantees the streamed connection is released on every path.
        # The timeout bounds connection setup and each socket read, so a stalled
        # server raises instead of hanging forever.
        with contextlib.closing(requests.get(from_s3_path, stream=True, timeout=60)) as response:
            # Without this, a 4xx/5xx error page would be saved as the downloaded file.
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get("content-length", 0))
            progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
            try:
                with open(tmp_path, "wb") as file:
                    for data in response.iter_content(block_size):
                        progress_bar.update(len(data))
                        file.write(data)
            finally:
                progress_bar.close()
        # 0 is the default when the Content-Length header is absent, so it cannot be checked.
        if total_size_in_bytes not in [0, progress_bar.n]:
            raise ValueError("ERROR, something went wrong")
        os.replace(tmp_path, to_local_path)
    except BaseException:
        # Drop any partial data; the original error is re-raised unchanged.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise
|
|
|
|
|
class InferenceContext:
    """Context manager that temporarily sets ``torch.backends.cudnn.benchmark``.

    On entry the current flag is saved and replaced with ``benchmark``; on exit the
    saved value is restored. ``__enter__`` returns ``None`` and ``__exit__`` never
    suppresses exceptions.
    """

    def __init__(self, benchmark=False):
        # Flag value to apply while the context is active.
        self._chosen_cudnn_benchmark = benchmark
        # Flag value saved on entry for restoration on exit; None until entered.
        self._cudnn_benchmark = None

    def __enter__(self):
        cudnn = torch.backends.cudnn
        # Right-hand side is evaluated first: save the old flag, then install the new one.
        self._cudnn_benchmark, cudnn.benchmark = cudnn.benchmark, self._chosen_cudnn_benchmark

    def __exit__(self, exc_type, exc_value, exc_traceback):
        torch.backends.cudnn.benchmark = self._cudnn_benchmark
|
|
|
|
|
# When a CUDA device is present, allow TF32 math for both cuDNN and CUDA matmul kernels.
if torch.cuda.is_available():
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
|
|
|
|
|
@contextlib.contextmanager
def inference_mode():
    """Combined context for running inference.

    Enters, outermost first: ``InferenceContext()`` (cudnn benchmark flag override),
    ``torch.inference_mode()``, ``torch.no_grad()`` and ``autocast()``. They are exited
    in reverse order when the ``with`` body finishes.
    """
    with InferenceContext():
        with torch.inference_mode(), torch.no_grad():
            with autocast():
                yield
|
|
|
|
|
def clear_cuda_cache():
    """Empty torch's CUDA cache and synchronize the device; does nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
|
|
|
|
|
def load_model(ckpt_path, device, config, model_type="text"):
    """Load a Bark GPT / FineGPT checkpoint, (re)downloading it first when needed.

    Unless ``config.USE_SMALLER_MODELS`` is set, an existing file whose MD5 differs from
    ``config.REMOTE_MODEL_PATHS[model_type]["checksum"]`` is deleted. A missing file is
    downloaded from ``config.REMOTE_MODEL_PATHS[model_type]["path"]``.

    Args:
        ckpt_path: Local path of the checkpoint file.
        device: Passed to ``torch.load(map_location=...)`` and ``model.to(...)``.
        config: Bark config object. Reads ``USE_SMALLER_MODELS``, ``REMOTE_MODEL_PATHS``
            and ``CACHE_DIR``. Sets ``semantic_config`` / ``coarse_config`` /
            ``fine_config`` (depending on ``model_type``) to the built model config.
        model_type: ``"text"``, ``"coarse"`` or ``"fine"``.

    Returns:
        ``(model, config)``: the model in eval mode on ``device``, and the updated config.

    Raises:
        NotImplementedError: if ``model_type`` is not one of the supported values.
        ValueError: if the checkpoint's state dict has extra or missing keys
            (``.attn.bias`` buffers are ignored).
    """
    logger.info(f"loading {model_type} model from {ckpt_path}...")

    if device == "cpu":
        logger.warning("No GPU being used. Careful, Inference might be extremely slow!")

    # model_type -> (config class, model class, attribute of `config` that receives the built config)
    model_specs = {
        "text": (GPTConfig, GPT, "semantic_config"),
        "coarse": (GPTConfig, GPT, "coarse_config"),
        "fine": (FineGPTConfig, FineGPT, "fine_config"),
    }
    if model_type not in model_specs:
        raise NotImplementedError(
            f"unsupported model_type {model_type!r}; expected one of {sorted(model_specs)}"
        )
    ConfigClass, ModelClass, config_attr = model_specs[model_type]

    if (
        not config.USE_SMALLER_MODELS
        and os.path.exists(ckpt_path)
        and _md5(ckpt_path) != config.REMOTE_MODEL_PATHS[model_type]["checksum"]
    ):
        logger.warning(f"found outdated {model_type} model, removing...")
        os.remove(ckpt_path)
    if not os.path.exists(ckpt_path):
        logger.info(f"{model_type} model not found, downloading...")
        _download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR)

    checkpoint = torch.load(ckpt_path, map_location=device)

    # Older checkpoints store a single `vocab_size`; split it into input/output sizes.
    model_args = checkpoint["model_args"]
    if "input_vocab_size" not in model_args:
        model_args["input_vocab_size"] = model_args["vocab_size"]
        model_args["output_vocab_size"] = model_args["vocab_size"]
        del model_args["vocab_size"]

    gptconf = ConfigClass(**model_args)
    setattr(config, config_attr, gptconf)

    model = ModelClass(gptconf)
    state_dict = checkpoint["model"]

    # Strip the "_orig_mod." prefix from keys that carry it, renaming them in place.
    unwanted_prefix = "_orig_mod."
    for key in list(state_dict):
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix) :]] = state_dict.pop(key)

    # `.attn.bias` entries are tolerated on either side; every other key must match exactly.
    model_keys = set(model.state_dict().keys())
    loaded_keys = set(state_dict.keys())
    extra_keys = {k for k in loaded_keys - model_keys if not k.endswith(".attn.bias")}
    missing_keys = {k for k in model_keys - loaded_keys if not k.endswith(".attn.bias")}
    if extra_keys:
        raise ValueError(f"extra keys found: {extra_keys}")
    if missing_keys:
        raise ValueError(f"missing keys: {missing_keys}")
    model.load_state_dict(state_dict, strict=False)

    n_params = model.get_num_params()
    # The loss is informational only. Accept a 0-dim tensor or a plain float, and do not
    # fail the whole load when the key is absent.
    best_val_loss = checkpoint.get("best_val_loss")
    if best_val_loss is None:
        logger.info(f"model loaded: {round(n_params/1e6,1)}M params")
    else:
        val_loss = float(best_val_loss)
        logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")

    model.eval()
    model.to(device)
    del checkpoint, state_dict
    clear_cuda_cache()
    return model, config
|
|