Spaces:
Sleeping
Sleeping
| from typing import Literal | |
| from pydantic import BaseModel, Field | |
| from private_gpt.settings.settings_loader import load_active_settings | |
class CorsSettings(BaseModel):
    """CORS configuration.

    For more details on the CORS configuration, see:
    * https://fastapi.tiangolo.com/tutorial/cors/
    * https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
    """

    enabled: bool = Field(
        # Trailing space added: the two literals are concatenated implicitly
        # and previously rendered as "...set or not.If set to True...".
        description="Flag indicating if CORS headers are set or not. "
        "If set to True, the CORS headers will be set to allow all origins, methods and headers.",
        default=False,
    )
    allow_credentials: bool = Field(
        description="Indicate that cookies should be supported for cross-origin requests",
        default=False,
    )
    allow_origins: list[str] = Field(
        description="A list of origins that should be permitted to make cross-origin requests.",
        default=[],
    )
    # The default is None, so the annotation must admit None; otherwise an
    # explicit `allow_origin_regex: null` in the configuration is rejected
    # even though it equals the default.
    # NOTE(review): the description speaks of a single regex string while the
    # declared type is a list of strings - confirm which one the consumer of
    # this setting expects.
    allow_origin_regex: list[str] | None = Field(
        description="A regex string to match against origins that should be permitted to make cross-origin requests.",
        default=None,
    )
    allow_methods: list[str] = Field(
        description="A list of HTTP methods that should be allowed for cross-origin requests.",
        default=[
            "GET",
        ],
    )
    allow_headers: list[str] = Field(
        description="A list of HTTP request headers that should be supported for cross-origin requests.",
        default=[],
    )
class AuthSettings(BaseModel):
    """Authentication configuration.

    Carries the switch that turns authentication on or off and the secret
    value described by the `secret` field. `enabled` falls back to False;
    `secret` has no default and therefore must always be supplied.
    """

    enabled: bool = Field(
        False,
        description="Flag indicating if authentication is enabled or not.",
    )
    # No default on purpose: callers have to provide the secret explicitly.
    secret: str = Field(
        description=(
            "The secret to be used for authentication. "
            "It can be any non-blank string. "
            "For HTTP basic authentication, "
            "this value should be the whole 'Authorization' header that is expected"
        ),
    )
class ServerSettings(BaseModel):
    """Settings of the PrivateGPT FastAPI server.

    Groups the environment name, the listening port and the nested CORS and
    authentication configurations.
    """

    env_name: str = Field(
        description="Name of the environment (prod, staging, local...)"
    )
    # The description promises a default of 8001, so the default is declared
    # here as well; previously the field was required despite that promise.
    port: int = Field(
        description="Port of PrivateGPT FastAPI server, defaults to 8001",
        default=8001,
    )
    # CORS is disabled unless the configuration says otherwise.
    cors: CorsSettings = Field(
        description="CORS configuration", default=CorsSettings(enabled=False)
    )
    # NOTE(review): the fallback secret is a hard-coded placeholder. It is only
    # used together with `enabled=False`; a real secret must be configured
    # whenever authentication is switched on.
    auth: AuthSettings = Field(
        description="Authentication configuration",
        default_factory=lambda: AuthSettings(enabled=False, secret="secret-key"),
    )
class DataSettings(BaseModel):
    """Settings describing where local data is stored."""

    local_data_folder: str = Field(
        # Trailing space added: the literals are concatenated implicitly and
        # previously rendered as "Path to local storage.It will be...".
        description="Path to local storage. "
        "It will be treated as an absolute path if it starts with /"
    )
class LLMSettings(BaseModel):
    """Settings for the LLM: which mode to run and the completion size cap."""

    # Required; restricted to the four literal values below.
    mode: Literal["local", "openai", "sagemaker", "mock"]
    max_new_tokens: int = Field(
        description=(
            "The maximum number of token that the LLM is authorized to generate in one completion."
        ),
        default=256,
    )
class VectorstoreSettings(BaseModel):
    """Settings selecting the vector store database."""

    # Required; only these two database identifiers are accepted.
    database: Literal["chroma", "qdrant"]
class LocalSettings(BaseModel):
    """Settings for the `local` mode: model identifiers and prompt handling."""

    # Both LLM identifiers are required and carry no further metadata.
    llm_hf_repo_id: str
    llm_hf_model_file: str
    embedding_hf_model_name: str = Field(
        description="Name of the HuggingFace model to use for embeddings",
    )
    prompt_style: Literal["default", "llama2", "tag"] = Field(
        description=(
            "The prompt style to use for the chat engine. "
            "If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
            "If `llama2` - use the llama2 prompt style from the llama_index. Based on `<s>`, `[INST]` and `<<SYS>>`.\n"
            "If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`. \n"
            "`llama2` is the historic behaviour. `default` might work better with your custom models."
        ),
        default="llama2",
    )
    default_system_prompt: str | None = Field(
        description=(
            "The default system prompt to use for the chat engine. "
            "If none is given - use the default system prompt (from the llama_index). "
            "Please note that the default prompt might not be the same for all prompt styles. "
            "Also note that this is only used if the first message is not a system message. "
        ),
        default=None,
    )
