#!/usr/bin/python3
# -*- coding: utf-8 -*-
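"""Streaming web service for the MetaGPT "software company" demo.

Streams each role's think/act steps to the front end over SSE, exposes an
OpenAI API-key check endpoint, and serves the static front end.
"""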
from __future__ import annotations

import asyncio
import contextlib
import os
import pathlib
import re
import shutil
import time
import traceback
import uuid
from collections import deque
from contextlib import asynccontextmanager
from functools import partial
from typing import Dict

import fire
import tenacity
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from loguru import logger
from metagpt.config import CONFIG
from metagpt.logs import set_llm_stream_logfunc
from metagpt.schema import Message
from metagpt.utils.common import any_to_name, any_to_str
from openai import OpenAI

from data_model import (
    LLMAPIkeyTest,
    MessageJsonModel,
    NewMsg,
    Sentence,
    Sentences,
    SentenceType,
    SentenceValue,
    ThinkActPrompt,
    ThinkActStep,
)
from message_enum import MessageStatus, QueryAnswerType
from software_company import RoleRun, SoftwareCompany


class Service:
    @classmethod
    async def create_message(cls, req_model: NewMsg, request: Request):
        """Session message stream."""
        tc_id = 0
        task = None
        try:
            exclude_keys = CONFIG.get("SERVER_METAGPT_CONFIG_EXCLUDE", [])
            config = {k.upper(): v for k, v in req_model.config.items() if k not in exclude_keys}
            cls._set_context(config)
            msg_queue = deque()
            # Redirect streamed LLM tokens into a per-session queue so the SSE loop below can drain them.
            CONFIG.LLM_STREAM_LOG = lambda x: msg_queue.appendleft(x) if x else None

            role = SoftwareCompany()
            role.recv(message=Message(content=req_model.query))
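
            # The first answer step simply echoes the user's query back as a finished TEXT sentence.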
            answer = MessageJsonModel(
                steps=[
                    Sentences(
                        contents=[
                            Sentence(
                                type=SentenceType.TEXT.value,
                                value=SentenceValue(answer=req_model.query),
                                is_finished=True,
                            ).model_dump()
                        ],
                        status=MessageStatus.COMPLETE.value,
                    )
                ],
                qa_type=QueryAnswerType.Answer.value,
            )
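
            # Watchdog: poll the connection once a second and cancel any in-flight act() task
            # as soon as the client disconnects.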
            async def stop_if_disconnect():
                while not await request.is_disconnected():
                    await asyncio.sleep(1)
                if task is None:
                    return
                if not task.done():
                    task.cancel()
                    logger.info(f"cancel task {task}")

            asyncio.create_task(stop_if_disconnect())
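
            # Main loop: one iteration per think/act cycle until the role has nothing left to do.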
            while True:
                tc_id += 1
                if await request.is_disconnected():
                    return
                think_result: RoleRun = await role.think()
                if not think_result:  # End of conversation
                    break

                think_act_prompt = ThinkActPrompt(role=think_result.role.profile)
                think_act_prompt.update_think(tc_id, think_result)
                yield think_act_prompt.prompt + "\n\n"
                task = asyncio.create_task(role.act())

                # Drain streamed tokens while act() runs; the `else` branch fires only when
                # the client disconnects (loop condition turns false without a `break`).
                while not await request.is_disconnected():
                    if msg_queue:
                        think_act_prompt.update_act(msg_queue.pop(), False)
                        yield think_act_prompt.prompt + "\n\n"
                        continue
                    if task.done():
                        break
                    await asyncio.sleep(0.5)
                else:
                    task.cancel()
                    return
                act_result = await task
                think_act_prompt.update_act(act_result)
                yield think_act_prompt.prompt + "\n\n"
                answer.add_think_act(think_act_prompt)
            yield answer.prompt + "\n\n"  # Notify the front-end that the message is complete.
        except asyncio.CancelledError:
            if task is not None:
                task.cancel()
        except tenacity.RetryError as retry_error:
            yield cls.handle_retry_error(tc_id, retry_error)
        except Exception as ex:
            description = str(ex)
            answer = traceback.format_exc()
            think_act_prompt = cls.create_error_think_act_prompt(tc_id, description, description, answer)
            yield think_act_prompt.prompt + "\n\n"
        finally:
            CONFIG.WORKSPACE_PATH: pathlib.Path  # type annotation only; no assignment
            if CONFIG.WORKSPACE_PATH.exists():
                shutil.rmtree(CONFIG.WORKSPACE_PATH)

    @staticmethod
    def create_error_think_act_prompt(tc_id: int, title: str, description: str, answer: str) -> ThinkActPrompt:
        step = ThinkActStep(
            id=tc_id,
            status="failed",
            title=title,
            description=description,
            content=Sentence(type=SentenceType.ERROR.value, id=1, value=SentenceValue(answer=answer), is_finished=True),
        )
        return ThinkActPrompt(step=step)

    @classmethod
    def handle_retry_error(cls, tc_id: int, error: tenacity.RetryError):
        # Known exception handling logic: map the root cause to a friendly error step.
        try:
            # Unwrap nested RetryErrors to reach the original exception.
            original_exception = error.last_attempt.exception()
            while isinstance(original_exception, tenacity.RetryError):
                original_exception = original_exception.last_attempt.exception()
            name = any_to_str(original_exception)
            if re.match(r"^openai\.", name):
                return cls._handle_openai_error(tc_id, original_exception)
            if re.match(r"^httpx\.", name):
                return cls._handle_httpx_error(tc_id, original_exception)
            if re.match(r"^json\.", name):
                return cls._handle_json_error(tc_id, original_exception)
            return cls.handle_unexpected_error(tc_id, error)
        except Exception:
            return cls.handle_unexpected_error(tc_id, error)

    @classmethod
    def _handle_openai_error(cls, tc_id, original_exception):
        answer = original_exception.message
        title = f"OpenAI {any_to_name(original_exception)}"
        think_act_prompt = cls.create_error_think_act_prompt(tc_id, title, title, answer)
        return think_act_prompt.prompt + "\n\n"

    @classmethod
    def _handle_httpx_error(cls, tc_id, original_exception):
        answer = f"{original_exception}. {original_exception.request}"
        title = f"httpx {any_to_name(original_exception)}"
        think_act_prompt = cls.create_error_think_act_prompt(tc_id, title, title, answer)
        return think_act_prompt.prompt + "\n\n"

    @classmethod
    def _handle_json_error(cls, tc_id, original_exception):
        answer = str(original_exception)
        title = "MetaGPT Action Node Error"
        description = f"LLM response parse error. {any_to_str(original_exception)}: {original_exception}"
        think_act_prompt = cls.create_error_think_act_prompt(tc_id, title, description, answer)
        return think_act_prompt.prompt + "\n\n"

    @classmethod
    def handle_unexpected_error(cls, tc_id, error):
        description = str(error)
        answer = traceback.format_exc()
        think_act_prompt = cls.create_error_think_act_prompt(tc_id, description, description, answer)
        return think_act_prompt.prompt + "\n\n"

    @staticmethod
    def _set_context(context: Dict) -> Dict:
        uid = uuid.uuid4().hex
        context["WORKSPACE_PATH"] = pathlib.Path("workspace", uid)
        # Migrate legacy config keys to their current names.
        for old, new in (("DEPLOYMENT_ID", "DEPLOYMENT_NAME"), ("OPENAI_API_BASE", "OPENAI_BASE_URL")):
            if old in context and new not in context:
                context[new] = context[old]
        CONFIG.set_context(context)
        return context


default_llm_stream_log = partial(print, end="")


def llm_stream_log(msg):
    # Suppress any error raised by the log hook so logging can never break generation.
    with contextlib.suppress(Exception):
        CONFIG._get("LLM_STREAM_LOG", default_llm_stream_log)(msg)
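
# Note: Service.create_message repoints CONFIG.LLM_STREAM_LOG at a per-session deque,
# so tokens streamed by the LLM reach the SSE loop instead of stdout.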

class ChatHandler:
    @staticmethod
    async def create_message(req_model: NewMsg, request: Request):
        """Message stream, using SSE (server-sent events)."""
        event = Service.create_message(req_model, request)
        headers = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
        return StreamingResponse(event, headers=headers, media_type="text/event-stream")


class LLMAPIHandler:
    @staticmethod
    async def check_openai_key(req_model: LLMAPIkeyTest):
        try:
            # List all models available to this key and check that the requested one exists.
            client = OpenAI(api_key=req_model.api_key)
            response = client.models.list()
            model_set = {model.id for model in response.data}
            if req_model.llm_type in model_set:
                logger.info("API Key is valid.")
                return JSONResponse({"valid": True})
            logger.info("API Key is invalid.")
            return JSONResponse({"valid": False, "message": "Model not found"})
        except Exception as e:
            # If the request fails, report the key as invalid with the error message.
            logger.info(f"Error: {e}")
            return JSONResponse({"valid": False, "message": str(e)})


@asynccontextmanager
async def lifespan(app: FastAPI):
    # On startup, launch the background task that purges stale workspaces.
    loop = asyncio.get_running_loop()
    loop.create_task(clear_storage())
    yield


app = FastAPI(lifespan=lifespan)

app.mount(
    "/storage",
    StaticFiles(directory="./storage/"),
    name="storage",
)
app.add_api_route(
    "/api/messages",
    endpoint=ChatHandler.create_message,
    methods=["post"],
    summary="Session message sending (streaming response)",
)
app.add_api_route(
    "/api/test-api-key",
    endpoint=LLMAPIHandler.check_openai_key,
    methods=["post"],
    summary="LLM API key check",
)
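
# Example client call (request shape inferred from NewMsg usage above; host and port
# depend on the SERVER_UVICORN config and are assumptions here):
#   curl -N -X POST http://localhost:8000/api/messages \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Write a 2048 game", "config": {"OPENAI_API_KEY": "sk-..."}}'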

app.mount(
    "/",
    StaticFiles(directory="./static/", html=True, follow_symlink=True),
    name="static",
)

set_llm_stream_logfunc(llm_stream_log)


def gen_file_modified_time(folder_path):
    """Yield the mtime of the folder itself, then of every file beneath it."""
    yield os.path.getmtime(folder_path)
    for root, _, files in os.walk(folder_path):
        for file in files:
            file_path = os.path.join(root, file)
            yield os.path.getmtime(file_path)


async def clear_storage(ttl: float = 1800):
    """Once a minute, delete storage directories untouched for longer than `ttl` seconds."""
    storage = pathlib.Path(CONFIG.get("LOCAL_ROOT", "storage"))
    logger.info("task `clear_storage` start running")
    while True:
        current_time = time.time()
        for i in os.listdir(storage):
            i = storage / i
            try:
                last_time = max(gen_file_modified_time(i))
                if current_time - last_time > ttl:
                    shutil.rmtree(i)
                    await asyncio.sleep(0)  # yield control back to the event loop after the blocking rmtree
                    logger.info(f"Deleted directory: {i}")
            except Exception:
                logger.exception(f"check {i} error")
        await asyncio.sleep(60)


def main():
    server_config = CONFIG.get("SERVER_UVICORN", {})
    uvicorn.run(app="__main__:app", **server_config)


if __name__ == "__main__":
    fire.Fire(main)