gabykim committed
Commit 6b5ac9a · 1 Parent(s): 279fcbd

chat graph implemented

src/know_lang_bot/chat_bot/__main__.py ADDED
@@ -0,0 +1,25 @@
+ from know_lang_bot.chat_bot.chat_config import ChatAppConfig, chat_app_config
+ from know_lang_bot.chat_bot.chat_graph import process_chat
+ import chromadb
+ import asyncio
+
+ async def test_chat_processing():
+     config = chat_app_config
+     db_client = chromadb.PersistentClient(
+         path=str(config.db.persist_directory)
+     )
+     collection = db_client.get_collection(
+         name=config.db.collection_name
+     )
+
+     result = await process_chat(
+         "How does the parser handle nested classes?",
+         collection,
+         config
+     )
+
+     print(f"Answer: {result.answer}")
+     print(f"References: {result.references_md}")
+
+ if __name__ == "__main__":
+     asyncio.run(test_chat_processing())
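Since the new entry point lives in `__main__.py`, the smoke test above can presumably be run with `python -m know_lang_bot.chat_bot`; below is a minimal sketch of driving it programmatically (the importable package path and an already-built ChromaDB collection are assumptions, not part of this commit):

```python
# Hypothetical manual run of the smoke test added above (not part of this commit).
# Assumes the know_lang_bot package is importable and that the ChromaDB collection
# referenced by chat_app_config was already populated by the indexing pipeline.
import asyncio

from know_lang_bot.chat_bot.__main__ import test_chat_processing

asyncio.run(test_chat_processing())
```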
src/know_lang_bot/chat_bot/chat_agent.py DELETED
@@ -1,100 +0,0 @@
- from dataclasses import dataclass
- from typing import List, Dict, Any, Optional
- import chromadb
- from pathlib import Path
- from pydantic_ai import Agent, RunContext
- from know_lang_bot.chat_bot.chat_config import ChatAppConfig, chat_app_config
- from know_lang_bot.utils.fancy_log import FancyLogger
- from pydantic import BaseModel
- import ollama
- import logfire
-
- LOG = FancyLogger(__name__)
-
- @dataclass
- class CodeQADeps:
-     """Dependencies for the Code Q&A Agent"""
-     collection: chromadb.Collection
-     config: ChatAppConfig
-
- class RetrievedContext(BaseModel):
-     """Structure for retrieved context"""
-     chunks: List[str]
-     metadatas: List[Dict[str, Any]]
-     references_md: str
-
- class AgentResponse(BaseModel):
-     """Structure for agent responses"""
-     answer: str
-     references_md: Optional[str] = None
-
- # Initialize the agent with system prompt and dependencies
- code_qa_agent = Agent(
-     f'{chat_app_config.llm.model_provider}:{chat_app_config.llm.model_name}',
-     deps_type=CodeQADeps,
-     result_type=AgentResponse,
-     system_prompt="""
-     You are an expert code assistant helping users understand a codebase.
-
-     Always:
-     1. Reference specific files and line numbers in your explanations
-     2. Be direct and concise while being comprehensive
-     3. If the context is insufficient, explain why
-     4. If you're unsure about something, acknowledge it
-
-     Your response should be helpful for software engineers trying to understand complex codebases.
-     """,
- )
-
- @code_qa_agent.tool
- @logfire.instrument()
- async def retrieve_context(
-     ctx: RunContext[CodeQADeps],
-     question: str
- ) -> RetrievedContext:
-     """
-     Retrieve relevant code context from the vector database.
-
-     Args:
-         ctx: The context containing dependencies
-         question: The user's question to find relevant code for
-     """
-     embedded_question = ollama.embed(
-         model=ctx.deps.config.llm.embedding_model,
-         input=question
-     )
-
-     results = ctx.deps.collection.query(
-         query_embeddings=embedded_question['embeddings'],
-         n_results=ctx.deps.config.chat.max_context_chunks,
-         include=['metadatas', 'documents', 'distances']
-     )
-
-     relevant_chunks = []
-     relevant_metadatas = []
-
-     for doc, meta, dist in zip(
-         results['documents'][0],
-         results['metadatas'][0],
-         results['distances'][0]
-     ):
-         if dist <= ctx.deps.config.chat.similarity_threshold:
-             relevant_chunks.append(doc)
-             relevant_metadatas.append(meta)
-
-
-     # Format references for display
-     references = []
-     for meta in relevant_metadatas:
-         file_path = Path(meta['file_path']).name
-         ref = f"**{file_path}** (lines {meta['start_line']}-{meta['end_line']})"
-         if meta.get('name'):
-             ref += f"\n- {meta['type']}: `{meta['name']}`"
-         references.append(ref)
-
-     return RetrievedContext(
-         chunks=relevant_chunks,
-         metadatas=relevant_metadatas,
-         references_md="\n\n".join(references)
-     )
-
src/know_lang_bot/chat_bot/chat_graph.py ADDED
@@ -0,0 +1,195 @@
+ from __future__ import annotations
+ from dataclasses import dataclass
+ from typing import List, Dict, Any, Optional
+ import chromadb
+ from pydantic import BaseModel
+ from pydantic_graph import BaseNode, Graph, GraphRunContext, End
+ import ollama
+ import logfire
+ from know_lang_bot.chat_bot.chat_config import ChatAppConfig, chat_app_config
+ from know_lang_bot.utils.fancy_log import FancyLogger
+ from pydantic_ai import Agent
+
+ LOG = FancyLogger(__name__)
+
+ # Data Models
+ class RetrievedContext(BaseModel):
+     """Structure for retrieved context"""
+     chunks: List[str]
+     metadatas: List[Dict[str, Any]]
+     references_md: str
+
+ class ChatResult(BaseModel):
+     """Final result from the chat graph"""
+     answer: str
+     references_md: Optional[str] = None
+
+
+
+
+ @dataclass
+ class ChatGraphState:
+     """State maintained throughout the graph execution"""
+     original_question: str
+     polished_question: Optional[str] = None
+     retrieved_context: Optional[RetrievedContext] = None
+
+ @dataclass
+ class ChatGraphDeps:
+     """Dependencies required by the graph"""
+     collection: chromadb.Collection
+     config: ChatAppConfig
+
+
+ # Graph Nodes
+ @dataclass
+ class PolishQuestion(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
+     """Node that polishes the user's question"""
+     system_prompt = """
+     You are an expert at understanding code-related questions and reformulating them
+     for better context retrieval. Your task is to polish the user's question to make
+     it more specific and searchable. Focus on technical terms and code concepts.
+     """
+
+     async def run(self, ctx: GraphRunContext[ChatGraphState]) -> RetrieveContext:
+         # Create an agent for question polishing
+         from pydantic_ai import Agent
+         polish_agent = Agent(
+             f"{ctx.deps.config.llm.model_provider}:{ctx.deps.config.llm.model_name}"
+         )
+         prompt = f"""
+         Original question: {ctx.state.original_question}
+
+         Please reformulate this question to be more specific and searchable,
+         focusing on technical terms and code concepts. Keep the core meaning
+         but make it more precise for code context retrieval.
+         """
+
+         result = await polish_agent.run(prompt)
+         ctx.state.polished_question = result.data
+         return RetrieveContext()
+
+ @dataclass
+ class RetrieveContext(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
+     """Node that retrieves relevant code context"""
+
+     async def run(self, ctx: GraphRunContext[ChatGraphState]) -> AnswerQuestion:
+         try:
+             embedded_question = ollama.embed(
+                 model=ctx.deps.config.llm.embedding_model,
+                 input=ctx.state.polished_question or ctx.state.original_question
+             )
+
+             results = ctx.deps.collection.query(
+                 query_embeddings=embedded_question['embeddings'],
+                 n_results=ctx.deps.config.chat.max_context_chunks,
+                 include=['metadatas', 'documents', 'distances']
+             )
+
+             relevant_chunks = []
+             relevant_metadatas = []
+
+             for doc, meta, dist in zip(
+                 results['documents'][0],
+                 results['metadatas'][0],
+                 results['distances'][0]
+             ):
+                 if dist <= ctx.deps.config.chat.similarity_threshold:
+                     relevant_chunks.append(doc)
+                     relevant_metadatas.append(meta)
+
+             # Format references for display
+             references = []
+             for meta in relevant_metadatas:
+                 file_path = meta['file_path'].split('/')[-1]
+                 ref = f"**{file_path}** (lines {meta['start_line']}-{meta['end_line']})"
+                 if meta.get('name'):
+                     ref += f"\n- {meta['type']}: `{meta['name']}`"
+                 references.append(ref)
+
+             ctx.state.retrieved_context = RetrievedContext(
+                 chunks=relevant_chunks,
+                 metadatas=relevant_metadatas,
+                 references_md="\n\n".join(references)
+             )
+
+             return AnswerQuestion()
+
+         except Exception as e:
+             LOG.error(f"Error retrieving context: {e}")
+             return AnswerQuestion()
+
+ @dataclass
+ class AnswerQuestion(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
+     """Node that generates the final answer"""
+     system_prompt = """
+     You are an expert code assistant helping users understand a codebase.
+     Always:
+     1. Reference specific files and line numbers in your explanations
+     2. Be direct and concise while being comprehensive
+     3. If the context is insufficient, explain why
+     4. If you're unsure about something, acknowledge it
+     """
+
+     async def run(self, ctx: GraphRunContext[ChatGraphState]) -> End[ChatResult]:
+
+         answer_agent = Agent(
+             f"{ctx.deps.config.llm.model_provider}:{ctx.deps.config.llm.model_name}",
+             system_prompt=self.system_prompt
+         )
+
+         if not ctx.state.retrieved_context or not ctx.state.retrieved_context.chunks:
+             return End(ChatResult(
+                 answer="I couldn't find any relevant code context for your question. "
+                        "Could you please rephrase or be more specific?",
+                 references_md=""
+             ))
+
+         context = ctx.state.retrieved_context
+         prompt = f"""
+         Question: {ctx.state.original_question}
+
+         Available Code Context:
+         {context.chunks}
+
+         Please provide a comprehensive answer based on the code context above.
+         Make sure to reference specific files and line numbers from the context.
+         """
+
+         try:
+             result = await answer_agent.run(prompt)
+             return End(ChatResult(
+                 answer=result.data,
+                 references_md=context.references_md
+             ))
+         except Exception as e:
+             LOG.error(f"Error generating answer: {e}")
+             return End(ChatResult(
+                 answer="I encountered an error processing your question. Please try again.",
+                 references_md=""
+             ))
+
+ # Create the graph
+ chat_graph = Graph(
+     nodes=[PolishQuestion, RetrieveContext, AnswerQuestion]
+ )
+
+ async def process_chat(
+     question: str,
+     collection: chromadb.Collection,
+     config: ChatAppConfig
+ ) -> ChatResult:
+     """
+     Process a chat question through the graph.
+     This is the main entry point for chat processing.
+     """
+     state = ChatGraphState(original_question=question)
+     deps = ChatGraphDeps(collection=collection, config=config)
+
+     result, _history = await chat_graph.run(
+         PolishQuestion(),
+         state=state,
+         deps=deps
+     )
+
+     return result
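For orientation, here is a hedged sketch of driving the new graph while skipping the `PolishQuestion` step, reusing only the names and the `Graph.run` call signature that appear in this commit; the entry-node choice and the `ask_without_polishing` helper are illustrative assumptions, not part of the change:

```python
# Illustrative only: enter the graph at RetrieveContext so the raw question is
# embedded directly, mirroring the setup used by process_chat and __main__.py.
import asyncio

import chromadb

from know_lang_bot.chat_bot.chat_config import chat_app_config
from know_lang_bot.chat_bot.chat_graph import (
    ChatGraphDeps,
    ChatGraphState,
    RetrieveContext,
    chat_graph,
)


async def ask_without_polishing(question: str):
    config = chat_app_config
    collection = chromadb.PersistentClient(
        path=str(config.db.persist_directory)
    ).get_collection(name=config.db.collection_name)

    state = ChatGraphState(original_question=question)
    deps = ChatGraphDeps(collection=collection, config=config)
    # Same tuple-unpacking convention process_chat uses above.
    result, _history = await chat_graph.run(
        RetrieveContext(), state=state, deps=deps
    )
    return result


if __name__ == "__main__":
    chat_result = asyncio.run(
        ask_without_polishing("How does the parser handle nested classes?")
    )
    print(chat_result.answer)
```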
src/know_lang_bot/chat_bot/chat_interface.py CHANGED
@@ -1,7 +1,7 @@
  import gradio as gr
  from know_lang_bot.chat_bot.chat_config import ChatAppConfig, chat_app_config
  from know_lang_bot.utils.fancy_log import FancyLogger
- from know_lang_bot.chat_bot.chat_agent import code_qa_agent, CodeQADeps, AgentResponse
+ from know_lang_bot.chat_bot.chat_graph import ChatResult, process_chat
  import chromadb
  from typing import List, Dict
  import logfire
@@ -12,7 +12,6 @@ class CodeQAChatInterface:
      def __init__(self, config: ChatAppConfig):
          self.config = config
          self._init_chroma()
-         self.agent = code_qa_agent
 
      def _init_chroma(self):
          """Initialize ChromaDB connection"""
@@ -28,23 +27,9 @@
          self,
          message: str,
          history: List[Dict[str, str]]
-     ) -> AgentResponse:
+     ) -> ChatResult:
          """Process a question and return the answer with references"""
-         try:
-             deps = CodeQADeps(
-                 collection=self.collection,
-                 config=self.config
-             )
-
-             response = await self.agent.run(message, deps=deps)
-             return response.data
-
-         except Exception as e:
-             LOG.error(f"Error processing question: {e}")
-             return AgentResponse(
-                 answer="I encountered an error processing your question. Please try again.",
-                 references_md=""
-             )
+         return await process_chat(message, self.collection, self.config)
 
      def create_interface(self) -> gr.Blocks:
          """Create the Gradio interface"""
@@ -54,10 +39,7 @@
 
          with gr.Row():
              with gr.Column(scale=2):
-                 chatbot = gr.Chatbot(
-                     type="messages",
-                     bubble_full_width=False
-                 )
+                 chatbot = gr.Chatbot(type="messages", bubble_full_width=False)
                  msg = gr.Textbox(
                      label="Ask about the codebase",
                      placeholder="What does the CodeParser class do?",
@@ -72,13 +54,13 @@
                  )
 
              async def respond(message, history):
-                 response = await self.process_question(message, history)
-                 references.value = response.references_md
+                 result = await self.process_question(message, history)
+                 references.value = result.references_md
                  return {
                      msg: "",
                      chatbot: history + [
                          {"role": "user", "content": message},
-                         {"role": "assistant", "content": response.answer}
+                         {"role": "assistant", "content": result.answer}
                      ]
                  }
src/know_lang_bot/chat_bot/{run.py → gradio.py} RENAMED
File without changes