# TangoFlux-ONNX-RKNN2 / inference.py
import numpy as np
# import onnxruntime as ort
# ztu_somemodelruntime_rknnlite2 exposes an onnxruntime-compatible API on top of RKNN Lite2
import ztu_somemodelruntime_rknnlite2 as ort
import sentencepiece as spm
import soundfile as sf
ort.set_default_logger_verbosity(0)
def load_onnx_model(model_path):
"""加载ONNX模型"""
return ort.InferenceSession(
model_path,
        # Provider list follows the onnxruntime API; the RKNN runtime may ignore it
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)
class SimpleT5Tokenizer:
def __init__(self, model_path, max_length=128):
"""初始化tokenizer
Args:
model_path: sentencepiece模型路径
max_length: 序列最大长度,默认128
"""
self.sp = spm.SentencePieceProcessor()
self.sp.Load(model_path)
        # IDs of the special T5 tokens
self.pad_token_id = 0
self.eos_token_id = 1
self.max_length = max_length
def __call__(self, texts, padding=True, truncation=True, max_length=None, return_tensors="np"):
"""处理文本序列
Args:
texts: 文本或文本列表
padding: 是否padding
truncation: 是否截断
max_length: 可选,覆盖默认max_length
return_tensors: 返回类型(只支持"np")
Returns:
dict: 包含input_ids和attention_mask
"""
if isinstance(texts, str):
texts = [texts]
max_len = max_length if max_length is not None else self.max_length
        # Tokenize and convert to IDs
input_ids = []
attention_mask = []
for text in texts:
ids = self.sp.EncodeAsIds(text)
            # Truncate, reserving one slot for the EOS token
if truncation and len(ids) > max_len - 1:
ids = ids[:max_len-1]
ids.append(self.eos_token_id)
            # Build the attention mask
mask = [1] * len(ids)
            # Pad to max_len
if padding:
pad_length = max_len - len(ids)
ids.extend([self.pad_token_id] * pad_length)
mask.extend([0] * pad_length)
input_ids.append(ids)
attention_mask.append(mask)
        # Convert to numpy arrays
input_ids = np.array(input_ids, dtype=np.int64)
attention_mask = np.array(attention_mask, dtype=np.int64)
return {
"input_ids": input_ids,
"attention_mask": attention_mask
}
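# Usage sketch for the tokenizer (illustrative, not executed; shapes assume the
# default padding behavior, which always pads to max_length):
#
#   tok = SimpleT5Tokenizer("spiece.model", max_length=8)   # hypothetical path
#   batch = tok("hello world")
#   batch["input_ids"].shape       # (1, 8): token ids, then EOS (1), then PAD (0)
#   batch["attention_mask"].shape  # (1, 8): 1 for real tokens, 0 for padding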
def encode_text(prompt, negative_prompt, tokenizer, text_encoder_onnx, guidance_scale=None):
"""编码文本,同时处理条件和无条件文本
Args:
prompt: 文本提示
tokenizer: T5 tokenizer
text_encoder_onnx: T5 ONNX模型
guidance_scale: 引导系数
"""
if not isinstance(prompt, list):
prompt = [prompt]
if guidance_scale is not None and guidance_scale > 1.0:
        # Batch the unconditional and conditional prompts together
all_prompts = [negative_prompt] + prompt
batch = tokenizer(
all_prompts,
padding=True,
truncation=True,
return_tensors="np"
)
        # ONNX inference
all_hidden_states = text_encoder_onnx.run(
['last_hidden_state'],
{
'input_ids': batch['input_ids'].astype(np.int64),
'attention_mask': batch['attention_mask'].astype(np.int64)
}
)[0]
        # Split unconditional and conditional results
uncond_hidden_states = all_hidden_states[0:1]
cond_hidden_states = all_hidden_states[1:]
uncond_mask = batch['attention_mask'][0:1]
cond_mask = batch['attention_mask'][1:]
return (uncond_hidden_states, uncond_mask), (cond_hidden_states, cond_mask)
else:
        # Conditional prompt only
batch = tokenizer(
prompt,
padding=True,
truncation=True,
return_tensors="np"
)
        # ONNX inference
hidden_states = text_encoder_onnx.run(
['last_hidden_state'],
{
'input_ids': batch['input_ids'].astype(np.int64),
'attention_mask': batch['attention_mask'].astype(np.int64)
}
)[0]
return hidden_states, batch['attention_mask']
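# Sketch of the CFG path through encode_text (illustrative; H denotes the hidden
# size, which depends on the exported T5 variant):
#
#   (uncond_h, uncond_m), (cond_h, cond_m) = encode_text(
#       "a dog barking", "", tokenizer, text_encoder_onnx, guidance_scale=4.5)
#   uncond_h.shape  # (1, max_length, H): negative-prompt embedding
#   cond_h.shape    # (1, max_length, H): prompt embedding
# Row 0 of the tokenized batch is always the negative prompt, rows 1.. the prompts.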
def retrieve_timesteps(scheduler, num_inference_steps, device=None, timesteps=None, sigmas=None):
    """Resolve the timestep schedule from the scheduler.
    Unused in this script; kept for parity with the diffusers pipeline API.
    Note the sigmas path requires a scheduler whose set_timesteps accepts a
    `sigmas` keyword, which SimpleFlowMatchScheduler below does not.
    """
    if sigmas is not None:
        scheduler.set_timesteps(sigmas=sigmas)
    else:
        scheduler.set_timesteps(num_inference_steps)
    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)
# A minimal flow-matching scheduler with an Euler update rule
class SimpleFlowMatchScheduler:
    def __init__(self, num_train_timesteps=1000, shift=1.0):
        """Initialize the scheduler.
        Args:
            num_train_timesteps: number of training timesteps
            shift: timestep shift factor
        """
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        # Linearly spaced timesteps, in descending order
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        # Map timesteps to sigmas and apply the shift
        sigmas = timesteps / num_train_timesteps
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        # Append the terminal sigma
        self.sigmas = np.append(sigmas, 0.0)
        self.timesteps = sigmas * num_train_timesteps
        self.step_index = None
    def set_timesteps(self, num_inference_steps):
        """Set the timestep schedule for inference.
        Args:
            num_inference_steps: number of inference steps
        """
        # Use the stored training-step count so repeated calls stay correct,
        # and apply the same shift as in __init__
        timesteps = np.linspace(1, self.num_train_timesteps, num_inference_steps, dtype=np.float32)[::-1].copy()
        sigmas = timesteps / self.num_train_timesteps
        sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
        self.sigmas = np.append(sigmas, 0.0)
        self.timesteps = sigmas * self.num_train_timesteps
        self.step_index = 0
def step(self, model_output, timestep, sample):
"""执行一步euler更新
Args:
model_output: 模型输出
timestep: 当前时间步
sample: 当前样本
Returns:
prev_sample: 更新后的样本
"""
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
        # Euler update: x <- x + (sigma_next - sigma) * v
prev_sample = sample + (sigma_next - sigma) * model_output
self.step_index += 1
return prev_sample
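# Worked example of the Euler flow-matching update (illustrative numbers):
#
#   sched = SimpleFlowMatchScheduler(num_train_timesteps=1000)
#   sched.set_timesteps(4)     # sigmas ~ [1.0, 0.667, 0.334, 0.001, 0.0]
#   x = np.zeros((1, 645, 64), dtype=np.float32)
#   v = np.ones_like(x)        # pretend the transformer predicted velocity 1
#   x = sched.step(v, sched.timesteps[0], x)   # x ~= (0.667 - 1.0) * 1 = -0.333
# Integrating v over all steps moves the sample from sigma=1 (noise) to sigma=0 (data).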
def generate_audio_onnx(
prompt="",
negative_prompt="",
duration=10,
steps=50,
guidance_scale=4.5,
onnx_dir="./onnx_models",
output_path="output_onnx.wav",
seed=None
):
    """Generate audio from a text prompt with the exported ONNX/RKNN models.
    Args:
        prompt: text prompt
        negative_prompt: negative prompt for classifier-free guidance
        duration: target audio length in seconds
        steps: number of inference steps
        guidance_scale: classifier-free guidance scale
        onnx_dir: directory containing the exported models
        output_path: path of the output WAV file
        seed: optional random seed for reproducibility
    Returns:
        Generated waveform of shape (channels, samples)
    """
    if seed is not None:
        np.random.seed(seed)
    # Load the tokenizer and text encoder with a fixed sequence length
tokenizer = SimpleT5Tokenizer(f"{onnx_dir}/spiece.model", max_length=63)
text_encoder_onnx = load_onnx_model(f"{onnx_dir}/text_encoder_nf4.onnx")
    # Load the remaining ONNX models
vae_decoder = load_onnx_model(f"{onnx_dir}/vae_decoder.onnx")
duration_embedder = load_onnx_model(f"{onnx_dir}/duration_embedder.onnx")
transformer = load_onnx_model(f"{onnx_dir}/transformer.onnx")
proj_layer = load_onnx_model(f"{onnx_dir}/proj.onnx")
# 1. duration embedding
duration_input = np.array([[duration]], dtype=np.float32)
print(f"[Shape] duration输入: {duration_input.shape}")
duration_hidden_states = duration_embedder.run(
['embedding'],
{'duration': duration_input}
)[0]
print(f"[Shape] duration embedding: {duration_hidden_states.shape}")
if guidance_scale > 1.0:
duration_hidden_states = np.concatenate([duration_hidden_states] * 2, axis=0)
print(f"[Shape] 复制后的duration embedding: {duration_hidden_states.shape}")
# 2. text encoder
if guidance_scale > 1.0:
(uncond_hidden_states, uncond_mask), (cond_hidden_states, cond_mask) = encode_text(
prompt, negative_prompt, tokenizer, text_encoder_onnx, guidance_scale=guidance_scale
)
encoder_hidden_states = np.concatenate([uncond_hidden_states, cond_hidden_states])
attention_mask = np.concatenate([uncond_mask, cond_mask])
    else:
        # Conditional prompt only (no classifier-free guidance); the original
        # call omitted negative_prompt, shifting the positional arguments
        encoder_hidden_states, attention_mask = encode_text(
            prompt, negative_prompt, tokenizer, text_encoder_onnx
        )
    # 3. pooled_text: masked mean-pool over real tokens only; padded positions
    # are set to NaN so that np.nanmean ignores them
    boolean_encoder_mask = (attention_mask == 1)
    mask_expanded = boolean_encoder_mask[..., None].repeat(encoder_hidden_states.shape[-1], axis=-1)
    masked_data = np.where(mask_expanded, encoder_hidden_states, np.nan)
    pooled = np.nanmean(masked_data, axis=1)
    # Project the pooled text embedding
pooled_text = proj_layer.run(
['projected'],
{'text_embedding': pooled.astype(np.float32)}
)[0]
    # 4. Concatenate text and duration features along the sequence axis
encoder_hidden_states = np.concatenate(
[encoder_hidden_states, duration_hidden_states],
axis=1
)
    # 5. Positional-id inputs: zeros for text tokens; sequential frame indices
    # repeated across the 3 id channels for the 645 latent frames
txt_ids = np.zeros((1, encoder_hidden_states.shape[1], 3), dtype=np.int64)
img_ids = np.tile(
np.arange(645, dtype=np.int64)[None, :, None],
(1, 1, 3)
)
# 6. scheduler
scheduler = SimpleFlowMatchScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(steps)
    # Initialize the latents (batch 1, 645 frames, 64 channels); they stay
    # batch-1 even under CFG, where the exported transformer is assumed to
    # handle the batch-2 text inputs
    latents = np.random.randn(1, 645, 64).astype(np.float32)
    # 7. Denoising loop
for i in range(steps):
        # Transformer forward pass
noise_pred = transformer.run(
['output'],
{
'hidden_states': latents,
'timestep': np.array([scheduler.timesteps[i]/1000], dtype=np.float32),
'pooled_text': pooled_text,
'encoder_hidden_states': encoder_hidden_states,
'txt_ids': txt_ids,
'img_ids': img_ids
}
)[0]
        if i == 0:  # print only on the first step
            print(f"[Shape] noise prediction output: {noise_pred.shape}")
        # Apply classifier-free guidance
if guidance_scale > 1.0:
noise_pred_uncond, noise_pred_text = noise_pred[0:1], noise_pred[1:2]
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        # Scheduler (Euler) update of the latents
latents = scheduler.step(noise_pred, scheduler.timesteps[i], latents)
if i % 10 == 0:
print(f"生成进度: {i}/{steps}")
    # 8. Pre-decode processing: rescale (sigmas[0] == 1.0, so effectively a
    # no-op) and transpose to (batch, channels, frames) for the VAE
latents = latents / scheduler.sigmas[0]
latents = np.transpose(latents, (0, 2, 1))
    # 9. VAE decode
wave = vae_decoder.run(['audio'], {'latent': latents})[0]
    # 10. Trim to the requested duration
sample_rate = 44100
waveform_end = int(duration * sample_rate)
wave = wave[:, :, :waveform_end]
print(f"[Shape] 裁剪后的最终波形: {wave.shape}")
    # 11. Save the audio
    wave = wave[0]  # drop the batch dimension
    sf.write(output_path, wave.T, sample_rate)  # soundfile expects (samples, channels)
return wave
if __name__ == "__main__":
import argparse
    parser = argparse.ArgumentParser(description="Test ONNX model inference")
    parser.add_argument("--prompt", type=str, default="What does the fox say?", help="text prompt")
    parser.add_argument("--negative_prompt", type=str, default="", help="negative text prompt")
    parser.add_argument("--onnx_dir", type=str, default=".", help="directory containing the ONNX models")
    parser.add_argument("--duration", type=float, default=10.0, help="audio duration in seconds")
    parser.add_argument("--steps", type=int, default=30, help="number of inference steps")
    parser.add_argument("--guidance_scale", type=float, default=4.5, help="guidance scale")
    parser.add_argument("--output", type=str, default="output_onnx.wav", help="output audio path")
    parser.add_argument("--seed", type=int, default=42, help="random seed")
args = parser.parse_args()
    # Generate audio
wave = generate_audio_onnx(
# prompt="What does the fox say?",
# prompt="Never gonna give you up, never gonna let you down",
# prompt="Electonic music, future house style",
prompt=args.prompt,
negative_prompt=args.negative_prompt,
duration=args.duration,
steps=args.steps,
guidance_scale=args.guidance_scale,
onnx_dir=args.onnx_dir,
output_path=args.output,
seed=args.seed
)
print(f"生成的音频shape为: {wave.shape}")
print(f"音频已保存到: {args.output}")