Spaces:
Runtime error
Runtime error
import json | |
import os | |
from pathlib import Path | |
from typing import Any, Optional, TypeVar, overload | |
import confuse | |
import dotenv | |
from pydantic import parse_obj_as | |
from hibiapi import __file__ as root_file | |
# User-editable configuration directory: the relative path "configs", which is
# resolved against the process's working directory whenever it is used.
CONFIG_DIR = Path("configs")

# Packaged default configuration directory: the "configs" folder sitting next
# to the hibiapi package's own ``__file__``.
DEFAULT_DIR = Path(root_file).parent.joinpath("configs")

# Generic placeholder for the type requested through a ``get(template)`` call.
_T = TypeVar("_T")
class ConfigSubView(confuse.Subview):
    """A ``confuse.Subview`` whose values can be coerced through pydantic.

    ``get`` optionally takes a target type and runs the raw configuration
    value through ``pydantic.parse_obj_as``; the ``as_*`` helpers are thin
    shortcuts for the most common target types.
    """

    @overload
    def get(self) -> Any: ...

    @overload
    def get(self, template: type[_T]) -> _T: ...

    def get(self, template: Optional[type[_T]] = None):  # type: ignore
        """Return this view's value, optionally parsed as ``template``.

        Args:
            template: Target type handed to ``parse_obj_as``. When ``None``
                the raw value is returned untouched.

        Returns:
            The value produced by ``confuse.Subview.get()`` (called with its
            default template), parsed into ``template`` when one is given.
        """
        object_ = super().get()
        if template is not None:
            return parse_obj_as(template, object_)
        return object_

    def get_optional(self, template: type[_T]) -> Optional[_T]:
        """Return the value parsed as ``template``, or ``None`` on any error.

        Deliberately best-effort: every ``Exception`` raised while looking up
        or parsing the value is swallowed and reported as ``None``.
        """
        try:
            return self.get(template)
        except Exception:
            return None

    def as_str(self) -> str:
        """Return the value parsed as ``str``."""
        return self.get(str)

    def as_str_seq(self, split: str = "\n") -> list[str]:  # type: ignore
        """Return the string value split on ``split`` into non-empty items.

        The whole string and every piece are stripped of surrounding
        whitespace; pieces that end up empty are dropped.
        """
        return [
            stripped
            for line in self.as_str().strip().split(split)
            if (stripped := line.strip())
        ]

    def as_number(self) -> int:
        """Return the value parsed as ``int``."""
        return self.get(int)

    def as_bool(self) -> bool:
        """Return the value parsed as ``bool``."""
        return self.get(bool)

    def as_path(self) -> Path:
        """Return the value parsed as ``pathlib.Path``."""
        return self.get(Path)

    def as_dict(self) -> dict[str, Any]:
        """Return the value parsed as ``dict[str, Any]``."""
        return self.get(dict[str, Any])

    def __getitem__(self, key: str) -> "ConfigSubView":
        """Return the child view for ``key``, built with this same class."""
        # ``self.__class__`` keeps nested lookups on this subclass (or any
        # further subclass), so the pydantic-aware helpers stay available.
        return self.__class__(self, key)
class AppConfig(confuse.Configuration):
    """A ``confuse.Configuration`` fed by ``<name>.yml`` files and the environment.

    The user file lives in ``CONFIG_DIR``, the default file in
    ``DEFAULT_DIR``, and environment variables prefixed with ``<name>_`` are
    inserted as the first entry of ``self.sources``.
    """

    def __init__(self, name: str):
        """Create the configuration called ``name`` and add the env source.

        Args:
            name: Configuration name; also the stem of the YAML file names
                and (case-insensitively) the environment-variable prefix.
        """
        self._config_name = name
        # The file paths are assigned before ``super().__init__`` runs;
        # presumably confuse's constructor invokes the ``_add_*_source`` hooks
        # below, which read them — confirm against the confuse version in use.
        self._config = CONFIG_DIR / (filename := f"{name}.yml")
        self._default = DEFAULT_DIR / filename
        super().__init__(name)
        self._add_env_source()

    def config_dir(self) -> str:
        """Return the user configuration directory (always ``CONFIG_DIR``)."""
        return str(CONFIG_DIR)

    def user_config_path(self) -> str:
        """Return the path of the user's ``<name>.yml`` file as a string."""
        return str(self._config)

    def _add_env_source(self):
        """Insert a source built from ``<name>_*`` environment variables.

        A ``.env`` file is loaded first when one can be found. Variable names
        are matched case-insensitively, the prefix is removed, and the rest is
        split on ``_`` into a nested mapping, e.g. ``AAA_BBB_CCC=DDD`` becomes
        ``{'aaa': {'bbb': {'ccc': 'DDD'}}}`` (prefix handling aside). Values
        are decoded as JSON when possible and kept as plain strings otherwise.
        Empty values are ignored entirely.
        """
        if dotenv.find_dotenv():
            dotenv.load_dotenv()

        config_name = f"{self._config_name.lower()}_"
        env_configs = {
            k[len(config_name) :].lower(): str(v)
            for k, v in os.environ.items()
            if k.lower().startswith(config_name)
        }

        # Convert `AAA_BBB_CCC=DDD` to `{'aaa':{'bbb':{'ccc':'ddd'}}}`
        source_tree: dict[str, Any] = {}
        for key, value in env_configs.items():
            # Skip empty values *before* touching the tree, so an ignored
            # variable does not leave empty parent mappings behind.
            if value == "":
                continue

            *nodes, name = key.split("_")
            _tmp = source_tree
            for node in nodes:
                child = _tmp.get(node)
                if not isinstance(child, dict):
                    # An earlier variable stored a non-mapping value here
                    # (e.g. both X_A and X_A_B are set). The later variable
                    # wins instead of crashing on ``.setdefault``.
                    child = _tmp[node] = {}
                _tmp = child

            try:
                _tmp[name] = json.loads(value)
            except json.JSONDecodeError:
                # Not valid JSON: keep the raw string.
                _tmp[name] = value

        # Index 0 puts the environment ahead of every source already present.
        self.sources.insert(0, confuse.ConfigSource.of(source_tree))

    def _add_default_source(self):
        """Add the packaged default YAML file, flagged ``default=True``."""
        self.add(confuse.YamlSource(self._default, default=True))

    def _add_user_source(self):
        """Add the user's YAML file, flagged ``optional=True``."""
        self.add(confuse.YamlSource(self._config, optional=True))

    def __getitem__(self, key: str) -> ConfigSubView:
        """Return the top-level entry ``key`` as a ``ConfigSubView``."""
        return ConfigSubView(self, key)
class GeneralConfig(AppConfig):
    """Configuration type that currently behaves exactly like ``AppConfig``.

    The constructor is inherited unchanged: ``GeneralConfig(name)``.
    """
class APIConfig(GeneralConfig):
    """Configuration type that adds nothing on top of ``GeneralConfig``.

    NOTE(review): presumably instantiated once per API by callers outside
    this module — confirm.
    """
# Shared configuration instance for the "general" settings, built once when
# this module is imported.
Config = GeneralConfig(name="general")