from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from llama_index.core.workflow import StopEvent
from llama_index.core.llms import ChatMessage
from workflow.sql_workflow import sql_workflow, SQLWorkflow
from workflow.roleplay_workflow import RolePlayWorkflow
from workflow.modules import load_chat_store, MySQLChatStore, ToyStatusStore
from workflow.events import TokenEvent, StatusEvent
import asyncio
import os
import json
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from typing import Any, Dict
from llama_index.llms.ollama import Ollama
from prompts.default_prompts import FINAL_RESPONSE_PROMPT
from workflow.vllm_model import MyVllm
from datetime import datetime

# MySQL / chat-store connection settings come from the environment.
CHAT_STORE_PATH = os.getenv("CHAT_STORE_PATH")
MYSQL_HOST = os.getenv("MYSQL_HOST")
MYSQL_USER = os.getenv("MYSQL_USER")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD")
MYSQL_DATABASE = os.getenv("MYSQL_DATABASE")
# Default to the standard MySQL port so a missing env var no longer crashes
# the module at import time (int(None) raised TypeError before).
MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))

sql_chat_router = APIRouter(prefix="/sql_chat", tags=["sql", "chat"])
chat_store_manager = APIRouter(prefix="/modify_chat_store", tags=["chat_store"])
roleplay_router = APIRouter(prefix="/roleplay", tags=["roleplay"])

# Expected request payload:
# {
#     "query": "Top 5 porn videos with url.",  # user question
#     "clientVersion": "ios_3.1.4",            # client version, not processed
#     "traceId": "1290325592508092416",        # unique request id, not processed
#     "messageId": 1290325592508092400,        # message id, not processed
#     "llmSessionId": "xxxxxxxx",              # session id used to fetch chat context
#     "llmUserId": "xxxxxx"
# }

# Global resource management.
polling_variable = 0


class ResourceManager:
    """Bounds concurrent model usage and caches chat-store instances."""

    def __init__(self, max_concurrent: int = 8):
        # A 24 GB GPU fits two ~9 GB models; the semaphore caps in-flight work.
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.thread_pool = ThreadPoolExecutor(max_workers=max_concurrent)
        # Cache of chat-store instances, keyed by store name.
        self.chat_stores: Dict[str, Any] = {}

    async def get_chat_store(self, store_name: str):
        """Return the chat store for *store_name*, loading it on first access."""
        if store_name not in self.chat_stores:
            self.chat_stores[store_name] = load_chat_store(store_name)
        return self.chat_stores[store_name]


# Process-wide resource manager singleton.
resource_manager = ResourceManager()


@asynccontextmanager
async def acquire_model_resource():
    """Hold one slot of the global model semaphore for the enclosed block.

    The acquire happens *before* the try block: the original acquired inside
    the try, so a failed or cancelled acquire still ran the ``finally`` and
    released a slot that was never taken, inflating the semaphore count.
    """
    await resource_manager.semaphore.acquire()
    try:
        yield
    finally:
        resource_manager.semaphore.release()


def _stream_response(request: Request, workflow_handler) -> StreamingResponse:
    """Bridge workflow stream events into a ``text/event-stream`` response.

    Shared by the SQL-chat and role-play endpoints (their generators were
    identical). Stops as soon as the client disconnects; emits raw token /
    status strings exactly as the workflow produces them.
    """

    async def generator():
        async for event in workflow_handler.stream_events():
            if await request.is_disconnected():
                break
            if isinstance(event, TokenEvent):
                yield event.token
            if isinstance(event, StatusEvent):
                yield event.status
            if isinstance(event, str):
                yield event
            # Yield control so other tasks can run between chunks.
            await asyncio.sleep(0)

    return StreamingResponse(content=generator(), media_type="text/event-stream")


