|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Tuple |
|
|
|
import numpy as np |
|
|
|
from .log import log |
|
|
|
|
|
class ContentSafetyGuardrail:
    """Base interface for content-safety guardrails.

    Subclasses are expected to override :meth:`is_safe`; the base
    implementation only signals that no override was provided.
    """

    def is_safe(self, *args, **kwargs) -> Tuple[bool, str]:
        """Decide whether the supplied content is safe.

        Positional arguments are accepted alongside keyword arguments because
        ``GuardrailRunner.run_safety_check`` calls ``is_safe(input)``
        positionally. With a keyword-only signature, a subclass that forgot to
        override this method would fail with a ``TypeError`` about argument
        counts instead of the intended ``NotImplementedError``.

        Returns:
            A ``(safe, message)`` tuple, as unpacked by the caller.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Child classes must implement the is_safe method")
|
|
|
|
|
class PostprocessingGuardrail:
    """Base interface for guardrails that post-process frames.

    Subclasses are expected to override :meth:`postprocess`.
    """

    def postprocess(self, frames: np.ndarray) -> np.ndarray:
        """Return the post-processed version of ``frames``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        reason = "Child classes must implement the postprocess method"
        raise NotImplementedError(reason)
|
|
|
|
|
class GuardrailRunner:
    """Run configured safety guardrails and postprocessors in order.

    Safety models are consulted one after another until one reports unsafe
    content; postprocessors are chained, each receiving the previous output.
    """

    def __init__(
        self,
        safety_models: list[ContentSafetyGuardrail] | None = None,
        generic_block_msg: str = "",
        generic_safe_msg: str = "",
        postprocessors: list[PostprocessingGuardrail] | None = None,
    ):
        """Store the guardrails and the messages reported to callers.

        Args:
            safety_models: Guardrails consulted by ``run_safety_check``.
            generic_block_msg: When non-empty, returned instead of the
                guardrail-specific reason whenever content is blocked.
            generic_safe_msg: Message returned for safe content; an empty
                value falls back to ``"Prompt is safe"``.
            postprocessors: Guardrails applied by ``postprocess``.
        """
        self.safety_models = safety_models
        self.postprocessors = postprocessors
        self.generic_block_msg = generic_block_msg
        self.generic_safe_msg = generic_safe_msg or "Prompt is safe"

    def run_safety_check(self, input: Any) -> Tuple[bool, str]:
        """Check ``input`` against every safety model, stopping at the first failure.

        Returns:
            ``(True, generic_safe_msg)`` when no model objects (or none are
            configured); otherwise ``(False, reason)`` where ``reason`` is the
            generic block message if set, else ``"<GUARDRAIL NAME>: <message>"``.
        """
        checkers = self.safety_models
        if not checkers:
            log.warning("No safety models found, returning safe")
            return True, self.generic_safe_msg

        for checker in checkers:
            label = str(checker.__class__.__name__).upper()
            log.debug(f"Running guardrail: {label}")
            verdict, detail = checker.is_safe(input)
            if verdict:
                continue
            # First unsafe verdict wins; remaining guardrails are not run.
            if self.generic_block_msg:
                return False, self.generic_block_msg
            return False, f"{label}: {detail}"

        return True, self.generic_safe_msg

    def postprocess(self, frames: np.ndarray) -> np.ndarray:
        """Pass the video frames through each postprocessor in sequence.

        Returns:
            The output of the last postprocessor, or ``frames`` unchanged when
            no postprocessors are configured.
        """
        stages = self.postprocessors
        if not stages:
            log.warning("No postprocessors found, returning original frames")
            return frames

        for stage in stages:
            label = str(stage.__class__.__name__).upper()
            log.debug(f"Running guardrail: {label}")
            frames = stage.postprocess(frames)

        return frames
|
|