# Viewer metadata: file size 7,374 bytes; 318db6e (presumably the commit hash).
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from llama_index.core.workflow import StopEvent
from llama_index.core.llms import ChatMessage
from workflow.sql_workflow import sql_workflow, SQLWorkflow
from workflow.roleplay_workflow import RolePlayWorkflow
from workflow.modules import load_chat_store, MySQLChatStore, ToyStatusStore
from workflow.events import TokenEvent, StatusEvent
import asyncio, os, json
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from typing import Dict, Any
from workflow.modules import load_chat_store
from llama_index.llms.ollama import Ollama
from prompts.default_prompts import FINAL_RESPONSE_PROMPT
from workflow.vllm_model import MyVllm
from datetime import datetime
# Connection settings, read from the environment once at import time.
# Unset string settings stay None.
CHAT_STORE_PATH = os.getenv("CHAT_STORE_PATH")
MYSQL_HOST = os.getenv("MYSQL_HOST")
MYSQL_USER = os.getenv("MYSQL_USER")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD")
MYSQL_DATABASE = os.getenv("MYSQL_DATABASE")
# int(None) raises TypeError, so an unset (or empty) MYSQL_PORT used to abort
# the import of this module. Fall back to MySQL's standard port instead; a
# non-numeric value still fails loudly with ValueError.
MYSQL_PORT = int(os.getenv("MYSQL_PORT") or "3306")
# Router carrying the SQL chat endpoint.
sql_chat_router = APIRouter(prefix="/sql_chat", tags=["sql", "chat"])

# Router carrying the endpoint that deletes messages from the chat store.
chat_store_manager = APIRouter(prefix="/modify_chat_store", tags=["chat_store"])

# Router carrying the role-play endpoint.
roleplay_router = APIRouter(prefix="/roleplay", tags=["roleplay"])
"""
Expected request body
{
    "query": "Top 5 porn videos with url.",  # the user's question
    "clientVersion": "ios_3.1.4",  # client version; not processed
    "traceId": "1290325592508092416",  # unique request identifier; not processed
    "messageId": 1290325592508092400,  # message identifier; not processed
    "llmSessionId": "xxxxxxxx",  # session id used to fetch the user's conversation context
    "llmUserId": "xxxxxx"
}
"""
# Global resource management
# NOTE(review): nothing in the visible part of this file reads or updates this
# value — confirm it is still needed.
polling_variable = 0
class ResourceManager:
    """Hold the concurrency limit, a thread pool and cached chat stores.

    Attributes:
        semaphore: ``asyncio.Semaphore`` created with ``max_concurrent`` slots.
        thread_pool: ``ThreadPoolExecutor`` with ``max_concurrent`` workers.
        chat_stores: chat store instances keyed by store name.
    """

    def __init__(self, max_concurrent=8):
        # Original sizing note: 24 GB of GPU memory can host two 9 GB models.
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.thread_pool = ThreadPoolExecutor(max_workers=max_concurrent)
        self.chat_stores: Dict[str, Any] = {}  # cache of chat store instances

    async def get_chat_store(self, store_name: str):
        """Return the chat store named ``store_name``, loading it on first use.

        The result of ``load_chat_store(store_name)`` is cached, so later calls
        with the same name return the same object.
        """
        if store_name not in self.chat_stores:
            self.chat_stores[store_name] = load_chat_store(store_name)
        return self.chat_stores[store_name]


# Module-level resource manager instance.
resource_manager = ResourceManager()


@asynccontextmanager
async def acquire_model_resource():
    """Hold one slot of ``resource_manager.semaphore`` for the ``async with`` body.

    The slot is released when the body exits, whether normally or by exception.
    """
    # Acquire *before* entering the try block: if acquire() is cancelled or
    # raises, no slot was taken, so release() must not run. Releasing an
    # un-acquired semaphore would raise its counter above the configured limit.
    await resource_manager.semaphore.acquire()
    try:
        yield
    finally:
        resource_manager.semaphore.release()