@sql_chat_router.post("/")
async def chat(request: Request):
    """Run the SQL workflow for a user query and stream the answer back."""
    requestJson = await request.json()
    query = str(requestJson.get("query"))
    sessionId = str(requestJson.get("llmSessionId"))
    try:
        isLLMContext = int(requestJson.get("isLlmContext"))
        adultModeFlag = int(requestJson.get("adultMode"))
        print(f"adultMode: {adultModeFlag} - isLLMContext: {isLLMContext}")
        adultMode = adultModeFlag == 1
    except (TypeError, ValueError):
        # Missing or malformed flags default to "no context, safe mode".
        isLLMContext = 0
        adultMode = False

    # Chat context lives in MySQL, keyed by session id.
    chat_store = MySQLChatStore(
        host=MYSQL_HOST,
        port=MYSQL_PORT,
        user=MYSQL_USER,
        password=MYSQL_PASSWORD,
        database=MYSQL_DATABASE,
    )
    print(f"query: {query} - sessionId: {sessionId}")

    llm = MyVllm(
        model="huihui-ai/Qwen2.5-7B-Instruct-abliterated-v2",
        api_key="token-abc123",
        base_url="http://localhost:17777/v1",
    )
    wf = SQLWorkflow(
        response_llm=llm,
        response_synthesis_prompt=FINAL_RESPONSE_PROMPT,
        chat_store=chat_store,
        sessionId=sessionId,
        context_flag=isLLMContext,
        adultMode=adultMode,
        verbose=True,
        timeout=60,
    )
    return _stream_response(request, wf.run(query=query))


@roleplay_router.post("/")
async def roleplay(request: Request):
    """Run the role-play workflow for a user query and stream the answer back."""
    requestJson = await request.json()
    query = str(requestJson.get("query"))
    # Role-play sessions are prefixed so they never collide with SQL-chat ones.
    sessionId = "rp_" + str(requestJson.get("llmSessionId"))
    toy_info = requestJson.get("toyIds")
    # presumably each entry is {"name": ...}; verify against the client payload
    toy_names = [item["name"] for item in toy_info] if toy_info else None

    chat_store = MySQLChatStore(
        host=MYSQL_HOST,
        port=MYSQL_PORT,
        user=MYSQL_USER,
        password=MYSQL_PASSWORD,
        database=MYSQL_DATABASE,
    )
    toy_status_store = ToyStatusStore(
        host=MYSQL_HOST,
        port=MYSQL_PORT,
        user=MYSQL_USER,
        password=MYSQL_PASSWORD,
        database=MYSQL_DATABASE,
    )
    print(f"query: {query} - sessionId: {sessionId} - toyName: {toy_names}")

    llm = MyVllm(
        model="huihui-ai/Qwen2.5-7B-Instruct-abliterated-v2",
        api_key="token-abc123",
        base_url="http://localhost:17777/v1",
    )
    wf = RolePlayWorkflow(
        response_llm=llm,
        chat_store=chat_store,
        toy_status_store=toy_status_store,
        sessionId=sessionId,
        gender="female",
        toy_names=toy_names,
        verbose=True,
        timeout=60,
    )
    return _stream_response(request, wf.run(query=query))


@chat_store_manager.post("/")
async def modify_chat_store(request: Request):
    """Delete a question/answer message pair from a session's chat history.

    Returns ``{"status": "SUCCESS"}`` on success, or ``{"status": "FAILED",
    "error": ...}`` instead of a 500 on any failure (boundary handler).
    """
    try:
        # NOTE: the original rebound the name ``request`` to the JSON dict,
        # shadowing the FastAPI Request parameter; use a distinct name.
        payload = await request.json()
        sessionId = payload.get("sessionId")
        upMessageBody = payload.get("upMessageBody")
        downMessageBody = payload.get("downMessageBody")
        chat_store = MySQLChatStore(
            MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DATABASE
        )
        chat_store.del_message(sessionId, upMessageBody)
        chat_store.del_message(sessionId, downMessageBody)
        return {"status": "SUCCESS"}
    except Exception as e:
        return {"status": "FAILED", "error": str(e)}