[BugFix] importing __future__ annotations is necessary
Browse files
src/know_lang_bot/chat_bot/chat_graph.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
from dataclasses import dataclass
|
2 |
from typing import List, Dict, Any, Optional
|
3 |
import chromadb
|
@@ -7,6 +9,8 @@ import ollama
|
|
7 |
from know_lang_bot.chat_bot.chat_config import ChatAppConfig
|
8 |
from know_lang_bot.utils.fancy_log import FancyLogger
|
9 |
from pydantic_ai import Agent
|
|
|
|
|
10 |
|
11 |
LOG = FancyLogger(__name__)
|
12 |
|
@@ -38,7 +42,7 @@ class ChatGraphDeps:
|
|
38 |
|
39 |
# Graph Nodes
|
40 |
@dataclass
|
41 |
-
class
|
42 |
"""Node that polishes the user's question"""
|
43 |
system_prompt = """
|
44 |
You are an expert at understanding code-related questions and reformulating them
|
@@ -46,7 +50,7 @@ class PolishQuestion(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
|
|
46 |
it more specific and searchable. Focus on technical terms and code concepts.
|
47 |
"""
|
48 |
|
49 |
-
async def run(self, ctx: GraphRunContext[ChatGraphState]) ->
|
50 |
# Create an agent for question polishing
|
51 |
from pydantic_ai import Agent
|
52 |
polish_agent = Agent(
|
@@ -62,13 +66,13 @@ class PolishQuestion(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
|
|
62 |
|
63 |
result = await polish_agent.run(prompt)
|
64 |
ctx.state.polished_question = result.data
|
65 |
-
return
|
66 |
|
67 |
@dataclass
|
68 |
-
class
|
69 |
"""Node that retrieves relevant code context"""
|
70 |
|
71 |
-
async def run(self, ctx: GraphRunContext[ChatGraphState]) ->
|
72 |
try:
|
73 |
embedded_question = ollama.embed(
|
74 |
model=ctx.deps.config.llm.embedding_model,
|
@@ -80,6 +84,7 @@ class RetrieveContext(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
|
|
80 |
n_results=ctx.deps.config.chat.max_context_chunks,
|
81 |
include=['metadatas', 'documents', 'distances']
|
82 |
)
|
|
|
83 |
|
84 |
relevant_chunks = []
|
85 |
relevant_metadatas = []
|
@@ -102,20 +107,23 @@ class RetrieveContext(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
|
|
102 |
ref += f"\n- {meta['type']}: `{meta['name']}`"
|
103 |
references.append(ref)
|
104 |
|
|
|
|
|
|
|
|
|
105 |
ctx.state.retrieved_context = RetrievedContext(
|
106 |
chunks=relevant_chunks,
|
107 |
metadatas=relevant_metadatas,
|
108 |
references_md="\n\n".join(references)
|
109 |
)
|
110 |
|
111 |
-
return AnswerQuestion()
|
112 |
-
|
113 |
except Exception as e:
|
114 |
LOG.error(f"Error retrieving context: {e}")
|
115 |
-
|
|
|
116 |
|
117 |
@dataclass
|
118 |
-
class
|
119 |
"""Node that generates the final answer"""
|
120 |
system_prompt = """
|
121 |
You are an expert code assistant helping users understand a codebase.
|
@@ -126,8 +134,7 @@ class AnswerQuestion(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
|
|
126 |
4. If you're unsure about something, acknowledge it
|
127 |
"""
|
128 |
|
129 |
-
async def run(self, ctx: GraphRunContext[ChatGraphState]) -> End[ChatResult]:
|
130 |
-
|
131 |
answer_agent = Agent(
|
132 |
f"{ctx.deps.config.llm.model_provider}:{ctx.deps.config.llm.model_name}",
|
133 |
system_prompt=self.system_prompt
|
@@ -166,7 +173,7 @@ class AnswerQuestion(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
|
|
166 |
|
167 |
# Create the graph
|
168 |
chat_graph = Graph(
|
169 |
-
nodes=[
|
170 |
)
|
171 |
|
172 |
async def process_chat(
|
@@ -182,7 +189,7 @@ async def process_chat(
|
|
182 |
deps = ChatGraphDeps(collection=collection, config=config)
|
183 |
|
184 |
result, _history = await chat_graph.run(
|
185 |
-
|
186 |
state=state,
|
187 |
deps=deps
|
188 |
)
|
|
|
1 |
+
# __future__ annotations is necessary for the type hints to work in this file
|
2 |
+
from __future__ import annotations
|
3 |
from dataclasses import dataclass
|
4 |
from typing import List, Dict, Any, Optional
|
5 |
import chromadb
|
|
|
9 |
from know_lang_bot.chat_bot.chat_config import ChatAppConfig
|
10 |
from know_lang_bot.utils.fancy_log import FancyLogger
|
11 |
from pydantic_ai import Agent
|
12 |
+
import logfire
|
13 |
+
from pprint import pformat
|
14 |
|
15 |
LOG = FancyLogger(__name__)
|
16 |
|
|
|
42 |
|
43 |
# Graph Nodes
|
44 |
@dataclass
|
45 |
+
class PolishQuestionNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
|
46 |
"""Node that polishes the user's question"""
|
47 |
system_prompt = """
|
48 |
You are an expert at understanding code-related questions and reformulating them
|
|
|
50 |
it more specific and searchable. Focus on technical terms and code concepts.
|
51 |
"""
|
52 |
|
53 |
+
async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> RetrieveContextNode:
|
54 |
# Create an agent for question polishing
|
55 |
from pydantic_ai import Agent
|
56 |
polish_agent = Agent(
|
|
|
66 |
|
67 |
result = await polish_agent.run(prompt)
|
68 |
ctx.state.polished_question = result.data
|
69 |
+
return RetrieveContextNode()
|
70 |
|
71 |
@dataclass
|
72 |
+
class RetrieveContextNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
|
73 |
"""Node that retrieves relevant code context"""
|
74 |
|
75 |
+
async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> AnswerQuestionNode:
|
76 |
try:
|
77 |
embedded_question = ollama.embed(
|
78 |
model=ctx.deps.config.llm.embedding_model,
|
|
|
84 |
n_results=ctx.deps.config.chat.max_context_chunks,
|
85 |
include=['metadatas', 'documents', 'distances']
|
86 |
)
|
87 |
+
logfire.debug('query result: {result}', result=pformat(results))
|
88 |
|
89 |
relevant_chunks = []
|
90 |
relevant_metadatas = []
|
|
|
107 |
ref += f"\n- {meta['type']}: `{meta['name']}`"
|
108 |
references.append(ref)
|
109 |
|
110 |
+
with logfire.span('formatted {count} references', count=len(references)):
|
111 |
+
for ref in references:
|
112 |
+
logfire.debug(ref)
|
113 |
+
|
114 |
ctx.state.retrieved_context = RetrievedContext(
|
115 |
chunks=relevant_chunks,
|
116 |
metadatas=relevant_metadatas,
|
117 |
references_md="\n\n".join(references)
|
118 |
)
|
119 |
|
|
|
|
|
120 |
except Exception as e:
|
121 |
LOG.error(f"Error retrieving context: {e}")
|
122 |
+
finally:
|
123 |
+
return AnswerQuestionNode()
|
124 |
|
125 |
@dataclass
|
126 |
+
class AnswerQuestionNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
|
127 |
"""Node that generates the final answer"""
|
128 |
system_prompt = """
|
129 |
You are an expert code assistant helping users understand a codebase.
|
|
|
134 |
4. If you're unsure about something, acknowledge it
|
135 |
"""
|
136 |
|
137 |
+
async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> End[ChatResult]:
|
|
|
138 |
answer_agent = Agent(
|
139 |
f"{ctx.deps.config.llm.model_provider}:{ctx.deps.config.llm.model_name}",
|
140 |
system_prompt=self.system_prompt
|
|
|
173 |
|
174 |
# Create the graph
|
175 |
chat_graph = Graph(
|
176 |
+
nodes=[PolishQuestionNode, RetrieveContextNode, AnswerQuestionNode]
|
177 |
)
|
178 |
|
179 |
async def process_chat(
|
|
|
189 |
deps = ChatGraphDeps(collection=collection, config=config)
|
190 |
|
191 |
result, _history = await chat_graph.run(
|
192 |
+
PolishQuestionNode(),
|
193 |
state=state,
|
194 |
deps=deps
|
195 |
)
|