chat graph implemented
Browse files
src/know_lang_bot/chat_bot/__main__.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from know_lang_bot.chat_bot.chat_config import ChatAppConfig, chat_app_config
from know_lang_bot.chat_bot.chat_graph import process_chat
import chromadb
import asyncio


async def test_chat_processing():
    """Smoke-test the chat graph end to end against the persisted ChromaDB store."""
    cfg = chat_app_config

    # Open the on-disk ChromaDB store and grab the code-chunk collection.
    client = chromadb.PersistentClient(path=str(cfg.db.persist_directory))
    code_collection = client.get_collection(name=cfg.db.collection_name)

    # Drive a single question through the full chat graph.
    result = await process_chat(
        "How does the parser handle nested classes?",
        code_collection,
        cfg,
    )

    print(f"Answer: {result.answer}")
    print(f"References: {result.references_md}")


if __name__ == "__main__":
    asyncio.run(test_chat_processing())
|
src/know_lang_bot/chat_bot/chat_agent.py
DELETED
@@ -1,100 +0,0 @@
|
|
1 |
-
from dataclasses import dataclass
from typing import List, Dict, Any, Optional
import chromadb
from pathlib import Path
from pydantic_ai import Agent, RunContext
from know_lang_bot.chat_bot.chat_config import ChatAppConfig, chat_app_config
from know_lang_bot.utils.fancy_log import FancyLogger
from pydantic import BaseModel
import ollama
import logfire

LOG = FancyLogger(__name__)

@dataclass
class CodeQADeps:
    """Dependencies for the Code Q&A Agent"""
    # ChromaDB collection holding the embedded code chunks
    collection: chromadb.Collection
    # Application configuration (LLM provider/model + chat settings)
    config: ChatAppConfig

class RetrievedContext(BaseModel):
    """Structure for retrieved context"""
    # Code chunks that passed the similarity threshold
    chunks: List[str]
    # Per-chunk metadata (file_path, start_line, end_line, type, name, ...)
    metadatas: List[Dict[str, Any]]
    # Markdown-rendered reference list for display
    references_md: str

class AgentResponse(BaseModel):
    """Structure for agent responses"""
    # The assistant's answer text
    answer: str
    # Optional markdown reference list backing the answer
    references_md: Optional[str] = None

# Initialize the agent with system prompt and dependencies
# NOTE: the model string is read from module-level config at import time.
code_qa_agent = Agent(
    f'{chat_app_config.llm.model_provider}:{chat_app_config.llm.model_name}',
    deps_type=CodeQADeps,
    result_type=AgentResponse,
    system_prompt="""
    You are an expert code assistant helping users understand a codebase.

    Always:
    1. Reference specific files and line numbers in your explanations
    2. Be direct and concise while being comprehensive
    3. If the context is insufficient, explain why
    4. If you're unsure about something, acknowledge it

    Your response should be helpful for software engineers trying to understand complex codebases.
    """,
)

@code_qa_agent.tool
@logfire.instrument()
async def retrieve_context(
    ctx: RunContext[CodeQADeps],
    question: str
) -> RetrievedContext:
    """
    Retrieve relevant code context from the vector database.

    Args:
        ctx: The context containing dependencies
        question: The user's question to find relevant code for
    """
    # NOTE(review): ollama.embed is a synchronous call inside an async tool;
    # it blocks the event loop for the duration of the embedding request.
    embedded_question = ollama.embed(
        model=ctx.deps.config.llm.embedding_model,
        input=question
    )

    # Vector search over the code collection, bounded by max_context_chunks.
    results = ctx.deps.collection.query(
        query_embeddings=embedded_question['embeddings'],
        n_results=ctx.deps.config.chat.max_context_chunks,
        include=['metadatas', 'documents', 'distances']
    )

    relevant_chunks = []
    relevant_metadatas = []

    # Keep only results whose distance is within the configured threshold.
    for doc, meta, dist in zip(
        results['documents'][0],
        results['metadatas'][0],
        results['distances'][0]
    ):
        if dist <= ctx.deps.config.chat.similarity_threshold:
            relevant_chunks.append(doc)
            relevant_metadatas.append(meta)


    # Format references for display
    references = []
    for meta in relevant_metadatas:
        # Show only the file's basename, not the full path.
        file_path = Path(meta['file_path']).name
        ref = f"**{file_path}** (lines {meta['start_line']}-{meta['end_line']})"
        if meta.get('name'):
            ref += f"\n- {meta['type']}: `{meta['name']}`"
        references.append(ref)

    return RetrievedContext(
        chunks=relevant_chunks,
        metadatas=relevant_metadatas,
        references_md="\n\n".join(references)
    )
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/know_lang_bot/chat_bot/chat_graph.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import List, Dict, Any, Optional
|
4 |
+
import chromadb
|
5 |
+
from pydantic import BaseModel
|
6 |
+
from pydantic_graph import BaseNode, Graph, GraphRunContext, End
|
7 |
+
import ollama
|
8 |
+
import logfire
|
9 |
+
from know_lang_bot.chat_bot.chat_config import ChatAppConfig, chat_app_config
|
10 |
+
from know_lang_bot.utils.fancy_log import FancyLogger
|
11 |
+
from pydantic_ai import Agent
|
12 |
+
|
13 |
+
LOG = FancyLogger(__name__)
|
14 |
+
|
15 |
+
# Data Models
class RetrievedContext(BaseModel):
    """Structure for retrieved context"""
    # Code chunks that passed the similarity threshold
    chunks: List[str]
    # Per-chunk metadata (file_path, start_line, end_line, type, name, ...)
    metadatas: List[Dict[str, Any]]
    # Markdown-rendered reference list for display
    references_md: str
