|
from pydantic import BaseModel |
|
|
|
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor |
|
from core.helper.encrypter import decrypt_token |
|
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult |
|
from extensions.ext_database import db |
|
from models.api_based_extension import APIBasedExtension |
|
|
|
|
|
class ModerationInputParams(BaseModel):
    """
    Payload describing user input that is submitted for moderation.

    Instances are serialised with ``model_dump()`` before being handed to the
    API-based extension requestor. Every field has a default, so the model can
    be constructed without arguments.
    """

    # Identifier of the app the moderated input belongs to.
    app_id: str = ""
    # Input variables passed through unchanged from the caller.
    inputs: dict = {}
    # The user's query text.
    query: str = ""
|
|
|
|
|
class ModerationOutputParams(BaseModel):
    """
    Payload describing generated text that is submitted for moderation.

    Instances are serialised with ``model_dump()`` before being handed to the
    API-based extension requestor.
    """

    # Identifier of the app the moderated text belongs to.
    app_id: str = ""
    # The text to moderate; required (no default).
    text: str
|
|
|
|
|
class ApiModeration(Moderation):
    """
    Moderation provider that delegates the decision to an API-based extension.

    The extension record is looked up by tenant id and the
    ``api_based_extension_id`` stored in the moderation config, and is called
    through ``APIBasedExtensionRequestor``.
    """

    name: str = "api"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :raises ValueError: if ``api_based_extension_id`` is missing or empty,
            or if no matching extension exists for the tenant
        :return:
        """
        cls._validate_inputs_and_outputs_config(config, False)

        api_based_extension_id = config.get("api_based_extension_id")
        if not api_based_extension_id:
            raise ValueError("api_based_extension_id is required")

        extension = cls._get_api_based_extension(tenant_id, api_based_extension_id)
        if not extension:
            raise ValueError("API-based Extension not found. Please check it again.")

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Moderate user inputs through the configured API-based extension.

        :param inputs: the input variables to moderate
        :param query: the user query text
        :raises ValueError: if input moderation is enabled but the configured
            extension cannot be found
        :return: the result built from the extension's response when input
            moderation is enabled; otherwise an unflagged result with an empty
            preset response
        """
        if self.config["inputs_config"]["enabled"]:
            params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)
            result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump())
            return ModerationInputsResult(**result)

        # Input moderation disabled: nothing is flagged.
        return ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT, preset_response="")

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Moderate generated text through the configured API-based extension.

        :param text: the text to moderate
        :raises ValueError: if output moderation is enabled but the configured
            extension cannot be found
        :return: the result built from the extension's response when output
            moderation is enabled; otherwise an unflagged result with an empty
            preset response
        """
        if self.config["outputs_config"]["enabled"]:
            params = ModerationOutputParams(app_id=self.app_id, text=text)
            result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump())
            return ModerationOutputsResult(**result)

        # Output moderation disabled: nothing is flagged.
        return ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT, preset_response="")

    def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
        """
        Call the configured API-based extension at the given extension point.

        :param extension_point: the extension point to invoke
        :param params: the request parameters forwarded to the extension
        :raises ValueError: if no extension matches the tenant and the
            configured ``api_based_extension_id``
        :return: the response returned by the requestor
        """
        extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id"))
        # The lookup yields None when no row matches (e.g. the extension was
        # removed after the config was validated); fail with the same message
        # validate_config uses instead of an AttributeError on None.
        if not extension:
            raise ValueError("API-based Extension not found. Please check it again.")

        requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key))
        return requestor.request(extension_point, params)

    @staticmethod
    def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> "APIBasedExtension | None":
        """
        Look up an API-based extension by tenant and extension id.

        :param tenant_id: the id of workspace
        :param api_based_extension_id: the id of the extension record
        :return: the first matching ``APIBasedExtension``, or ``None`` when
            there is no match
        """
        return (
            db.session.query(APIBasedExtension)
            .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
            .first()
        )
|
|