#!/usr/bin/python3
# -*- coding: utf-8 -*-
from __future__ import annotations

import asyncio
import contextlib
import pathlib
import shutil
import traceback
import uuid
from collections import deque
from datetime import datetime
from enum import Enum
from functools import partial
from typing import Any, Optional

import fire
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from loguru import logger
from metagpt.actions.action import Action
from metagpt.actions.action_output import ActionOutput
from metagpt.config import CONFIG
from metagpt.logs import set_llm_stream_logfunc
from metagpt.schema import Message
from pydantic import BaseModel, Field

from software_company import RoleRun, SoftwareCompany


class QueryAnswerType(Enum):
    Query = "Q"
    Answer = "A"


class SentenceType(Enum):
    TEXT = "text"
    HINT = "hint"
    ACTION = "action"
    ERROR = "error"


class MessageStatus(Enum):
    COMPLETE = "complete"


class SentenceValue(BaseModel):
    answer: str


class Sentence(BaseModel):
    type: str
    id: Optional[str] = None
    value: SentenceValue
    is_finished: Optional[bool] = None


class Sentences(BaseModel):
    id: Optional[str] = None
    action: Optional[str] = None
    role: Optional[str] = None
    skill: Optional[str] = None
    description: Optional[str] = None
    timestamp: str = Field(default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f%z"))
    status: str
    contents: list[dict]


class NewMsg(BaseModel):
    """Chat with MetaGPT"""

    query: str = Field(description="Problem description")
    config: dict[str, Any] = Field(description="Configuration information")


class ErrorInfo(BaseModel):
    error: Optional[str] = None
    traceback: Optional[str] = None


class ThinkActStep(BaseModel):
    id: str
    status: str
    title: str
    timestamp: str = Field(default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f%z"))
    description: str
    content: Optional[Sentence] = None


class ThinkActPrompt(BaseModel):
    message_id: Optional[int] = None
    timestamp: str = Field(default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f%z"))
    step: Optional[ThinkActStep] = None
    skill: Optional[str] = None
    role: Optional[str] = None

    def update_think(self, tc_id, action: Action):
        self.step = ThinkActStep(
            id=str(tc_id),
            status="running",
            title=action.desc,
            description=action.desc,
        )

    def update_act(self, message: ActionOutput | str, is_finished: bool = True):
        if is_finished:
            self.step.status = "finish"
        # During streaming (is_finished=False) `message` is a raw token chunk;
        # once finished it is the ActionOutput whose .content holds the answer.
        self.step.content = Sentence(
            type=SentenceType.TEXT.value,
            id="1",
            value=SentenceValue(answer=message.content if is_finished else message),
            is_finished=is_finished,
        )

    @staticmethod
    def guid32():
        return str(uuid.uuid4()).replace("-", "")[0:32]

    @property
    def prompt(self):
        return self.json(exclude_unset=True)


class MessageJsonModel(BaseModel):
    steps: list[Sentences]
    qa_type: str
    created_at: datetime = Field(default_factory=datetime.now)
    query_time: datetime = Field(default_factory=datetime.now)
    answer_time: datetime = Field(default_factory=datetime.now)
    score: Optional[int] = None
    feedback: Optional[str] = None

    def add_think_act(self, think_act_prompt: ThinkActPrompt):
        s = Sentences(
            action=think_act_prompt.step.title,
            skill=think_act_prompt.skill,
            description=think_act_prompt.step.description,
            timestamp=think_act_prompt.timestamp,
            status=think_act_prompt.step.status,
            contents=[think_act_prompt.step.content.dict()],
        )
        self.steps.append(s)

    @property
    def prompt(self):
        return self.json(exclude_unset=True)


async def create_message(req_model: NewMsg, request: Request):
    """Session message stream."""
    tc_id = 0
    task: Optional[asyncio.Task] = None  # initialized up front so the except/finally paths never see it unbound
    try:
        exclude_keys = CONFIG.get("SERVER_METAGPT_CONFIG_EXCLUDE", [])
        config = {k.upper(): v for k, v in req_model.config.items() if k not in exclude_keys}
        set_context(config, uuid.uuid4().hex)

        # Route the LLM token stream into a queue that the SSE generator below drains.
        msg_queue = deque()
        CONFIG.LLM_STREAM_LOG = lambda x: msg_queue.appendleft(x) if x else None

        role = SoftwareCompany()
        role.recv(message=Message(content=req_model.query))
        answer = MessageJsonModel(
            steps=[
                Sentences(
                    contents=[
                        Sentence(
                            type=SentenceType.TEXT.value,
                            value=SentenceValue(answer=req_model.query),
                            is_finished=True,
                        )
                    ],
                    status=MessageStatus.COMPLETE.value,
                )
            ],
            qa_type=QueryAnswerType.Answer.value,
        )

        async def stop_if_disconnect():
            # Poll once a second and cancel the running act() task as soon as
            # the client goes away, so abandoned sessions stop burning tokens.
            while not await request.is_disconnected():
                await asyncio.sleep(1)
            if task is None:
                return
            if not task.done():
                task.cancel()
                logger.info(f"cancel task {task}")

        asyncio.create_task(stop_if_disconnect())

        # Think/act loop: each iteration streams one role's step to the client.
        while True:
            tc_id += 1
            if await request.is_disconnected():
                return
            think_result: RoleRun = await role.think()
            if not think_result:  # End of conversation
                break

            think_act_prompt = ThinkActPrompt(role=think_result.role.profile)
            think_act_prompt.update_think(tc_id, think_result)
            yield think_act_prompt.prompt + "\n\n"

            task = asyncio.create_task(role.act())
            # Relay streamed tokens until act() finishes or the client disconnects.
            while not await request.is_disconnected():
                if msg_queue:
                    think_act_prompt.update_act(msg_queue.pop(), False)
                    yield think_act_prompt.prompt + "\n\n"
                    continue
                if task.done():
                    break
                await asyncio.sleep(0.5)
            else:
                task.cancel()
                return
            act_result = await task
            think_act_prompt.update_act(act_result)
            yield think_act_prompt.prompt + "\n\n"
            answer.add_think_act(think_act_prompt)
        yield answer.prompt + "\n\n"  # Notify the front-end that the message is complete.
    except asyncio.CancelledError:
        if task:
            task.cancel()
    except Exception as ex:
        description = str(ex)
        tb = traceback.format_exc()
        step = ThinkActStep(
            id=str(tc_id),
            status="failed",
            title=description,
            description=description,
            content=Sentence(
                type=SentenceType.ERROR.value,
                id="1",
                value=SentenceValue(answer=tb),
                is_finished=True,
            ),
        )
        think_act_prompt = ThinkActPrompt(step=step)
        yield think_act_prompt.prompt + "\n\n"
    finally:
        # Remove the per-session workspace created by set_context().
        CONFIG.WORKSPACE_PATH: pathlib.Path
        if CONFIG.WORKSPACE_PATH.exists():
            shutil.rmtree(CONFIG.WORKSPACE_PATH)


default_llm_stream_log = partial(print, end="")


def llm_stream_log(msg):
    with contextlib.suppress(Exception):
        CONFIG._get("LLM_STREAM_LOG", default_llm_stream_log)(msg)


def set_context(context, uid):
    context["WORKSPACE_PATH"] = pathlib.Path("workspace", uid)
    # Back-fill renamed config keys for backward compatibility.
    for old, new in (("DEPLOYMENT_ID", "DEPLOYMENT_NAME"), ("OPENAI_API_BASE", "OPENAI_BASE_URL")):
        if old in context and new not in context:
            context[new] = context[old]
    CONFIG.set_context(context)
    return context


class ChatHandler:
    @staticmethod
    async def create_message(req_model: NewMsg, request: Request):
        """Message stream, using SSE."""
        event = create_message(req_model, request)
        headers = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
        return StreamingResponse(event, headers=headers, media_type="text/event-stream")


app = FastAPI()

app.mount(
    "/storage",
    StaticFiles(directory="./storage/"),
    name="storage",
)

app.add_api_route(
    "/api/messages",
    endpoint=ChatHandler.create_message,
    methods=["POST"],
    summary="Session message sending (streaming response)",
)

# Mount the front-end at "/" last so it does not shadow the API routes above.
app.mount(
    "/",
    StaticFiles(directory="./static/", html=True, follow_symlink=True),
    name="static",
)

set_llm_stream_logfunc(llm_stream_log)


def main():
    server_config = CONFIG.get("SERVER_UVICORN", {})
    uvicorn.run(app="__main__:app", **server_config)


if __name__ == "__main__":
    fire.Fire(main)