auto / run_local.py
zeroMN's picture
Upload 9 files
b744e9c verified
raw
history blame
2.02 kB
import os
import torch
from model import AutoModel, Config
def load_model(model_path, config_path):
"""
加载模型权重和配置
"""
# 加载配置
if not os.path.exists(config_path):
raise FileNotFoundError(f"配置文件未找到: {config_path}")
print(f"加载配置文件: {config_path}")
config = Config()
# 初始化模型
model = AutoModel(config)
# 加载权重
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件未找到: {model_path}")
print(f"加载模型权重: {model_path}")
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
model.load_state_dict(state_dict)
model.eval()
print("模型加载成功并设置为评估模式。")
return model, config
def run_inference(model, config):
"""
使用模型运行推理
"""
# 模拟示例输入
image = torch.randn(1, 3, 224, 224) # 图像输入
text = torch.randn(1, config.max_position_embeddings, config.hidden_size) # 文本输入
audio = torch.randn(1, config.audio_sample_rate) # 音频输入
# 模型推理
outputs = model(image, text, audio)
vqa_output, caption_output, retrieval_output, asr_output, realtime_asr_output = outputs
# 打印结果
print("\n推理结果:")
print(f"VQA output shape: {vqa_output.shape}")
print(f"Caption output shape: {caption_output.shape}")
print(f"Retrieval output shape: {retrieval_output.shape}")
print(f"ASR output shape: {asr_output.shape}")
print(f"Realtime ASR output shape: {realtime_asr_output.shape}")
if __name__ == "__main__":
# 文件路径
model_path = "AutoModel.pth"
config_path = "config.json"
# 加载模型
try:
model, config = load_model(model_path, config_path)
# 运行推理
run_inference(model, config)
except Exception as e:
print(f"运行失败: {e}")