def _int_field(payload, key, default=0):
    """Return ``payload.get(key)`` converted with ``int()``, or ``default`` if that fails."""
    try:
        return int(payload.get(key))
    except (TypeError, ValueError, OverflowError):
        # Missing key (None), non-numeric text, or a non-finite float.
        return default


@sql_chat_router.post("/")
async def chat(request: Request):
    """Run the SQL workflow for one user query and stream its output.

    Reads ``query``, ``llmSessionId``, ``isLlmContext`` and ``adultMode`` from
    the JSON request body, builds a ``SQLWorkflow`` backed by a
    ``MySQLChatStore`` and a ``MyVllm`` client, and starts it with the query.

    Returns:
        A ``StreamingResponse`` (``text/event-stream``) that yields the token of
        every ``TokenEvent``, the status of every ``StatusEvent`` and any plain
        string event, until the stream ends or the client disconnects.
    """
    requestJson = await request.json()
    query = str(requestJson.get("query"))
    sessionId = str(requestJson.get("llmSessionId"))

    # Parse each flag on its own. Previously both lived in one try/except, so a
    # missing or malformed adultMode also reset a valid isLlmContext to 0.
    isLLMContext = _int_field(requestJson, "isLlmContext")
    adult_flag = _int_field(requestJson, "adultMode")
    # isCache = int(request.get("isCache"))
    print(f"adultMode: {adult_flag} - isLLMContext: {isLLMContext}")
    # Only the exact value 1 switches adult mode on.
    adultMode = adult_flag == 1

    # Fetch the user's conversation context (elasticsearch)
    # chat_history = parse_chat_history(sessionId)
    # try:
    #     chat_store = load_chat_store("testing_chat_store")
    # except:
    #     # use the current date as the store name
    #     current_date = datetime.now().strftime('%Y%m%d')
    #     chat_store = load_chat_store(current_date)
    chat_store = MySQLChatStore(
        host=MYSQL_HOST,
        port=MYSQL_PORT,
        user=MYSQL_USER,
        password=MYSQL_PASSWORD,
        database=MYSQL_DATABASE,
    )
    print(f"query: {query} - sessionId: {sessionId}")

    llm = MyVllm(
        model="huihui-ai/Qwen2.5-7B-Instruct-abliterated-v2",
        api_key="token-abc123",
        base_url="http://localhost:17777/v1",
    )
    response_synthesis_prompt = FINAL_RESPONSE_PROMPT
    # Ollama("solonglin/qwen2.5-q6_k-abliterated", temperature=0.8, request_timeout=120, context_window=5000, keep_alive=-1),
    wf = SQLWorkflow(
        response_llm=llm,
        response_synthesis_prompt=response_synthesis_prompt,
        chat_store=chat_store,
        sessionId=sessionId,
        context_flag=isLLMContext,
        adultMode=adultMode,
        verbose=True,
        timeout=60,
    )
    r = wf.run(query=query)

    async def generator():
        async for event in r.stream_events():
            # NOTE(review): this stops streaming when the client goes away, but
            # nothing here cancels the workflow run itself — confirm whether it
            # should be cancelled.
            if await request.is_disconnected():
                break
            if isinstance(event, TokenEvent):
                token = event.token
                yield token
            if isinstance(event, StatusEvent):
                yield event.status
            if isinstance(event, str):
                yield event
            # Hand control back to the event loop between events.
            await asyncio.sleep(0)

    # result = await sql_workflow(query=query, chat_store=chat_store, sessionId=sessionId, llm=llm, context_flag=isLLMContext, adultMode=adultMode)
    return StreamingResponse(content=generator(), media_type="text/event-stream")
