# generic-chatbot-backend / common/configuration.py
# Source-page metadata: author muryshev, commit 0341212 ("update"), 3.91 kB.
"""This module includes classes to define configurations."""
from typing import Any, Dict, Optional
from pyaml_env import parse_config
# Case-insensitive, whitespace-stripped spellings accepted for boolean flags.
_TRUE_STRINGS = frozenset({"true", "yes", "on", "1"})
_FALSE_STRINGS = frozenset({"false", "no", "off", "0", "", "null", "none"})


def _parse_bool(value: Any) -> bool:
    """Convert a configuration value to ``bool``.

    Plain ``bool()`` is wrong for strings: ``bool("false")`` is ``True``.
    String values can reach this code, e.g. values substituted from
    environment variables through ``pyaml_env``'s ``!ENV`` tag are strings
    unless a YAML type tag is added.

    Recognised strings map to their boolean meaning. Any other value
    (native YAML booleans, numbers, unrecognised strings) falls back to
    ``bool(value)``, so non-string values behave exactly as before.
    """
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in _TRUE_STRINGS:
            return True
        if normalized in _FALSE_STRINGS:
            return False
    return bool(value)


class EntitiesExtractorConfiguration:
    """Settings from the ``entities`` sub-section of the database config.

    Args:
        config_data: Mapping with the keys ``strategy_name``,
            ``strategy_params``, ``process_tables`` and
            ``neighbors_max_distance``. A missing key raises ``KeyError``.
    """

    def __init__(self, config_data):
        self.strategy_name = str(config_data['strategy_name'])
        # Stored as-is (no conversion); presumably a dict of strategy
        # parameters or None -- TODO confirm against the strategy factory.
        self.strategy_params: dict | None = config_data['strategy_params']
        self.process_tables = _parse_bool(config_data['process_tables'])
        self.neighbors_max_distance = int(config_data['neighbors_max_distance'])


class SearchConfiguration:
    """Settings from the ``search`` sub-section of the database config.

    Args:
        config_data: Mapping with the keys ``use_vector_search``,
            ``vectorizer_path``, ``device``, ``max_entities_per_message``,
            ``max_entities_per_dialogue`` and ``use_qe``. A missing key
            raises ``KeyError``.
    """

    def __init__(self, config_data):
        self.use_vector_search = _parse_bool(config_data['use_vector_search'])
        self.vectorizer_path = str(config_data['vectorizer_path'])
        self.device = str(config_data['device'])
        self.max_entities_per_message = int(config_data['max_entities_per_message'])
        self.max_entities_per_dialogue = int(config_data['max_entities_per_dialogue'])
        self.use_qe = _parse_bool(config_data['use_qe'])


class FilesConfiguration:
    """Settings from the ``files`` sub-section of the database config.

    Args:
        config_data: Mapping with the keys ``empty_start`` and
            ``documents_path``. A missing key raises ``KeyError``.
    """

    def __init__(self, config_data):
        self.empty_start = _parse_bool(config_data['empty_start'])
        self.documents_path = str(config_data['documents_path'])


class DataBaseConfiguration:
    """Groups the ``entities``, ``search`` and ``files`` sub-sections.

    Args:
        config_data: Mapping containing the three sub-section mappings.
            A missing sub-section raises ``KeyError``.
    """

    def __init__(self, config_data):
        self.entities = EntitiesExtractorConfiguration(config_data['entities'])
        self.search = SearchConfiguration(config_data['search'])
        self.files = FilesConfiguration(config_data['files'])
