|
import numpy as np |
|
|
|
import ztu_somemodelruntime_rknnlite2 as ort |
|
import sentencepiece as spm |
|
import soundfile as sf |
|
|
|
ort.set_default_logger_verbosity(0) |
|
|
|
def load_onnx_model(model_path): |
|
"""加载ONNX模型""" |
|
return ort.InferenceSession( |
|
model_path, |
|
providers=['CUDAExecutionProvider', 'CPUExecutionProvider'] |
|
) |
|
|
|
class SimpleT5Tokenizer: |
|
def __init__(self, model_path, max_length=128): |
|
"""初始化tokenizer |
|
|
|
Args: |
|
model_path: sentencepiece模型路径 |
|
max_length: 序列最大长度,默认128 |
|
""" |
|
self.sp = spm.SentencePieceProcessor() |
|
self.sp.Load(model_path) |
|
|
|
|
|
self.pad_token_id = 0 |
|
self.eos_token_id = 1 |
|
self.max_length = max_length |
|
|
|
def __call__(self, texts, padding=True, truncation=True, max_length=None, return_tensors="np"): |
|
"""处理文本序列 |
|
|
|
Args: |
|
texts: 文本或文本列表 |
|
padding: 是否padding |
|
truncation: 是否截断 |
|
max_length: 可选,覆盖默认max_length |
|
return_tensors: 返回类型(只支持"np") |
|
|
|
Returns: |
|
dict: 包含input_ids和attention_mask |
|
""" |
|
if isinstance(texts, str): |
|
texts = [texts] |
|
|
|
max_len = max_length if max_length is not None else self.max_length |
|
|
|
|
|
input_ids = [] |
|
attention_mask = [] |
|
for text in texts: |
|
ids = self.sp.EncodeAsIds(text) |
|
|
|
|
|
if truncation and len(ids) > max_len - 1: |
|
ids = ids[:max_len-1] |
|
ids.append(self.eos_token_id) |
|
|
|
|
|
mask = [1] * len(ids) |
|
|
|
|
|
if padding: |
|
pad_length = max_len - len(ids) |
|
ids.extend([self.pad_token_id] * pad_length) |
|
mask.extend([0] * pad_length) |
|
|
|
input_ids.append(ids) |
|
attention_mask.append(mask) |
|
|
|
|
|
input_ids = np.array(input_ids, dtype=np.int64) |
|
attention_mask = np.array(attention_mask, dtype=np.int64) |
|
|
|
return { |
|
"input_ids": input_ids, |
|
"attention_mask": attention_mask |
|
} |
|
|
|
def encode_text(prompt, negative_prompt, tokenizer, text_encoder_onnx, guidance_scale=None): |
|
"""编码文本,同时处理条件和无条件文本 |
|
|
|
Args: |
|
prompt: 文本提示 |
|
tokenizer: T5 tokenizer |
|
text_encoder_onnx: T5 ONNX模型 |
|
guidance_scale: 引导系数 |
|
""" |
|
if not isinstance(prompt, list): |
|
prompt = [prompt] |
|
|
|
if guidance_scale is not None and guidance_scale > 1.0: |
|
|
|
all_prompts = [negative_prompt] + prompt |
|
batch = tokenizer( |
|
all_prompts, |
|
padding=True, |
|
truncation=True, |
|
return_tensors="np" |
|
) |
|
|
|
|
|
all_hidden_states = text_encoder_onnx.run( |
|
['last_hidden_state'], |
|
{ |
|
'input_ids': batch['input_ids'].astype(np.int64), |
|
'attention_mask': batch['attention_mask'].astype(np.int64) |
|
} |
|
)[0] |
|
|
|
|
|
uncond_hidden_states = all_hidden_states[0:1] |
|
cond_hidden_states = all_hidden_states[1:] |
|
uncond_mask = batch['attention_mask'][0:1] |
|
cond_mask = batch['attention_mask'][1:] |
|
|
|
return (uncond_hidden_states, uncond_mask), (cond_hidden_states, cond_mask) |
|
else: |
|
|
|
batch = tokenizer( |
|
prompt, |
|
padding=True, |
|
truncation=True, |
|
return_tensors="np" |
|
) |
|
|
|
|
|
hidden_states = text_encoder_onnx.run( |
|
['last_hidden_state'], |
|
{ |
|
'input_ids': batch['input_ids'].astype(np.int64), |
|
'attention_mask': batch['attention_mask'].astype(np.int64) |
|
} |
|
)[0] |
|
|
|
return hidden_states, batch['attention_mask'] |
|
|
|
def retrieve_timesteps(scheduler, num_inference_steps, device, timesteps=None, sigmas=None): |
|
"""获取timesteps""" |
|
if sigmas is not None: |
|
scheduler.set_timesteps(sigmas=sigmas) |
|
timesteps = scheduler.timesteps |
|
num_inference_steps = len(timesteps) |
|
else: |
|
scheduler.set_timesteps(num_inference_steps) |
|
timesteps = scheduler.timesteps |
|
return timesteps, num_inference_steps |
|
|
|
|
|
class SimpleFlowMatchScheduler: |
|
def __init__(self, num_train_timesteps=1000, shift=1.0): |
|
"""初始化scheduler |
|
|
|
Args: |
|
num_train_timesteps: 训练步数 |
|
shift: 时间步偏移量 |
|
""" |
|
|
|
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() |
|
|
|
|
|
sigmas = timesteps / num_train_timesteps |
|
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) |
|
|
|
|
|
self.sigmas = np.append(sigmas, 0.0) |
|
self.timesteps = sigmas * num_train_timesteps |
|
self.step_index = None |
|
|
|
def set_timesteps(self, num_inference_steps): |
|
"""设置推理时的timesteps |
|
|
|
Args: |
|
num_inference_steps: 推理步数 |
|
""" |
|
timesteps = np.linspace(1, len(self.timesteps), num_inference_steps, dtype=np.float32)[::-1].copy() |
|
sigmas = timesteps / len(self.timesteps) |
|
self.sigmas = np.append(sigmas, 0.0) |
|
self.timesteps = sigmas * len(self.timesteps) |
|
self.step_index = 0 |
|
|
|
def step(self, model_output, timestep, sample): |
|
"""执行一步euler更新 |
|
|
|
Args: |
|
model_output: 模型输出 |
|
timestep: 当前时间步 |
|
sample: 当前样本 |
|
|
|
Returns: |
|
prev_sample: 更新后的样本 |
|
""" |
|
sigma = self.sigmas[self.step_index] |
|
sigma_next = self.sigmas[self.step_index + 1] |
|
|
|
|
|
prev_sample = sample + (sigma_next - sigma) * model_output |
|
|
|
self.step_index += 1 |
|
return prev_sample |
|
|
|
def generate_audio_onnx( |
|
prompt="", |
|
negative_prompt="", |
|
duration=10, |
|
steps=50, |
|
guidance_scale=4.5, |
|
onnx_dir="./onnx_models", |
|
output_path="output_onnx.wav", |
|
seed=None |
|
): |
|
if seed is not None: |
|
np.random.seed(seed) |
|
|
|
|
|
tokenizer = SimpleT5Tokenizer(f"{onnx_dir}/spiece.model", max_length=63) |
|
text_encoder_onnx = load_onnx_model(f"{onnx_dir}/text_encoder_nf4.onnx") |
|
|
|
|
|
vae_decoder = load_onnx_model(f"{onnx_dir}/vae_decoder.onnx") |
|
duration_embedder = load_onnx_model(f"{onnx_dir}/duration_embedder.onnx") |
|
transformer = load_onnx_model(f"{onnx_dir}/transformer.onnx") |
|
proj_layer = load_onnx_model(f"{onnx_dir}/proj.onnx") |
|
|
|
|
|
duration_input = np.array([[duration]], dtype=np.float32) |
|
print(f"[Shape] duration输入: {duration_input.shape}") |
|
|
|
duration_hidden_states = duration_embedder.run( |
|
['embedding'], |
|
{'duration': duration_input} |
|
)[0] |
|
print(f"[Shape] duration embedding: {duration_hidden_states.shape}") |
|
|
|
if guidance_scale > 1.0: |
|
duration_hidden_states = np.concatenate([duration_hidden_states] * 2, axis=0) |
|
print(f"[Shape] 复制后的duration embedding: {duration_hidden_states.shape}") |
|
|
|
|
|
if guidance_scale > 1.0: |
|
(uncond_hidden_states, uncond_mask), (cond_hidden_states, cond_mask) = encode_text( |
|
prompt, negative_prompt, tokenizer, text_encoder_onnx, guidance_scale=guidance_scale |
|
) |
|
print(cond_hidden_states) |
|
encoder_hidden_states = np.concatenate([uncond_hidden_states, cond_hidden_states]) |
|
attention_mask = np.concatenate([uncond_mask, cond_mask]) |
|
else: |
|
encoder_hidden_states, attention_mask = encode_text( |
|
prompt, tokenizer, text_encoder_onnx |
|
) |
|
|
|
|
|
boolean_encoder_mask = (attention_mask == 1) |
|
mask_expanded = boolean_encoder_mask[..., None].repeat(encoder_hidden_states.shape[-1], axis=-1) |
|
masked_data = np.where(mask_expanded, encoder_hidden_states, np.nan) |
|
pooled = np.nanmean(masked_data, axis=1) |
|
|
|
|
|
pooled_text = proj_layer.run( |
|
['projected'], |
|
{'text_embedding': pooled.astype(np.float32)} |
|
)[0] |
|
|
|
|
|
encoder_hidden_states = np.concatenate( |
|
[encoder_hidden_states, duration_hidden_states], |
|
axis=1 |
|
) |
|
|
|
|
|
txt_ids = np.zeros((1, encoder_hidden_states.shape[1], 3), dtype=np.int64) |
|
img_ids = np.tile( |
|
np.arange(645, dtype=np.int64)[None, :, None], |
|
(1, 1, 3) |
|
) |
|
|
|
|
|
scheduler = SimpleFlowMatchScheduler(num_train_timesteps=1000) |
|
scheduler.set_timesteps(steps) |
|
|
|
|
|
latents = np.random.randn(1, 645, 64).astype(np.float32) |
|
|
|
|
|
for i in range(steps): |
|
|
|
noise_pred = transformer.run( |
|
['output'], |
|
{ |
|
'hidden_states': latents, |
|
'timestep': np.array([scheduler.timesteps[i]/1000], dtype=np.float32), |
|
'pooled_text': pooled_text, |
|
'encoder_hidden_states': encoder_hidden_states, |
|
'txt_ids': txt_ids, |
|
'img_ids': img_ids |
|
} |
|
)[0] |
|
|
|
if i == 0: |
|
print(f"[Shape] noise预测输出: {noise_pred.shape}") |
|
|
|
|
|
if guidance_scale > 1.0: |
|
noise_pred_uncond, noise_pred_text = noise_pred[0:1], noise_pred[1:2] |
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
|
|
|
|
|
latents = scheduler.step(noise_pred, scheduler.timesteps[i], latents) |
|
|
|
if i % 10 == 0: |
|
print(f"生成进度: {i}/{steps}") |
|
|
|
|
|
latents = latents / scheduler.sigmas[0] |
|
latents = np.transpose(latents, (0, 2, 1)) |
|
|
|
|
|
wave = vae_decoder.run(['audio'], {'latent': latents})[0] |
|
|
|
|
|
sample_rate = 44100 |
|
waveform_end = int(duration * sample_rate) |
|
wave = wave[:, :, :waveform_end] |
|
print(f"[Shape] 裁剪后的最终波形: {wave.shape}") |
|
|
|
|
|
wave = wave[0] |
|
sf.write(output_path, wave.T, sample_rate) |
|
|
|
return wave |
|
|
|
if __name__ == "__main__": |
|
import argparse |
|
|
|
parser = argparse.ArgumentParser(description="测试ONNX模型推理") |
|
parser.add_argument("--prompt", type=str, default="What does the fox say?", help="文本提示") |
|
parser.add_argument("--negative_prompt", type=str, default="", help="负文本提示") |
|
parser.add_argument("--onnx_dir", type=str, default=".", help="ONNX模型目录") |
|
parser.add_argument("--duration", type=float, default=10.0, help="生成音频时长(秒)") |
|
parser.add_argument("--steps", type=int, default=30, help="推理步数") |
|
parser.add_argument("--guidance_scale", type=float, default=4.5, help="引导系数") |
|
parser.add_argument("--output", type=str, default="output_onnx.wav", help="输出音频路径") |
|
parser.add_argument("--seed", type=int, default=42, help="随机种子") |
|
|
|
args = parser.parse_args() |
|
|
|
|
|
wave = generate_audio_onnx( |
|
|
|
|
|
|
|
prompt=args.prompt, |
|
negative_prompt=args.negative_prompt, |
|
duration=args.duration, |
|
steps=args.steps, |
|
guidance_scale=args.guidance_scale, |
|
onnx_dir=args.onnx_dir, |
|
output_path=args.output, |
|
seed=args.seed |
|
) |
|
|
|
print(f"生成的音频shape为: {wave.shape}") |
|
print(f"音频已保存到: {args.output}") |