|
from fastapi import APIRouter, Request |
|
from fastapi.responses import StreamingResponse |
|
from llama_index.core.workflow import StopEvent |
|
from llama_index.core.llms import ChatMessage |
|
from workflow.sql_workflow import sql_workflow, SQLWorkflow |
|
from workflow.roleplay_workflow import RolePlayWorkflow |
|
from workflow.modules import load_chat_store, MySQLChatStore, ToyStatusStore |
|
from workflow.events import TokenEvent, StatusEvent |
|
import asyncio, os, json |
|
from concurrent.futures import ThreadPoolExecutor |
|
from contextlib import asynccontextmanager |
|
from typing import Dict, Any |
|
from workflow.modules import load_chat_store |
|
from llama_index.llms.ollama import Ollama |
|
from prompts.default_prompts import FINAL_RESPONSE_PROMPT |
|
from workflow.vllm_model import MyVllm |
|
from datetime import datetime |
|
|
|
|
|
# Connection settings, read once at import time from the environment.
# Unset string settings are left as None.
CHAT_STORE_PATH = os.getenv("CHAT_STORE_PATH")
MYSQL_HOST = os.getenv("MYSQL_HOST")
MYSQL_USER = os.getenv("MYSQL_USER")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD")
MYSQL_DATABASE = os.getenv("MYSQL_DATABASE")
# int(None) / int("") would crash the whole module at import time when
# MYSQL_PORT is unset or empty; fall back to MySQL's standard port instead.
MYSQL_PORT = int(os.getenv("MYSQL_PORT") or "3306")
|
|
|
# Router for the text-to-SQL chat endpoint, mounted under /sql_chat.
sql_chat_router = APIRouter(tags=["sql", "chat"], prefix="/sql_chat")

# Router for editing stored chat history, mounted under /modify_chat_store.
chat_store_manager = APIRouter(tags=["chat_store"], prefix="/modify_chat_store")

# Router for the role-play chat endpoint, mounted under /roleplay.
roleplay_router = APIRouter(tags=["roleplay"], prefix="/roleplay")
|
|
|
""" |
|
接收数据格式 |
|
{ |
|
"query": "Top 5 porn videos with url.", # 用户提问 |
|
"clientVersion": "ios_3.1.4", # 客户端版本, 不做处理 |
|
"traceId": "1290325592508092416", # 请求唯一标识, 不做处理 |
|
"messageId": 1290325592508092400, # message表示, 不做处理 |
|
"llmSessionId": "xxxxxxxx", # sessionId用于调取用户对话上下文 |
|
"llmUserId": "xxxxxx" |
|
} |
|
""" |
|
|
|
|
|
|
|
# Module-level counter initialised to zero. Nothing in this module's visible
# handlers reads or updates it.
# NOTE(review): may be used by other modules — confirm before removing.
polling_variable: int = 0
|
|
|
class ResourceManager:
    """Shared concurrency primitives plus a per-name cache of chat stores.

    Args:
        max_concurrent: capacity used for both ``semaphore`` and
            ``thread_pool``.
    """

    def __init__(self, max_concurrent=8):
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.thread_pool = ThreadPoolExecutor(max_workers=max_concurrent)
        # store name -> object returned by load_chat_store
        self.chat_stores: Dict[str, Any] = {}

    async def get_chat_store(self, store_name: str):
        """Return the chat store named ``store_name``, loading it on first use.

        The loaded object is cached in ``self.chat_stores``, so repeated calls
        with the same name return the same object.
        NOTE(review): ``load_chat_store`` is called synchronously on the event
        loop; ``thread_pool`` is not used for it.
        """
        if store_name in self.chat_stores:
            return self.chat_stores[store_name]
        loaded = load_chat_store(store_name)
        self.chat_stores[store_name] = loaded
        return loaded
|
|
|
|
|
# Single process-wide instance shared by the request handlers; 8 is the
# constructor's default concurrency, spelled out here for visibility.
resource_manager = ResourceManager(max_concurrent=8)
|
|
|
@asynccontextmanager
async def acquire_model_resource():
    """Hold one slot of ``resource_manager.semaphore`` for the ``async with`` body.

    Usage::

        async with acquire_model_resource():
            ...  # limited to as many concurrent callers as the semaphore allows

    The slot is acquired before the protected region is entered and released
    only if it was actually acquired. The previous try/finally form released
    even when the task was cancelled while still waiting in ``acquire()``,
    which silently raised the concurrency limit by one each time.
    """
    async with resource_manager.semaphore:
        yield
