gabykim's picture
add evaluation round config setting
3a9d0c3
raw
history blame
7.16 kB
from typing import Optional, Dict, Any, List
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field, field_validator, ValidationInfo
from pathlib import Path
import fnmatch
from know_lang_bot.core.types import ModelProvider
import os
def _validate_api_key(v: Optional[str], info: ValidationInfo) -> Optional[str]:
    """Validate that an API key is present when the provider requires one.

    Shared body for the ``api_key`` field validators in this module. When the
    already-validated ``model_provider`` is OpenAI, Anthropic or Voyage, the
    key must be non-empty; it is then also exported to the provider's
    environment variable as a side effect.

    Args:
        v: The API key value being validated.
        info: Validation context; ``info.data`` holds the fields validated
            before this one.

    Returns:
        ``v`` unchanged.

    Raises:
        ValueError: If the provider requires a key and ``v`` is empty or None.
    """
    # ``info.data`` only contains fields that validated successfully, so
    # ``model_provider`` is missing when its own validation failed. Using
    # .get() lets pydantic report that failure instead of a KeyError here.
    provider = info.data.get('model_provider')

    # Compared with ``==`` (not hashed) so the check behaves the same whether
    # the provider arrives as an enum member or as a plain value.
    env_var_by_provider = (
        (ModelProvider.ANTHROPIC, "ANTHROPIC_API_KEY"),
        (ModelProvider.OPENAI, "OPENAI_API_KEY"),
        (ModelProvider.VOYAGE, "VOYAGE_API_KEY"),
    )

    for known_provider, env_var in env_var_by_provider:
        if provider == known_provider:
            if not v:
                raise ValueError(f"API key required for {provider}")
            os.environ[env_var] = v
            break

    return v
class PathPatterns(BaseSettings):
    # Include/exclude glob patterns that decide which paths get processed.
    include: List[str] = Field(
        description="Glob patterns for paths to include",
        default=["**/*"],
    )
    exclude: List[str] = Field(
        description="Glob patterns for paths to exclude",
        default=[
            "**/venv/**",
            "**/.git/**",
            "**/__pycache__/**",
            "**/tests/**",
        ],
    )

    def should_process_path(self, path: str) -> bool:
        """Return True if *path* matches an include pattern and no exclude pattern.

        Matching uses ``fnmatch.fnmatch`` on ``str(path)``.

        NOTE: in fnmatch, ``*`` also matches ``/``, so ``**`` is not a
        recursive-glob operator here; a pattern such as ``**/*`` only matches
        paths that contain at least one ``/``.
        """
        candidate = str(path)

        # Exclusions take precedence over inclusions.
        if any(fnmatch.fnmatch(candidate, glob) for glob in self.exclude):
            return False

        return any(fnmatch.fnmatch(candidate, glob) for glob in self.include)
class LanguageConfig(BaseSettings):
    # Per-language parsing options; ``file_extensions``,
    # ``tree_sitter_language`` and ``chunk_types`` have no defaults and must
    # be supplied.
    enabled: bool = True
    file_extensions: List[str]
    tree_sitter_language: str
    chunk_types: List[str]
    max_file_size: int = Field(
        description="Maximum file size to process in bytes",
        default=1_000_000,  # 1MB
    )
class ParserConfig(BaseSettings):
    # Parser settings: one LanguageConfig per language name, plus the
    # include/exclude path patterns. Only "python" is configured by default.
    languages: Dict[str, LanguageConfig] = Field(
        default={
            "python": LanguageConfig(
                tree_sitter_language="python",
                file_extensions=[".py"],
                chunk_types=["class_definition", "function_definition"],
            ),
        },
    )
    path_patterns: PathPatterns = Field(default_factory=PathPatterns)
class EmbeddingConfig(BaseSettings):
    """Shared embedding configuration"""

    # NOTE: ``model_provider`` must stay declared before ``api_key`` — the
    # api_key validator reads the already-validated provider from info.data.
    model_name: str = Field(
        description="Name of the embedding model",
        default="mxbai-embed-large",
    )
    model_provider: ModelProvider = Field(
        description="Provider for embeddings",
        default=ModelProvider.OLLAMA,
    )
    dimension: int = Field(
        description="Embedding dimension",
        default=768,
    )
    settings: Dict[str, Any] = Field(
        description="Provider-specific settings",
        default_factory=dict,
    )
    api_key: Optional[str] = Field(
        description="API key for the model provider",
        default=None,
    )

    @field_validator("api_key", mode="after")
    @classmethod
    def validate_api_key(cls, v: Optional[str], info: ValidationInfo) -> Optional[str]:
        """Delegate to the module-level ``_validate_api_key`` helper."""
        return _validate_api_key(v, info)
