|
import numpy as np |
|
import torch |
|
|
|
from modules.utils.SeedContext import SeedContext |
|
|
|
from modules import models, config |
|
|
|
|
|
@torch.inference_mode()
def refine_text(
    text: str,
    prompt="[oral_2][laugh_0][break_6]",
    seed=-1,
    top_P=0.7,
    top_K=20,
    temperature=0.7,
    repetition_penalty=1.0,
    max_new_token=384,
) -> str:
    """Run the ChatTTS refiner on *text* and return its output.

    The model is obtained via ``models.load_chat_tts()`` and its
    ``refiner_prompt`` method is invoked inside ``SeedContext(seed)``.
    Text normalization is explicitly disabled for the call.

    Args:
        text: Input text handed to the refiner unchanged.
        prompt: Refiner control prompt, forwarded under the ``"prompt"`` key.
        seed: Value passed to ``SeedContext``.
            NOTE(review): presumably -1 means "no fixed seed" -- confirm
            against ``SeedContext``.
        top_P: Forwarded under the ``"top_P"`` key.
        top_K: Forwarded under the ``"top_K"`` key.
        temperature: Forwarded under the ``"temperature"`` key.
        repetition_penalty: Forwarded under the ``"repetition_penalty"`` key.
        max_new_token: Forwarded under the ``"max_new_token"`` key.

    Returns:
        Whatever ``refiner_prompt`` returns (annotated here as ``str``).
    """
    model = models.load_chat_tts()

    with SeedContext(seed):
        # Parameters are assembled inside the seeded context so the order of
        # evaluation relative to the context entry stays the same.
        refiner_params = {
            "prompt": prompt,
            "top_K": top_K,
            "top_P": top_P,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_token": max_new_token,
            # Progress-bar visibility follows the project-wide config flag.
            "disable_tqdm": config.disable_tqdm,
        }
        return model.refiner_prompt(
            text,
            refiner_params,
            do_text_normalization=False,
        )
|
|