Spaces:
Running
Running
import json | |
import logging | |
from httpx import get | |
from core.model_runtime.utils.encoders import jsonable_encoder | |
from core.tools.entities.common_entities import I18nObject | |
from core.tools.entities.tool_bundle import ApiBasedToolBundle | |
from core.tools.entities.tool_entities import ( | |
ApiProviderAuthType, | |
ApiProviderSchemaType, | |
ToolCredentialsOption, | |
ToolProviderCredentials, | |
) | |
from core.tools.entities.user_entities import UserTool, UserToolProvider | |
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError | |
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController | |
from core.tools.provider.builtin._positions import BuiltinToolProviderSort | |
from core.tools.provider.tool_provider import ToolProviderController | |
from core.tools.tool_manager import ToolManager | |
from core.tools.utils.configuration import ToolConfigurationManager | |
from core.tools.utils.parser import ApiBasedToolSchemaParser | |
from extensions.ext_database import db | |
from models.tools import ApiToolProvider, BuiltinToolProvider | |
from services.model_provider_service import ModelProviderService | |
from services.tools_transform_service import ToolTransformService | |
logger = logging.getLogger(__name__) | |
class ToolManageService: | |
def list_tool_providers(user_id: str, tenant_id: str): | |
""" | |
list tool providers | |
:return: the list of tool providers | |
""" | |
providers = ToolManager.user_list_providers( | |
user_id, tenant_id | |
) | |
# add icon | |
for provider in providers: | |
ToolTransformService.repack_provider(provider) | |
result = [provider.to_dict() for provider in providers] | |
return result | |
def list_builtin_tool_provider_tools( | |
user_id: str, tenant_id: str, provider: str | |
) -> list[UserTool]: | |
""" | |
list builtin tool provider tools | |
""" | |
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) | |
tools = provider_controller.get_tools() | |
tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |
# check if user has added the provider | |
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( | |
BuiltinToolProvider.tenant_id == tenant_id, | |
BuiltinToolProvider.provider == provider, | |
).first() | |
credentials = {} | |
if builtin_provider is not None: | |
# get credentials | |
credentials = builtin_provider.credentials | |
credentials = tool_provider_configurations.decrypt_tool_credentials(credentials) | |
result = [] | |
for tool in tools: | |
result.append(ToolTransformService.tool_to_user_tool( | |
tool=tool, credentials=credentials, tenant_id=tenant_id | |
)) | |
return result | |
def list_builtin_provider_credentials_schema( | |
provider_name | |
): | |
""" | |
list builtin provider credentials schema | |
:return: the list of tool providers | |
""" | |
provider = ToolManager.get_builtin_provider(provider_name) | |
return jsonable_encoder([ | |
v for _, v in (provider.credentials_schema or {}).items() | |
]) | |
def parser_api_schema(schema: str) -> list[ApiBasedToolBundle]: | |
""" | |
parse api schema to tool bundle | |
""" | |
try: | |
warnings = {} | |
try: | |
tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings) | |
except Exception as e: | |
raise ValueError(f'invalid schema: {str(e)}') | |
credentials_schema = [ | |
ToolProviderCredentials( | |
name='auth_type', | |
type=ToolProviderCredentials.CredentialsType.SELECT, | |
required=True, | |
default='none', | |
options=[ | |
ToolCredentialsOption(value='none', label=I18nObject( | |
en_US='None', | |
zh_Hans='无' | |
)), | |
ToolCredentialsOption(value='api_key', label=I18nObject( | |
en_US='Api Key', | |
zh_Hans='Api Key' | |
)), | |
], | |
placeholder=I18nObject( | |
en_US='Select auth type', | |
zh_Hans='选择认证方式' | |
) | |
), | |
ToolProviderCredentials( | |
name='api_key_header', | |
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, | |
required=False, | |
placeholder=I18nObject( | |
en_US='Enter api key header', | |
zh_Hans='输入 api key header,如:X-API-KEY' | |
), | |
default='api_key', | |
help=I18nObject( | |
en_US='HTTP header name for api key', | |
zh_Hans='HTTP 头部字段名,用于传递 api key' | |
) | |
), | |
ToolProviderCredentials( | |
name='api_key_value', | |
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, | |
required=False, | |
placeholder=I18nObject( | |
en_US='Enter api key', | |
zh_Hans='输入 api key' | |
), | |
default='' | |
), | |
] | |
return jsonable_encoder({ | |
'schema_type': schema_type, | |
'parameters_schema': tool_bundles, | |
'credentials_schema': credentials_schema, | |
'warning': warnings | |
}) | |
except Exception as e: | |
raise ValueError(f'invalid schema: {str(e)}') | |
def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiBasedToolBundle]: | |
""" | |
convert schema to tool bundles | |
:return: the list of tool bundles, description | |
""" | |
try: | |
tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info) | |
return tool_bundles | |
except Exception as e: | |
raise ValueError(f'invalid schema: {str(e)}') | |
def create_api_tool_provider( | |
user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict, | |
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str | |
): | |
""" | |
create api tool provider | |
""" | |
if schema_type not in [member.value for member in ApiProviderSchemaType]: | |
raise ValueError(f'invalid schema type {schema}') | |
# check if the provider exists | |
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( | |
ApiToolProvider.tenant_id == tenant_id, | |
ApiToolProvider.name == provider_name, | |
).first() | |
if provider is not None: | |
raise ValueError(f'provider {provider_name} already exists') | |
# parse openapi to tool bundle | |
extra_info = {} | |
# extra info like description will be set here | |
tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info) | |
if len(tool_bundles) > 100: | |
raise ValueError('the number of apis should be less than 100') | |
# create db provider | |
db_provider = ApiToolProvider( | |
tenant_id=tenant_id, | |
user_id=user_id, | |
name=provider_name, | |
icon=json.dumps(icon), | |
schema=schema, | |
description=extra_info.get('description', ''), | |
schema_type_str=schema_type, | |
tools_str=json.dumps(jsonable_encoder(tool_bundles)), | |
credentials_str={}, | |
privacy_policy=privacy_policy, | |
custom_disclaimer=custom_disclaimer | |
) | |
if 'auth_type' not in credentials: | |
raise ValueError('auth_type is required') | |
# get auth type, none or api key | |
auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) | |
# create provider entity | |
provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type) | |
# load tools into provider entity | |
provider_controller.load_bundled_tools(tool_bundles) | |
# encrypt credentials | |
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |
encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials) | |
db_provider.credentials_str = json.dumps(encrypted_credentials) | |
db.session.add(db_provider) | |
db.session.commit() | |
return { 'result': 'success' } | |
def get_api_tool_provider_remote_schema( | |
user_id: str, tenant_id: str, url: str | |
): | |
""" | |
get api tool provider remote schema | |
""" | |
headers = { | |
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0", | |
"Accept": "*/*", | |
} | |
try: | |
response = get(url, headers=headers, timeout=10) | |
if response.status_code != 200: | |
raise ValueError(f'Got status code {response.status_code}') | |
schema = response.text | |
# try to parse schema, avoid SSRF attack | |
ToolManageService.parser_api_schema(schema) | |
except Exception as e: | |
logger.error(f"parse api schema error: {str(e)}") | |
raise ValueError('invalid schema, please check the url you provided') | |
return { | |
'schema': schema | |
} | |
def list_api_tool_provider_tools( | |
user_id: str, tenant_id: str, provider: str | |
) -> list[UserTool]: | |
""" | |
list api tool provider tools | |
""" | |
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( | |
ApiToolProvider.tenant_id == tenant_id, | |
ApiToolProvider.name == provider, | |
).first() | |
if provider is None: | |
raise ValueError(f'you have not added provider {provider}') | |
return [ | |
ToolTransformService.tool_to_user_tool(tool_bundle) for tool_bundle in provider.tools | |
] | |
def update_builtin_tool_provider( | |
user_id: str, tenant_id: str, provider_name: str, credentials: dict | |
): | |
""" | |
update builtin tool provider | |
""" | |
# get if the provider exists | |
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( | |
BuiltinToolProvider.tenant_id == tenant_id, | |
BuiltinToolProvider.provider == provider_name, | |
).first() | |
try: | |
# get provider | |
provider_controller = ToolManager.get_builtin_provider(provider_name) | |
if not provider_controller.need_credentials: | |
raise ValueError(f'provider {provider_name} does not need credentials') | |
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |
# get original credentials if exists | |
if provider is not None: | |
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) | |
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) | |
# check if the credential has changed, save the original credential | |
for name, value in credentials.items(): | |
if name in masked_credentials and value == masked_credentials[name]: | |
credentials[name] = original_credentials[name] | |
# validate credentials | |
provider_controller.validate_credentials(credentials) | |
# encrypt credentials | |
credentials = tool_configuration.encrypt_tool_credentials(credentials) | |
except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e: | |
raise ValueError(str(e)) | |
if provider is None: | |
# create provider | |
provider = BuiltinToolProvider( | |
tenant_id=tenant_id, | |
user_id=user_id, | |
provider=provider_name, | |
encrypted_credentials=json.dumps(credentials), | |
) | |
db.session.add(provider) | |
db.session.commit() | |
else: | |
provider.encrypted_credentials = json.dumps(credentials) | |
db.session.add(provider) | |
db.session.commit() | |
# delete cache | |
tool_configuration.delete_tool_credentials_cache() | |
return { 'result': 'success' } | |
def get_builtin_tool_provider_credentials( | |
user_id: str, tenant_id: str, provider: str | |
): | |
""" | |
get builtin tool provider credentials | |
""" | |
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( | |
BuiltinToolProvider.tenant_id == tenant_id, | |
BuiltinToolProvider.provider == provider, | |
).first() | |
if provider is None: | |
return {} | |
provider_controller = ToolManager.get_builtin_provider(provider.provider) | |
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) | |
credentials = tool_configuration.mask_tool_credentials(credentials) | |
return credentials | |
def update_api_tool_provider( | |
user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict, | |
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str | |
): | |
""" | |
update api tool provider | |
""" | |
if schema_type not in [member.value for member in ApiProviderSchemaType]: | |
raise ValueError(f'invalid schema type {schema}') | |
# check if the provider exists | |
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( | |
ApiToolProvider.tenant_id == tenant_id, | |
ApiToolProvider.name == original_provider, | |
).first() | |
if provider is None: | |
raise ValueError(f'api provider {provider_name} does not exists') | |
# parse openapi to tool bundle | |
extra_info = {} | |
# extra info like description will be set here | |
tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info) | |
# update db provider | |
provider.name = provider_name | |
provider.icon = json.dumps(icon) | |
provider.schema = schema | |
provider.description = extra_info.get('description', '') | |
provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value | |
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) | |
provider.privacy_policy = privacy_policy | |
provider.custom_disclaimer = custom_disclaimer | |
if 'auth_type' not in credentials: | |
raise ValueError('auth_type is required') | |
# get auth type, none or api key | |
auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) | |
# create provider entity | |
provider_controller = ApiBasedToolProviderController.from_db(provider, auth_type) | |
# load tools into provider entity | |
provider_controller.load_bundled_tools(tool_bundles) | |
# get original credentials if exists | |
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) | |
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) | |
# check if the credential has changed, save the original credential | |
for name, value in credentials.items(): | |
if name in masked_credentials and value == masked_credentials[name]: | |
credentials[name] = original_credentials[name] | |
credentials = tool_configuration.encrypt_tool_credentials(credentials) | |
provider.credentials_str = json.dumps(credentials) | |
db.session.add(provider) | |
db.session.commit() | |
# delete cache | |
tool_configuration.delete_tool_credentials_cache() | |
return { 'result': 'success' } | |
def delete_builtin_tool_provider( | |
user_id: str, tenant_id: str, provider_name: str | |
): | |
""" | |
delete tool provider | |
""" | |
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( | |
BuiltinToolProvider.tenant_id == tenant_id, | |
BuiltinToolProvider.provider == provider_name, | |
).first() | |
if provider is None: | |
raise ValueError(f'you have not added provider {provider_name}') | |
db.session.delete(provider) | |
db.session.commit() | |
# delete cache | |
provider_controller = ToolManager.get_builtin_provider(provider_name) | |
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |
tool_configuration.delete_tool_credentials_cache() | |
return { 'result': 'success' } | |
def get_builtin_tool_provider_icon( | |
provider: str | |
): | |
""" | |
get tool provider icon and it's mimetype | |
""" | |
icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider) | |
with open(icon_path, 'rb') as f: | |
icon_bytes = f.read() | |
return icon_bytes, mime_type | |
def get_model_tool_provider_icon( | |
provider: str | |
): | |
""" | |
get tool provider icon and it's mimetype | |
""" | |
service = ModelProviderService() | |
icon_bytes, mime_type = service.get_model_provider_icon(provider=provider, icon_type='icon_small', lang='en_US') | |
if icon_bytes is None: | |
raise ValueError(f'provider {provider} does not exists') | |
return icon_bytes, mime_type | |
def list_model_tool_provider_tools( | |
user_id: str, tenant_id: str, provider: str | |
) -> list[UserTool]: | |
""" | |
list model tool provider tools | |
""" | |
provider_controller = ToolManager.get_model_provider(tenant_id=tenant_id, provider_name=provider) | |
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) | |
result = [ | |
UserTool( | |
author=tool.identity.author, | |
name=tool.identity.name, | |
label=tool.identity.label, | |
description=tool.description.human, | |
parameters=tool.parameters or [] | |
) for tool in tools | |
] | |
return jsonable_encoder(result) | |
def delete_api_tool_provider( | |
user_id: str, tenant_id: str, provider_name: str | |
): | |
""" | |
delete tool provider | |
""" | |
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( | |
ApiToolProvider.tenant_id == tenant_id, | |
ApiToolProvider.name == provider_name, | |
).first() | |
if provider is None: | |
raise ValueError(f'you have not added provider {provider_name}') | |
db.session.delete(provider) | |
db.session.commit() | |
return { 'result': 'success' } | |
def get_api_tool_provider( | |
user_id: str, tenant_id: str, provider: str | |
): | |
""" | |
get api tool provider | |
""" | |
return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id) | |
def test_api_tool_preview( | |
tenant_id: str, | |
provider_name: str, | |
tool_name: str, | |
credentials: dict, | |
parameters: dict, | |
schema_type: str, | |
schema: str | |
): | |
""" | |
test api tool before adding api tool provider | |
""" | |
if schema_type not in [member.value for member in ApiProviderSchemaType]: | |
raise ValueError(f'invalid schema type {schema_type}') | |
try: | |
tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema) | |
except Exception as e: | |
raise ValueError('invalid schema') | |
# get tool bundle | |
tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None) | |
if tool_bundle is None: | |
raise ValueError(f'invalid tool name {tool_name}') | |
db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( | |
ApiToolProvider.tenant_id == tenant_id, | |
ApiToolProvider.name == provider_name, | |
).first() | |
if not db_provider: | |
# create a fake db provider | |
db_provider = ApiToolProvider( | |
tenant_id='', user_id='', name='', icon='', | |
schema=schema, | |
description='', | |
schema_type_str=ApiProviderSchemaType.OPENAPI.value, | |
tools_str=json.dumps(jsonable_encoder(tool_bundles)), | |
credentials_str=json.dumps(credentials), | |
) | |
if 'auth_type' not in credentials: | |
raise ValueError('auth_type is required') | |
# get auth type, none or api key | |
auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) | |
# create provider entity | |
provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type) | |
# load tools into provider entity | |
provider_controller.load_bundled_tools(tool_bundles) | |
# decrypt credentials | |
if db_provider.id: | |
tool_configuration = ToolConfigurationManager( | |
tenant_id=tenant_id, | |
provider_controller=provider_controller | |
) | |
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) | |
# check if the credential has changed, save the original credential | |
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) | |
for name, value in credentials.items(): | |
if name in masked_credentials and value == masked_credentials[name]: | |
credentials[name] = decrypted_credentials[name] | |
try: | |
provider_controller.validate_credentials_format(credentials) | |
# get tool | |
tool = provider_controller.get_tool(tool_name) | |
tool = tool.fork_tool_runtime(meta={ | |
'credentials': credentials, | |
'tenant_id': tenant_id, | |
}) | |
result = tool.validate_credentials(credentials, parameters) | |
except Exception as e: | |
return { 'error': str(e) } | |
return { 'result': result or 'empty response' } | |
def list_builtin_tools( | |
user_id: str, tenant_id: str | |
) -> list[UserToolProvider]: | |
""" | |
list builtin tools | |
""" | |
# get all builtin providers | |
provider_controllers = ToolManager.list_builtin_providers() | |
# get all user added providers | |
db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter( | |
BuiltinToolProvider.tenant_id == tenant_id | |
).all() or [] | |
# find provider | |
find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) | |
result: list[UserToolProvider] = [] | |
for provider_controller in provider_controllers: | |
# convert provider controller to user provider | |
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( | |
provider_controller=provider_controller, | |
db_provider=find_provider(provider_controller.identity.name), | |
decrypt_credentials=True | |
) | |
# add icon | |
ToolTransformService.repack_provider(user_builtin_provider) | |
tools = provider_controller.get_tools() | |
for tool in tools: | |
user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool( | |
tenant_id=tenant_id, | |
tool=tool, | |
credentials=user_builtin_provider.original_credentials, | |
)) | |
result.append(user_builtin_provider) | |
return BuiltinToolProviderSort.sort(result) | |
def list_api_tools( | |
user_id: str, tenant_id: str | |
) -> list[UserToolProvider]: | |
""" | |
list api tools | |
""" | |
# get all api providers | |
db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter( | |
ApiToolProvider.tenant_id == tenant_id | |
).all() or [] | |
result: list[UserToolProvider] = [] | |
for provider in db_providers: | |
# convert provider controller to user provider | |
provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) | |
user_provider = ToolTransformService.api_provider_to_user_provider( | |
provider_controller, | |
db_provider=provider, | |
decrypt_credentials=True | |
) | |
# add icon | |
ToolTransformService.repack_provider(user_provider) | |
tools = provider_controller.get_tools( | |
user_id=user_id, tenant_id=tenant_id | |
) | |
for tool in tools: | |
user_provider.tools.append(ToolTransformService.tool_to_user_tool( | |
tenant_id=tenant_id, | |
tool=tool, | |
credentials=user_provider.original_credentials, | |
)) | |
result.append(user_provider) | |
return result | |