|
from typing import Optional, Dict, Any, List |
|
from pydantic_settings import BaseSettings, SettingsConfigDict |
|
from pydantic import Field, field_validator, ValidationInfo |
|
from pathlib import Path |
|
import fnmatch |
|
from know_lang_bot.core.types import ModelProvider |
|
import os |
|
|
|
def _validate_api_key(v: Optional[str], info: ValidationInfo) -> Optional[str]:
    """Validate that an API key is present when the provider requires one.

    For the OpenAI, Anthropic and Voyage providers a non-empty key is
    mandatory; once accepted, the key is also written to the matching
    ``*_API_KEY`` environment variable. For any other provider the value
    is returned untouched.

    Args:
        v: Candidate API key; may be ``None`` or empty.
        info: Pydantic validation context. ``info.data`` is consulted for
            the already-validated ``model_provider`` field.

    Returns:
        ``v`` unchanged.

    Raises:
        ValueError: If the provider requires a key and ``v`` is empty.
    """
    # ``info.data`` only contains fields that validated successfully. If
    # ``model_provider`` itself was rejected it is absent, and indexing would
    # raise a KeyError that escapes pydantic and hides the real validation
    # error. ``.get`` lets that original error be reported instead.
    provider = info.data.get('model_provider')

    # Providers that need a key, paired with the environment variable that
    # receives it.
    key_env_vars = (
        (ModelProvider.OPENAI, "OPENAI_API_KEY"),
        (ModelProvider.ANTHROPIC, "ANTHROPIC_API_KEY"),
        (ModelProvider.VOYAGE, "VOYAGE_API_KEY"),
    )

    for candidate, env_var in key_env_vars:
        # Compare with ``==`` (not a dict lookup) so plain-string provider
        # values keep matching exactly as they did before.
        if provider == candidate:
            if not v:
                raise ValueError(f"API key required for {provider}")
            os.environ[env_var] = v
            break

    return v
|
|
|
def _glob_match(path_str: str, pattern: str) -> bool:
    """Match *path_str* against *pattern*, letting a leading ``**/`` match zero directories."""
    if fnmatch.fnmatch(path_str, pattern):
        return True
    # fnmatch has no special notion of "**": "**/" simply means "anything
    # followed by a slash". As a result "**/*" misses a top-level "main.py"
    # and "**/venv/**" misses a root-level "venv/lib/x.py". Retrying with the
    # leading "**/" removed restores the usual glob meaning of
    # "zero or more directories".
    while pattern.startswith("**/"):
        pattern = pattern[3:]
        if fnmatch.fnmatch(path_str, pattern):
            return True
    return False


class PathPatterns(BaseSettings):
    """Include/exclude glob patterns used to decide which paths are processed."""

    include: List[str] = Field(
        default=["**/*"],
        description="Glob patterns for paths to include"
    )
    exclude: List[str] = Field(
        default=[
            "**/venv/**",
            "**/.git/**",
            "**/__pycache__/**",
            "**/tests/**",
        ],
        description="Glob patterns for paths to exclude"
    )

    def should_process_path(self, path: str) -> bool:
        """Check if a path should be processed based on include/exclude patterns.

        Exclusion takes precedence: a path matching any ``exclude`` pattern is
        rejected even if it also matches an ``include`` pattern. A leading
        ``**/`` in a pattern may match zero directories, so the defaults also
        cover files and directories at the root of a relative path.

        Args:
            path: Path to test; it is converted with ``str()`` before matching.

        Returns:
            ``True`` if the path matches no exclude pattern and at least one
            include pattern, ``False`` otherwise.
        """
        path_str = str(path)

        if any(_glob_match(path_str, pattern) for pattern in self.exclude):
            return False

        return any(_glob_match(path_str, pattern) for pattern in self.include)
|
|
|
class LanguageConfig(BaseSettings):
    """Settings describing how files of one language are handled."""

    # Enabled unless explicitly switched off.
    enabled: bool = True

    # The next three fields have no default and must be supplied.
    file_extensions: List[str]
    tree_sitter_language: str
    chunk_types: List[str]

    max_file_size: int = Field(
        1_000_000,
        description="Maximum file size to process in bytes",
    )
|
|
|
class ParserConfig(BaseSettings):
    """Parser settings: per-language configuration plus path filtering."""

    # Keyed by language name; only Python is configured out of the box.
    languages: Dict[str, LanguageConfig] = Field(
        default=dict(
            python=LanguageConfig(
                file_extensions=[".py"],
                tree_sitter_language="python",
                chunk_types=["class_definition", "function_definition"],
            ),
        ),
    )

    # A fresh PathPatterns (with its own defaults) is built per instance.
    path_patterns: PathPatterns = Field(default_factory=PathPatterns)