@roleplay_router.post("/")
async def roleplay(request: Request):
    """Run the role-play workflow for one user query and stream its output.

    Reads ``query``, ``llmSessionId`` and the optional ``toyIds`` list from the
    JSON request body. The session id is prefixed with ``"rp_"`` and the
    ``name`` of every ``toyIds`` entry is handed to ``RolePlayWorkflow``.

    Returns:
        A ``StreamingResponse`` (``text/event-stream``) that yields the token of
        every ``TokenEvent``, the status of every ``StatusEvent`` and any plain
        string event, until the stream ends or the client disconnects.
    """
    payload = await request.json()
    user_query = str(payload.get("query"))
    session_key = "rp_" + str(payload.get("llmSessionId"))
    # messageId = str(requestJson.get("messageId"))
    toys = payload.get("toyIds")
    # An absent or empty toy list is passed on as None.
    toy_names = [toy["name"] for toy in toys] if toys else None

    # Both stores are built from the same MySQL connection settings.
    mysql_settings = dict(
        host=MYSQL_HOST,
        port=MYSQL_PORT,
        user=MYSQL_USER,
        password=MYSQL_PASSWORD,
        database=MYSQL_DATABASE,
    )
    chat_store = MySQLChatStore(**mysql_settings)
    toy_status_store = ToyStatusStore(**mysql_settings)
    print(f"query: {user_query} - sessionId: {session_key} - toyName: {toy_names}")

    # Language-model client.
    llm = MyVllm(
        model="huihui-ai/Qwen2.5-7B-Instruct-abliterated-v2",
        api_key="token-abc123",
        base_url="http://localhost:17777/v1",
    )

    # Build the workflow and start it.
    workflow = RolePlayWorkflow(
        response_llm=llm,
        chat_store=chat_store,
        toy_status_store=toy_status_store,
        sessionId=session_key,
        gender="female",
        toy_names=toy_names,
        verbose=True,
        timeout=60,
    )
    handler = workflow.run(query=user_query)

    async def event_stream():
        # (event type, how to turn that event into an output chunk), checked in
        # this order; an event matching several types is emitted once per match.
        emitters = (
            (TokenEvent, lambda ev: ev.token),
            (StatusEvent, lambda ev: ev.status),
            (str, lambda ev: ev),
        )
        async for event in handler.stream_events():
            if await request.is_disconnected():
                break
            for event_type, to_chunk in emitters:
                if isinstance(event, event_type):
                    yield to_chunk(event)
            # Hand control back to the event loop between events.
            await asyncio.sleep(0)

    return StreamingResponse(content=event_stream(), media_type="text/event-stream")
@chat_store_manager.post("/")
async def modify_chat_store(request: Request):
    """Delete two messages from a session's stored chat history.

    Reads ``sessionId``, ``upMessageBody`` and ``downMessageBody`` from the
    JSON request body and calls ``MySQLChatStore.del_message`` once for each
    body (presumably the user's message and the model's reply — TODO confirm).

    Returns:
        ``{"status": "SUCCESS"}`` when both deletions complete, otherwise
        ``{"status": "FAILED", "error": <exception text>}``. Exceptions are
        never propagated to the framework.
    """
    try:
        # Keep the parsed body in its own name rather than rebinding `request`.
        payload = await request.json()
        sessionId = payload.get("sessionId")
        upMessageBody = payload.get("upMessageBody")
        downMessageBody = payload.get("downMessageBody")
        # Keyword arguments, matching how the other handlers build the store.
        chat_store = MySQLChatStore(
            host=MYSQL_HOST,
            port=MYSQL_PORT,
            user=MYSQL_USER,
            password=MYSQL_PASSWORD,
            database=MYSQL_DATABASE,
        )
        chat_store.del_message(sessionId, upMessageBody)
        chat_store.del_message(sessionId, downMessageBody)
        return {"status": "SUCCESS"}
    except Exception as e:
        # Handler boundary: report the failure in the response body.
        # NOTE(review): the raw exception text is sent to the client — confirm
        # that this is acceptable.
        return {"status": "FAILED", "error": str(e)}