Spaces:
Runtime error
Runtime error
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
from __future__ import annotations | |
import asyncio | |
import contextlib | |
import os | |
import pathlib | |
import re | |
import shutil | |
import time | |
import traceback | |
import uuid | |
from collections import deque | |
from contextlib import asynccontextmanager | |
from functools import partial | |
from typing import Dict | |
import fire | |
import tenacity | |
import uvicorn | |
from fastapi import FastAPI, Request | |
from fastapi.responses import JSONResponse, StreamingResponse | |
from fastapi.staticfiles import StaticFiles | |
from loguru import logger | |
from metagpt.config import CONFIG | |
from metagpt.logs import set_llm_stream_logfunc | |
from metagpt.schema import Message | |
from metagpt.utils.common import any_to_name, any_to_str | |
from openai import OpenAI | |
from data_model import ( | |
LLMAPIkeyTest, | |
MessageJsonModel, | |
NewMsg, | |
Sentence, | |
Sentences, | |
SentenceType, | |
SentenceValue, | |
ThinkActPrompt, | |
ThinkActStep, | |
) | |
from message_enum import MessageStatus, QueryAnswerType | |
from software_company import RoleRun, SoftwareCompany | |
class Service:
    """Run a ``SoftwareCompany`` role for one request and stream its progress.

    Every method is a class- or static method; the class is used as a
    namespace and is never instantiated in this file.
    """

    @classmethod
    async def create_message(cls, req_model: NewMsg, request: Request):
        """Session message stream.

        Async generator that yields prompt strings, each terminated by a blank
        line. For every think/act round it yields the "think" prompt, then
        incremental "act" prompts as streamed LLM chunks arrive, then the
        final "act" prompt. After the last round it yields the aggregated
        answer prompt. Failures are reported as one final error prompt rather
        than raised to the caller.

        Args:
            req_model: Request payload; ``config`` and ``query`` are read.
            request: Incoming request, polled for client disconnection.
        """
        tc_id = 0
        task = None
        watcher = None
        try:
            # NOTE(review): the exclusion test runs on the key as received,
            # before upper-casing -- confirm the exclude list uses that casing.
            exclude_keys = CONFIG.get("SERVER_METAGPT_CONFIG_EXCLUDE", [])
            config = {k.upper(): v for k, v in req_model.config.items() if k not in exclude_keys}
            cls._set_context(config)

            # Streamed chunks are pushed on the left and popped from the right (FIFO).
            # NOTE(review): CONFIG looks process-global; concurrent requests would
            # overwrite each other's LLM_STREAM_LOG / WORKSPACE_PATH -- confirm.
            msg_queue = deque()
            CONFIG.LLM_STREAM_LOG = lambda x: msg_queue.appendleft(x) if x else None

            role = SoftwareCompany()
            role.recv(message=Message(content=req_model.query))
            answer = MessageJsonModel(
                steps=[
                    Sentences(
                        contents=[
                            Sentence(
                                type=SentenceType.TEXT.value,
                                value=SentenceValue(answer=req_model.query),
                                is_finished=True,
                            ).model_dump()
                        ],
                        status=MessageStatus.COMPLETE.value,
                    )
                ],
                qa_type=QueryAnswerType.Answer.value,
            )

            async def stop_if_disconnect():
                # Poll once per second; when the client goes away, cancel the
                # act task that is current at that moment (if any).
                while not await request.is_disconnected():
                    await asyncio.sleep(1)
                if task is None:
                    return
                if not task.done():
                    task.cancel()
                    logger.info(f"cancel task {task}")

            # Keep a reference so the watcher cannot be garbage-collected
            # mid-flight and can be cancelled once the stream ends.
            watcher = asyncio.create_task(stop_if_disconnect())

            while True:
                tc_id += 1
                if await request.is_disconnected():
                    return
                think_result: RoleRun = await role.think()
                if not think_result:  # End of conversion
                    break

                think_act_prompt = ThinkActPrompt(role=think_result.role.profile)
                think_act_prompt.update_think(tc_id, think_result)
                yield think_act_prompt.prompt + "\n\n"
                task = asyncio.create_task(role.act())

                # Drain streamed chunks until the act task finishes; the
                # ``else`` branch runs only when the client disconnected.
                while not await request.is_disconnected():
                    if msg_queue:
                        think_act_prompt.update_act(msg_queue.pop(), False)
                        yield think_act_prompt.prompt + "\n\n"
                        continue

                    if task.done():
                        break

                    await asyncio.sleep(0.5)
                else:
                    task.cancel()
                    return

                act_result = await task
                think_act_prompt.update_act(act_result)
                yield think_act_prompt.prompt + "\n\n"
                answer.add_think_act(think_act_prompt)
            yield answer.prompt + "\n\n"  # Notify the front-end that the message is complete.
        except asyncio.CancelledError:
            # Cancellation may arrive before any act task exists.
            if task is not None:
                task.cancel()
        except tenacity.RetryError as retry_error:
            yield cls.handle_retry_error(tc_id, retry_error)
        except Exception as ex:
            description = str(ex)
            trace = traceback.format_exc()
            think_act_prompt = cls.create_error_think_act_prompt(tc_id, description, description, trace)
            yield think_act_prompt.prompt + "\n\n"
        finally:
            # Do not leave background tasks behind, whichever way we exit.
            if watcher is not None:
                watcher.cancel()
            if task is not None and not task.done():
                task.cancel()
            workspace: pathlib.Path = CONFIG.WORKSPACE_PATH
            if workspace.exists():
                shutil.rmtree(workspace)

    @staticmethod
    def create_error_think_act_prompt(tc_id: int, title, description: str, answer: str) -> ThinkActPrompt:
        """Build a ``ThinkActPrompt`` holding a single failed step.

        Args:
            tc_id: Step id to report.
            title: Step title.
            description: Step description.
            answer: Error text placed in the step's content sentence.
        """
        step = ThinkActStep(
            id=tc_id,
            status="failed",
            title=title,
            description=description,
            content=Sentence(type=SentenceType.ERROR.value, id=1, value=SentenceValue(answer=answer), is_finished=True),
        )
        return ThinkActPrompt(step=step)

    @classmethod
    def handle_retry_error(cls, tc_id: int, error: tenacity.RetryError):
        """Turn a ``tenacity.RetryError`` into an error prompt string.

        Nested ``RetryError`` wrappers are unwrapped to the innermost
        exception, which is routed to a specific handler by the prefix of
        ``any_to_str(exception)`` (presumably its qualified type name --
        confirm against ``any_to_str``). Anything else, including a failure
        while unwrapping or handling, goes to ``handle_unexpected_error``.
        """
        # Known exception handling logic
        try:
            # Try to get the original exception
            original_exception = error.last_attempt.exception()
            while isinstance(original_exception, tenacity.RetryError):
                original_exception = original_exception.last_attempt.exception()

            name = any_to_str(original_exception)
            handlers = (
                ("openai.", cls._handle_openai_error),
                ("httpx.", cls._handle_httpx_error),
                ("json.", cls._handle_json_error),
            )
            for prefix, handler in handlers:
                if name.startswith(prefix):
                    return handler(tc_id, original_exception)

            return cls.handle_unexpected_error(tc_id, error)
        except Exception:
            return cls.handle_unexpected_error(tc_id, error)

    @classmethod
    def _handle_openai_error(cls, tc_id, original_exception):
        """Return the error prompt for an exception classified as ``openai.*``."""
        answer = original_exception.message
        title = f"OpenAI {any_to_name(original_exception)}"
        think_act_prompt = cls.create_error_think_act_prompt(tc_id, title, title, answer)
        return think_act_prompt.prompt + "\n\n"

    @classmethod
    def _handle_httpx_error(cls, tc_id, original_exception):
        """Return the error prompt for an exception classified as ``httpx.*``."""
        answer = f"{original_exception}. {original_exception.request}"
        title = f"httpx {any_to_name(original_exception)}"
        think_act_prompt = cls.create_error_think_act_prompt(tc_id, title, title, answer)
        return think_act_prompt.prompt + "\n\n"

    @classmethod
    def _handle_json_error(cls, tc_id, original_exception):
        """Return the error prompt for an exception classified as ``json.*``."""
        answer = str(original_exception)
        title = "MetaGPT Action Node Error"
        description = f"LLM response parse error. {any_to_str(original_exception)}: {original_exception}"
        think_act_prompt = cls.create_error_think_act_prompt(tc_id, title, description, answer)
        return think_act_prompt.prompt + "\n\n"

    @classmethod
    def handle_unexpected_error(cls, tc_id, error):
        """Return the error prompt for an unclassified error.

        The answer text is ``traceback.format_exc()``, so this is meant to be
        called while an exception is being handled.
        """
        description = str(error)
        answer = traceback.format_exc()
        think_act_prompt = cls.create_error_think_act_prompt(tc_id, description, description, answer)
        return think_act_prompt.prompt + "\n\n"

    @staticmethod
    def _set_context(context: Dict) -> Dict:
        """Complete ``context`` in place, install it on ``CONFIG`` and return it.

        Adds a fresh ``workspace/<uuid hex>`` path under ``WORKSPACE_PATH`` and
        copies ``DEPLOYMENT_ID`` / ``OPENAI_API_BASE`` to ``DEPLOYMENT_NAME`` /
        ``OPENAI_BASE_URL`` when only the old key is present.
        """
        uid = uuid.uuid4().hex
        context["WORKSPACE_PATH"] = pathlib.Path("workspace", uid)
        for old, new in (("DEPLOYMENT_ID", "DEPLOYMENT_NAME"), ("OPENAI_API_BASE", "OPENAI_BASE_URL")):
            if old in context and new not in context:
                context[new] = context[old]
        CONFIG.set_context(context)
        return context
