|
import numpy as np |
|
import torch |
|
|
|
from modules.speaker import Speaker |
|
from modules.utils.SeedContext import SeedContext |
|
|
|
from modules import models, config |
|
|
|
import logging |
|
import gc |
|
|
|
from modules.devices import devices |
|
from typing import Union |
|
|
|
from modules.utils.cache import conditional_cache |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def generate_audio(
    text: str,
    temperature: float = 0.3,
    top_P: float = 0.7,
    top_K: float = 20,
    spk: Union[int, Speaker] = -1,
    infer_seed: int = -1,
    use_decoder: bool = True,
    prompt1: str = "",
    prompt2: str = "",
    prefix: str = "",
):
    """Synthesize a single text by delegating to ``generate_audio_batch``.

    The text is wrapped in a one-element list and every other argument is
    forwarded unchanged, by keyword, to ``generate_audio_batch``.

    Args:
        text: The text to synthesize.
        temperature: Forwarded as-is.
        top_P: Forwarded as-is.
        top_K: Forwarded as-is.
        spk: Speaker selector (an int or a ``Speaker``), forwarded as-is.
        infer_seed: Inference seed, forwarded as-is.
        use_decoder: Forwarded as-is.
        prompt1: Forwarded as-is.
        prompt2: Forwarded as-is.
        prefix: Forwarded as-is.

    Returns:
        The ``(sample_rate, wav)`` pair for the first (and only) entry of the
        batch result.
    """
    # Everything except the text goes through as keyword arguments, in the
    # same order as the signature.
    forwarded = dict(
        temperature=temperature,
        top_P=top_P,
        top_K=top_K,
        spk=spk,
        infer_seed=infer_seed,
        use_decoder=use_decoder,
        prompt1=prompt1,
        prompt2=prompt2,
        prefix=prefix,
    )
    batch_result = generate_audio_batch([text], **forwarded)

    # A one-text batch: only the first entry is of interest.
    sample_rate, wav = batch_result[0]
    return (sample_rate, wav)
|
|
|
|
|
@torch.inference_mode()
def generate_audio_batch(
    texts: list[str],
    temperature: float = 0.3,
    top_P: float = 0.7,
    top_K: float = 20,
    spk: Union[int, Speaker] = -1,
    infer_seed: int = -1,
    use_decoder: bool = True,
    prompt1: str = "",
    prompt2: str = "",
    prefix: str = "",
):
    """Synthesize a batch of texts with the ChatTTS model.

    Runs under ``torch.inference_mode()``. The model is obtained from
    ``models.load_chat_tts()`` and called once for the whole batch.

    Args:
        texts: Texts to synthesize in a single model call.
        temperature: Stored under ``"temperature"`` in the inference params.
        top_P: Stored under ``"top_P"`` in the inference params.
        top_K: Stored under ``"top_K"`` in the inference params.
        spk: Speaker selection.
            * ``int``: used as the ``SeedContext`` seed around
              ``chat_tts.sample_random_speaker()``.
            * ``Speaker``: its ``emb`` tensor is used directly.
            * anything else: a warning is logged and a speaker is sampled
              under ``SeedContext(2, True)``.
        infer_seed: Seed passed to the ``SeedContext`` that wraps generation.
        use_decoder: Forwarded to ``chat_tts.generate_audio``.
        prompt1: Stored under ``"prompt1"``; falsy values become ``""``.
        prompt2: Stored under ``"prompt2"``; falsy values become ``""``.
        prefix: Stored under ``"prefix"``; falsy values become ``""``.

    Returns:
        A list with one ``(sample_rate, wav)`` tuple per generated waveform,
        where ``sample_rate`` is 24000 and ``wav`` is a flattened
        ``numpy.float32`` array.

    Raises:
        ValueError: If ``spk`` is a ``Speaker`` whose ``emb`` is not a
            ``torch.Tensor``.
    """
    chat_tts = models.load_chat_tts()
    params_infer_code = {
        "spk_emb": None,
        "temperature": temperature,
        "top_P": top_P,
        "top_K": top_K,
        "prompt1": prompt1 or "",
        "prompt2": prompt2 or "",
        "prefix": prefix or "",
        "repetition_penalty": 1.0,
        "disable_tqdm": config.runtime_env_vars.off_tqdm,
    }

    if isinstance(spk, int):
        # NOTE(review): presumably SeedContext makes the sampled speaker
        # reproducible for a given seed; confirm in modules.utils.SeedContext.
        with SeedContext(spk, True):
            params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
        logger.debug(("spk", spk))
    elif isinstance(spk, Speaker):
        if not isinstance(spk.emb, torch.Tensor):
            raise ValueError("spk.pt is broken, please retrain the model.")
        params_infer_code["spk_emb"] = spk.emb
        logger.debug(("spk", spk.name))
    else:
        # Logger.warn is a deprecated alias; warning() is the supported name.
        logger.warning(
            f"spk must be int or Speaker, but: <{type(spk)}> {spk}, will set to default voice"
        )
        # Fall back to the speaker sampled with the fixed seed 2.
        with SeedContext(2, True):
            params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()

    logger.debug(
        {
            "text": texts,
            "infer_seed": infer_seed,
            "temperature": temperature,
            "top_P": top_P,
            "top_K": top_K,
            "prompt1": prompt1 or "",
            "prompt2": prompt2 or "",
            "prefix": prefix or "",
        }
    )

    with SeedContext(infer_seed, True):
        wavs = chat_tts.generate_audio(
            texts, params_infer_code, use_decoder=use_decoder
        )

    sample_rate = 24000

    # Optional cleanup after inference, controlled by config.auto_gc.
    if config.auto_gc:
        devices.torch_gc()
        gc.collect()

    return [(sample_rate, np.array(wav).flatten().astype(np.float32)) for wav in wavs]
