# routers/sql_chat.py
# (uploaded via huggingface_hub by Pew404, commit 318db6e, verified)
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from llama_index.core.workflow import StopEvent
from llama_index.core.llms import ChatMessage
from workflow.sql_workflow import sql_workflow, SQLWorkflow
from workflow.roleplay_workflow import RolePlayWorkflow
from workflow.modules import load_chat_store, MySQLChatStore, ToyStatusStore
from workflow.events import TokenEvent, StatusEvent
import asyncio, os, json
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from typing import Dict, Any
from workflow.modules import load_chat_store
from llama_index.llms.ollama import Ollama
from prompts.default_prompts import FINAL_RESPONSE_PROMPT
from workflow.vllm_model import MyVllm
from datetime import datetime
# Connection settings for the MySQL-backed chat/status stores, read from the
# environment at import time.
CHAT_STORE_PATH = os.getenv("CHAT_STORE_PATH")
MYSQL_HOST = os.getenv("MYSQL_HOST")
MYSQL_USER = os.getenv("MYSQL_USER")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD")
MYSQL_DATABASE = os.getenv("MYSQL_DATABASE")
# int(os.getenv("MYSQL_PORT")) raised TypeError when the variable was unset,
# crashing the module import; fall back to MySQL's standard port instead.
MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
# Router for the SQL-grounded chat endpoint (mounted at POST /sql_chat/).
sql_chat_router = APIRouter(
    prefix="/sql_chat",
    tags=["sql", "chat"]
)
# Router for deleting message pairs from the persisted chat history
# (mounted at POST /modify_chat_store/).
chat_store_manager = APIRouter(
    prefix="/modify_chat_store",
    tags=["chat_store"]
)
# Router for the roleplay chat endpoint (mounted at POST /roleplay/).
roleplay_router = APIRouter(
    prefix="/roleplay",
    tags=["roleplay"]
)
"""
接收数据格式
{
"query": "Top 5 porn videos with url.", # 用户提问
"clientVersion": "ios_3.1.4", # 客户端版本, 不做处理
"traceId": "1290325592508092416", # 请求唯一标识, 不做处理
"messageId": 1290325592508092400, # message表示, 不做处理
"llmSessionId": "xxxxxxxx", # sessionId用于调取用户对话上下文
"llmUserId": "xxxxxx"
}
"""
# Global resource management
# NOTE(review): polling_variable is never read or updated anywhere in this
# file — presumably intended for round-robin selection between model
# instances; confirm against the rest of the codebase before removing.
polling_variable = 0
class ResourceManager:
    """Shared concurrency primitives plus a lazy cache of chat stores."""

    def __init__(self, max_concurrent=8):
        # A 24G GPU fits two ~9G models, hence the default concurrency cap.
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.thread_pool = ThreadPoolExecutor(max_workers=max_concurrent)
        # Chat-store instances cached by store name so each is loaded once.
        self.chat_stores: Dict[str, Any] = {}

    async def get_chat_store(self, store_name: str):
        """Return the chat store for *store_name*, loading it on first use."""
        try:
            return self.chat_stores[store_name]
        except KeyError:
            store = load_chat_store(store_name)
            self.chat_stores[store_name] = store
            return store
# Module-level singleton ResourceManager shared by every request handler.
resource_manager = ResourceManager()
@asynccontextmanager
async def acquire_model_resource():
    """Async context manager that bounds concurrent model usage.

    Holds one permit of the shared semaphore on ``resource_manager`` for the
    duration of the ``with`` body, so at most ``max_concurrent`` requests use
    a model slot at once.
    """
    # The original acquired inside ``try`` and released in ``finally``: if the
    # acquire itself was cancelled, the finally clause released a permit that
    # was never obtained, silently inflating the semaphore. ``async with``
    # only releases after a successful acquire.
    async with resource_manager.semaphore:
        yield