|
21 |
+
|
22 |
+
class ChatResult(BaseModel):
    """Final result from the chat graph"""
    # The assistant's answer text
    answer: str
    # Markdown list of code references; empty string when no context was found
    references_md: Optional[str] = None
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
@dataclass
class ChatGraphState:
    """State maintained throughout the graph execution"""
    # Question exactly as typed by the user
    original_question: str
    # Set by PolishQuestion; RetrieveContext falls back to original_question if None
    polished_question: Optional[str] = None
    # Set by RetrieveContext; None/empty triggers the "no context" answer path
    retrieved_context: Optional[RetrievedContext] = None
|
36 |
+
|
37 |
+
@dataclass
class ChatGraphDeps:
    """Dependencies required by the graph"""
    # ChromaDB collection holding the embedded code chunks
    collection: chromadb.Collection
    # Application configuration (LLM provider/model + chat settings)
    config: ChatAppConfig
|
42 |
+
|
43 |
+
|
44 |
+
# Graph Nodes
@dataclass
class PolishQuestion(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
    """Node that polishes the user's question.

    Reformulates the raw question into a more specific, searchable form,
    stores it in ``state.polished_question``, and hands off to RetrieveContext.
    """
    system_prompt = """
    You are an expert at understanding code-related questions and reformulating them
    for better context retrieval. Your task is to polish the user's question to make
    it more specific and searchable. Focus on technical terms and code concepts.
    """

    async def run(self, ctx: GraphRunContext[ChatGraphState]) -> RetrieveContext:
        # BUG FIX: system_prompt was defined on the class but never passed to
        # the Agent, so the polishing instructions were silently ignored.
        # (Also dropped the redundant function-local `from pydantic_ai import
        # Agent`; Agent is already imported at module level.)
        polish_agent = Agent(
            f"{ctx.deps.config.llm.model_provider}:{ctx.deps.config.llm.model_name}",
            system_prompt=self.system_prompt,
        )
        prompt = f"""
        Original question: {ctx.state.original_question}

        Please reformulate this question to be more specific and searchable,
        focusing on technical terms and code concepts. Keep the core meaning
        but make it more precise for code context retrieval.
        """

        result = await polish_agent.run(prompt)
        ctx.state.polished_question = result.data
        return RetrieveContext()
|
71 |
+
|
72 |
+
@dataclass
class RetrieveContext(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
    """Node that retrieves relevant code context.

    Embeds the (polished) question, queries the vector store, filters hits by
    the configured similarity threshold, and stores a RetrievedContext on the
    graph state before handing off to AnswerQuestion. On any failure it logs
    and proceeds with no context (AnswerQuestion emits a fallback answer).
    """

    async def run(self, ctx: GraphRunContext[ChatGraphState]) -> AnswerQuestion:
        from pathlib import Path  # local import: keeps the node self-contained

        try:
            # Prefer the polished question; fall back to the user's wording.
            # NOTE(review): ollama.embed is synchronous and blocks the event
            # loop while the embedding request runs — confirm acceptable here.
            embedded_question = ollama.embed(
                model=ctx.deps.config.llm.embedding_model,
                input=ctx.state.polished_question or ctx.state.original_question
            )

            results = ctx.deps.collection.query(
                query_embeddings=embedded_question['embeddings'],
                n_results=ctx.deps.config.chat.max_context_chunks,
                include=['metadatas', 'documents', 'distances']
            )

            relevant_chunks = []
            relevant_metadatas = []

            # Keep only hits whose distance is within the threshold.
            for doc, meta, dist in zip(
                results['documents'][0],
                results['metadatas'][0],
                results['distances'][0]
            ):
                if dist <= ctx.deps.config.chat.similarity_threshold:
                    relevant_chunks.append(doc)
                    relevant_metadatas.append(meta)

            # Format references for display
            references = []
            for meta in relevant_metadatas:
                # BUG FIX: split('/') broke on Windows-style paths; Path.name
                # is portable and matches the previous chat_agent behavior.
                file_path = Path(meta['file_path']).name
                ref = f"**{file_path}** (lines {meta['start_line']}-{meta['end_line']})"
                if meta.get('name'):
                    ref += f"\n- {meta['type']}: `{meta['name']}`"
                references.append(ref)

            ctx.state.retrieved_context = RetrievedContext(
                chunks=relevant_chunks,
                metadatas=relevant_metadatas,
                references_md="\n\n".join(references)
            )

            return AnswerQuestion()

        except Exception as e:
            # Deliberate best-effort: log and continue with no context so the
            # user still gets the "couldn't find relevant code" answer.
            LOG.error(f"Error retrieving context: {e}")
            return AnswerQuestion()
|
121 |
+
|
122 |
+
@dataclass
class AnswerQuestion(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
    """Node that generates the final answer.

    Terminal node: returns End(ChatResult) whether or not context was found.
    """
    system_prompt = """
    You are an expert code assistant helping users understand a codebase.
    Always:
    1. Reference specific files and line numbers in your explanations
    2. Be direct and concise while being comprehensive
    3. If the context is insufficient, explain why
    4. If you're unsure about something, acknowledge it
    """

    async def run(self, ctx: GraphRunContext[ChatGraphState]) -> End[ChatResult]:
        # Guard first: when retrieval produced nothing there is no point
        # constructing an agent (previously the Agent was built before this
        # check and then discarded on the empty path).
        if not ctx.state.retrieved_context or not ctx.state.retrieved_context.chunks:
            return End(ChatResult(
                answer="I couldn't find any relevant code context for your question. "
                "Could you please rephrase or be more specific?",
                references_md=""
            ))

        answer_agent = Agent(
            f"{ctx.deps.config.llm.model_provider}:{ctx.deps.config.llm.model_name}",
            system_prompt=self.system_prompt
        )

        context = ctx.state.retrieved_context
        prompt = f"""
        Question: {ctx.state.original_question}

        Available Code Context:
        {context.chunks}

        Please provide a comprehensive answer based on the code context above.
        Make sure to reference specific files and line numbers from the context.
        """

        try:
            result = await answer_agent.run(prompt)
            return End(ChatResult(
                answer=result.data,
                references_md=context.references_md
            ))
        except Exception as e:
            # Boundary node: log the failure and surface a generic error answer.
            LOG.error(f"Error generating answer: {e}")
            return End(ChatResult(
                answer="I encountered an error processing your question. Please try again.",
                references_md=""
            ))
|
171 |
+
|
172 |
+
# Create the graph
# Execution path: PolishQuestion -> RetrieveContext -> AnswerQuestion -> End
chat_graph = Graph(
    nodes=[PolishQuestion, RetrieveContext, AnswerQuestion]
)
|
176 |
+
|
177 |
+
async def process_chat(
    question: str,
    collection: chromadb.Collection,
    config: ChatAppConfig
) -> ChatResult:
    """
    Process a chat question through the graph.
    This is the main entry point for chat processing.
    """
    graph_state = ChatGraphState(original_question=question)
    graph_deps = ChatGraphDeps(collection=collection, config=config)

    # Start at PolishQuestion and run the graph to its End node; the run
    # history is not needed by callers, only the final ChatResult.
    final_result, _run_history = await chat_graph.run(
        PolishQuestion(),
        state=graph_state,
        deps=graph_deps,
    )
    return final_result
|
src/know_lang_bot/chat_bot/chat_interface.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
from know_lang_bot.chat_bot.chat_config import ChatAppConfig, chat_app_config
|
3 |
from know_lang_bot.utils.fancy_log import FancyLogger
|
4 |
-
from know_lang_bot.chat_bot.
|
5 |
import chromadb
|
6 |
from typing import List, Dict
|
7 |
import logfire
|
@@ -12,7 +12,6 @@ class CodeQAChatInterface:
|
|
12 |
def __init__(self, config: ChatAppConfig):
|
13 |
self.config = config
|
14 |
self._init_chroma()
|
15 |
-
self.agent = code_qa_agent
|
16 |
|
17 |
def _init_chroma(self):
|
18 |
"""Initialize ChromaDB connection"""
|
@@ -28,23 +27,9 @@ class CodeQAChatInterface:
|
|
28 |
self,
|
29 |
message: str,
|
30 |
history: List[Dict[str, str]]
|
31 |
-
) ->
|
32 |
"""Process a question and return the answer with references"""
|
33 |
-
|
34 |
-
deps = CodeQADeps(
|
35 |
-
collection=self.collection,
|
36 |
-
config=self.config
|
37 |
-
)
|
38 |
-
|
39 |
-
response = await self.agent.run(message, deps=deps)
|
40 |
-
return response.data
|
41 |
-
|
42 |
-
except Exception as e:
|
43 |
-
LOG.error(f"Error processing question: {e}")
|
44 |
-
return AgentResponse(
|
45 |
-
answer="I encountered an error processing your question. Please try again.",
|
46 |
-
references_md=""
|
47 |
-
)
|
48 |
|
49 |
def create_interface(self) -> gr.Blocks:
|
50 |
"""Create the Gradio interface"""
|
@@ -54,10 +39,7 @@ class CodeQAChatInterface:
|
|
54 |
|
55 |
with gr.Row():
|
56 |
with gr.Column(scale=2):
|
57 |
-
chatbot = gr.Chatbot(
|
58 |
-
type="messages",
|
59 |
-
bubble_full_width=False
|
60 |
-
)
|
61 |
msg = gr.Textbox(
|
62 |
label="Ask about the codebase",
|
63 |
placeholder="What does the CodeParser class do?",
|
@@ -72,13 +54,13 @@ class CodeQAChatInterface:
|
|
72 |
)
|
73 |
|
74 |
async def respond(message, history):
|
75 |
-
|
76 |
-
references.value =
|
77 |
return {
|
78 |
msg: "",
|
79 |
chatbot: history + [
|
80 |
{"role": "user", "content": message},
|
81 |
-
{"role": "assistant", "content":
|
82 |
]
|
83 |
}
|
84 |
|
|
|
1 |
import gradio as gr
|
2 |
from know_lang_bot.chat_bot.chat_config import ChatAppConfig, chat_app_config
|
3 |
from know_lang_bot.utils.fancy_log import FancyLogger
|
4 |
+
from know_lang_bot.chat_bot.chat_graph import ChatResult, process_chat
|
5 |
import chromadb
|
6 |
from typing import List, Dict
|
7 |
import logfire
|
|
|
12 |
def __init__(self, config: ChatAppConfig):
|
13 |
self.config = config
|
14 |
self._init_chroma()
|
|
|
15 |
|
16 |
def _init_chroma(self):
|
17 |
"""Initialize ChromaDB connection"""
|
|
|
27 |
self,
|
28 |
message: str,
|
29 |
history: List[Dict[str, str]]
|
30 |
+
) -> ChatResult:
|
31 |
"""Process a question and return the answer with references"""
|
32 |
+
return await process_chat(message, self.collection, self.config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
def create_interface(self) -> gr.Blocks:
|
35 |
"""Create the Gradio interface"""
|
|
|
39 |
|
40 |
with gr.Row():
|
41 |
with gr.Column(scale=2):
|
42 |
+
chatbot = gr.Chatbot(type="messages", bubble_full_width=False)
|
|
|
|
|
|
|
43 |
msg = gr.Textbox(
|
44 |
label="Ask about the codebase",
|
45 |
placeholder="What does the CodeParser class do?",
|
|
|
54 |
)
|
55 |
|
56 |
async def respond(message, history):
    # Run the message through the chat graph via the interface wrapper.
    result = await self.process_question(message, history)
    # NOTE(review): assigning to references.value probably does not refresh
    # the rendered component — Gradio updates components via values returned
    # from the event handler. Confirm the references panel actually updates.
    references.value = result.references_md
    # Clear the input box and append the user/assistant turn to the chat.
    return {
        msg: "",
        chatbot: history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": result.answer}
        ]
    }
|
66 |
|
src/know_lang_bot/chat_bot/{run.py → gradio.py}
RENAMED
File without changes
|