File size: 6,718 Bytes
3079197 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import socket
from pathlib import Path
from web_server import utils
from .db_models import DB, ServiceRegistryInfo, ServerRegistryInfo
from .reload_config_base import ReloadConfigBase
class ServiceRegistry(ReloadConfigBase):
    """Read and write ``ServiceRegistryInfo`` rows (one row per server/service pair)."""

    @classmethod
    @DB.connection_context()
    def load_service(cls, **kwargs) -> [ServiceRegistryInfo]:
        """Return the rows produced by ``ServiceRegistryInfo.query(**kwargs)`` as a list.

        ``kwargs`` are forwarded to ``ServiceRegistryInfo.query`` unchanged.
        """
        return list(ServiceRegistryInfo.query(**kwargs))

    @classmethod
    @DB.connection_context()
    def save_service_info(
        cls,
        server_name,
        service_name,
        uri,
        method="POST",
        server_info=None,
        params=None,
        data=None,
        headers=None,
        protocol="http",
    ):
        """Create or overwrite the registry row for one service of a server.

        Args:
            server_name: Value stored in ``f_server_name``; also used to look
                the server up when ``server_info`` is not supplied.
            service_name: Value stored in ``f_service_name``.
            uri: Suffix appended to ``<protocol>://<host>:<port>`` to form ``f_url``.
            method: Value stored in ``f_method``.
            server_info: Optional mapping read with ``.get`` for ``'host'``,
                ``'port'`` and ``'protocol'``. When falsy, the first result of
                ``ServerRegistry.query_server_info_from_db`` is used instead.
            params: Stored in ``f_params``; an empty dict when falsy.
            data: Stored in ``f_data``; an empty dict when falsy.
            headers: Stored in ``f_headers``; an empty dict when falsy.
            protocol: Fallback scheme, used only when ``server_info`` is given
                but has no ``'protocol'`` entry.

        Raises:
            Exception: ``server_info`` is falsy and no server named
                ``server_name`` is returned by the database lookup.
        """
        if server_info:
            # Caller supplied the address directly as a mapping.
            scheme = server_info.get('protocol', protocol)
            host = server_info.get('host')
            port = server_info.get('port')
            url = f"{scheme}://{host}:{port}{uri}"
        else:
            # Fall back to the server record already stored in the database.
            registered = ServerRegistry.query_server_info_from_db(server_name=server_name)
            if not registered:
                raise Exception(f"no found server {server_name}")
            record = registered[0]
            url = f"{record.f_protocol}://{record.f_host}:{record.f_port}{uri}"

        fields = {
            "f_server_name": server_name,
            "f_service_name": service_name,
            "f_url": url,
            "f_method": method,
            "f_params": params or {},
            "f_data": data or {},
            "f_headers": headers or {},
        }

        row, created = ServiceRegistryInfo.get_or_create(
            f_server_name=server_name,
            f_service_name=service_name,
            defaults=fields,
        )
        if created is not False:
            # Row was created with `fields` as defaults; nothing left to update.
            return

        # Row already existed: overwrite every field and persist as an update.
        for column, value in fields.items():
            setattr(row, column, value)
        row.save(force_insert=False)