class EmbeddingSettings(BaseModel):
    """Settings for the embedding engine and the file ingestion strategy."""

    # Required; restricted to the four literal values below.
    mode: Literal["local", "openai", "sagemaker", "mock"]
    ingest_mode: Literal["simple", "batch", "parallel"] = Field(
        description=(
            "The ingest mode to use for the embedding engine:\n"
            "If `simple` - ingest files sequentially and one by one. It is the historic behaviour.\n"
            "If `batch` - if multiple files, parse all the files in parallel, "
            "and send them in batch to the embedding model.\n"
            "If `parallel` - parse the files in parallel using multiple cores, and embedd them in parallel.\n"
            "`parallel` is the fastest mode for local setup, as it parallelize IO RW in the index.\n"
            "For modes that leverage parallelization, you can specify the number of "
            "workers to use with `count_workers`.\n"
        ),
        default="simple",
    )
    count_workers: int = Field(
        description=(
            "The number of workers to use for file ingestion.\n"
            "In `batch` mode, this is the number of workers used to parse the files.\n"
            "In `parallel` mode, this is the number of workers used to parse the files and embed them.\n"
            "This is only used if `ingest_mode` is not `simple`.\n"
            "Do not go too high with this number, as it might cause memory issues. (especially in `parallel` mode)\n"
            "Do not set it higher than your number of threads of your CPU."
        ),
        default=2,
    )
class SagemakerSettings(BaseModel):
    """Settings naming the SageMaker endpoints; both fields are required."""

    llm_endpoint_name: str
    embedding_endpoint_name: str
class OpenAISettings(BaseModel):
    """Settings for OpenAI; holds the required API key."""

    api_key: str
class UISettings(BaseModel):
    """Settings for the UI: an on/off flag and a path; both are required."""

    enabled: bool
    path: str
class QdrantSettings(BaseModel):
    """Options for connecting to Qdrant.

    Every field is optional; see the individual field descriptions for the
    meaning of each option.
    """

    location: str | None = Field(
        None,
        description=(
            "If `:memory:` - use in-memory Qdrant instance.\n"
            "If `str` - use it as a `url` parameter.\n"
        ),
    )
    url: str | None = Field(
        None,
        description=(
            "Either host or str of 'Optional[scheme], host, Optional[port], Optional[prefix]'."
        ),
    )
    port: int | None = Field(6333, description="Port of the REST API interface.")
    grpc_port: int | None = Field(6334, description="Port of the gRPC interface.")
    prefer_grpc: bool | None = Field(
        False,
        description="If `true` - use gRPC interface whenever possible in custom methods.",
    )
    https: bool | None = Field(
        None,
        description="If `true` - use HTTPS(SSL) protocol.",
    )
    api_key: str | None = Field(
        None,
        description="API key for authentication in Qdrant Cloud.",
    )
    prefix: str | None = Field(
        None,
        description=(
            # Trailing space added: the implicit concatenation previously
            # rendered as "...REST URL path.Example: ...".
            "Prefix to add to the REST URL path. "
            "Example: `service/v1` will result in "
            "'http://localhost:6333/service/v1/{qdrant-endpoint}' for REST API."
        ),
    )
    timeout: float | None = Field(
        None,
        description="Timeout for REST and gRPC API requests.",
    )
    host: str | None = Field(
        None,
        description="Host name of Qdrant service. If url and host are None, set to 'localhost'.",
    )
    path: str | None = Field(None, description="Persistence path for QdrantLocal.")
    force_disable_check_same_thread: bool | None = Field(
        True,
        description=(
            # Sentence separator added: the implicit concatenation previously
            # rendered as "Default: `True`Only use this...".
            "For QdrantLocal, force disable check_same_thread. Default: `True`. "
            "Only use this if you can guarantee that you can resolve the thread safety outside QdrantClient."
        ),
    )
class Settings(BaseModel):
    """Root settings model aggregating every configuration section."""

    # Required sections.
    server: ServerSettings
    data: DataSettings
    ui: UISettings
    llm: LLMSettings
    embedding: EmbeddingSettings
    local: LocalSettings
    sagemaker: SagemakerSettings
    openai: OpenAISettings
    vectorstore: VectorstoreSettings

    # Optional section: may be omitted, in which case it is None.
    qdrant: QdrantSettings | None = None
| """ | |
| This is visible just for DI or testing purposes. | |
| Use dependency injection or `settings()` method instead. | |
| """ | |
# Raw settings as returned by `load_active_settings()`; they are validated
# into a `Settings` instance just below.
unsafe_settings = load_active_settings()

# This is visible just for DI or testing purposes.
# Use dependency injection or the `settings()` function instead.
unsafe_typed_settings = Settings(**unsafe_settings)
def settings() -> Settings:
    """Return the `Settings` instance held by the global DI container.

    Kept for compatibility with existing code that needs global access to
    the settings. Regular components should receive `Settings` through
    dependency injection instead of calling this function.

    Returns:
        Whatever `global_injector.get(Settings)` resolves to.
    """
    # Imported at call time rather than at module top.
    # NOTE(review): presumably this avoids a circular import with
    # `private_gpt.di` - confirm.
    from private_gpt.di import global_injector

    loaded = global_injector.get(Settings)
    return loaded