# Fallback sink: echo streamed chunks to stdout without a trailing newline.
default_llm_stream_log = partial(print, end="")


def llm_stream_log(msg):
    """Forward a streamed LLM chunk to the currently configured sink.

    The sink is looked up on ``CONFIG`` under ``LLM_STREAM_LOG`` at call time,
    with ``default_llm_stream_log`` as the fallback.

    Delivery is best-effort: ``contextlib.suppress()`` with no exception types
    suppresses nothing, so the types are named explicitly here. An error while
    resolving or calling the sink is dropped rather than propagated.
    """
    with contextlib.suppress(Exception):
        CONFIG._get("LLM_STREAM_LOG", default_llm_stream_log)(msg)
class ChatHandler:
    """HTTP endpoint functions for the chat API."""

    @staticmethod
    async def create_message(req_model: NewMsg, request: Request):
        """Message stream, using SSE.

        Wraps the ``Service.create_message`` async generator in a
        ``StreamingResponse`` served as ``text/event-stream``.
        """
        event = Service.create_message(req_model, request)
        headers = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
        return StreamingResponse(event, headers=headers, media_type="text/event-stream")
class LLMAPIHandler:
    """HTTP endpoint functions for checking LLM credentials."""

    @staticmethod
    async def check_openai_key(req_model: LLMAPIkeyTest):
        """Report whether ``req_model.api_key`` can list ``req_model.llm_type``.

        Returns a ``JSONResponse`` with ``{"valid": True}`` when the requested
        model id is among those listed for the key. Otherwise it returns
        ``{"valid": False, "message": ...}`` -- also when the listing fails.
        """
        try:
            # Listing all available models.
            client = OpenAI(api_key=req_model.api_key)
            # ``models.list`` is a blocking network call on the synchronous
            # client; run it in the default executor so the event loop keeps
            # serving other requests meanwhile.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(None, client.models.list)
            model_set = {model.id for model in response.data}
            if req_model.llm_type in model_set:
                logger.info("API Key is valid.")
                return JSONResponse({"valid": True})
            else:
                logger.info("API Key is invalid.")
                return JSONResponse({"valid": False, "message": "Model not found"})
        except Exception as e:
            # If the request fails, return False
            logger.info(f"Error: {e}")
            return JSONResponse({"valid": False, "message": str(e)})
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: run ``clear_storage`` in the background while serving.

    The task is started on the running loop at startup and cancelled at
    shutdown. A reference is held for the whole lifetime so the task cannot be
    garbage-collected while it is still running.
    """
    loop = asyncio.get_running_loop()
    cleanup_task = loop.create_task(clear_storage())
    try:
        yield
    finally:
        cleanup_task.cancel()
# ASGI application; `lifespan` drives the background storage cleanup.
app = FastAPI(lifespan=lifespan)

# Files under ./storage/ are served at /storage.
app.mount(path="/storage", app=StaticFiles(directory="./storage/"), name="storage")

# JSON / streaming API endpoints.
app.add_api_route(
    path="/api/messages",
    endpoint=ChatHandler.create_message,
    methods=["post"],
    summary="Session message sending (streaming response)",
)
app.add_api_route(
    path="/api/test-api-key",
    endpoint=LLMAPIHandler.check_openai_key,
    methods=["post"],
    summary="LLM APIkey detection",
)

# Static front-end at the root. Kept after the API routes, as in the original
# ordering (presumably so the catch-all mount does not shadow them -- confirm).
app.mount(
    path="/",
    app=StaticFiles(directory="./static/", html=True, follow_symlink=True),
    name="static",
)

# Route MetaGPT's streamed LLM output through `llm_stream_log`.
set_llm_stream_logfunc(llm_stream_log)
def gen_file_modified_time(folder_path):
    """Yield modification times for ``folder_path`` and every file below it.

    The first value is the mtime of ``folder_path`` itself. It is followed by
    the mtime of each regular file found by ``os.walk``. Sub-directories
    themselves are not reported.
    """
    yield os.path.getmtime(folder_path)
    for dirpath, _subdirs, filenames in os.walk(folder_path):
        yield from (os.path.getmtime(os.path.join(dirpath, name)) for name in filenames)
async def clear_storage(ttl: float = 1800, interval: float = 60):
    """Periodically delete stale entries from the storage directory. Runs forever.

    The directory is ``CONFIG.get("LOCAL_ROOT", "storage")``. An entry is stale
    when the newest modification time reported by ``gen_file_modified_time``
    is more than ``ttl`` seconds old.

    Args:
        ttl: Maximum age in seconds before an entry is removed.
        interval: Seconds to sleep between sweeps.
    """
    storage = pathlib.Path(CONFIG.get("LOCAL_ROOT", "storage"))
    logger.info("task `clear_storage` start running")

    while True:
        current_time = time.time()
        try:
            entries = list(storage.iterdir())
        except OSError as err:
            # A missing or unreadable storage directory must not kill the
            # cleanup task for good; skip this sweep and retry next time.
            logger.warning(f"list {storage} error: {err}")
            entries = []

        for i in entries:
            try:
                last_time = max(gen_file_modified_time(i))
                if current_time - last_time > ttl:
                    if i.is_dir() and not i.is_symlink():
                        shutil.rmtree(i)
                        await asyncio.sleep(0)
                        logger.info(f"Deleted directory: {i}")
                    else:
                        # `rmtree` rejects plain files and symlinks.
                        i.unlink()
                        await asyncio.sleep(0)
                        logger.info(f"Deleted file: {i}")
            except Exception:
                logger.exception(f"check {i} error")
        await asyncio.sleep(interval)
def main():
    """Start uvicorn serving ``app``, with options from ``CONFIG["SERVER_UVICORN"]``."""
    uvicorn_options = CONFIG.get("SERVER_UVICORN", {})
    uvicorn.run(app="__main__:app", **uvicorn_options)


if __name__ == "__main__":
    # Expose `main` as a command line via python-fire.
    fire.Fire(main)