Spaces:
Sleeping
Sleeping
from typing import Optional, Dict, Any, List | |
from pydantic_settings import BaseSettings, SettingsConfigDict | |
from pydantic import Field, field_validator, ValidationInfo | |
from pathlib import Path | |
import fnmatch | |
from know_lang_bot.core.types import ModelProvider | |
import os | |
def _validate_api_key(v: Optional[str], info: ValidationInfo) -> Optional[str]:
    """Validate that an API key is present for providers that require one.

    Intended to be called from the ``validate_api_key`` methods of the config
    classes in this module. For Anthropic, OpenAI and Voyage the key is also
    exported to the matching ``*_API_KEY`` environment variable.

    Args:
        v: The ``api_key`` value being validated.
        info: Pydantic validation context; ``info.data`` is read for the
            ``model_provider`` value.

    Returns:
        ``v`` unchanged.

    Raises:
        ValueError: If the provider is OpenAI, Anthropic or Voyage and ``v``
            is empty or ``None``.
    """
    # ``info.data`` only contains fields that validated successfully, so a
    # rejected ``model_provider`` is simply absent. Use ``.get`` so that case
    # surfaces as the provider's own validation error instead of a KeyError
    # raised from inside this validator.
    provider = info.data.get('model_provider')

    # Providers that need a key, paired with the environment variable the key
    # is exported to.
    env_var_by_provider = (
        (ModelProvider.OPENAI, "OPENAI_API_KEY"),
        (ModelProvider.ANTHROPIC, "ANTHROPIC_API_KEY"),
        (ModelProvider.VOYAGE, "VOYAGE_API_KEY"),
    )

    for known_provider, env_var in env_var_by_provider:
        if provider == known_provider:
            if not v:
                raise ValueError(f"API key required for {provider}")
            os.environ[env_var] = v
            break

    return v
class PathPatterns(BaseSettings):
    """Include/exclude glob patterns that decide which paths are processed."""

    include: List[str] = Field(
        default=["**/*"],
        description="Glob patterns for paths to include"
    )
    exclude: List[str] = Field(
        default=[
            "**/venv/**",
            "**/.git/**",
            "**/__pycache__/**",
            "**/tests/**",
        ],
        description="Glob patterns for paths to exclude"
    )

    def should_process_path(self, path: str) -> bool:
        """Check if a path should be processed based on include/exclude patterns.

        Exclusions take precedence: a path matching any ``exclude`` pattern is
        rejected even if it also matches an ``include`` pattern.

        Args:
            path: Path to test. It is converted with ``str()``, so path-like
                objects are accepted as well.

        Returns:
            ``True`` if the path matches no exclude pattern and at least one
            include pattern, otherwise ``False``.
        """
        path_str = str(path)

        # ``fnmatch`` does not treat "/" specially, so "**/venv/**" only
        # matches when a "/" literally precedes "venv". A relative path such
        # as "venv/lib/x.py" (or a top-level "setup.py" against "**/*") would
        # therefore never match. Testing a "/"-anchored copy as well lets
        # "**/"-prefixed patterns apply to top-level entries too.
        if path_str.startswith("/"):
            candidates = (path_str,)
        else:
            candidates = (path_str, "/" + path_str)

        def matches_any(patterns: List[str]) -> bool:
            return any(
                fnmatch.fnmatch(candidate, pattern)
                for pattern in patterns
                for candidate in candidates
            )

        # Exclusions first, then inclusions.
        if matches_any(self.exclude):
            return False
        return matches_any(self.include)
class LanguageConfig(BaseSettings):
    """Parsing settings for a single programming language."""

    # Enabled unless explicitly switched off.
    enabled: bool = True
    # Required: file extensions for this language (e.g. [".py"]).
    file_extensions: List[str]
    # Required: language identifier (presumably the tree-sitter grammar name).
    tree_sitter_language: str
    # Required: node types to extract (e.g. "class_definition").
    chunk_types: List[str]
    max_file_size: int = Field(
        description="Maximum file size to process in bytes",
        default=1_000_000,  # 1MB
    )
class ParserConfig(BaseSettings):
    """Parser settings: per-language configuration plus path filtering."""

    # Keyed by language name; only Python is configured out of the box.
    languages: Dict[str, LanguageConfig] = Field(
        default={
            "python": LanguageConfig(
                chunk_types=["class_definition", "function_definition"],
                tree_sitter_language="python",
                file_extensions=[".py"],
            ),
        },
    )
    # Include/exclude globs; a fresh PathPatterns is built per instance.
    path_patterns: PathPatterns = Field(default_factory=PathPatterns)
class EmbeddingConfig(BaseSettings):
    """Shared embedding configuration"""

    model_name: str = Field(
        default="mxbai-embed-large",
        description="Name of the embedding model"
    )
    model_provider: ModelProvider = Field(
        default=ModelProvider.OLLAMA,
        description="Provider for embeddings"
    )
    dimension: int = Field(
        default=768,
        description="Embedding dimension"
    )
    settings: Dict[str, Any] = Field(
        default_factory=dict,
        description="Provider-specific settings"
    )
    api_key: Optional[str] = Field(
        default=None,
        description="API key for the model provider"
    )

    # Without these decorators the method is just an unused function and the
    # API-key check never runs. ``api_key`` is declared after
    # ``model_provider`` so the provider is already in ``info.data``.
    @field_validator('api_key', mode='after')
    @classmethod
    def validate_api_key(cls, v: Optional[str], info: ValidationInfo) -> Optional[str]:
        """Validate ``api_key`` against the selected ``model_provider``.

        Delegates to ``_validate_api_key``. NOTE: pydantic does not run field
        validators on default values unless ``validate_default`` is enabled,
        so this only fires when an ``api_key`` value is actually supplied.
        """
        return _validate_api_key(v, info)
