import os

import torch

from model import AutoModel, Config


def load_model(model_path, config_path):
    """Build an ``AutoModel`` and load its weights from a checkpoint file.

    Args:
        model_path: Path to the checkpoint passed to ``torch.load``.
        config_path: Path to the configuration file. Only its existence is
            checked here.

    Returns:
        A ``(model, config)`` tuple; the model has been switched to
        evaluation mode via ``model.eval()``.

    Raises:
        FileNotFoundError: If ``config_path`` or ``model_path`` does not
            exist.
    """
    # Configuration.
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"配置文件未找到: {config_path}")
    print(f"加载配置文件: {config_path}")
    # NOTE(review): the contents of config_path are never read in this
    # function; Config() is constructed with no arguments. Confirm whether
    # Config is supposed to be populated from the file.
    config = Config()

    # Model construction.
    model = AutoModel(config)

    # Weights.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"模型文件未找到: {model_path}")
    print(f"加载模型权重: {model_path}")
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from a trusted source (or pass
    # weights_only=True where the installed torch version supports it).
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    print("模型加载成功并设置为评估模式。")

    return model, config


def run_inference(model, config):
    """Run one forward pass on random inputs and print the output shapes.

    Args:
        model: Callable invoked as ``model(image, text, audio)``; it must
            return five values, each exposing a ``.shape`` attribute.
        config: Object providing ``max_position_embeddings``,
            ``hidden_size`` and ``audio_sample_rate``, which size the
            random text and audio inputs.
    """
    # Random stand-in inputs, each with a leading dimension of 1.
    image = torch.randn(1, 3, 224, 224)
    text = torch.randn(1, config.max_position_embeddings, config.hidden_size)
    audio = torch.randn(1, config.audio_sample_rate)

    # Only the output shapes are reported, so gradient tracking is not
    # needed; disabling it avoids building the autograd graph.
    with torch.no_grad():
        outputs = model(image, text, audio)
    (
        vqa_output,
        caption_output,
        retrieval_output,
        asr_output,
        realtime_asr_output,
    ) = outputs

    # Report results.
    print("\n推理结果:")
    print(f"VQA output shape: {vqa_output.shape}")
    print(f"Caption output shape: {caption_output.shape}")
    print(f"Retrieval output shape: {retrieval_output.shape}")
    print(f"ASR output shape: {asr_output.shape}")
    print(f"Realtime ASR output shape: {realtime_asr_output.shape}")


if __name__ == "__main__":
    # File paths, resolved relative to the current working directory.
    model_path = "AutoModel.pth"
    config_path = "config.json"

    # Load the model, then run the demo inference; any failure is reported
    # on stdout rather than propagated as a traceback.
    try:
        model, config = load_model(model_path, config_path)
        run_inference(model, config)
    except Exception as e:
        print(f"运行失败: {e}")