Spaces:
Runtime error
Runtime error
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
from __future__ import annotations | |
import asyncio | |
import contextlib | |
import pathlib | |
import shutil | |
import traceback | |
import uuid | |
from collections import deque | |
from datetime import datetime | |
from enum import Enum | |
from functools import partial | |
from typing import Any, Optional | |
import fire | |
import uvicorn | |
from fastapi import FastAPI, Request | |
from fastapi.responses import StreamingResponse | |
from fastapi.staticfiles import StaticFiles | |
from loguru import logger | |
from metagpt.actions.action import Action | |
from metagpt.actions.action_output import ActionOutput | |
from metagpt.config import CONFIG | |
from metagpt.logs import set_llm_stream_logfunc | |
from metagpt.schema import Message | |
from pydantic import BaseModel, Field | |
from software_company import RoleRun, SoftwareCompany | |
class QueryAnswerType(Enum):
    """Query/answer tag for a message; ``create_message`` tags its summary as ``Answer``."""

    Query = "Q"
    Answer = "A"  # used for MessageJsonModel.qa_type
class SentenceType(Enum):
    """Kind of content carried by a ``Sentence`` (its value is stored in ``Sentence.type``)."""

    TEXT = "text"
    # Misspelled original member name, kept canonical so existing references
    # (and ``SentenceType.HIHT.name``) keep working.
    HIHT = "hint"
    # Correctly spelled alias: same value, so ``SentenceType.HINT is SentenceType.HIHT``.
    # Aliases are not yielded when iterating the enum.
    HINT = "hint"
    ACTION = "action"
    ERROR = "error"
class MessageStatus(Enum):
    """Status values for a ``Sentences`` entry."""

    COMPLETE = "complete"  # used for the echoed user query in create_message
class SentenceValue(BaseModel):
    """Payload of a ``Sentence``."""

    # The text itself: the user query, a streamed chunk, an act result or a traceback.
    answer: str
class Sentence(BaseModel):
    """One piece of displayable content inside a step."""

    type: str  # a SentenceType value, e.g. "text" or "error"
    id: Optional[str] = None
    value: SentenceValue
    # False while partial (streamed) output is being reported, True once final.
    is_finished: Optional[bool] = None
