"""Role-play chat workflow: screen the user message, then refuse or reply.

The workflow streams its output as ``data:<payload>\\n\\n`` frames wrapped in
``TokenEvent`` objects and, when toys are attached to the session, follows the
reply with a generated toy command.
"""

import asyncio
import concurrent.futures
import json
import math
import os
import threading

from dotenv import load_dotenv
from llama_index.core import PromptTemplate
from llama_index.core.workflow import (
    Context,
    Workflow,
    StartEvent,
    StopEvent,
    step,
)

from workflow.events import (
    SafeStartEvent,
    RefuseEvent,
    TokenEvent,
    ControlEvent,
)
from workflow.vllm_model import MyVllm
from workflow.modules import MySQLChatStore, ToyStatusStore, prGreen, prRed, prYellow
from prompts.default_prompts import (
    ALIGNMENT_PROMPT,
    REFUSE_PROMPT,
    FEMALE_ROLEPLAY_PROMPT,
    MALE_ROLEPLAY_PROMPT,
    TOY_CONTROL_PROMPT_TEST,
    TOY_CONTROL_PROMPT,
)

# Intent labels handed to the alignment prompt as ``intent_labels``.
REFUSE_INTENTS = [
    "medical advice",
    "Overdose medication",
    "child pornography",
    "self-harm",
    "political bias",
    "racial hate speech",
    "illegal drugs",
    "not harmful",
    "violent tendencies",
    "weaponry",
    "religious hate",
    "Theft",
    "Robbery",
    "Body Disposal",
    "Forgery",
    "Smuggling",
    "Money laundering",
    "Extortion",
    "Terrorism",
    "Explosion",
    "Cyberattack & Hacking",
    "illegal stalking",
    "Arms trafficking",
]

# Operations offered to the toy-control prompt as ``available_operations``.
OPERATIONS = ["vibrate"]

# Intents that let the message through to the chat step.
_ALLOWED_INTENTS = ("not harmful", "BDSM content")

# Language values that are answered in English instead.
_ENGLISH_FALLBACK_LANGUAGES = ("zh", "chinese")

# How many times ``censor`` re-queues the same query after an unparseable reply.
_MAX_CENSOR_RETRIES = 3

# Total number of LLM calls ``generate_command`` makes before giving up.
_MAX_COMMAND_ATTEMPTS = 5

# Role-play prompt per supported ``gender`` value.
_ROLEPLAY_PROMPTS = {
    "male": MALE_ROLEPLAY_PROMPT,
    "female": FEMALE_ROLEPLAY_PROMPT,
}

# Frame that closes a streamed response.
_DONE_FRAME = "data:[DONE]\n\n"


