import torch
import torch.nn as nn
from diffusers import AutoencoderOobleck, FluxTransformer2DModel
from tangoflux import TangoFluxInference
from tangoflux.model import DurationEmbedder, TangoFlux


def export_vae_encoder(vae, save_path, batch_size=1, audio_length=441000):
    """Export the VAE encoder to ONNX format.

    Args:
        vae: AutoencoderOobleck instance
        save_path: output file path
        batch_size: batch size
        audio_length: audio length in samples (default: 10 s at 44100 Hz)
    """
    vae.eval()

    # Dummy input - note the audio is stereo (2 channels)
    dummy_input = torch.randn(batch_size, 2, audio_length)

    # Wrapper module so ONNX tracing goes through encode() rather than the
    # VAE's default forward()
    class VAEEncoderWrapper(nn.Module):
        def __init__(self, vae):
            super().__init__()
            self.vae = vae

        def forward(self, audio):
            # sample() is stochastic, so a random op is baked into the graph
            return self.vae.encode(audio).latent_dist.sample()

    wrapper = VAEEncoderWrapper(vae)

    # Export the encoder
    torch.onnx.export(
        wrapper,
        dummy_input,
        save_path,
        input_names=['audio'],
        output_names=['latent'],
        dynamic_axes={
            'audio': {0: 'batch_size', 2: 'audio_length'},
            'latent': {0: 'batch_size', 2: 'latent_length'}
        },
        opset_version=17
    )


def export_vae_decoder(vae, save_path, batch_size=1, latent_length=645):
    """Export the VAE decoder to ONNX format.

    Args:
        vae: AutoencoderOobleck instance
        save_path: output file path
        batch_size: batch size
        latent_length: length of the latent sequence
    """
    vae.eval()

    # Dummy input: 64 latent channels
    dummy_input = torch.randn(batch_size, 64, latent_length)

    # Wrapper module so ONNX tracing goes through decode()
    class VAEDecoderWrapper(nn.Module):
        def __init__(self, vae):
            super().__init__()
            self.vae = vae

        def forward(self, latent):
            return self.vae.decode(latent).sample

    wrapper = VAEDecoderWrapper(vae)

    # Export the decoder
    torch.onnx.export(
        wrapper,
        dummy_input,
        save_path,
        input_names=['latent'],
        output_names=['audio'],
        dynamic_axes={
            'latent': {0: 'batch_size', 2: 'latent_length'},
            'audio': {0: 'batch_size', 2: 'audio_length'}
        },
        opset_version=17
    )


def export_duration_embedder(duration_embedder, save_path, batch_size=1):
    """Export the duration embedder to ONNX format.

    Args:
        duration_embedder: DurationEmbedder instance
        save_path: output file path
        batch_size: batch size
    """
    duration_embedder.eval()

    # Dummy input - a single scalar duration per batch item
    dummy_input = torch.tensor([[10.0]], dtype=torch.float32)  # 10 seconds

    torch.onnx.export(
        duration_embedder,
        dummy_input,
        save_path,
        input_names=['duration'],
        output_names=['embedding'],
        dynamic_axes={
            'duration': {0: 'batch_size'},
            'embedding': {0: 'batch_size'}
        },
        opset_version=17
    )


def export_flux_transformer(transformer, save_path, batch_size=1, seq_length=645):
    """Export the FluxTransformer2D to ONNX format.

    Args:
        transformer: FluxTransformer2DModel instance
        save_path: output file path
        batch_size: batch size
        seq_length: sequence length
    """
    transformer.eval()

    # Dummy inputs - note the shape of each input
    hidden_states = torch.randn(batch_size, seq_length, 64)    # [B, S, C]
    timestep = torch.tensor([0.5])                             # [1]
    pooled_text = torch.randn(batch_size, 1024)                # [B, D]
    encoder_hidden_states = torch.randn(batch_size, 64, 1024)  # [B, L, D]
    txt_ids = torch.zeros(batch_size, 64, 3).to(torch.int64)   # [B, L, 3]
    img_ids = torch.arange(seq_length).unsqueeze(0).unsqueeze(-1).repeat(batch_size, 1, 3).to(torch.int64)  # [B, S, 3]

    # Wrapper module to fix the keyword arguments and unpack the output tuple
    class TransformerWrapper(nn.Module):
        def __init__(self, transformer):
            super().__init__()
            self.transformer = transformer

        def forward(self, hidden_states, timestep, pooled_text,
                    encoder_hidden_states, txt_ids, img_ids):
            return self.transformer(
                hidden_states=hidden_states,
                timestep=timestep,
                guidance=None,
                pooled_projections=pooled_text,
                encoder_hidden_states=encoder_hidden_states,
                txt_ids=txt_ids,
                img_ids=img_ids,
                return_dict=False
            )[0]

    wrapper = TransformerWrapper(transformer)

    torch.onnx.export(
        wrapper,
        (hidden_states, timestep, pooled_text, encoder_hidden_states, txt_ids, img_ids),
        save_path,
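        # `timestep` is deliberately left out of dynamic_axes below, so the
        # exported graph expects a 1-element tensor; batch and sequence dims
        # of the other inputs are marked dynamic so one graph serves all
        # prompt and latent lengths.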
        input_names=['hidden_states', 'timestep', 'pooled_text',
                     'encoder_hidden_states', 'txt_ids', 'img_ids'],
        output_names=['output'],
        dynamic_axes={
            'hidden_states': {0: 'batch_size', 1: 'sequence_length'},
            'pooled_text': {0: 'batch_size'},
            'encoder_hidden_states': {0: 'batch_size', 1: 'text_length'},
            'txt_ids': {0: 'batch_size', 1: 'text_length'},
            'img_ids': {0: 'batch_size', 1: 'sequence_length'}
        },
        opset_version=17
    )


def export_proj_layer(proj_layer, save_path, batch_size=1):
    """Export the text projection (fc) layer to ONNX format.

    Args:
        proj_layer: projection (fc) layer instance
        save_path: output file path
        batch_size: batch size
    """
    proj_layer.eval()

    # Dummy input - matches the T5 hidden size
    dummy_input = torch.randn(batch_size, 1024)  # T5-large hidden size

    torch.onnx.export(
        proj_layer,
        dummy_input,
        save_path,
        input_names=['text_embedding'],
        output_names=['projected'],
        dynamic_axes={
            'text_embedding': {0: 'batch_size'},
            'projected': {0: 'batch_size'}
        },
        opset_version=17
    )


def export_all(model_path, output_dir):
    """Export every component to ONNX format.

    Args:
        model_path: path to the TangoFlux model
        output_dir: output directory
    """
    import os

    # Load the model
    model = TangoFluxInference(name=model_path, device="cpu")

    # Create the output directory
    os.makedirs(output_dir, exist_ok=True)

    # Export the VAE
    export_vae_encoder(model.vae, f"{output_dir}/vae_encoder.onnx")
    export_vae_decoder(model.vae, f"{output_dir}/vae_decoder.onnx")

    # Export the duration embedder ("duration_emebdder" is the attribute's
    # spelling in the upstream tangoflux.model code)
    export_duration_embedder(model.model.duration_emebdder, f"{output_dir}/duration_embedder.onnx")

    # Export the transformer
    export_flux_transformer(model.model.transformer, f"{output_dir}/transformer.onnx")

    # Export the projection layer
    export_proj_layer(model.model.fc, f"{output_dir}/proj.onnx")

    print(f"All models exported to: {output_dir}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Export TangoFlux models to ONNX format")
    parser.add_argument("--model_path", type=str, required=True, help="path to the TangoFlux model")
    parser.add_argument("--output_dir", type=str, required=True, help="output directory")
    args = parser.parse_args()

    export_all(args.model_path, args.output_dir)
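
# ---------------------------------------------------------------------------
# Optional sanity check (a minimal sketch, not part of the export flow): run
# a random latent through the exported VAE decoder with onnxruntime. This
# assumes `onnxruntime` and `numpy` are installed; the input name "latent"
# and the (batch, 64, latent_length) shape match the decoder export above,
# while the file path is illustrative.
#
#   import numpy as np
#   import onnxruntime as ort
#
#   sess = ort.InferenceSession("output_dir/vae_decoder.onnx")
#   latent = np.random.randn(1, 64, 645).astype(np.float32)
#   (audio,) = sess.run(None, {"latent": latent})
#   print(audio.shape)  # expected: (1, 2, audio_length)
# ---------------------------------------------------------------------------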