gabykim commited on
Commit
070f7e7
·
1 Parent(s): 0e9e5fc

app configuration refactoring

Browse files
src/know_lang_bot/__main__.py CHANGED
@@ -5,7 +5,7 @@ from typing import List, Optional
5
  from rich.console import Console
6
  from rich.table import Table
7
 
8
- from know_lang_bot.code_parser.parser import CodeChunk
9
  from know_lang_bot.config import AppConfig
10
  from know_lang_bot.parser.factory import CodeParserFactory
11
  from know_lang_bot.parser.providers.git import GitProvider
 
5
  from rich.console import Console
6
  from rich.table import Table
7
 
8
+ from know_lang_bot.core.types import CodeChunk
9
  from know_lang_bot.config import AppConfig
10
  from know_lang_bot.parser.factory import CodeParserFactory
11
  from know_lang_bot.parser.providers.git import GitProvider
src/know_lang_bot/chat_bot/chat_config.py DELETED
@@ -1,27 +0,0 @@
1
- from pydantic_settings import BaseSettings
2
- from pydantic import Field
3
- from know_lang_bot.config import AppConfig
4
-
5
- class ChatConfig(BaseSettings):
6
- max_context_chunks: int = Field(
7
- default=5,
8
- description="Maximum number of similar chunks to include in context"
9
- )
10
- similarity_threshold: float = Field(
11
- default=0.7,
12
- description="Minimum similarity score to include a chunk"
13
- )
14
- interface_title: str = Field(
15
- default="Code Repository Q&A Assistant",
16
- description="Title shown in the chat interface"
17
- )
18
- interface_description: str = Field(
19
- default="Ask questions about the codebase and I'll help you understand it!",
20
- description="Description shown in the chat interface"
21
- )
22
-
23
- class ChatAppConfig(AppConfig):
24
- chat: ChatConfig = Field(default_factory=ChatConfig)
25
-
26
-
27
- chat_app_config = ChatAppConfig()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/know_lang_bot/chat_bot/chat_graph.py CHANGED
@@ -6,7 +6,7 @@ import chromadb
6
  from pydantic import BaseModel
7
  from pydantic_graph import BaseNode, Graph, GraphRunContext, End
8
  import ollama
9
- from know_lang_bot.chat_bot.chat_config import ChatAppConfig
10
  from know_lang_bot.utils.fancy_log import FancyLogger
11
  from pydantic_ai import Agent
12
  import logfire
@@ -36,7 +36,7 @@ class ChatGraphState:
36
  class ChatGraphDeps:
37
  """Dependencies required by the graph"""
38
  collection: chromadb.Collection
39
- config: ChatAppConfig
40
 
41
 
42
  # Graph Nodes
@@ -74,7 +74,7 @@ class RetrieveContextNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
74
  async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> AnswerQuestionNode:
75
  try:
76
  embedded_question = ollama.embed(
77
- model=ctx.deps.config.llm.embedding_model,
78
  input=ctx.state.polished_question or ctx.state.original_question
79
  )
80
 
