Spaces:
Build error
Build error
import logging
import os

import runpod

from main import initialize_agent
# Initialize the agent at import time (choose which tools to load as needed).
# NOTE(review): `handler` only dispatches to the classifier, segmentation and
# report-generator tools; the visualizer, DICOM and VQA tools are loaded but
# never reached from the handler — drop them if they are not needed elsewhere.
selected_tools = [
    "ImageVisualizerTool",
    "DicomProcessorTool",
    "ChestXRayClassifierTool",
    "ChestXRaySegmentationTool",
    "ChestXRayReportGeneratorTool",
    "XRayVQATool",
]

# Deployment paths can be overridden through the environment so the same image
# can run with prompts/weights mounted elsewhere; the defaults are the original
# hard-coded locations.
agent, tools_dict = initialize_agent(
    os.environ.get("MEDRAX_SYSTEM_PROMPTS", "medrax/docs/system_prompts.txt"),
    tools_to_use=selected_tools,
    model_dir=os.environ.get("MEDRAX_MODEL_DIR", "/model-weights"),
)
logger = logging.getLogger(__name__)

# Maps each supported request ``task`` value to the ``tools_dict`` key of the
# tool that serves it.
_TASK_TO_TOOL = {
    "classification": "ChestXRayClassifierTool",
    "segmentation": "ChestXRaySegmentationTool",
    "report": "ChestXRayReportGeneratorTool",
}


def handler(event):
    """Handle one RunPod API request by running the requested chest X-ray tool.

    Args:
        event: The job payload. ``event["input"]`` must be a dict containing an
            ``"image"`` entry and a ``"task"`` entry, where ``task`` is one of
            ``"classification"``, ``"segmentation"`` or ``"report"``.

    Returns:
        ``{"status": "success", "result": <tool output>}`` on success, or
        ``{"error": <message>}`` when the request is invalid, the tool is not
        loaded, or the tool raises. Errors are returned, never raised.
    """
    # Validate the envelope explicitly so callers get a readable message
    # instead of e.g. "'input'" or "argument of type 'NoneType' is not iterable".
    try:
        job_input = event["input"]
    except (KeyError, TypeError):
        return {"error": "Missing required parameter: input"}
    if not isinstance(job_input, dict):
        return {"error": "Parameter 'input' must be a JSON object"}

    for name in ("image", "task"):
        if name not in job_input:
            return {"error": f"Missing required parameter: {name}"}

    # NOTE(review): the original code assumed `image` is base64-encoded image
    # data; confirm the selected tools accept that rather than, e.g., a path.
    image_data = job_input["image"]
    task = job_input["task"]

    # The isinstance guard keeps unhashable values (lists, dicts) from raising
    # TypeError in the dict lookup; they are simply unsupported tasks.
    tool_name = _TASK_TO_TOOL.get(task) if isinstance(task, str) else None
    if tool_name is None:
        return {"error": f"Unsupported task type: {task}"}

    try:
        tool = tools_dict[tool_name]
    except KeyError:
        return {"error": f"Tool not available: {tool_name}"}

    try:
        result = tool.run(image_data)
    except Exception as e:  # top-level worker boundary: report, don't crash
        logger.exception("Tool %s failed for task %r", tool_name, task)
        # Some exceptions stringify to "", which would give an empty error.
        return {"error": str(e) or type(e).__name__}

    return {
        "status": "success",
        "result": result,
    }
# Register `handler` as the job handler with the RunPod serverless runtime.
_worker_config = {"handler": handler}
runpod.serverless.start(_worker_config)