# SexBot / workflow/roleplay_workflow.py
# Uploaded by Pew404 via huggingface_hub (commit 318db6e).
import concurrent.futures
import threading, math
import asyncio, json, os
from dotenv import load_dotenv
from llama_index.core import PromptTemplate
from llama_index.core.workflow import (
Context,
Workflow,
StartEvent,
StopEvent,
step,
)
from workflow.events import (
SafeStartEvent,
RefuseEvent,
TokenEvent,
ControlEvent
)
from workflow.vllm_model import MyVllm
from workflow.modules import MySQLChatStore, ToyStatusStore, prGreen, prRed, prYellow
from prompts.default_prompts import(
ALIGNMENT_PROMPT,
REFUSE_PROMPT,
FEMALE_ROLEPLAY_PROMPT,
MALE_ROLEPLAY_PROMPT,
TOY_CONTROL_PROMPT_TEST,
TOY_CONTROL_PROMPT
)
# Intent labels offered to the safety classifier (passed to ALIGNMENT_PROMPT
# as `intent_labels`). "not harmful" is the benign label; the censor step
# refuses any returned intent other than "not harmful" / "BDSM content".
REFUSE_INTENTS = [
    "medical advice",
    "Overdose medication",
    "child pornography",
    "self-harm",
    "political bias",
    "racial hate speech",
    "illegal drugs",
    "not harmful",
    "violent tendencies",
    "weaponry",
    "religious hate",
    "Theft",
    "Robbery",
    "Body Disposal",
    "Forgery",
    "Smuggling",
    "Money laundering",
    "Extortion",
    "Terrorism",
    "Explosion",
    "Cyberattack & Hacking",
    "illegal stalking",
    "Arms trafficking",
]
# Toy operations the toy-control prompt may choose from
# (passed as `available_operations`).
OPERATIONS = ["vibrate"]
class RolePlayWorkflow(Workflow):
def __init__(
self,
response_llm: MyVllm,
chat_store: MySQLChatStore,
toy_status_store: ToyStatusStore,
sessionId: str,
gender: str,
toy_names: list[str] | None,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.response_llm = response_llm
self.chat_store = chat_store
self.sessionId = sessionId
self.chat_history = self.chat_store.get_chat_history(self.sessionId)
self.gender = gender
self.toy_names = toy_names
self.toy_status_store = toy_status_store
self.current_pattern = self.toy_status_store.get_latest(self.sessionId)["pattern"]
self.retry_ct = 0
@step
async def censor(self, ctx: Context, ev: StartEvent) -> SafeStartEvent | RefuseEvent | StartEvent:
# process llm output
fmt_messages = ALIGNMENT_PROMPT.format_messages(
user_input=ev.query,
intent_labels=REFUSE_INTENTS
)
response = self.response_llm.chat(fmt_messages).message.content
try:
response = json.loads(response)
intent = response["intent"]
lang = response["language"]
prYellow(f"language: {lang}")
# 检测中文输入默认用英文回复
if lang.lower() in ["zh", "chinese"]:
lang = "english"
await ctx.set("language", lang)
except:
if self.retry_ct < 3:
self.retry_ct += 1
return StartEvent(query=ev.query)
return SafeStartEvent(query=ev.query)
# judge
if intent in ("not harmful", "BDSM content"):
return SafeStartEvent(query=ev.query)
prRed(f"refuse: {intent}")
return RefuseEvent(lang=lang)
@step
async def refuse(self, ctx: Context, ev: RefuseEvent) -> StopEvent:
response = self.response_llm.stream(REFUSE_PROMPT, language=ev.lang)
response_str = ""
for token in response:
response_str += token
content = json.dumps({"content": token})
ctx.write_event_to_stream(TokenEvent(token=f"data:{content}\n\n"))
await asyncio.sleep(0)
ctx.write_event_to_stream(TokenEvent(token=f"data:[DONE]\n\n"))
prRed(f"Response: {response_str}")
return StopEvent(result="success")
@step
async def chat(self, ctx: Context, ev: SafeStartEvent) -> StopEvent:
# preprocess chat history
self.chat_store.add_message(self.sessionId, "user", ev.query)
# generate response
response_str = ""
match self.gender:
case "male":
prompt = MALE_ROLEPLAY_PROMPT
case "female":
prompt = FEMALE_ROLEPLAY_PROMPT
response = self.response_llm.stream(
prompt,
user_input=ev.query,
chat_history=self.chat_history
)
for token in response:
response_str += token
content = json.dumps({"content": token})
ctx.write_event_to_stream(TokenEvent(token=f"data:{content}\n\n"))
await asyncio.sleep(0.005)
ctx.write_event_to_stream(TokenEvent(token=f"data:[DONE]\n\n"))
# update chat history
t = threading.Thread(target=self.chat_store.add_message, args=(self.sessionId, "assistant", response_str))
t.start()
prGreen(f"Response:\n{response_str}")
# control toy
if self.toy_names:
pattern = await self.control_toy(ev.query)
ctx.write_event_to_stream(TokenEvent(token=f"data:{pattern}\n\n"))
return StopEvent(result="success")
async def control_toy(self, user_input:str):
command = generate_command(
self.response_llm,
TOY_CONTROL_PROMPT,
user_input=user_input,
chat_history=self.chat_history,
toy_status=self.current_pattern,
available_operations=OPERATIONS
)
command_str = json.dumps(command)
prGreen(command_str)
# ctx.write_event_to_stream(TokenEvent(token=f"data:{command_str}\n\n"))
# await asyncio.sleep(0)
for toy_name in self.toy_names:
t = threading.Thread(target=self.toy_status_store.update, args=(self.sessionId, command_str, toy_name))
t.start()
return command_str
def generate_command(llm: MyVllm, prompt: PromptTemplate, **prompt_kwargs):
    """Ask ``llm`` for a toy-control command and return it decoded from JSON.

    The prompt is formatted once with ``prompt_kwargs``. The model is then
    queried until its reply parses as JSON, with at most 5 attempts in total.
    Only JSON decoding errors are retried; errors raised by ``llm.chat``
    propagate to the caller.

    Expected reply shape (defined by the prompt, not validated here):
        {
            "operation": {
                "pattern": [
                    {"duration": 100, "operation": int(level)},
                    {"duration": 100, "operation": int(level)},
                ]
            }
        }

    Returns:
        The decoded JSON value on success, or the string
        "Failed to generate command" if every attempt failed.
    """
    max_attempts = 5
    fmt_messages = prompt.format_messages(**prompt_kwargs)
    for _ in range(max_attempts):
        raw = llm.chat(fmt_messages).message.content
        try:
            return json.loads(raw)
        except (ValueError, TypeError):
            # ValueError covers json.JSONDecodeError; TypeError covers a
            # non-string reply such as None. Anything else is a real bug and
            # is not swallowed.
            continue
    return "Failed to generate command"
def post_process_command(command: dict):
total_time = sum([item["duration"] for item in command["pattern"]])
mults = math.ceil(10000 / total_time)