class LLMConfig(BaseSettings):
    # Settings for the LLM: model name, provider, API key and extra options.
    #
    # NOTE: ``model_provider`` must stay declared before ``api_key`` — the
    # api_key validator reads the already-validated provider from info.data.
    # NOTE(review): ``model_provider`` is annotated ``str`` here although its
    # default is a ModelProvider member and EmbeddingConfig annotates the
    # same field as ModelProvider — confirm whether that is intentional.
    model_name: str = Field(
        description="Name of the LLM model to use",
        default="llama3.2",
    )
    model_provider: str = Field(
        description="Model provider (anthropic, openai, ollama, etc)",
        default=ModelProvider.OLLAMA,
    )
    api_key: Optional[str] = Field(
        description="API key for the model provider",
        default=None,
    )
    model_settings: Dict[str, Any] = Field(
        description="Additional model settings",
        default_factory=dict,
    )

    @field_validator("api_key", mode="after")
    @classmethod
    def validate_api_key(cls, v: Optional[str], info: ValidationInfo) -> Optional[str]:
        """Delegate to the module-level ``_validate_api_key`` helper."""
        return _validate_api_key(v, info)
class DBConfig(BaseSettings):
    # Storage settings: where the ChromaDB files live, which collection is
    # used, and which codebase directory is analyzed.
    persist_directory: Path = Field(
        description="Directory to store ChromaDB files",
        default=Path("./chroma_db"),
    )
    collection_name: str = Field(
        description="Name of the ChromaDB collection",
        default="code_chunks",
    )
    codebase_directory: Path = Field(
        description="Root directory of the codebase to analyze",
        default=Path("./"),
    )
class RerankerConfig(BaseSettings):
    # Settings for the reranking step, which is disabled by default.
    #
    # NOTE: ``model_provider`` must stay declared before ``api_key`` — the
    # api_key validator reads the already-validated provider from info.data.
    enabled: bool = Field(
        description="Enable reranking",
        default=False,
    )
    model_name: str = Field(
        description="Name of the reranker model to use",
        default="reranker",
    )
    model_provider: str = Field(
        description="Model provider (anthropic, openai, ollama, etc)",
        default=ModelProvider.OLLAMA,
    )
    api_key: Optional[str] = Field(
        description="API key for the model provider",
        default=None,
    )
    top_k: int = Field(
        description="Number of most relevant documents to return from reranking",
        default=4,
    )
    relevance_threshold: float = Field(
        description="Minimum relevance score to include a document in reranking",
        default=0.5,
    )

    @field_validator("api_key", mode="after")
    @classmethod
    def validate_api_key(cls, v: Optional[str], info: ValidationInfo) -> Optional[str]:
        """Delegate to the module-level ``_validate_api_key`` helper."""
        return _validate_api_key(v, info)
class ChatConfig(BaseSettings):
    # Chat settings: context-retrieval limits and the interface texts.
    max_context_chunks: int = Field(
        description="Maximum number of similar chunks to include in context",
        default=5,
    )
    similarity_threshold: float = Field(
        description="Minimum similarity score to include a chunk",
        default=0.7,
    )
    interface_title: str = Field(
        description="Title shown in the chat interface",
        default="Code Repository Q&A Assistant",
    )
    interface_description: str = Field(
        description="Description shown in the chat interface",
        default="Ask questions about the codebase and I'll help you understand it!",
    )
    max_length_per_chunk: int = Field(
        description="Maximum number of characters per chunk",
        default=8000,
    )
class EvaluatorConfig(LLMConfig):
    # Inherits every LLMConfig field (and its api_key validator) and adds the
    # number of evaluation rounds to run per test case.
    evaluation_rounds: int = Field(
        description="Number of evaluation rounds per test case",
        default=1,
    )
class AppConfig(BaseSettings):
    # Top-level application settings that aggregate every sub-configuration.
    # Settings are configured with a UTF-8 ``.env`` file and ``__`` as the
    # delimiter for nested environment variable names.
    model_config = SettingsConfigDict(
        env_nested_delimiter="__",
        env_file=".env",
        env_file_encoding="utf-8",
    )

    # Each section falls back to its own defaults when not provided.
    llm: LLMConfig = Field(default_factory=LLMConfig)
    evaluator: EvaluatorConfig = Field(default_factory=EvaluatorConfig)
    reranker: RerankerConfig = Field(default_factory=RerankerConfig)
    db: DBConfig = Field(default_factory=DBConfig)
    parser: ParserConfig = Field(default_factory=ParserConfig)
    chat: ChatConfig = Field(default_factory=ChatConfig)
    embedding: EmbeddingConfig = Field(default_factory=EmbeddingConfig)