|
|
|
@sql_chat_router.post("/") |
|
async def chat(request: Request): |
|
requestJson = await request.json() |
|
query = str(requestJson.get("query")) |
|
sessionId = str(requestJson.get("llmSessionId")) |
|
try: |
|
isLLMContext = int(requestJson.get("isLlmContext")) |
|
adultMode = int(requestJson.get("adultMode")) |
|
|
|
print(f"adultMode: {adultMode} - isLLMContext: {isLLMContext}") |
|
if adultMode == 1: |
|
adultMode = True |
|
else: |
|
adultMode = False |
|
except Exception as e: |
|
isLLMContext = 0 |
|
adultMode = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
chat_store = MySQLChatStore(host=MYSQL_HOST, port=MYSQL_PORT, user=MYSQL_USER, password=MYSQL_PASSWORD, database=MYSQL_DATABASE) |
|
print(f"query: {query} - sessionId: {sessionId}") |
|
|
|
llm = MyVllm( |
|
model="huihui-ai/Qwen2.5-7B-Instruct-abliterated-v2", |
|
api_key="token-abc123", |
|
base_url="http://localhost:17777/v1", |
|
) |
|
|
|
response_synthesis_prompt = FINAL_RESPONSE_PROMPT |
|
|
|
wf = SQLWorkflow( |
|
response_llm=llm, |
|
response_synthesis_prompt=response_synthesis_prompt, |
|
chat_store=chat_store, |
|
sessionId=sessionId, |
|
context_flag=isLLMContext, |
|
adultMode=adultMode, |
|
verbose=True, |
|
timeout=60 |
|
) |
|
r = wf.run(query=query) |
|
async def generator(): |
|
async for event in r.stream_events(): |
|
if await request.is_disconnected(): |
|
break |
|
if isinstance(event, TokenEvent): |
|
token = event.token |
|
yield token |
|
if isinstance(event, StatusEvent): |
|
yield event.status |
|
if isinstance(event, str): |
|
yield event |
|
await asyncio.sleep(0) |
|
|
|
return StreamingResponse(content=generator(), media_type="text/event-stream") |
|
|
|
@roleplay_router.post("/") |
|
async def roleplay(request: Request): |
|
requestJson = await request.json() |
|
query = str(requestJson.get("query")) |
|
sessionId = "rp_" + str(requestJson.get("llmSessionId")) |
|
|
|
toy_info = requestJson.get("toyIds") |
|
if toy_info: |
|
toy_names = [item["name"] for item in toy_info] |
|
else: |
|
toy_names = None |
|
chat_store = MySQLChatStore(host=MYSQL_HOST, port=MYSQL_PORT, user=MYSQL_USER, password=MYSQL_PASSWORD, database=MYSQL_DATABASE) |
|
toy_status_store = ToyStatusStore(host=MYSQL_HOST, port=MYSQL_PORT, user=MYSQL_USER, password=MYSQL_PASSWORD, database=MYSQL_DATABASE) |
|
print(f"query: {query} - sessionId: {sessionId} - toyName: {toy_names}") |
|
|
|
|
|
llm = MyVllm( |
|
model="huihui-ai/Qwen2.5-7B-Instruct-abliterated-v2", |
|
api_key="token-abc123", |
|
base_url="http://localhost:17777/v1", |
|
) |
|
|
|
|
|
wf = RolePlayWorkflow( |
|
response_llm=llm, |
|
chat_store=chat_store, |
|
toy_status_store=toy_status_store, |
|
sessionId=sessionId, |
|
gender="female", |
|
toy_names=toy_names, |
|
verbose=True, |
|
timeout=60, |
|
) |
|
r = wf.run(query=query) |
|
async def generator(): |
|
async for event in r.stream_events(): |
|
if await request.is_disconnected(): |
|
break |
|
if isinstance(event, TokenEvent): |
|
token = event.token |
|
yield token |
|
if isinstance(event, StatusEvent): |
|
yield event.status |
|
if isinstance(event, str): |
|
yield event |
|
await asyncio.sleep(0) |
|
return StreamingResponse(content=generator(), media_type="text/event-stream") |
|
|
|
|
|
@chat_store_manager.post("/") |
|
async def modify_chat_store(request: Request): |
|
try: |
|
request = await request.json() |
|
sessionId = request.get("sessionId") |
|
upMessageBody = request.get("upMessageBody") |
|
downMessageBody = request.get("downMessageBody") |
|
|
|
chat_store = MySQLChatStore(MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DATABASE) |
|
chat_store.del_message(sessionId, upMessageBody) |
|
chat_store.del_message(sessionId, downMessageBody) |
|
return {"status": "SUCCESS"} |
|
except Exception as e: |
|
return {"status": "FAILED", "error": str(e)} |