|
import os
|
|
import torch
|
|
from model import AutoModel, Config
|
|
|
|
def load_model(model_path, config_path):
    """Load model weights and configuration, returning the model in eval mode.

    Args:
        model_path: Path to the saved state-dict file (e.g. ``AutoModel.pth``).
        config_path: Path to the configuration file.

    Returns:
        A ``(model, config)`` tuple; the model has its weights loaded and is
        switched to evaluation mode.

    Raises:
        FileNotFoundError: If either ``config_path`` or ``model_path`` does
            not exist.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"配置文件未找到: {config_path}")
    print(f"加载配置文件: {config_path}")
    # NOTE(review): config_path is only existence-checked — Config() is built
    # with defaults and the file contents are never parsed. Confirm whether
    # Config is supposed to be constructed from the file.
    config = Config()

    model = AutoModel(config)

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"模型文件未找到: {model_path}")
    print(f"加载模型权重: {model_path}")
    # weights_only=True blocks arbitrary-code execution through pickled
    # payloads in the checkpoint; a plain state dict of tensors loads fine
    # under this restriction. Loaded onto CPU so no GPU is required.
    state_dict = torch.load(
        model_path, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()
    print("模型加载成功并设置为评估模式。")

    return model, config
|
|
|
|
|
|
def run_inference(model, config):
    """Run one forward pass on random dummy inputs and print output shapes.

    Args:
        model: Callable taking ``(image, text, audio)`` and returning five
            output tensors (VQA, caption, retrieval, ASR, realtime ASR).
        config: Object providing ``max_position_embeddings``, ``hidden_size``
            and ``audio_sample_rate`` used to size the dummy inputs.
    """
    # Synthetic batch-of-one inputs, one per modality.
    dummy_image = torch.randn(1, 3, 224, 224)
    dummy_text = torch.randn(1, config.max_position_embeddings, config.hidden_size)
    dummy_audio = torch.randn(1, config.audio_sample_rate)

    # Unpack the five task heads straight from the forward call.
    (
        vqa_out,
        caption_out,
        retrieval_out,
        asr_out,
        realtime_asr_out,
    ) = model(dummy_image, dummy_text, dummy_audio)

    print("\n推理结果:")
    print(f"VQA output shape: {vqa_out.shape}")
    print(f"Caption output shape: {caption_out.shape}")
    print(f"Retrieval output shape: {retrieval_out.shape}")
    print(f"ASR output shape: {asr_out.shape}")
    print(f"Realtime ASR output shape: {realtime_asr_out.shape}")
|
|
|
|
if __name__ == "__main__":
    # Default artifact locations in the working directory.
    weights_file = "AutoModel.pth"
    config_file = "config.json"

    try:
        # Load the checkpoint/config, then exercise the model once.
        loaded_model, loaded_config = load_model(weights_file, config_file)
        run_inference(loaded_model, loaded_config)
    except Exception as e:
        # Top-level boundary: report the failure instead of a raw traceback.
        print(f"运行失败: {e}")