class LLMConfiguration:
    """Settings for the LLM client, read from the ``llm`` config section.

    Args:
        config_data: Mapping with the keys ``base_url``, ``api_key_env``,
            ``model``, ``tokenizer_name``, ``temperature``, ``top_p``,
            ``min_p``, ``frequency_penalty``, ``presence_penalty`` and
            ``seed``. A missing key raises ``KeyError``; a value that
            ``float()``/``int()`` cannot convert raises ``ValueError`` or
            ``TypeError``.
    """

    # Spellings (compared case-insensitively, after stripping) that mean
    # "not set" for the optional string settings.
    _NULL_MARKERS = frozenset({"", "null", "none"})

    def __init__(self, config_data):
        self.base_url = self._optional_str(config_data['base_url'])
        # Presumably the *name* of an env var holding the API key -- confirm
        # with the code that consumes it.
        self.api_key_env = self._optional_str(config_data['api_key_env'])
        self.model = str(config_data['model'])
        # The config key is ``tokenizer_name`` but the attribute is ``tokenizer``.
        self.tokenizer = str(config_data['tokenizer_name'])
        self.temperature = float(config_data['temperature'])
        self.top_p = float(config_data['top_p'])
        self.min_p = float(config_data['min_p'])
        self.frequency_penalty = float(config_data['frequency_penalty'])
        self.presence_penalty = float(config_data['presence_penalty'])
        self.seed = int(config_data['seed'])

    @classmethod
    def _optional_str(cls, value: Any) -> Optional[str]:
        """Return ``value`` as ``str``, or ``None`` when it means "not set".

        A YAML ``null`` arrives as Python ``None`` and must map to ``None``:
        ``str(None)`` would otherwise produce the truthy string ``"None"``.
        String markers such as ``""``, ``"null"`` or ``"None"`` (e.g. from
        environment-variable substitution) are treated the same way.
        """
        if value is None:
            return None
        text = str(value)
        return None if text.strip().lower() in cls._NULL_MARKERS else text
class CommonConfiguration:
    """Application-wide settings taken from the ``common`` config section.

    Every value is coerced with ``str()``; a missing key raises ``KeyError``.
    """

    log_file_path: str
    log_sql_path: str
    log_level: str

    def __init__(self, config_data):
        keys = ('log_file_path', 'log_sql_path', 'log_level')
        self.log_file_path, self.log_sql_path, self.log_level = (
            str(config_data[key]) for key in keys
        )
class Configuration:
    """Encapsulates all configuration parameters.

    After construction the following attributes are available:

    * ``common_config`` -- built from the ``common`` section;
    * ``db_config`` -- built from the ``bd`` section;
    * ``llm_config`` -- built from the ``llm`` section.
    """

    def __init__(self, config_file_path: Optional[str] = None):
        """Create the configuration by loading it from a file.

        Loading from a file is currently the only supported source, so
        ``config_file_path`` must be provided even though the signature makes
        it optional.

        Args:
            config_file_path: Path to the configuration file.

        Raises:
            ValueError: If ``config_file_path`` is ``None`` or the parsed file
                is not a mapping.
        """
        if config_file_path is None:
            raise ValueError('config_file_path must not be None.')
        self._load_from_config(config_file_path)

    def _load_data(self, data: Dict[str, Any]):
        """Build the section objects from an already-parsed dictionary.

        Args:
            data: Parsed configuration with ``common``, ``bd`` and ``llm``
                sections.

        Raises:
            ValueError: If ``data`` is not a dictionary. For example, a YAML
                parser typically returns ``None`` for an empty file.
            KeyError: If a required section or key is missing.
        """
        if not isinstance(data, dict):
            raise ValueError(
                f'Configuration data must be a mapping, got {type(data).__name__}.'
            )
        self.common_config: CommonConfiguration = CommonConfiguration(data['common'])
        # NOTE: the database section key is spelled 'bd' in the config file;
        # renaming it would break existing configuration files.
        self.db_config: DataBaseConfiguration = DataBaseConfiguration(data['bd'])
        self.llm_config: LLMConfiguration = LLMConfiguration(data['llm'])

    def _load_from_config(self, config_file_path: str):
        """Parse the configuration file and load its contents.

        The file is parsed with ``pyaml_env.parse_config``, which also
        resolves environment-variable references.

        Args:
            config_file_path: Path to the configuration file.
        """
        data = parse_config(config_file_path)
        self._load_data(data)