|
|
|
|
|
class EmbeddingConfig(BaseSettings):
    """Configuration for the embedding model (name, provider, size, credentials)."""

    model_name: str = Field(
        "mxbai-embed-large",
        description="Name of the embedding model",
    )
    model_provider: ModelProvider = Field(
        ModelProvider.OLLAMA,
        description="Provider for embeddings",
    )
    # NOTE(review): mxbai-embed-large is usually listed as producing
    # 1024-dimensional vectors - confirm that 768 is the intended default.
    dimension: int = Field(
        768,
        description="Embedding dimension",
    )
    settings: Dict[str, Any] = Field(
        default_factory=dict,
        description="Provider-specific settings",
    )
    # Declared after ``model_provider`` so the validator below can read the
    # provider from the already-validated data.
    api_key: Optional[str] = Field(
        None,
        description="API key for the model provider",
    )

    @field_validator("api_key", mode="after")
    @classmethod
    def validate_api_key(cls, v: Optional[str], info: ValidationInfo) -> Optional[str]:
        """Delegate the API-key check to the module-level ``_validate_api_key``."""
        return _validate_api_key(v, info)
|
|
|
class LLMConfig(BaseSettings):
    """Configuration for the chat/completion LLM."""

    model_name: str = Field(
        "llama3.2",
        description="Name of the LLM model to use",
    )
    # NOTE(review): typed as ``str`` here, whereas EmbeddingConfig types the
    # same field as ``ModelProvider`` - confirm whether that is intentional.
    model_provider: str = Field(
        ModelProvider.OLLAMA,
        description="Model provider (anthropic, openai, ollama, etc)",
    )
    # Declared after ``model_provider`` so the validator below can read the
    # provider from the already-validated data.
    api_key: Optional[str] = Field(
        None,
        description="API key for the model provider",
    )
    model_settings: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional model settings",
    )

    @field_validator("api_key", mode="after")
    @classmethod
    def validate_api_key(cls, v: Optional[str], info: ValidationInfo) -> Optional[str]:
        """Delegate the API-key check to the module-level ``_validate_api_key``."""
        return _validate_api_key(v, info)
|
|
|
class DBConfig(BaseSettings):
    """Locations and naming for the ChromaDB store and the analysed codebase."""

    persist_directory: Path = Field(
        Path("./chroma_db"),
        description="Directory to store ChromaDB files",
    )
    collection_name: str = Field(
        "code_chunks",
        description="Name of the ChromaDB collection",
    )
    # Defaults to the current working directory.
    codebase_directory: Path = Field(
        Path("./"),
        description="Root directory of the codebase to analyze",
    )
|
|
|
class RerankerConfig(BaseSettings):
    """Configuration for the optional reranking step (disabled by default)."""

    enabled: bool = Field(
        False,
        description="Enable reranking",
    )
    model_name: str = Field(
        "reranker",
        description="Name of the reranker model to use",
    )
    # NOTE(review): typed as ``str`` here, whereas EmbeddingConfig types the
    # same field as ``ModelProvider`` - confirm whether that is intentional.
    model_provider: str = Field(
        ModelProvider.OLLAMA,
        description="Model provider (anthropic, openai, ollama, etc)",
    )
    # Declared after ``model_provider`` so the validator below can read the
    # provider from the already-validated data.
    api_key: Optional[str] = Field(
        None,
        description="API key for the model provider",
    )
    top_k: int = Field(
        4,
        description="Number of most relevant documents to return from reranking",
    )
    relevance_threshold: float = Field(
        0.5,
        description="Minimum relevance score to include a document in reranking",
    )

    @field_validator("api_key", mode="after")
    @classmethod
    def validate_api_key(cls, v: Optional[str], info: ValidationInfo) -> Optional[str]:
        """Delegate the API-key check to the module-level ``_validate_api_key``."""
        return _validate_api_key(v, info)
|
|
|
class ChatConfig(BaseSettings):
    """Settings for the chat interface and the context passed to it."""

    # --- context selection ---
    max_context_chunks: int = Field(
        5,
        description="Maximum number of similar chunks to include in context",
    )
    similarity_threshold: float = Field(
        0.7,
        description="Minimum similarity score to include a chunk",
    )

    # --- user-facing interface text ---
    interface_title: str = Field(
        "Code Repository Q&A Assistant",
        description="Title shown in the chat interface",
    )
    interface_description: str = Field(
        "Ask questions about the codebase and I'll help you understand it!",
        description="Description shown in the chat interface",
    )

    # --- chunk size limit ---
    max_length_per_chunk: int = Field(
        8000,
        description="Maximum number of characters per chunk",
    )
|
|
|
class EvaluatorConfig(LLMConfig):
    """LLM configuration for evaluation; adds a round count on top of LLMConfig."""

    evaluation_rounds: int = Field(
        1,
        description="Number of evaluation rounds per test case",
    )
|
|
|
class AppConfig(BaseSettings):
    """Top-level application settings that aggregate every sub-configuration.

    Values are read from a UTF-8 ``.env`` file, and ``__`` is the delimiter
    for addressing nested fields through environment variables.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        env_nested_delimiter="__",
    )

    # Each section is built from its own defaults when not supplied.
    llm: LLMConfig = Field(default_factory=LLMConfig)
    evaluator: EvaluatorConfig = Field(default_factory=EvaluatorConfig)
    reranker: RerankerConfig = Field(default_factory=RerankerConfig)
    db: DBConfig = Field(default_factory=DBConfig)
    parser: ParserConfig = Field(default_factory=ParserConfig)
    chat: ChatConfig = Field(default_factory=ChatConfig)
    embedding: EmbeddingConfig = Field(default_factory=EmbeddingConfig)