|
|
|
|
|
# Flipped to True the first time setup_lru_cache() runs, so that later calls
# return immediately.
lru_cache_enabled = False


def setup_lru_cache():
    """Wrap ``generate_audio_batch`` with ``conditional_cache``, at most once.

    The module-level ``generate_audio_batch`` is rebound to the wrapped
    function when ``config.runtime_env_vars.lru_size`` is an ``int``;
    otherwise only a debug message is logged. In both cases
    ``lru_cache_enabled`` is set, so subsequent calls do nothing.
    """
    global generate_audio_batch
    global lru_cache_enabled

    if lru_cache_enabled:
        return
    lru_cache_enabled = True

    lru_size = config.runtime_env_vars.lru_size
    if not isinstance(lru_size, int):
        logger.debug(f"LRU cache failed to enable, invalid size {lru_size}")
        return

    def should_cache(*args, **kwargs):
        # Only keyword arguments are inspected: a call qualifies when both
        # ``spk`` and ``infer_seed`` are given by keyword and differ from -1.
        return kwargs.get("spk", -1) != -1 and kwargs.get("infer_seed", -1) != -1

    cache_decorator = conditional_cache(lru_size, should_cache)
    generate_audio_batch = cache_decorator(generate_audio_batch)
    logger.info(f"LRU cache enabled with size {lru_size}")
|
|
|
|
|
if __name__ == "__main__":
    import soundfile as sf

    # Manual smoke run: synthesize the same texts once as a batch and once
    # one at a time, writing every result to a wav file.
    inputs = ["你好[lbreak]", "再见[lbreak]", "长度不同的文本片段[lbreak]"]
    fixed_seeds = {"spk": 5, "infer_seed": 42}

    # Batch generation.
    for i, (rate, audio) in enumerate(generate_audio_batch(inputs, **fixed_seeds)):
        print(i, rate, audio.shape)
        sf.write(f"batch_{i}.wav", audio, rate, format="wav")

    # One text per call.
    for i, text in enumerate(inputs):
        rate, audio = generate_audio(text, **fixed_seeds)
        print(i, rate, audio.shape)
        sf.write(f"one_{i}.wav", audio, rate, format="wav")
|
|