#!/usr/bin/python3
# -*- coding: utf-8 -*-
from __future__ import annotations

import asyncio
import contextlib
import pathlib
import shutil
import traceback
import uuid
from collections import deque
from functools import partial
from json import JSONDecodeError
from typing import Dict

import fire
import openai
import tenacity
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from loguru import logger
from metagpt.config import CONFIG
from metagpt.logs import set_llm_stream_logfunc
from metagpt.schema import Message
from openai import OpenAI

from data_model import (
    NewMsg,
    MessageJsonModel,
    Sentences,
    Sentence,
    SentenceType,
    SentenceValue,
    ThinkActPrompt,
    LLMAPIkeyTest,
    ThinkActStep,
)
from message_enum import QueryAnswerType, MessageStatus
from software_company import RoleRun, SoftwareCompany


class Service:
    @classmethod
    async def create_message(cls, req_model: NewMsg, request: Request):
        """
        Session message stream
        """
        tc_id = 0
        task = None
        try:
            exclude_keys = CONFIG.get("SERVER_METAGPT_CONFIG_EXCLUDE", [])
            config = {k.upper(): v for k, v in req_model.config.items() if k not in exclude_keys}
            cls._set_context(config)
            msg_queue = deque()
            # Route LLM streaming output into a queue so it can be relayed to the client.
            CONFIG.LLM_STREAM_LOG = lambda x: msg_queue.appendleft(x) if x else None

            role = SoftwareCompany()
            role.recv(message=Message(content=req_model.query))
            answer = MessageJsonModel(
                steps=[
                    Sentences(
                        contents=[
                            Sentence(
                                type=SentenceType.TEXT.value,
                                value=SentenceValue(answer=req_model.query),
                                is_finished=True,
                            )
                        ],
                        status=MessageStatus.COMPLETE.value,
                    )
                ],
                qa_type=QueryAnswerType.Answer.value,
            )

            async def stop_if_disconnect():
                while not await request.is_disconnected():
                    await asyncio.sleep(1)
                if task is None:
                    return
                if not task.done():
                    task.cancel()
                    logger.info(f"cancel task {task}")

            asyncio.create_task(stop_if_disconnect())

            # Alternate think/act cycles, streaming intermediate output to the client as SSE chunks.
            while True:
                tc_id += 1
                if await request.is_disconnected():
                    return
                think_result: RoleRun = await role.think()
                if not think_result:  # End of conversation
                    break

                think_act_prompt = ThinkActPrompt(role=think_result.role.profile)
                think_act_prompt.update_think(tc_id, think_result)
                yield think_act_prompt.prompt + "\n\n"
                task = asyncio.create_task(role.act())

                while not await request.is_disconnected():
                    if msg_queue:
                        think_act_prompt.update_act(msg_queue.pop(), False)
                        yield think_act_prompt.prompt + "\n\n"
                        continue

                    if task.done():
                        break

                    await asyncio.sleep(0.5)
                else:
                    task.cancel()
                    return

                act_result = await task
                think_act_prompt.update_act(act_result)
                yield think_act_prompt.prompt + "\n\n"
                answer.add_think_act(think_act_prompt)
            yield answer.prompt + "\n\n"  # Notify the front-end that the message is complete.
        except asyncio.CancelledError:
            # The client went away; stop any in-flight act() task.
            if task:
                task.cancel()
        except tenacity.RetryError as retry_error:
            yield cls.handle_retry_error(tc_id, retry_error)
        except Exception as ex:
            description = str(ex)
            answer = traceback.format_exc()
            think_act_prompt = cls.create_error_think_act_prompt(tc_id, description, description, answer)
            yield think_act_prompt.prompt + "\n\n"
        finally:
            # Clean up the per-session workspace created in _set_context().
            CONFIG.WORKSPACE_PATH: pathlib.Path
            if CONFIG.WORKSPACE_PATH.exists():
                shutil.rmtree(CONFIG.WORKSPACE_PATH)

    @staticmethod
    def create_error_think_act_prompt(tc_id: int, title: str, description: str, answer: str) -> ThinkActPrompt:
        step = ThinkActStep(
            id=tc_id,
            status="failed",
            title=title,
            description=description,
            content=Sentence(type=SentenceType.ERROR.value, id=1, value=SentenceValue(answer=answer), is_finished=True),
        )
        return ThinkActPrompt(step=step)

    @classmethod
    def handle_retry_error(cls, tc_id: int, error: tenacity.RetryError):
        # Known exception handling logic
        try:
            # Try to get the original exception, unwrapping nested RetryErrors.
            original_exception = error.last_attempt.exception()
            while isinstance(original_exception, tenacity.RetryError):
                original_exception = original_exception.last_attempt.exception()

            if isinstance(original_exception, openai.AuthenticationError):
                answer = original_exception.message
                title = "OpenAI AuthenticationError"
                think_act_prompt = cls.create_error_think_act_prompt(tc_id, title, title, answer)
                return think_act_prompt.prompt + "\n\n"
            elif isinstance(original_exception, openai.APITimeoutError):
                answer = original_exception.message
                title = "OpenAI APITimeoutError"
                think_act_prompt = cls.create_error_think_act_prompt(tc_id, title, title, answer)
                return think_act_prompt.prompt + "\n\n"
            elif isinstance(original_exception, JSONDecodeError):
                answer = str(original_exception)
                title = "MetaGPT Error"
                description = "LLM return result parsing error"
                think_act_prompt = cls.create_error_think_act_prompt(tc_id, title, description, answer)
                return think_act_prompt.prompt + "\n\n"
            else:
                return cls.handle_unexpected_error(tc_id, error)
        except Exception:
            return cls.handle_unexpected_error(tc_id, error)

    @classmethod
    def handle_unexpected_error(cls, tc_id, error):
        description = str(error)
        answer = traceback.format_exc()
        think_act_prompt = cls.create_error_think_act_prompt(tc_id, description, description, answer)
        return think_act_prompt.prompt + "\n\n"

    @staticmethod
    def _set_context(context: Dict) -> Dict:
        uid = uuid.uuid4().hex
        context["WORKSPACE_PATH"] = pathlib.Path("workspace", uid)
        for old, new in (("DEPLOYMENT_ID", "DEPLOYMENT_NAME"), ("OPENAI_API_BASE", "OPENAI_BASE_URL")):
            if old in context and new not in context:
                context[new] = context[old]
        CONFIG.set_context(context)
        return context


default_llm_stream_log = partial(print, end="")


def llm_stream_log(msg):
    # Swallow logging errors so they never break the response stream.
    with contextlib.suppress(Exception):
        CONFIG._get("LLM_STREAM_LOG", default_llm_stream_log)(msg)


class ChatHandler:
    @staticmethod
    async def create_message(req_model: NewMsg, request: Request):
        """Message stream, using SSE."""
        event = Service.create_message(req_model, request)
        headers = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
        return StreamingResponse(event, headers=headers, media_type="text/event-stream")


class LLMAPIHandler:
    @staticmethod
    async def check_openai_key(req_model: LLMAPIkeyTest):
        try:
            # List all available models and check that the requested one exists.
            client = OpenAI(api_key=req_model.api_key)
            response = client.models.list()
            model_set = {model.id for model in response.data}
            if req_model.llm_type in model_set:
                logger.info("API Key is valid.")
                return JSONResponse({"valid": True})
            else:
                logger.info("API Key is invalid.")
                return JSONResponse({"valid": False, "message": "Model not found"})
        except Exception as e:
            # If the request fails, report the key as invalid along with the error.
            logger.info(f"Error: {e}")
            return JSONResponse({"valid": False, "message": str(e)})


app = FastAPI()

app.mount(
    "/storage",
    StaticFiles(directory="./storage/"),
    name="storage",
)

app.add_api_route(
    "/api/messages",
    endpoint=ChatHandler.create_message,
    methods=["post"],
    summary="Session message sending (streaming response)",
)

app.add_api_route(
    "/api/test-api-key",
    endpoint=LLMAPIHandler.check_openai_key,
    methods=["post"],
    summary="LLM API key validation",
)

app.mount(
    "/",
    StaticFiles(directory="./static/", html=True, follow_symlink=True),
    name="static",
)

set_llm_stream_logfunc(llm_stream_log)


def main():
    server_config = CONFIG.get("SERVER_UVICORN", {})
    uvicorn.run(app="__main__:app", **server_config)


if __name__ == "__main__":
    fire.Fire(main)
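
# --- Usage sketch (illustration only, not wired into the app) ---
# A minimal client for the /api/messages SSE endpoint above, assuming the server runs
# locally on port 8000 (the real host/port come from the SERVER_UVICORN config). The
# payload fields mirror what NewMsg is accessed for above (query, config); any other
# field names are not guaranteed by this file.
#
# import json
# import urllib.request
#
# def demo_stream_message(query: str, config: dict) -> None:
#     """POST a query to /api/messages and print SSE chunks as they arrive."""
#     req = urllib.request.Request(
#         "http://localhost:8000/api/messages",
#         data=json.dumps({"query": query, "config": config}).encode("utf-8"),
#         headers={"Content-Type": "application/json"},
#         method="POST",
#     )
#     with urllib.request.urlopen(req) as resp:
#         for raw_line in resp:
#             line = raw_line.decode("utf-8").rstrip()
#             if line:
#                 print(line)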