class Sentences(BaseModel):
    """One entry of ``MessageJsonModel.steps``: a step plus the content it produced."""

    id: Optional[str] = None
    action: Optional[str] = None
    role: Optional[str] = None
    skill: Optional[str] = None
    description: Optional[str] = None
    # Local wall-clock time. ``datetime.now()`` is naive, so ``%z`` renders as "".
    timestamp: str = Field(default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f%z"))
    status: str  # e.g. MessageStatus.COMPLETE.value, or a step status such as "finish"
    contents: list[dict]  # serialized Sentence objects
class NewMsg(BaseModel):
    """Chat with MetaGPT"""

    # ``...`` marks both fields as required, exactly like ``Field(description=...)``
    # without a default; the descriptions surface in the generated JSON schema.
    query: str = Field(..., description="Problem description")
    # Overrides for CONFIG; create_message upper-cases the keys and drops excluded ones.
    config: dict[str, Any] = Field(..., description="Configuration information")
class ErrorInfo(BaseModel):
    """Error text plus an optional traceback string (not referenced in the visible code)."""

    # Annotated Optional because the default is None: implicit Optional is not
    # accepted by type checkers, and pydantic v2 rejects an explicit None for ``str``.
    error: Optional[str] = None
    traceback: Optional[str] = None
class ThinkActStep(BaseModel):
    """State of one think/act step as reported to the client."""

    id: str  # step counter, as a string
    status: str  # "running", "finish" or "failed" in this file
    title: str
    # Local wall-clock time. ``datetime.now()`` is naive, so ``%z`` renders as "".
    timestamp: str = Field(default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f%z"))
    description: str
    # None until ThinkActPrompt.update_act (or the error path) fills it in;
    # annotated Optional to match that default.
    content: Optional[Sentence] = None
class ThinkActPrompt(BaseModel):
    """Progress report for one think/act step; ``prompt`` serializes it for the stream."""

    message_id: Optional[int] = None
    # Local wall-clock time. ``datetime.now()`` is naive, so ``%z`` renders as "".
    timestamp: str = Field(default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f%z"))
    step: Optional[ThinkActStep] = None
    skill: Optional[str] = None
    role: Optional[str] = None

    def update_think(self, tc_id, action: Action):
        """Start a new "running" step.

        Args:
            tc_id: Step counter; stored as a string.
            action: Object whose ``desc`` becomes the step title and description
                (create_message passes a RoleRun here).
        """
        self.step = ThinkActStep(
            id=str(tc_id),
            status="running",
            title=action.desc,
            description=action.desc,
        )

    def update_act(self, message: ActionOutput | str, is_finished: bool = True):
        """Attach output to the current step.

        Args:
            message: A streamed text chunk, or the final act result. A final
                result's text is read from its ``content`` attribute unless it
                is already a plain string.
            is_finished: True for the final result; this also marks the step "finish".

        Raises:
            RuntimeError: If ``update_think`` has not been called yet.
        """
        if self.step is None:
            raise RuntimeError("update_think() must be called before update_act()")
        if is_finished:
            self.step.status = "finish"
        # The type hint admits ``str`` for finished messages too; only objects
        # need their ``.content`` unwrapped.
        if is_finished and not isinstance(message, str):
            text = message.content
        else:
            text = message
        self.step.content = Sentence(
            type=SentenceType.TEXT.value,
            id="1",
            value=SentenceValue(answer=text),
            is_finished=is_finished,
        )

    @staticmethod
    def guid32():
        """Return a random 32-character lowercase hex identifier."""
        return uuid.uuid4().hex

    @property
    def prompt(self):
        """This model as JSON, omitting fields that were never set.

        A property because callers use it as ``prompt + "\\n\\n"``.
        """
        return self.json(exclude_unset=True)
class MessageJsonModel(BaseModel):
    """Summary of a whole session, sent as the final chunk of the stream."""

    steps: list[Sentences]
    qa_type: str  # a QueryAnswerType value
    created_at: datetime = Field(default_factory=datetime.now)
    query_time: datetime = Field(default_factory=datetime.now)
    answer_time: datetime = Field(default_factory=datetime.now)
    score: Optional[int] = None
    feedback: Optional[str] = None

    def add_think_act(self, think_act_prompt: ThinkActPrompt):
        """Append the (finished) step held by ``think_act_prompt`` to ``steps``.

        A step whose content was never set is recorded with empty ``contents``
        instead of failing.
        """
        step = think_act_prompt.step
        contents = [step.content.dict()] if step.content is not None else []
        entry = Sentences(
            action=step.title,
            skill=think_act_prompt.skill,
            description=step.description,
            timestamp=think_act_prompt.timestamp,
            status=step.status,
            contents=contents,
        )
        self.steps.append(entry)

    @property
    def prompt(self):
        """This model as JSON, omitting fields that were never set.

        A property because callers use it as ``prompt + "\\n\\n"``.
        """
        return self.json(exclude_unset=True)
async def create_message(req_model: NewMsg, request: Request):
    """Run one SoftwareCompany session and stream its progress as text chunks.

    This is an async generator. Each yielded chunk is a JSON document followed
    by a blank line; ``ChatHandler.create_message`` wraps the generator in a
    streaming HTTP response.

    Args:
        req_model: The client's query and config overrides. Keys are upper-cased;
            keys listed in ``SERVER_METAGPT_CONFIG_EXCLUDE`` are dropped
            (compared case-insensitively).
        request: The incoming request, polled to detect client disconnects.

    Yields:
        For every step: a ThinkActPrompt when it starts, one per streamed LLM
        chunk, and one when it finishes. After all steps, the MessageJsonModel
        summary. On an unexpected error, a single "failed" ThinkActPrompt
        carrying the traceback is yielded instead.
    """
    tc_id = 0
    task: Optional[asyncio.Task] = None  # the role.act() call currently in flight
    watcher: Optional[asyncio.Task] = None  # background disconnect monitor
    workspace: Optional[pathlib.Path] = None  # workspace created for this request
    try:
        excluded = CONFIG.get("SERVER_METAGPT_CONFIG_EXCLUDE", [])
        if isinstance(excluded, str):
            excluded = [excluded]
        # Compare upper-cased on both sides: keys are upper-cased before use,
        # so a case-sensitive check could be bypassed with e.g. "openai_api_key".
        exclude_keys = {str(key).upper() for key in excluded}
        config = {k.upper(): v for k, v in req_model.config.items() if k.upper() not in exclude_keys}
        set_context(config, uuid.uuid4().hex)
        # Remember exactly which directory this request owns, so cleanup never
        # deletes a workspace it did not set up.
        workspace = CONFIG.WORKSPACE_PATH

        msg_queue = deque()
        # Streamed LLM output lands on the left of the queue; pop() below takes
        # the oldest chunk from the right.
        CONFIG.LLM_STREAM_LOG = lambda x: msg_queue.appendleft(x) if x else None

        role = SoftwareCompany()
        role.recv(message=Message(content=req_model.query))
        answer = MessageJsonModel(
            steps=[
                Sentences(
                    contents=[
                        Sentence(
                            type=SentenceType.TEXT.value, value=SentenceValue(answer=req_model.query), is_finished=True
                        )
                    ],
                    status=MessageStatus.COMPLETE.value,
                )
            ],
            qa_type=QueryAnswerType.Answer.value,
        )

        async def stop_if_disconnect():
            # Cancel the in-flight act() as soon as the client goes away.
            while not await request.is_disconnected():
                await asyncio.sleep(1)
            if task is not None and not task.done():
                task.cancel()
                logger.info(f"cancel task {task}")

        # Keep a reference: the event loop only holds weak references to tasks.
        watcher = asyncio.create_task(stop_if_disconnect())
        while True:
            tc_id += 1
            if await request.is_disconnected():
                return
            think_result: RoleRun = await role.think()
            if not think_result:  # End of conversation
                break
            think_act_prompt = ThinkActPrompt(role=think_result.role.profile)
            think_act_prompt.update_think(tc_id, think_result)
            yield think_act_prompt.prompt + "\n\n"
            task = asyncio.create_task(role.act())
            while not await request.is_disconnected():
                if msg_queue:
                    think_act_prompt.update_act(msg_queue.pop(), False)
                    yield think_act_prompt.prompt + "\n\n"
                    continue
                if task.done():
                    break
                await asyncio.sleep(0.5)
            else:
                # The client disconnected mid-step.
                task.cancel()
                return
            if task.cancelled():
                # Cancelled by stop_if_disconnect(): the client is gone.
                return
            act_result = await task
            think_act_prompt.update_act(act_result)
            yield think_act_prompt.prompt + "\n\n"
            answer.add_think_act(think_act_prompt)
        yield answer.prompt + "\n\n"  # Notify the front-end that the message is complete.
    except Exception as ex:
        # asyncio.CancelledError is a BaseException and is deliberately not caught
        # here: it propagates after the cleanup in ``finally``.
        logger.exception(f"create_message failed at step {tc_id}")
        description = str(ex)
        error_trace = traceback.format_exc()
        step = ThinkActStep(
            id=str(tc_id),
            status="failed",
            title=description,
            description=description,
            content=Sentence(
                type=SentenceType.ERROR.value, id="1", value=SentenceValue(answer=error_trace), is_finished=True
            ),
        )
        think_act_prompt = ThinkActPrompt(step=step)
        yield think_act_prompt.prompt + "\n\n"
    finally:
        # Runs on normal completion, errors, cancellation and generator close.
        if watcher is not None:
            watcher.cancel()
        if task is not None and not task.done():
            task.cancel()
        if workspace is not None and workspace.exists():
            try:
                shutil.rmtree(workspace)
            except OSError:
                logger.exception(f"failed to remove workspace {workspace}")
# Fallback sink when no LLM_STREAM_LOG is configured: print chunks without newlines.
default_llm_stream_log = partial(print, end="")


def llm_stream_log(msg):
    """Forward one piece of streamed LLM output to the current sink.

    The sink is ``CONFIG.LLM_STREAM_LOG`` when present (create_message installs
    one per request), otherwise ``default_llm_stream_log``. Delivery is
    best-effort: any exception raised while looking up or calling the sink is
    suppressed.

    Args:
        msg: The text chunk to forward.
    """
    # ``contextlib.suppress()`` with no arguments suppresses nothing; the
    # exception class must be named for the best-effort behavior to apply.
    with contextlib.suppress(Exception):
        CONFIG._get("LLM_STREAM_LOG", default_llm_stream_log)(msg)
def set_context(context, uid):
    """Install a per-request configuration overlay on ``CONFIG`` and return it.

    ``context`` is modified in place. ``WORKSPACE_PATH`` is always overwritten
    with ``workspace/<uid>``. The legacy keys ``DEPLOYMENT_ID`` and
    ``OPENAI_API_BASE`` are copied to ``DEPLOYMENT_NAME`` and
    ``OPENAI_BASE_URL`` respectively, unless the newer key is already present.

    Args:
        context: Mutable dict of config keys to values (create_message passes
            upper-cased keys).
        uid: Unique id used as the workspace directory name.

    Returns:
        The same ``context`` object, after it has been passed to ``CONFIG.set_context``.
    """
    renamed_keys = {
        "DEPLOYMENT_ID": "DEPLOYMENT_NAME",
        "OPENAI_API_BASE": "OPENAI_BASE_URL",
    }
    context["WORKSPACE_PATH"] = pathlib.Path("workspace", uid)
    for legacy_key, current_key in renamed_keys.items():
        if current_key in context or legacy_key not in context:
            continue
        context[current_key] = context[legacy_key]
    CONFIG.set_context(context)
    return context
class ChatHandler:
    """HTTP handlers for the chat API."""

    # A staticmethod: the handler takes no ``self``. Without the decorator,
    # calling it on an instance would bind the instance to ``req_model``.
    @staticmethod
    async def create_message(req_model: NewMsg, request: Request):
        """Message stream, using SSE."""
        # Method bodies do not see the class namespace, so this name resolves to
        # the module-level async generator, not to this method.
        event = create_message(req_model, request)
        headers = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
        return StreamingResponse(event, headers=headers, media_type="text/event-stream")
# Application wiring: uploaded/generated files, the streaming chat API, then the
# static front-end.
app = FastAPI()
app.mount(path="/storage", app=StaticFiles(directory="./storage/"), name="storage")
app.add_api_route(
    path="/api/messages",
    endpoint=ChatHandler.create_message,
    methods=["post"],
    summary="Session message sending (streaming response)",
)
# Mounted last: a mount at "/" matches every path, so anything registered after
# it would be shadowed.
app.mount(path="/", app=StaticFiles(directory="./static/", html=True, follow_symlink=True), name="static")
# Route MetaGPT's streamed LLM output through the module-level llm_stream_log.
set_llm_stream_logfunc(llm_stream_log)
def main():
    """Serve ``app`` with uvicorn, passing the ``SERVER_UVICORN`` config section as keyword arguments."""
    uvicorn_options = CONFIG.get("SERVER_UVICORN", {})
    # The app is given as an import string rather than the object; uvicorn
    # requires that form for options such as reload/workers.
    uvicorn.run(app="__main__:app", **uvicorn_options)


if __name__ == "__main__":
    fire.Fire(component=main)