import torch
import torch.nn as nn

# AutoencoderOobleck, FluxTransformer2DModel, DurationEmbedder and TangoFlux
# are not called directly here; they document the component types handled below.
from diffusers import AutoencoderOobleck, FluxTransformer2DModel
from tangoflux import TangoFluxInference
from tangoflux.model import DurationEmbedder, TangoFlux

def export_vae_encoder(vae, save_path, batch_size=1, audio_length=441000):
    """Export the VAE encoder to ONNX format.

    Args:
        vae: AutoencoderOobleck instance
        save_path: output file path
        batch_size: batch size
        audio_length: audio length in samples (default 10 s at 44100 Hz)
    """
    vae.eval()

    # Dummy stereo waveform: (batch, channels, samples)
    dummy_input = torch.randn(batch_size, 2, audio_length)

    class VAEEncoderWrapper(nn.Module):
        """Wraps encode() so the traced graph returns a single tensor."""

        def __init__(self, vae):
            super().__init__()
            self.vae = vae

        def forward(self, audio):
            # .sample() draws Gaussian noise, so the exported graph
            # contains a random op and its outputs vary between runs
            return self.vae.encode(audio).latent_dist.sample()

    wrapper = VAEEncoderWrapper(vae)

    torch.onnx.export(
        wrapper,
        dummy_input,
        save_path,
        input_names=['audio'],
        output_names=['latent'],
        dynamic_axes={
            'audio': {0: 'batch_size', 2: 'audio_length'},
            'latent': {0: 'batch_size', 2: 'latent_length'}
        },
        opset_version=17
    )
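
# Optional sanity check, not part of the original export flow: a minimal
# sketch, assuming onnxruntime is installed, that feeds a 5-second clip
# through the exported encoder to confirm the dynamic audio_length axis
# accepts lengths other than the export dummy. The function name and clip
# length are illustrative.
def _check_vae_encoder_dynamic_axes(onnx_path, seconds=5, sample_rate=44100):
    import onnxruntime as ort

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    audio = torch.randn(1, 2, seconds * sample_rate).numpy()
    (latent,) = session.run(None, {"audio": audio})
    print(f"{seconds}s audio -> latent shape {latent.shape}")
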
def export_vae_decoder(vae, save_path, batch_size=1, latent_length=645):
    """Export the VAE decoder to ONNX format.

    Args:
        vae: AutoencoderOobleck instance
        save_path: output file path
        batch_size: batch size
        latent_length: latent sequence length
    """
    vae.eval()

    # Dummy latent: (batch, channels, frames); Oobleck latents have 64 channels
    dummy_input = torch.randn(batch_size, 64, latent_length)

    class VAEDecoderWrapper(nn.Module):
        """Wraps decode() so the traced graph returns a single tensor."""

        def __init__(self, vae):
            super().__init__()
            self.vae = vae

        def forward(self, latent):
            return self.vae.decode(latent).sample

    wrapper = VAEDecoderWrapper(vae)

    torch.onnx.export(
        wrapper,
        dummy_input,
        save_path,
        input_names=['latent'],
        output_names=['audio'],
        dynamic_axes={
            'latent': {0: 'batch_size', 2: 'latent_length'},
            'audio': {0: 'batch_size', 2: 'audio_length'}
        },
        opset_version=17
    )
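
# Optional round-trip check, not part of the original export flow: a minimal
# sketch, assuming onnxruntime is installed, that pushes noise through the
# exported encoder and decoder to confirm the pair is shape-consistent.
# The function name is illustrative.
def _roundtrip_vae_onnx(encoder_path, decoder_path, seconds=10, sample_rate=44100):
    import onnxruntime as ort

    enc = ort.InferenceSession(encoder_path, providers=["CPUExecutionProvider"])
    dec = ort.InferenceSession(decoder_path, providers=["CPUExecutionProvider"])
    audio = torch.randn(1, 2, seconds * sample_rate).numpy()
    (latent,) = enc.run(None, {"audio": audio})
    (recon,) = dec.run(None, {"latent": latent})
    print(f"audio {audio.shape} -> latent {latent.shape} -> audio {recon.shape}")
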
def export_duration_embedder(duration_embedder, save_path, batch_size=1):
    """Export the duration embedder to ONNX format.

    Args:
        duration_embedder: DurationEmbedder instance
        save_path: output file path
        batch_size: batch size
    """
    duration_embedder.eval()

    # Dummy duration in seconds, shape (batch, 1)
    dummy_input = torch.full((batch_size, 1), 10.0, dtype=torch.float32)

    torch.onnx.export(
        duration_embedder,
        dummy_input,
        save_path,
        input_names=['duration'],
        output_names=['embedding'],
        dynamic_axes={
            'duration': {0: 'batch_size'},
            'embedding': {0: 'batch_size'}
        },
        opset_version=17
    )
def export_flux_transformer(transformer, save_path, batch_size=1, seq_length=645):
    """Export the FluxTransformer2D model to ONNX format.

    Args:
        transformer: FluxTransformer2DModel instance
        save_path: output file path
        batch_size: batch size
        seq_length: latent sequence length
    """
    transformer.eval()

    # Dummy inputs shaped like TangoFlux's call into the transformer:
    # 64-dim latent tokens, 1024-dim text features, a scalar timestep, and
    # position ids with a trailing axis of 3 for the rotary embeddings
    hidden_states = torch.randn(batch_size, seq_length, 64)
    timestep = torch.tensor([0.5])
    pooled_text = torch.randn(batch_size, 1024)
    encoder_hidden_states = torch.randn(batch_size, 64, 1024)
    txt_ids = torch.zeros(batch_size, 64, 3).to(torch.int64)
    img_ids = torch.arange(seq_length).unsqueeze(0).unsqueeze(-1).repeat(batch_size, 1, 3).to(torch.int64)

    class TransformerWrapper(nn.Module):
        """Pins the keyword arguments so the export sees positional tensor inputs."""

        def __init__(self, transformer):
            super().__init__()
            self.transformer = transformer

        def forward(self, hidden_states, timestep, pooled_text, encoder_hidden_states, txt_ids, img_ids):
            return self.transformer(
                hidden_states=hidden_states,
                timestep=timestep,
                guidance=None,
                pooled_projections=pooled_text,
                encoder_hidden_states=encoder_hidden_states,
                txt_ids=txt_ids,
                img_ids=img_ids,
                return_dict=False
            )[0]

    wrapper = TransformerWrapper(transformer)

    torch.onnx.export(
        wrapper,
        (hidden_states, timestep, pooled_text, encoder_hidden_states, txt_ids, img_ids),
        save_path,
        input_names=['hidden_states', 'timestep', 'pooled_text', 'encoder_hidden_states', 'txt_ids', 'img_ids'],
        output_names=['output'],
        dynamic_axes={
            'hidden_states': {0: 'batch_size', 1: 'sequence_length'},
            'pooled_text': {0: 'batch_size'},
            'encoder_hidden_states': {0: 'batch_size', 1: 'text_length'},
            'txt_ids': {0: 'batch_size', 1: 'text_length'},
            'img_ids': {0: 'batch_size', 1: 'sequence_length'}
        },
        opset_version=17
    )
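
# Optional smoke test, not part of the original export flow: a minimal sketch,
# assuming onnxruntime is installed, that runs the exported transformer once
# with inputs mirroring the dummy tensors above. The function name and the
# text_length of 64 are illustrative.
def _smoke_test_transformer(onnx_path, batch_size=1, seq_length=645, text_length=64):
    import onnxruntime as ort

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    feeds = {
        "hidden_states": torch.randn(batch_size, seq_length, 64).numpy(),
        "timestep": torch.tensor([0.5]).numpy(),
        "pooled_text": torch.randn(batch_size, 1024).numpy(),
        "encoder_hidden_states": torch.randn(batch_size, text_length, 1024).numpy(),
        "txt_ids": torch.zeros(batch_size, text_length, 3, dtype=torch.int64).numpy(),
        "img_ids": torch.arange(seq_length).reshape(1, -1, 1).repeat(batch_size, 1, 3).numpy(),
    }
    (output,) = session.run(None, feeds)
    print(f"transformer output shape: {output.shape}")
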
def export_proj_layer(proj_layer, save_path, batch_size=1):
    """Export the text projection layer to ONNX format.

    Args:
        proj_layer: projection (fc) layer instance
        save_path: output file path
        batch_size: batch size
    """
    proj_layer.eval()

    # Dummy pooled text embedding: (batch, 1024)
    dummy_input = torch.randn(batch_size, 1024)

    torch.onnx.export(
        proj_layer,
        dummy_input,
        save_path,
        input_names=['text_embedding'],
        output_names=['projected'],
        dynamic_axes={
            'text_embedding': {0: 'batch_size'},
            'projected': {0: 'batch_size'}
        },
        opset_version=17
    )
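
# Optional numeric parity check, not part of the original export flow: a
# minimal sketch, assuming onnxruntime is installed, that compares an exported
# model against its PyTorch source on the same inputs. The function name and
# tolerances are illustrative; components with random ops in the graph (e.g.
# the VAE encoder's latent sampling) will not match and should be skipped.
def verify_onnx_export(module, dummy_inputs, onnx_path, rtol=1e-3, atol=1e-4):
    import numpy as np
    import onnxruntime as ort

    if not isinstance(dummy_inputs, tuple):
        dummy_inputs = (dummy_inputs,)

    module.eval()
    with torch.no_grad():
        expected = module(*dummy_inputs)

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    # torch.onnx.export preserves input order, so zip by position
    feeds = {inp.name: t.numpy() for inp, t in zip(session.get_inputs(), dummy_inputs)}
    (actual,) = session.run(None, feeds)
    np.testing.assert_allclose(expected.numpy(), actual, rtol=rtol, atol=atol)
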
def export_all(model_path, output_dir):
    """Export all TangoFlux components to ONNX format.

    Args:
        model_path: TangoFlux model path
        output_dir: output directory
    """
    import os

    model = TangoFluxInference(name=model_path, device="cpu")

    os.makedirs(output_dir, exist_ok=True)

    export_vae_encoder(model.vae, f"{output_dir}/vae_encoder.onnx")
    export_vae_decoder(model.vae, f"{output_dir}/vae_decoder.onnx")

    # "duration_emebdder" is spelled as in tangoflux.model; do not
    # "fix" it to duration_embedder or the attribute lookup fails
    export_duration_embedder(model.model.duration_emebdder, f"{output_dir}/duration_embedder.onnx")

    export_flux_transformer(model.model.transformer, f"{output_dir}/transformer.onnx")

    export_proj_layer(model.model.fc, f"{output_dir}/proj.onnx")

    print(f"All models exported to: {output_dir}")

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Export TangoFlux models to ONNX format")
    parser.add_argument("--model_path", type=str, required=True, help="TangoFlux model path")
    parser.add_argument("--output_dir", type=str, required=True, help="output directory")

    args = parser.parse_args()
    export_all(args.model_path, args.output_dir)
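
# Example invocation (the script name and model identifier are illustrative):
#   python export_onnx.py --model_path declare-lab/TangoFlux --output_dir ./onnx_export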