class ServerRegistry(ReloadConfigBase):
    """Class-level registry of server connection info.

    Server info is kept as class attributes (set with ``setattr``) and is
    filled from the service YAML config, from ``ServerRegistryInfo`` rows in
    the database, and from explicit ``register`` / ``save`` calls.
    """

    # Predeclared server slots. Each stays None until one of the loaders,
    # register() or save() assigns a value to the attribute.
    FATEBOARD = None
    FATE_ON_STANDALONE = None
    FATE_ON_EGGROLL = None
    FATE_ON_SPARK = None
    MODEL_STORE_ADDRESS = None
    SERVINGS = None
    FATEMANAGER = None
    STUDIO = None

    @classmethod
    def load(cls):
        """Populate the class attributes from the config files, then from the database.

        The database is read second, so a server present in both places ends
        up with the database values.
        """
        cls.load_server_info_from_conf()
        cls.load_server_info_from_db()

    @classmethod
    def load_server_info_from_conf(cls):
        """Load dict-valued entries of the service config as class attributes.

        Reads ``<project base>/conf/<utils.SERVICE_CONF>`` and, if it exists,
        the sibling file ``local.<utils.SERVICE_CONF>``. Top-level keys of the
        local file replace those of the main file (``dict.update``, so the
        merge is shallow).

        Raises:
            ValueError: either file does not load as a dict.
        """
        path = Path(utils.file_utils.get_project_base_directory()) / 'conf' / utils.SERVICE_CONF
        conf = utils.file_utils.load_yaml_conf(path)
        if not isinstance(conf, dict):
            raise ValueError('invalid config file')

        local_path = path.with_name(f'local.{utils.SERVICE_CONF}')
        if local_path.exists():
            local_conf = utils.file_utils.load_yaml_conf(local_path)
            if not isinstance(local_conf, dict):
                raise ValueError('invalid local config file')
            conf.update(local_conf)

        # Only dict-valued entries are exposed; scalar settings are skipped.
        for k, v in conf.items():
            if isinstance(v, dict):
                setattr(cls, k.upper(), v)

    @classmethod
    def register(cls, server_name, server_info):
        """Persist one server and expose it as a class attribute.

        Args:
            server_name: Name stored in the database and used as attribute name.
            server_info: Mapping read with ``.get`` for ``'host'``, ``'port'``
                and ``'protocol'`` (``"http"`` when absent).
        """
        cls.save_server_info_to_db(
            server_name,
            server_info.get("host"),
            server_info.get("port"),
            protocol=server_info.get("protocol", "http"),
        )
        # NOTE(review): unlike save() and the loaders, the attribute name is
        # not upper-cased here - confirm whether callers rely on that.
        setattr(cls, server_name, server_info)

    @classmethod
    def save(cls, service_config):
        """Persist every server in ``service_config`` together with its services.

        For each ``server_name -> server_info`` pair: the address is checked
        with ``parameter_check``, each entry under ``server_info["api"]`` is
        saved through ``ServiceRegistry.save_service_info``, the server itself
        is written to the database, and ``server_info`` is stored on the class
        under the upper-cased name.

        Note that the ``"api"`` key is popped from each ``server_info``, so the
        caller's dicts are modified.

        Args:
            service_config: Mapping of server name to server info mapping.

        Returns:
            dict: currently always an empty dict.

        Raises:
            ConnectionRefusedError: propagated from ``parameter_check`` when a
                server with both ``host`` and ``port`` cannot be reached.
        """
        # NOTE(review): never populated in this method; kept because callers
        # receive it as the return value.
        update_server = {}
        for server_name, server_info in service_config.items():
            cls.parameter_check(server_info)
            api_info = server_info.pop("api", {})
            for service_name, info in api_info.items():
                ServiceRegistry.save_service_info(
                    server_name,
                    service_name,
                    uri=info.get('uri'),
                    method=info.get('method', 'POST'),
                    server_info=server_info,
                )
            # Store the configured protocol (as register() does) rather than a
            # hard-coded "http", so the server row agrees with the service URLs
            # built from the same server_info.
            cls.save_server_info_to_db(
                server_name,
                server_info.get("host"),
                server_info.get("port"),
                protocol=server_info.get("protocol", "http"),
            )
            setattr(cls, server_name.upper(), server_info)
        return update_server

    @classmethod
    def parameter_check(cls, service_info):
        """Run ``connection_test`` when ``service_info`` has both ``host`` and ``port``.

        Entries lacking either key are accepted without any check.
        """
        if "host" in service_info and "port" in service_info:
            cls.connection_test(service_info.get("host"), service_info.get("port"))

    @classmethod
    def connection_test(cls, ip, port):
        """Check that a TCP connection to ``(ip, port)`` can be opened.

        Args:
            ip: Host passed to ``socket.connect_ex``.
            port: Port passed to ``socket.connect_ex``.

        Raises:
            ConnectionRefusedError: ``connect_ex`` returned a non-zero error code.
        """
        # The context manager closes the probe socket on every path; the
        # socket is only needed for the duration of the connect attempt.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            result = s.connect_ex((ip, port))
        if result != 0:
            raise ConnectionRefusedError(f"connection refused: host {ip}, port {port}")

    @classmethod
    def query(cls, service_name, default=None):
        """Return the info stored on the class under ``service_name``.

        When the attribute is missing or falsy, falls back to
        ``utils.get_base_config(service_name, default)``.
        """
        service_info = getattr(cls, service_name, default)
        if not service_info:
            service_info = utils.get_base_config(service_name, default)
        return service_info

    @classmethod
    @DB.connection_context()
    def query_server_info_from_db(cls, server_name=None) -> [ServerRegistryInfo]:
        """Return ``ServerRegistryInfo`` rows as a list.

        Args:
            server_name: When truthy, only rows whose ``f_server_name`` equals
                ``server_name.upper()`` are returned; otherwise all rows.
        """
        # NOTE(review): the name is upper-cased here, while
        # save_server_info_to_db stores it as given - confirm callers always
        # register upper-case names.
        if server_name:
            server_list = ServerRegistryInfo.select().where(
                ServerRegistryInfo.f_server_name == server_name.upper()
            )
        else:
            server_list = ServerRegistryInfo.select()
        return list(server_list)

    @classmethod
    @DB.connection_context()
    def load_server_info_from_db(cls):
        """Set a class attribute (upper-cased server name) for every server row.

        Each attribute is a dict with the keys ``host``, ``port`` and ``protocol``.
        """
        for server in cls.query_server_info_from_db():
            server_info = {
                "host": server.f_host,
                "port": server.f_port,
                "protocol": server.f_protocol,
            }
            setattr(cls, server.f_server_name.upper(), server_info)

    @classmethod
    @DB.connection_context()
    def save_server_info_to_db(cls, server_name, host, port, protocol="http"):
        """Create or overwrite the ``ServerRegistryInfo`` row for ``server_name``.

        Args:
            server_name: Lookup key and value of ``f_server_name``.
            host: Value stored in ``f_host``.
            port: Value stored in ``f_port``.
            protocol: Value stored in ``f_protocol``.
        """
        server_info = {
            "f_server_name": server_name,
            "f_host": host,
            "f_port": port,
            "f_protocol": protocol,
        }
        entity_model, status = ServerRegistryInfo.get_or_create(
            f_server_name=server_name,
            defaults=server_info,
        )
        # status False means the row already existed (presumably peewee's
        # "created" flag - confirm against the model base class), so its
        # fields are overwritten and saved as an update.
        if status is False:
            for key, value in server_info.items():
                setattr(entity_model, key, value)
            entity_model.save(force_insert=False)