@@ -164,7 +164,7 @@ chat_graph = Graph(
164
  async def process_chat(
165
  question: str,
166
  collection: chromadb.Collection,
167
- config: ChatAppConfig
168
  ) -> ChatResult:
169
  """
170
  Process a chat question through the graph.
 
6
  from pydantic import BaseModel
7
  from pydantic_graph import BaseNode, Graph, GraphRunContext, End
8
  import ollama
9
+ from know_lang_bot.config import AppConfig
10
  from know_lang_bot.utils.fancy_log import FancyLogger
11
  from pydantic_ai import Agent
12
  import logfire
 
36
  class ChatGraphDeps:
37
  """Dependencies required by the graph"""
38
  collection: chromadb.Collection
39
+ config: AppConfig
40
 
41
 
42
  # Graph Nodes
 
74
  async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> AnswerQuestionNode:
75
  try:
76
  embedded_question = ollama.embed(
77
+ model=ctx.deps.config.embedding.model_name,
78
  input=ctx.state.polished_question or ctx.state.original_question
79
  )
80
 
 
164
  async def process_chat(
165
  question: str,
166
  collection: chromadb.Collection,
167
+ config: AppConfig
168
  ) -> ChatResult:
169
  """
170
  Process a chat question through the graph.
src/know_lang_bot/config.py CHANGED
@@ -1,8 +1,9 @@
1
  from typing import Optional, Dict, Any, List
2
  from pydantic_settings import BaseSettings, SettingsConfigDict
3
- from pydantic import Field
4
  from pathlib import Path
5
  import fnmatch
 
6
 
7
  class PathPatterns(BaseSettings):
8
  include: List[str] = Field(
@@ -53,13 +54,32 @@ class ParserConfig(BaseSettings):
53
  path_patterns: PathPatterns = Field(default_factory=PathPatterns)
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  class LLMConfig(BaseSettings):
57
  model_name: str = Field(
58
  default="llama3.2",
59
  description="Name of the LLM model to use"
60
  )
61
  model_provider: str = Field(
62
- default="ollama",
63
  description="Model provider (anthropic, openai, ollama, etc)"
64
  )
65
  api_key: Optional[str] = Field(
@@ -70,14 +90,14 @@ class LLMConfig(BaseSettings):
70
  default_factory=dict,
71
  description="Additional model settings"
72
  )
73
- embedding_model: str = Field(
74
- default="mxbai-embed-large",
75
- description="Name of the embedding model to use"
76
- )
77
- embedding_provider: str = Field(
78
- default="ollama",
79
- description="Provider for embeddings (ollama, openai, etc)"
80
- )
81
 
82
  class DBConfig(BaseSettings):
83
  persist_directory: Path = Field(
@@ -88,15 +108,29 @@ class DBConfig(BaseSettings):
88
  default="code_chunks",
89
  description="Name of the ChromaDB collection"
90
  )
91
- embedding_model: str = Field(
92
- default="sentence-transformers/all-mpnet-base-v2",
93
- description="Embedding model to use"
94
- )
95
  codebase_directory: Path = Field(
96
  default=Path("./"),
97
  description="Root directory of the codebase to analyze"
98
  )
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  class AppConfig(BaseSettings):
101
  model_config = SettingsConfigDict(
102
  env_file='.env',
@@ -106,4 +140,6 @@ class AppConfig(BaseSettings):
106
 
107
  llm: LLMConfig = Field(default_factory=LLMConfig)
108
  db: DBConfig = Field(default_factory=DBConfig)
109
- parser: ParserConfig = Field(default_factory=ParserConfig)
 
 
 
1
  from typing import Optional, Dict, Any, List
2
  from pydantic_settings import BaseSettings, SettingsConfigDict
3
+ from pydantic import Field, field_validator, ValidationInfo
4
  from pathlib import Path
5
  import fnmatch
6
+ from know_lang_bot.core.types import ModelProvider
7
 
8
  class PathPatterns(BaseSettings):
9
  include: List[str] = Field(
 
54
  path_patterns: PathPatterns = Field(default_factory=PathPatterns)
55
 
56
 
57
+ class EmbeddingConfig(BaseSettings):
58
+ """Shared embedding configuration"""
59
+ model_name: str = Field(
60
+ default="mxbai-embed-large",
61
+ description="Name of the embedding model"
62
+ )
63
+ provider: ModelProvider = Field(
64
+ default=ModelProvider.OLLAMA,
65
+ description="Provider for embeddings"
66
+ )
67
+ dimension: int = Field(
68
+ default=768,
69
+ description="Embedding dimension"
70
+ )
71
+ settings: Dict[str, Any] = Field(
72
+ default_factory=dict,
73
+ description="Provider-specific settings"
74
+ )
75
+
76
  class LLMConfig(BaseSettings):
77
  model_name: str = Field(
78
  default="llama3.2",
79
  description="Name of the LLM model to use"
80
  )
81
  model_provider: str = Field(
82
+ default=ModelProvider.OLLAMA,
83
  description="Model provider (anthropic, openai, ollama, etc)"
84
  )
85
  api_key: Optional[str] = Field(
 
90
  default_factory=dict,
91
  description="Additional model settings"
92
  )
93
+
94
+ @field_validator('api_key', mode='after')
95
+ @classmethod
96
+ def validate_api_key(cls, v: Optional[str], info: ValidationInfo) -> Optional[str]:
97
+ """Validate API key is present when required"""
98
+ if info.data['model_provider'] in [ModelProvider.OPENAI, ModelProvider.ANTHROPIC] and not v:
99
+ raise ValueError(f"API key required for {info.data['model_provider']}")
100
+ return v
101
 
102
  class DBConfig(BaseSettings):
103
  persist_directory: Path = Field(
 
108
  default="code_chunks",
109
  description="Name of the ChromaDB collection"
110
  )
 
 
 
 
111
  codebase_directory: Path = Field(
112
  default=Path("./"),
113
  description="Root directory of the codebase to analyze"
114
  )
115
 
116
+ class ChatConfig(BaseSettings):
117
+ max_context_chunks: int = Field(
118
+ default=5,
119
+ description="Maximum number of similar chunks to include in context"
120
+ )
121
+ similarity_threshold: float = Field(
122
+ default=0.7,
123
+ description="Minimum similarity score to include a chunk"
124
+ )
125
+ interface_title: str = Field(
126
+ default="Code Repository Q&A Assistant",
127
+ description="Title shown in the chat interface"
128
+ )
129
+ interface_description: str = Field(
130
+ default="Ask questions about the codebase and I'll help you understand it!",
131
+ description="Description shown in the chat interface"
132
+ )
133
+
134
  class AppConfig(BaseSettings):
135
  model_config = SettingsConfigDict(
136
  env_file='.env',
 
140
 
141
  llm: LLMConfig = Field(default_factory=LLMConfig)
142
  db: DBConfig = Field(default_factory=DBConfig)
143
+ parser: ParserConfig = Field(default_factory=ParserConfig)
144
+ chat: ChatConfig = Field(default_factory=ChatConfig)
145
+ embedding: EmbeddingConfig = Field(default_factory=EmbeddingConfig)
src/know_lang_bot/core/types.py CHANGED
@@ -16,4 +16,11 @@ class CodeChunk(BaseModel):
16
  file_path: str
17
  name: Optional[str] = None
18
  parent_name: Optional[str] = None # For nested classes/functions
19
- docstring: Optional[str] = None
 
 
 
 
 
 
 
 
16
  file_path: str
17
  name: Optional[str] = None
18
  parent_name: Optional[str] = None # For nested classes/functions
19
+ docstring: Optional[str] = None
20
+
21
+
22
+ class ModelProvider(str, Enum):
23
+ OPENAI = "openai"
24
+ ANTHROPIC = "anthropic"
25
+ OLLAMA = "ollama"
26
+ HUGGINGFACE = "huggingface"
src/know_lang_bot/summarizer/summarizer.py CHANGED
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
6
  import ollama
7
 
8
  from know_lang_bot.config import AppConfig
9
- from know_lang_bot.code_parser.parser import CodeChunk
10
  from know_lang_bot.utils.fancy_log import FancyLogger
11
  from pprint import pformat
12
 
@@ -65,14 +65,14 @@ class CodeSummarizer:
65
 
66
  def _get_embedding(self, text: str) -> List[float]:
67
  """Get embedding for text using configured provider"""
68
- if self.config.llm.embedding_provider == "ollama":
69
  response = ollama.embed(
70
- model=self.config.llm.embedding_model,
71
  input=text
72
  )
73
  return response['embeddings']
74
  else:
75
- raise ValueError(f"Unsupported embedding provider: {self.config.llm.embedding_provider}")
76
 
77
  async def summarize_chunk(self, chunk: CodeChunk) -> str:
78
  """Summarize a single code chunk using the LLM"""
 
6
  import ollama
7
 
8
  from know_lang_bot.config import AppConfig
9
+ from know_lang_bot.core.types import CodeChunk, ModelProvider
10
  from know_lang_bot.utils.fancy_log import FancyLogger
11
  from pprint import pformat
12
 
 
65
 
66
  def _get_embedding(self, text: str) -> List[float]:
67
  """Get embedding for text using configured provider"""
68
+ if self.config.embedding.provider == ModelProvider.OLLAMA:
69
  response = ollama.embed(
70
+ model=self.config.embedding.model_name,
71
  input=text
72
  )
73
  return response['embeddings']
74
  else:
75
+ raise ValueError(f"Unsupported embedding provider: {self.config.embedding.provider}")
76
 
77
  async def summarize_chunk(self, chunk: CodeChunk) -> str:
78
  """Summarize a single code chunk using the LLM"""
tests/test_summarizer.py CHANGED
@@ -113,7 +113,7 @@ async def test_process_and_store_chunk_with_embedding(
113
 
114
  # Verify ollama.embed was called with correct parameters
115
  mock_ollama.embed.assert_called_once_with(
116
- model=config.llm.embedding_model,
117
  input=mock_run_result.data
118
  )
119
 
 
113
 
114
  # Verify ollama.embed was called with correct parameters
115
  mock_ollama.embed.assert_called_once_with(
116
+ model=config.embedding.model_name,
117
  input=mock_run_result.data
118
  )
119