gabykim committed
Commit 3cf6c5a · 1 Parent(s): 6c18ccd

[BugFix] importing __future__ annotations is necessary

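For context on the fix: CPython evaluates function annotations eagerly, at definition time, unless postponed evaluation (PEP 563) is enabled with the __future__ import. In this module, PolishQuestionNode.run is annotated to return RetrieveContextNode, a class defined further down the file, so without the import the module raises NameError the moment it is loaded (pydantic_graph also reads these return annotations to determine each node's outgoing edges). A minimal sketch of the failure and the fix, with illustrative class names (First and Second are not from this repository):

    # Without the __future__ import, this module fails on load:
    #
    #   class First:
    #       def run(self) -> Second:   # NameError: name 'Second' is not defined
    #           return Second()
    #
    #   class Second: ...
    #
    # With postponed evaluation, annotations are stored as strings and only
    # resolved on demand, so the forward reference is harmless:
    from __future__ import annotations

    class First:
        def run(self) -> Second:  # stored as the string "Second"
            return Second()

    class Second:
        pass

    print(type(First().run()).__name__)  # prints: Second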
src/know_lang_bot/chat_bot/chat_graph.py CHANGED
@@ -1,3 +1,5 @@
+# __future__ annotations is necessary for the type hints to work in this file
+from __future__ import annotations
 from dataclasses import dataclass
 from typing import List, Dict, Any, Optional
 import chromadb
@@ -7,6 +9,8 @@ import ollama
 from know_lang_bot.chat_bot.chat_config import ChatAppConfig
 from know_lang_bot.utils.fancy_log import FancyLogger
 from pydantic_ai import Agent
+import logfire
+from pprint import pformat
 
 LOG = FancyLogger(__name__)
 
@@ -38,7 +42,7 @@ class ChatGraphDeps:
 
 # Graph Nodes
 @dataclass
-class PolishQuestion(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
+class PolishQuestionNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
     """Node that polishes the user's question"""
     system_prompt = """
     You are an expert at understanding code-related questions and reformulating them
@@ -46,7 +50,7 @@ class PolishQuestion(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
     it more specific and searchable. Focus on technical terms and code concepts.
     """
 
-    async def run(self, ctx: GraphRunContext[ChatGraphState]) -> RetrieveContext:
+    async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> RetrieveContextNode:
         # Create an agent for question polishing
         from pydantic_ai import Agent
         polish_agent = Agent(
@@ -62,13 +66,13 @@ class PolishQuestion(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
 
         result = await polish_agent.run(prompt)
         ctx.state.polished_question = result.data
-        return RetrieveContext()
+        return RetrieveContextNode()
 
 @dataclass
-class RetrieveContext(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
+class RetrieveContextNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
     """Node that retrieves relevant code context"""
 
-    async def run(self, ctx: GraphRunContext[ChatGraphState]) -> AnswerQuestion:
+    async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> AnswerQuestionNode:
         try:
             embedded_question = ollama.embed(
                 model=ctx.deps.config.llm.embedding_model,
@@ -80,6 +84,7 @@ class RetrieveContext(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
                 n_results=ctx.deps.config.chat.max_context_chunks,
                 include=['metadatas', 'documents', 'distances']
             )
+            logfire.debug('query result: {result}', result=pformat(results))
 
             relevant_chunks = []
             relevant_metadatas = []
@@ -102,20 +107,23 @@ class RetrieveContext(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
                 ref += f"\n- {meta['type']}: `{meta['name']}`"
                 references.append(ref)
 
+            with logfire.span('formatted {count} references', count=len(references)):
+                for ref in references:
+                    logfire.debug(ref)
+
             ctx.state.retrieved_context = RetrievedContext(
                 chunks=relevant_chunks,
                 metadatas=relevant_metadatas,
                 references_md="\n\n".join(references)
             )
 
-            return AnswerQuestion()
-
         except Exception as e:
             LOG.error(f"Error retrieving context: {e}")
-            return AnswerQuestion()
+        finally:
+            return AnswerQuestionNode()
 
 @dataclass
-class AnswerQuestion(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
+class AnswerQuestionNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
     """Node that generates the final answer"""
     system_prompt = """
     You are an expert code assistant helping users understand a codebase.
@@ -126,8 +134,7 @@ class AnswerQuestion(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
     4. If you're unsure about something, acknowledge it
     """
 
-    async def run(self, ctx: GraphRunContext[ChatGraphState]) -> End[ChatResult]:
-
+    async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> End[ChatResult]:
         answer_agent = Agent(
             f"{ctx.deps.config.llm.model_provider}:{ctx.deps.config.llm.model_name}",
             system_prompt=self.system_prompt
@@ -166,7 +173,7 @@ class AnswerQuestion(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
 
 # Create the graph
 chat_graph = Graph(
-    nodes=[PolishQuestion, RetrieveContext, AnswerQuestion]
+    nodes=[PolishQuestionNode, RetrieveContextNode, AnswerQuestionNode]
 )
 
 async def process_chat(
@@ -182,7 +189,7 @@ async def process_chat(
     deps = ChatGraphDeps(collection=collection, config=config)
 
     result, _history = await chat_graph.run(
-        PolishQuestion(),
+        PolishQuestionNode(),
         state=state,
         deps=deps
     )
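A note on the logfire instrumentation added above: logfire calls are inert until logfire.configure() runs once at process startup, so the new debug lines and span only show up in a configured process. A minimal, self-contained sketch of the same call shapes, assuming default logfire settings (the query result and reference below are fabricated sample data, not repository output):

    import logfire
    from pprint import pformat

    logfire.configure()  # one-time setup; assumption: default settings suffice

    # Same shapes as in RetrieveContextNode.run, with made-up sample data
    results = {"documents": [["def foo(): ..."]], "distances": [[0.12]]}
    logfire.debug('query result: {result}', result=pformat(results))

    references = ["src/example.py\n- function: `foo`"]
    with logfire.span('formatted {count} references', count=len(references)):
        for ref in references:
            logfire.debug(ref)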
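The retrieval node's error handling also changed shape: instead of duplicating return AnswerQuestion() on both the success and failure paths, the new code returns once from a finally block, so the graph always advances to AnswerQuestionNode whether or not retrieval succeeds. A small sketch of that control flow (function, names, and messages are illustrative):

    def retrieve(fail: bool) -> str:
        try:
            if fail:
                raise RuntimeError("retrieval failed")
            print("context stored")
        except Exception as e:
            print(f"error: {e}")  # logged, then we still proceed
        finally:
            return "AnswerQuestionNode"  # runs on both paths

    retrieve(False)  # context stored -> returns "AnswerQuestionNode"
    retrieve(True)   # error: retrieval failed -> returns "AnswerQuestionNode"

One trade-off to be aware of: a return inside finally also suppresses any in-flight exception that the except clause did not catch (KeyboardInterrupt included), so failures outside Exception are silently swallowed as well.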