class LLMConfig(BaseSettings):
    """Configuration of the LLM: model, provider, credentials and extra settings."""

    model_name: str = Field(
        default="llama3.2",
        description="Name of the LLM model to use"
    )
    # NOTE(review): typed as ``str`` here but as ``ModelProvider`` in
    # EmbeddingConfig; left unchanged to avoid altering accepted values.
    model_provider: str = Field(
        default=ModelProvider.OLLAMA,
        description="Model provider (anthropic, openai, ollama, etc)"
    )
    api_key: Optional[str] = Field(
        default=None,
        description="API key for the model provider"
    )
    model_settings: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional model settings"
    )

    # Without these decorators the method is just an unused function and the
    # API-key check never runs. ``api_key`` is declared after
    # ``model_provider`` so the provider is already in ``info.data``.
    @field_validator('api_key', mode='after')
    @classmethod
    def validate_api_key(cls, v: Optional[str], info: ValidationInfo) -> Optional[str]:
        """Validate ``api_key`` against the selected ``model_provider``.

        Delegates to ``_validate_api_key``. NOTE: pydantic does not run field
        validators on default values unless ``validate_default`` is enabled,
        so this only fires when an ``api_key`` value is actually supplied.
        """
        return _validate_api_key(v, info)
class DBConfig(BaseSettings):
    """Storage locations and collection name for the ChromaDB vector store."""

    persist_directory: Path = Field(
        description="Directory to store ChromaDB files",
        default=Path("./chroma_db"),
    )
    collection_name: str = Field(
        description="Name of the ChromaDB collection",
        default="code_chunks",
    )
    codebase_directory: Path = Field(
        description="Root directory of the codebase to analyze",
        default=Path("./"),
    )
class RerankerConfig(BaseSettings):
    """Configuration of the optional reranking step (disabled by default)."""

    enabled: bool = Field(
        default=False,
        description="Enable reranking"
    )
    model_name: str = Field(
        default="reranker",
        description="Name of the reranker model to use"
    )
    model_provider: str = Field(
        default=ModelProvider.OLLAMA,
        description="Model provider (anthropic, openai, ollama, etc)"
    )
    api_key: Optional[str] = Field(
        default=None,
        description="API key for the model provider"
    )
    top_k: int = Field(
        default=4,
        description="Number of most relevant documents to return from reranking"
    )
    relevance_threshold: float = Field(
        default=0.5,
        description="Minimum relevance score to include a document in reranking"
    )

    # Without these decorators the method is just an unused function and the
    # API-key check never runs. ``api_key`` is declared after
    # ``model_provider`` so the provider is already in ``info.data``.
    @field_validator('api_key', mode='after')
    @classmethod
    def validate_api_key(cls, v: Optional[str], info: ValidationInfo) -> Optional[str]:
        """Validate ``api_key`` against the selected ``model_provider``.

        Delegates to ``_validate_api_key``. NOTE: pydantic does not run field
        validators on default values unless ``validate_default`` is enabled,
        so this only fires when an ``api_key`` value is actually supplied.
        """
        return _validate_api_key(v, info)
class ChatConfig(BaseSettings):
    """Settings for the chat interface and the context given to the model."""

    # --- retrieval limits ---
    max_context_chunks: int = Field(
        description="Maximum number of similar chunks to include in context",
        default=5,
    )
    similarity_threshold: float = Field(
        description="Minimum similarity score to include a chunk",
        default=0.7,
    )

    # --- UI text ---
    interface_title: str = Field(
        description="Title shown in the chat interface",
        default="Code Repository Q&A Assistant",
    )
    interface_description: str = Field(
        description="Description shown in the chat interface",
        default="Ask questions about the codebase and I'll help you understand it!",
    )

    # --- chunk size cap ---
    max_length_per_chunk: int = Field(
        description="Maximum number of characters per chunk",
        default=8000,
    )
class EvaluatorConfig(LLMConfig):
    """LLM configuration for evaluation; adds a rounds-per-test-case setting."""

    evaluation_rounds: int = Field(
        description="Number of evaluation rounds per test case",
        default=1,
    )
class AppConfig(BaseSettings):
    """Top-level application settings, composed of one section per subsystem.

    Values are read from a UTF-8 ``.env`` file; ``__`` is the delimiter for
    nested settings. Each section falls back to its own defaults.
    """

    model_config = SettingsConfigDict(
        env_nested_delimiter="__",
        env_file=".env",
        env_file_encoding="utf-8",
    )

    # One sub-config per subsystem, each built fresh by its default factory.
    llm: LLMConfig = Field(default_factory=LLMConfig)
    evaluator: EvaluatorConfig = Field(default_factory=EvaluatorConfig)
    reranker: RerankerConfig = Field(default_factory=RerankerConfig)
    db: DBConfig = Field(default_factory=DBConfig)
    parser: ParserConfig = Field(default_factory=ParserConfig)
    chat: ChatConfig = Field(default_factory=ChatConfig)
    embedding: EmbeddingConfig = Field(default_factory=EmbeddingConfig)