import asyncio
import concurrent.futures
import json
import math
import os
import threading

from dotenv import load_dotenv
from llama_index.core import PromptTemplate
from llama_index.core.workflow import (
    Context,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)

from workflow.events import (
    ControlEvent,
    RefuseEvent,
    SafeStartEvent,
    TokenEvent,
)
from workflow.modules import MySQLChatStore, ToyStatusStore, prGreen, prRed, prYellow
from workflow.vllm_model import MyVllm
from prompts.default_prompts import (
    ALIGNMENT_PROMPT,
    REFUSE_PROMPT,
    FEMALE_ROLEPLAY_PROMPT,
    MALE_ROLEPLAY_PROMPT,
    TOY_CONTROL_PROMPT,
    TOY_CONTROL_PROMPT_TEST,
)
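
# Assumption: runtime configuration (model endpoint, DB credentials) lives in
# a .env file; loading it here so the stores and LLM client below can read it
# from the environment.
load_dotenv()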

# Intent labels passed to the alignment classifier. Despite the name, the
# list also contains the pass-through label "not harmful"; the censor step
# treats "not harmful" (and "BDSM content") as safe and refuses the rest.
REFUSE_INTENTS = [
    "medical advice", "Overdose medication", "child pornography", "self-harm",
    "political bias", "racial hate speech", "illegal drugs", "not harmful",
    "violent tendencies", "weaponry", "religious hate", "Theft", "Robbery",
    "Body Disposal", "Forgery", "Smuggling", "Money laundering", "Extortion",
    "Terrorism", "Explosion", "Cyberattack & Hacking", "illegal stalking",
    "Arms trafficking",
]

# Toy operations the command generator is allowed to emit.
OPERATIONS = ["vibrate"]


class RolePlayWorkflow(Workflow):
    def __init__(
        self,
        response_llm: MyVllm,
        chat_store: MySQLChatStore,
        toy_status_store: ToyStatusStore,
        sessionId: str,
        gender: str,
        toy_names: list[str] | None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.response_llm = response_llm
        self.chat_store = chat_store
        self.sessionId = sessionId
        self.chat_history = self.chat_store.get_chat_history(self.sessionId)
        self.gender = gender
        self.toy_names = toy_names
        self.toy_status_store = toy_status_store
        # Latest stored toy pattern for this session, used as context when
        # generating the next control command.
        self.current_pattern = self.toy_status_store.get_latest(self.sessionId)["pattern"]
        # Retry budget for the censor step when the classifier returns
        # malformed JSON.
        self.retry_ct = 0

    @step
    async def censor(
        self, ctx: Context, ev: StartEvent
    ) -> SafeStartEvent | RefuseEvent | StartEvent:
        # Classify the user message against the intent labels before any
        # roleplay response is generated.
        fmt_messages = ALIGNMENT_PROMPT.format_messages(
            user_input=ev.query,
            intent_labels=REFUSE_INTENTS,
        )
        response = self.response_llm.chat(fmt_messages).message.content
        try:
            parsed = json.loads(response)
            intent = parsed["intent"]
            lang = parsed["language"]
            prYellow(f"language: {lang}")

            # Replies are always generated in English, even for Chinese input.
            if lang.lower() in ("zh", "chinese"):
                lang = "english"
            await ctx.set("language", lang)
        except (json.JSONDecodeError, KeyError):
            # Malformed classifier output: re-run this step up to three times,
            # then fail open and treat the query as safe.
            if self.retry_ct < 3:
                self.retry_ct += 1
                return StartEvent(query=ev.query)
            return SafeStartEvent(query=ev.query)

        if intent in ("not harmful", "BDSM content"):
            return SafeStartEvent(query=ev.query)
        prRed(f"refuse: {intent}")
        return RefuseEvent(lang=lang)

    @step
    async def refuse(self, ctx: Context, ev: RefuseEvent) -> StopEvent:
        # Stream a refusal message, in the detected language, as SSE frames.
        response = self.response_llm.stream(REFUSE_PROMPT, language=ev.lang)
        response_str = ""
        for token in response:
            response_str += token
            content = json.dumps({"content": token})
            ctx.write_event_to_stream(TokenEvent(token=f"data:{content}\n\n"))
            # Yield control so the event loop can flush the stream.
            await asyncio.sleep(0)
        ctx.write_event_to_stream(TokenEvent(token="data:[DONE]\n\n"))

        prRed(f"Response: {response_str}")
        return StopEvent(result="success")

    @step
    async def chat(self, ctx: Context, ev: SafeStartEvent) -> StopEvent:
        self.chat_store.add_message(self.sessionId, "user", ev.query)

        response_str = ""
        match self.gender:
            case "male":
                prompt = MALE_ROLEPLAY_PROMPT
            case "female":
                prompt = FEMALE_ROLEPLAY_PROMPT
            case _:
                # Without a default case, an unexpected gender would leave
                # `prompt` unbound and crash below.
                raise ValueError(f"unsupported gender: {self.gender}")
        response = self.response_llm.stream(
            prompt,
            user_input=ev.query,
            chat_history=self.chat_history,
        )
        for token in response:
            response_str += token
            content = json.dumps({"content": token})
            ctx.write_event_to_stream(TokenEvent(token=f"data:{content}\n\n"))
            # Small delay to pace the SSE stream.
            await asyncio.sleep(0.005)
        ctx.write_event_to_stream(TokenEvent(token="data:[DONE]\n\n"))

        # Persist the assistant reply in the background so it does not block
        # the response.
        t = threading.Thread(
            target=self.chat_store.add_message,
            args=(self.sessionId, "assistant", response_str),
        )
        t.start()
        prGreen(f"Response:\n{response_str}")

        if self.toy_names:
            pattern = await self.control_toy(ev.query)
            ctx.write_event_to_stream(TokenEvent(token=f"data:{pattern}\n\n"))

        return StopEvent(result="success")

    async def control_toy(self, user_input: str):
        command = generate_command(
            self.response_llm,
            TOY_CONTROL_PROMPT,
            user_input=user_input,
            chat_history=self.chat_history,
            toy_status=self.current_pattern,
            available_operations=OPERATIONS,
        )
        if isinstance(command, str):
            # generate_command signals failure with a plain string; do not
            # persist it as a toy pattern.
            prRed(command)
            return command
        command_str = json.dumps(command)
        prGreen(command_str)

        # Update each toy's stored pattern in the background.
        for toy_name in self.toy_names:
            t = threading.Thread(
                target=self.toy_status_store.update,
                args=(self.sessionId, command_str, toy_name),
            )
            t.start()
        return command_str
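

# Illustrative example (an assumption, not captured from a live run) of the
# toy-control frame that `chat` appends to the SSE stream; the payload shape
# follows the format documented on generate_command below:
#
#   data:{"operation": {"pattern": [{"duration": 100, "operation": 3}]}}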


def generate_command(llm: MyVllm, prompt: PromptTemplate, **kwargs):
    """
    Ask the LLM for a toy-control command and parse it as JSON.

    Expected format:
    {
        "operation": {
            "pattern": [
                {"duration": 100, "operation": int(level)},
                {"duration": 100, "operation": int(level)},
            ]
        }
    }
    """
    fmt_messages = prompt.format_messages(**kwargs)
    # Up to five attempts to get valid JSON out of the model.
    for _ in range(5):
        response = llm.chat(fmt_messages).message.content
        try:
            return json.loads(response)
        except json.JSONDecodeError:
            continue
    return "Failed to generate command"
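

# A minimal usage sketch (illustrative; never invoked here). Assumptions: the
# no-argument constructors below stand in for the real MyVllm/MySQLChatStore/
# ToyStatusStore configuration, and the session/gender/toy values are
# placeholders.
async def _example_run() -> None:
    workflow = RolePlayWorkflow(
        response_llm=MyVllm(),              # assumed constructor
        chat_store=MySQLChatStore(),        # assumed constructor
        toy_status_store=ToyStatusStore(),  # assumed constructor
        sessionId="demo-session",
        gender="female",
        toy_names=["toy-1"],
        timeout=60,
    )
    handler = workflow.run(query="hello")
    # Relay SSE frames as they are written to the event stream.
    async for event in handler.stream_events():
        if isinstance(event, TokenEvent):
            print(event.token, end="")
    await handler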


def post_process_command(command: dict):
    # Duration of one pass through the pattern, and how many repetitions are
    # needed to cover a 10,000-unit (presumably millisecond) window.
    total_time = sum(item["duration"] for item in command["pattern"])
    mults = math.ceil(10000 / total_time)
|