Spaces:
Running
Running
import logging | |
import threading | |
import time | |
from typing import Any, Optional | |
from flask import Flask, current_app | |
from pydantic import BaseModel | |
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |
from core.app.entities.queue_entities import QueueMessageReplaceEvent | |
from core.moderation.base import ModerationAction, ModerationOutputsResult | |
from core.moderation.factory import ModerationFactory | |
# Module-level logger; OutputModeration reports provider failures through it.
logger: logging.Logger = logging.getLogger(__name__)
class ModerationRule(BaseModel): | |
type: str | |
config: dict[str, Any] | |
class OutputModeration(BaseModel):
    """Moderate LLM output, either incrementally while streaming or all at once.

    Streamed tokens are accumulated with :meth:`append_new_token`. The first
    token lazily starts a background worker thread. The worker polls once per
    second and runs the configured moderation rule each time at least
    ``buffer_size`` new characters have arrived.

    :meth:`moderation_completion` instead moderates a complete text
    synchronously. When content is flagged, a ``QueueMessageReplaceEvent``
    carrying the replacement text is published through ``queue_manager``.
    """

    # Used when MODERATION_BUFFER_SIZE is absent from the Flask config or is non-positive.
    DEFAULT_BUFFER_SIZE: int = 300

    tenant_id: str
    app_id: str

    rule: ModerationRule
    queue_manager: AppQueueManager

    # Background worker, created on the first append_new_token() call.
    thread: Optional[threading.Thread] = None
    # Cleared by stop_thread() to ask the worker loop to exit.
    thread_running: bool = True
    # Text accumulated so far; replaced wholesale by moderation_completion().
    buffer: str = ''
    # Set by moderation_completion() once the complete text is in `buffer`.
    is_final_chunk: bool = False
    # Preset response recorded by the worker when the rule says DIRECT_OUTPUT.
    final_output: Optional[str] = None

    class Config:
        arbitrary_types_allowed = True

    def should_direct_output(self):
        """Return True once the worker has recorded a direct-output response."""
        return self.final_output is not None

    def get_final_output(self):
        """Return the direct-output response recorded by the worker, or None."""
        return self.final_output

    def append_new_token(self, token: str):
        """Append a streamed ``token`` to the buffer, starting the worker on first use."""
        self.buffer += token

        if not self.thread:
            self.thread = self.start_thread()

    def moderation_completion(self, completion: str, public_event: bool = False) -> str:
        """Moderate a complete text synchronously and return the text to emit.

        :param completion: the full output text; it also replaces ``buffer`` and
            marks the stream as final for the worker.
        :param public_event: when true and the text is flagged, also publish a
            ``QueueMessageReplaceEvent`` with the replacement.
        :return: ``completion`` unchanged when moderation fails (returns None)
            or does not flag it; otherwise the rule's ``preset_response``
            (DIRECT_OUTPUT) or ``result.text``.
        """
        # Buffer first, then the flag: the worker relies on this order to know
        # that a True flag means the buffer already holds the final text.
        self.buffer = completion
        self.is_final_chunk = True

        result = self.moderation(
            tenant_id=self.tenant_id,
            app_id=self.app_id,
            moderation_buffer=completion
        )

        if not result or not result.flagged:
            return completion

        if result.action == ModerationAction.DIRECT_OUTPUT:
            final_output = result.preset_response
        else:
            final_output = result.text

        if public_event:
            self._publish_replace(final_output)

        return final_output

    def start_thread(self) -> threading.Thread:
        """Start the moderation worker thread and return it.

        Reads ``MODERATION_BUFFER_SIZE`` from the current Flask app config;
        non-positive values fall back to ``DEFAULT_BUFFER_SIZE``. It uses
        ``current_app``, so it must be called inside a Flask app context. The
        concrete app object is handed to the worker so the thread can push its
        own context.
        """
        buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
        thread = threading.Thread(target=self.worker, kwargs={
            'flask_app': current_app._get_current_object(),
            'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
        })

        thread.start()

        return thread

    def stop_thread(self):
        """Ask a live worker to exit at its next loop check (does not join it)."""
        if self.thread and self.thread.is_alive():
            self.thread_running = False

    def worker(self, flask_app: Flask, buffer_size: int):
        """Worker loop: moderate the growing buffer until stopped or finished.

        While streaming, it sleeps one second at a time until at least
        ``buffer_size`` characters beyond the last moderated length are
        buffered, then moderates the whole buffer.

        On a flagged result it publishes a replace event. DIRECT_OUTPUT uses
        the preset response, records it in ``final_output`` and stops the loop.
        Otherwise it uses the moderated text plus any tokens that arrived
        during the check.

        Once the stream is final, it moderates the final buffer once and exits.
        """
        with flask_app.app_context():
            current_length = 0

            while self.thread_running:
                # Read the flag before the buffer: moderation_completion() writes
                # the buffer first, so a True flag guarantees the final text is in it.
                is_final = self.is_final_chunk
                moderation_buffer = self.buffer
                buffer_length = len(moderation_buffer)

                if not is_final:
                    chunk_length = buffer_length - current_length
                    if 0 <= chunk_length < buffer_size:
                        time.sleep(1)
                        continue

                current_length = buffer_length

                result = self.moderation(
                    tenant_id=self.tenant_id,
                    app_id=self.app_id,
                    moderation_buffer=moderation_buffer
                )

                if result and result.flagged:
                    if result.action == ModerationAction.DIRECT_OUTPUT:
                        final_output = result.preset_response
                        self.final_output = final_output
                    else:
                        # Keep tokens appended while the moderation call was running.
                        final_output = result.text + self.buffer[len(moderation_buffer):]

                    # trigger replace event
                    if self.thread_running:
                        self._publish_replace(final_output)

                    if result.action == ModerationAction.DIRECT_OUTPUT:
                        break

                # The final buffer never changes again, so one pass is enough.
                # Looping here would re-run moderation (and re-publish a flagged
                # replacement) back-to-back with no sleep until stop_thread().
                if is_final:
                    break

    def _publish_replace(self, text: str) -> None:
        """Publish a ``QueueMessageReplaceEvent`` carrying ``text`` from the task pipeline."""
        self.queue_manager.publish(
            QueueMessageReplaceEvent(
                text=text
            ),
            PublishFrom.TASK_PIPELINE
        )

    def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
        """Run the configured rule over ``moderation_buffer``.

        Builds a ``ModerationFactory`` from ``rule.type`` and ``rule.config``
        and returns its ``moderation_for_outputs`` result. Any exception is
        logged with its traceback and reported as ``None``: moderation here is
        best-effort and must not break the output stream.
        """
        try:
            moderation_factory = ModerationFactory(
                name=self.rule.type,
                app_id=app_id,
                tenant_id=tenant_id,
                config=self.rule.config
            )

            result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
            return result
        except Exception as e:
            # logger.exception keeps the traceback that logger.error discarded.
            logger.exception("Moderation Output error: %s", e)

        return None