@sql_chat_router.post("/")
async def chat(request: Request):
requestJson = await request.json()
query = str(requestJson.get("query"))
sessionId = str(requestJson.get("llmSessionId"))
try:
isLLMContext = int(requestJson.get("isLlmContext"))
adultMode = int(requestJson.get("adultMode"))
# isCache = int(request.get("isCache"))
print(f"adultMode: {adultMode} - isLLMContext: {isLLMContext}")
if adultMode == 1:
adultMode = True
else:
adultMode = False
except Exception as e:
isLLMContext = 0
adultMode = False
# 调取用户对话上下文 elasticsearch
# chat_history = parse_chat_history(sessionId)
# try:
# chat_store = load_chat_store("testing_chat_store")
# except:
# # 获取当前日期
# current_date = datetime.now().strftime('%Y%m%d')
# chat_store = load_chat_store(current_date)
chat_store = MySQLChatStore(host=MYSQL_HOST, port=MYSQL_PORT, user=MYSQL_USER, password=MYSQL_PASSWORD, database=MYSQL_DATABASE)
print(f"query: {query} - sessionId: {sessionId}")
llm = MyVllm(
model="huihui-ai/Qwen2.5-7B-Instruct-abliterated-v2",
api_key="token-abc123",
base_url="http://localhost:17777/v1",
)
response_synthesis_prompt = FINAL_RESPONSE_PROMPT
# Ollama("solonglin/qwen2.5-q6_k-abliterated", temperature=0.8, request_timeout=120, context_window=5000, keep_alive=-1),
wf = SQLWorkflow(
response_llm=llm,
response_synthesis_prompt=response_synthesis_prompt,
chat_store=chat_store,
sessionId=sessionId,
context_flag=isLLMContext,
adultMode=adultMode,
verbose=True,
timeout=60
)
r = wf.run(query=query)
async def generator():
async for event in r.stream_events():
if await request.is_disconnected():
break
if isinstance(event, TokenEvent):
token = event.token
yield token
if isinstance(event, StatusEvent):
yield event.status
if isinstance(event, str):
yield event
await asyncio.sleep(0)
# result = await sql_workflow(query=query, chat_store=chat_store, sessionId=sessionId, llm=llm, context_flag=isLLMContext, adultMode=adultMode)
return StreamingResponse(content=generator(), media_type="text/event-stream")
@roleplay_router.post("/")
async def roleplay(request: Request):
requestJson = await request.json()
query = str(requestJson.get("query"))
sessionId = "rp_" + str(requestJson.get("llmSessionId"))
# messageId = str(requestJson.get("messageId"))
toy_info = requestJson.get("toyIds")
if toy_info:
toy_names = [item["name"] for item in toy_info]
else:
toy_names = None
chat_store = MySQLChatStore(host=MYSQL_HOST, port=MYSQL_PORT, user=MYSQL_USER, password=MYSQL_PASSWORD, database=MYSQL_DATABASE)
toy_status_store = ToyStatusStore(host=MYSQL_HOST, port=MYSQL_PORT, user=MYSQL_USER, password=MYSQL_PASSWORD, database=MYSQL_DATABASE)
print(f"query: {query} - sessionId: {sessionId} - toyName: {toy_names}")
# setup llm
llm = MyVllm(
model="huihui-ai/Qwen2.5-7B-Instruct-abliterated-v2",
api_key="token-abc123",
base_url="http://localhost:17777/v1",
)
# init workflow
wf = RolePlayWorkflow(
response_llm=llm,
chat_store=chat_store,
toy_status_store=toy_status_store,
sessionId=sessionId,
gender="female",
toy_names=toy_names,
verbose=True,
timeout=60,
)
r = wf.run(query=query)
async def generator():
async for event in r.stream_events():
if await request.is_disconnected():
break
if isinstance(event, TokenEvent):
token = event.token
yield token
if isinstance(event, StatusEvent):
yield event.status
if isinstance(event, str):
yield event
await asyncio.sleep(0)
return StreamingResponse(content=generator(), media_type="text/event-stream")
@chat_store_manager.post("/")
async def modify_chat_store(request: Request):
try:
request = await request.json()
sessionId = request.get("sessionId")
upMessageBody = request.get("upMessageBody")
downMessageBody = request.get("downMessageBody")
chat_store = MySQLChatStore(MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DATABASE)
chat_store.del_message(sessionId, upMessageBody)
chat_store.del_message(sessionId, downMessageBody)
return {"status": "SUCCESS"}
except Exception as e:
return {"status": "FAILED", "error": str(e)}