class RolePlayWorkflow(Workflow):
    """Screen a user message, then stream either a refusal or a role-play reply.

    Steps:
        censor: classifies the message and routes it to ``refuse`` or ``chat``.
        refuse: streams a refusal in the detected language.
        chat:   streams the role-play reply, stores both sides of the exchange
                and, if the session has toys, emits a toy command.
    """

    def __init__(
        self,
        response_llm: MyVllm,
        chat_store: MySQLChatStore,
        toy_status_store: ToyStatusStore,
        sessionId: str,
        gender: str,
        toy_names: list[str] | None,
        *args,
        **kwargs
    ):
        """Bind the session's collaborators and load its stored state.

        Args:
            response_llm: Model used for screening, replies and toy commands.
            chat_store: Store the chat history is read from and appended to.
            toy_status_store: Store the latest toy pattern is read from and
                updated in.
            sessionId: Key identifying the session in both stores.
            gender: Selects the role-play prompt; ``"male"`` or ``"female"``.
            toy_names: Toys to control after each reply; falsy disables it.
            *args: Forwarded to ``Workflow.__init__``.
            **kwargs: Forwarded to ``Workflow.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.response_llm = response_llm
        self.chat_store = chat_store
        self.sessionId = sessionId
        # Snapshot taken once at construction; not refreshed by later steps.
        self.chat_history = self.chat_store.get_chat_history(self.sessionId)
        self.gender = gender
        self.toy_names = toy_names
        self.toy_status_store = toy_status_store
        self.current_pattern = self.toy_status_store.get_latest(self.sessionId)["pattern"]
        self.retry_ct = 0

    @step
    async def censor(self, ctx: Context, ev: StartEvent) -> SafeStartEvent | RefuseEvent | StartEvent:
        """Classify the query's intent and language and route it.

        Returns:
            ``SafeStartEvent`` when the intent is allowed, ``RefuseEvent`` when
            it is not, or a fresh ``StartEvent`` to retry when the model's
            reply could not be parsed.
        """
        fmt_messages = ALIGNMENT_PROMPT.format_messages(
            user_input=ev.query, intent_labels=REFUSE_INTENTS
        )
        raw_reply = self.response_llm.chat(fmt_messages).message.content

        # Only reply parsing is guarded, and only against the errors a
        # malformed reply can produce, so cancellation and unrelated failures
        # propagate instead of being turned into a retry.
        try:
            parsed = json.loads(raw_reply)
            intent = parsed["intent"]
            lang = parsed["language"]
            prYellow(f"language: {lang}")
            # When Chinese input is detected, reply in English by default.
            if lang.lower() in _ENGLISH_FALLBACK_LANGUAGES:
                lang = "english"
        except (ValueError, TypeError, KeyError, AttributeError):
            if self.retry_ct < _MAX_CENSOR_RETRIES:
                self.retry_ct += 1
                return StartEvent(query=ev.query)
            # NOTE(review): once retries are exhausted the query is let
            # through unscreened (fail-open) - confirm this is intended.
            return SafeStartEvent(query=ev.query)

        await ctx.set("language", lang)

        if intent in _ALLOWED_INTENTS:
            return SafeStartEvent(query=ev.query)

        prRed(f"refuse: {intent}")
        return RefuseEvent(lang=lang)

    @step
    async def refuse(self, ctx: Context, ev: RefuseEvent) -> StopEvent:
        """Stream a refusal message written in ``ev.lang``."""
        tokens = self.response_llm.stream(REFUSE_PROMPT, language=ev.lang)
        response_str = await self._stream_tokens(ctx, tokens, 0)
        prRed(f"Response: {response_str}")
        return StopEvent(result="success")

    @step
    async def chat(self, ctx: Context, ev: SafeStartEvent) -> StopEvent:
        """Stream the role-play reply, store the exchange and drive the toys.

        Raises:
            ValueError: If ``self.gender`` has no role-play prompt. Raised
                before anything is written to the chat store.
        """
        # Resolve the prompt first so a misconfigured session fails with a
        # clear error and without a dangling user message in the history.
        try:
            prompt = _ROLEPLAY_PROMPTS[self.gender]
        except KeyError:
            raise ValueError(
                f"Unsupported gender {self.gender!r}; "
                f"expected one of {sorted(_ROLEPLAY_PROMPTS)}"
            ) from None

        self.chat_store.add_message(self.sessionId, "user", ev.query)

        tokens = self.response_llm.stream(
            prompt, user_input=ev.query, chat_history=self.chat_history
        )
        response_str = await self._stream_tokens(ctx, tokens, 0.005)

        # Persist the assistant turn on a separate thread so the step does not
        # wait for the store.
        threading.Thread(
            target=self.chat_store.add_message,
            args=(self.sessionId, "assistant", response_str),
        ).start()
        prGreen(f"Response:\n{response_str}")

        if self.toy_names:
            pattern = await self.control_toy(ev.query)
            ctx.write_event_to_stream(TokenEvent(token=f"data:{pattern}\n\n"))
        return StopEvent(result="success")

    async def _stream_tokens(self, ctx: Context, tokens, delay: float) -> str:
        """Emit each token as a ``data:`` frame, then the DONE frame.

        Args:
            ctx: Workflow context whose event stream receives the frames.
            tokens: Iterable of text pieces produced by the model.
            delay: Seconds to sleep after each token; ``0`` just yields to the
                event loop.

        Returns:
            The concatenation of all tokens.
        """
        pieces = []
        for token in tokens:
            pieces.append(token)
            payload = json.dumps({"content": token})
            ctx.write_event_to_stream(TokenEvent(token=f"data:{payload}\n\n"))
            await asyncio.sleep(delay)
        ctx.write_event_to_stream(TokenEvent(token=_DONE_FRAME))
        return "".join(pieces)

    async def control_toy(self, user_input: str):
        """Generate a toy command for this turn and record it for every toy.

        Returns:
            The generated command serialized with ``json.dumps``.
        """
        command = generate_command(
            self.response_llm,
            TOY_CONTROL_PROMPT,
            user_input=user_input,
            chat_history=self.chat_history,
            toy_status=self.current_pattern,
            available_operations=OPERATIONS,
        )
        # NOTE(review): when generation fails, ``command`` is the failure
        # string and it is still serialized and written to the status store
        # below - confirm that consumers of the store tolerate this.
        command_str = json.dumps(command)
        prGreen(command_str)

        for toy_name in self.toy_names:
            threading.Thread(
                target=self.toy_status_store.update,
                args=(self.sessionId, command_str, toy_name),
            ).start()
        return command_str


def generate_command(llm: MyVllm, prompt: PromptTemplate, **kwargs):
    """Ask the model for a JSON command, retrying while the reply is not JSON.

    The reply is expected to have this format:
    {
        "operation":{
            "pattern":[
                {
                    "duration": 100,
                    "operation": int(level),
                },
                {
                    "duration": 100,
                    "operation": int(level),
                },
            ]
        }
    }

    Args:
        llm: Model queried through ``chat``.
        prompt: Template formatted with ``kwargs`` into chat messages.
        **kwargs: Values substituted into ``prompt``.

    Returns:
        The decoded JSON value of the first parseable reply, or the string
        ``"Failed to generate command"`` when every attempt failed.
    """
    fmt_messages = prompt.format_messages(**kwargs)
    for _ in range(_MAX_COMMAND_ATTEMPTS):
        raw_reply = llm.chat(fmt_messages).message.content
        try:
            return json.loads(raw_reply)
        except (ValueError, TypeError):
            # ValueError covers json.JSONDecodeError; TypeError covers a reply
            # that is not str/bytes (e.g. None). Anything else propagates.
            continue
    return "Failed to generate command"


def post_process_command(command: dict):
    """Work out how many repetitions of the pattern reach a total of 10000.

    NOTE(review): only the two statements below are visible here and nothing
    is returned; the function may continue beyond this excerpt - confirm
    before relying on it. It also divides by the summed duration, so an empty
    or all-zero pattern raises ``ZeroDivisionError``.
    """
    total_time = sum([item["duration"] for item in command["pattern"]])
    mults = math.ceil(10000 / total_time)