# `from __future__ import annotations` is required so forward references in
# type hints (e.g. ChatGraphState, RetrievedContext) resolve in this file
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Any, AsyncGenerator, Dict, List, Optional

import chromadb
import logfire
import voyageai
from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_graph import BaseNode, End, EndStep, Graph, GraphRunContext, HistoryStep
from rich.console import Console
from voyageai.object.reranking import RerankingObject

from knowlang.configs.config import AppConfig, EmbeddingConfig, RerankerConfig
from knowlang.models.embeddings import EmbeddingInputType, generate_embedding
from knowlang.utils.chunking_util import truncate_chunk
from knowlang.utils.fancy_log import FancyLogger
from knowlang.utils.model_provider import create_pydantic_model

LOG = FancyLogger(__name__)
console = Console()


class ChatStatus(str, Enum):
"""Enum for tracking chat progress status"""
STARTING = "starting"
POLISHING = "polishing"
RETRIEVING = "retrieving"
ANSWERING = "answering"
COMPLETE = "complete"
ERROR = "error"


class StreamingChatResult(BaseModel):
"""Extended chat result with streaming information"""
answer: str
retrieved_context: Optional[RetrievedContext] = None
status: ChatStatus
progress_message: str
@classmethod
def from_node(cls, node: BaseNode, state: ChatGraphState) -> StreamingChatResult:
"""Create a StreamingChatResult from a node's current state"""
if isinstance(node, PolishQuestionNode):
return cls(
answer="",
status=ChatStatus.POLISHING,
progress_message=f"Refining question: '{state.original_question}'"
)
elif isinstance(node, RetrieveContextNode):
return cls(
answer="",
status=ChatStatus.RETRIEVING,
progress_message=f"Searching codebase with: '{state.polished_question or state.original_question}'"
)
elif isinstance(node, AnswerQuestionNode):
context_msg = f"Found {len(state.retrieved_context.chunks)} relevant segments" if state.retrieved_context else "No context found"
return cls(
answer="",
retrieved_context=state.retrieved_context,
status=ChatStatus.ANSWERING,
progress_message=f"Generating answer... {context_msg}"
)
else:
return cls(
answer="",
status=ChatStatus.ERROR,
progress_message=f"Unknown node type: {type(node).__name__}"
)
@classmethod
def complete(cls, result: ChatResult) -> StreamingChatResult:
"""Create a completed StreamingChatResult"""
return cls(
answer=result.answer,
retrieved_context=result.retrieved_context,
status=ChatStatus.COMPLETE,
progress_message="Response complete"
)
@classmethod
def error(cls, error_msg: str) -> StreamingChatResult:
"""Create an error StreamingChatResult"""
return cls(
answer=f"Error: {error_msg}",
status=ChatStatus.ERROR,
progress_message=f"An error occurred: {error_msg}"
)


class RetrievedContext(BaseModel):
"""Structure for retrieved context"""
chunks: List[str]
metadatas: List[Dict[str, Any]]


class ChatResult(BaseModel):
"""Final result from the chat graph"""
answer: str
retrieved_context: Optional[RetrievedContext] = None


@dataclass
class ChatGraphState:
"""State maintained throughout the graph execution"""
original_question: str
polished_question: Optional[str] = None
retrieved_context: Optional[RetrievedContext] = None


@dataclass
class ChatGraphDeps:
"""Dependencies required by the graph"""
collection: chromadb.Collection
config: AppConfig


# Graph Nodes
@dataclass
class PolishQuestionNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
"""Node that polishes the user's question"""
system_prompt = """You are a code question refinement expert. Your ONLY task is to rephrase questions
to be more precise for code context retrieval. Follow these rules strictly:
1. Output ONLY the refined question - no explanations or analysis
2. Preserve the original intent completely
3. Add missing technical terms if obvious
4. Keep the question concise - ideally one sentence
5. Focus on searchable technical terms
6. Do not add speculative terms not implied by the original question
Example Input: "How do I use transformers for translation?"
Example Output: "How do I use the Transformers pipeline for machine translation tasks?"
Example Input: "Where is the config stored?"
Example Output: "Where is the configuration file or configuration settings stored in this codebase?"
"""
async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> RetrieveContextNode:
# Create an agent for question polishing
polish_agent = Agent(
create_pydantic_model(
model_provider=ctx.deps.config.llm.model_provider,
model_name=ctx.deps.config.llm.model_name
),
system_prompt=self.system_prompt
)
prompt = f"""Original question: "{ctx.state.original_question}"
Return ONLY the polished question - no explanations or analysis.
Focus on making the question more searchable while preserving its original intent."""
result = await polish_agent.run(prompt)
ctx.state.polished_question = result.data
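        # Hand off to retrieval; it reads polished_question from the shared state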
return RetrieveContextNode()


@dataclass
class RetrieveContextNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
"""Node that retrieves relevant code context using hybrid search: embeddings + reranking"""
async def _get_initial_chunks(
self,
query: str,
embedding_config: EmbeddingConfig,
collection: chromadb.Collection,
n_results: int
) -> tuple[List[str], List[Dict], List[float]]:
"""Get initial chunks using embedding search"""
question_embedding = generate_embedding(
input=query,
config=embedding_config,
input_type=EmbeddingInputType.QUERY
)
results = collection.query(
query_embeddings=question_embedding,
n_results=n_results,
include=['metadatas', 'documents', 'distances']
)
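        # Chroma returns one result list per query embedding; a single query
        # was sent, so index 0 selects its results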
return (
results['documents'][0],
results['metadatas'][0],
results['distances'][0]
)
async def _rerank_chunks(
self,
query: str,
chunks: List[str],
reranker_config: RerankerConfig,
) -> RerankingObject:
"""Rerank chunks using Voyage AI"""
voyage_client = voyageai.Client()
return voyage_client.rerank(
query=query,
documents=chunks,
model=reranker_config.model_name,
top_k=reranker_config.top_k,
truncation=True
)
def _filter_by_distance(
self,
chunks: List[str],
metadatas: List[Dict],
distances: List[float],
threshold: float
) -> tuple[List[str], List[Dict]]:
"""Filter chunks by distance threshold"""
filtered_chunks = []
filtered_metadatas = []
for chunk, meta, dist in zip(chunks, metadatas, distances):
if dist <= threshold:
filtered_chunks.append(chunk)
filtered_metadatas.append(meta)
return filtered_chunks, filtered_metadatas
async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> AnswerQuestionNode:
try:
# Get query
query = ctx.state.polished_question or ctx.state.original_question
            # First pass: over-fetch candidates with embedding search (twice the
            # final context budget, capped at 50) so the reranker has more to choose from
initial_chunks, initial_metadatas, distances = await self._get_initial_chunks(
query=query,
embedding_config=ctx.deps.config.embedding,
collection=ctx.deps.collection,
n_results=min(ctx.deps.config.chat.max_context_chunks * 2, 50)
)
            # Log the top-k embedding hits (Chroma distances sort ascending:
            # smaller distance means more similar)
            top_k_initial = sorted(
                zip(initial_chunks, distances),
                key=lambda x: x[1]
            )[:ctx.deps.config.reranker.top_k]
            logfire.info('top k embedding search results: {results}', results=top_k_initial)
# Only proceed to reranking if we have initial results
if not initial_chunks:
LOG.warning("No initial chunks found through embedding search")
raise Exception("No chunks found through embedding search")
            # Second pass: rerank the candidates
            try:
                if not ctx.deps.config.reranker.enabled:
                    # Raising here routes the disabled case through the same
                    # distance-based fallback used when reranking fails
                    raise Exception("Reranker is disabled")
reranking = await self._rerank_chunks(
query=query,
chunks=initial_chunks,
reranker_config=ctx.deps.config.reranker
)
logfire.info('top k reranking search results: {results}', results=reranking.results)
# Build final context from reranked results
relevant_chunks = []
relevant_metadatas = []
for result in reranking.results:
# Only include if score is good enough
if result.relevance_score >= ctx.deps.config.reranker.relevance_threshold:
relevant_chunks.append(result.document)
# Get corresponding metadata using original index
relevant_metadatas.append(initial_metadatas[result.index])
if not relevant_chunks:
raise Exception("No relevant chunks found through reranking")
except Exception as e:
# Fallback to distance-based filtering if reranking fails
LOG.error(f"Reranking failed, falling back to distance-based filtering: {e}")
                # Fall back on the full, index-aligned candidate lists so each
                # chunk keeps its own metadata and distance
                relevant_chunks, relevant_metadatas = self._filter_by_distance(
                    chunks=initial_chunks,
                    metadatas=initial_metadatas,
                    distances=distances,
                    threshold=ctx.deps.config.chat.similarity_threshold
                )
ctx.state.retrieved_context = RetrievedContext(
chunks=relevant_chunks,
metadatas=relevant_metadatas,
)
        except Exception as e:
            LOG.error(f"Error in context retrieval: {e}")
            ctx.state.retrieved_context = RetrievedContext(chunks=[], metadatas=[])
        # Always proceed to the answer node; it handles the empty-context case
        return AnswerQuestionNode()


@dataclass
class AnswerQuestionNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
"""Node that generates the final answer"""
system_prompt = """
You are an expert code assistant helping developers understand complex codebases. Follow these rules strictly:
1. ALWAYS answer the user's question - this is your primary task
2. Base your answer ONLY on the provided code context, not on general knowledge
3. When referencing code:
- Cite specific files and line numbers
- Quote relevant code snippets briefly
- Explain why this code is relevant to the question
4. If you cannot find sufficient context to answer fully:
- Clearly state what's missing
- Explain what additional information would help
5. Focus on accuracy over comprehensiveness:
- If you're unsure about part of your answer, explicitly say so
- Better to acknowledge limitations than make assumptions
Remember: Your primary goal is answering the user's specific question, not explaining the entire codebase."""
async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> End[ChatResult]:
answer_agent = Agent(
create_pydantic_model(
model_provider=ctx.deps.config.llm.model_provider,
model_name=ctx.deps.config.llm.model_name
),
system_prompt=self.system_prompt
)
if not ctx.state.retrieved_context or not ctx.state.retrieved_context.chunks:
return End(ChatResult(
answer="I couldn't find any relevant code context for your question. "
"Could you please rephrase or be more specific?",
retrieved_context=None,
))
        context = ctx.state.retrieved_context
        # Truncate each chunk to the configured per-chunk limit, then join the
        # chunks into a single context block for the prompt
        truncated_chunks = [
            truncate_chunk(chunk, ctx.deps.config.chat.max_length_per_chunk)
            for chunk in context.chunks
        ]
        context_text = "\n\n".join(truncated_chunks)
        prompt = f"""
Question: {ctx.state.original_question}

Relevant Code Context:
{context_text}

Provide a focused answer to the question based on the provided context.

Important: Stay focused on answering the specific question asked.
"""
try:
result = await answer_agent.run(prompt)
return End(ChatResult(
answer=result.data,
retrieved_context=context,
))
except Exception as e:
LOG.error(f"Error generating answer: {e}")
return End(ChatResult(
answer="I encountered an error processing your question. Please try again.",
retrieved_context=context,
))


# Create the graph
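# Flow: PolishQuestionNode -> RetrieveContextNode -> AnswerQuestionNode -> End.
# (PolishQuestionNode is currently bypassed; both entry points below start at
# RetrieveContextNode.)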
chat_graph = Graph(
nodes=[PolishQuestionNode, RetrieveContextNode, AnswerQuestionNode]
)


async def process_chat(
question: str,
collection: chromadb.Collection,
config: AppConfig
) -> ChatResult:
"""
Process a chat question through the graph.
This is the main entry point for chat processing.
"""
state = ChatGraphState(original_question=question)
deps = ChatGraphDeps(collection=collection, config=config)
try:
result, _history = await chat_graph.run(
# Temporary fix to disable PolishQuestionNode
RetrieveContextNode(),
state=state,
deps=deps
)
except Exception as e:
LOG.error(f"Error processing chat in graph: {e}")
console.print_exception()
result = ChatResult(
answer="I encountered an error processing your question. Please try again."
)
    return result


async def stream_chat_progress(
question: str,
collection: chromadb.Collection,
config: AppConfig
) -> AsyncGenerator[StreamingChatResult, None]:
"""
Stream chat progress through the graph.
This is the main entry point for chat processing.
"""
state = ChatGraphState(original_question=question)
deps = ChatGraphDeps(collection=collection, config=config)
# Temporary fix to disable PolishQuestionNode
start_node = RetrieveContextNode()
history: list[HistoryStep[ChatGraphState, ChatResult]] = []
try:
# Initial status
yield StreamingChatResult(
answer="",
status=ChatStatus.STARTING,
progress_message=f"Processing question: {question}"
)
with logfire.span(
'{graph_name} run {start=}',
graph_name='RAG_chat_graph',
start=start_node,
) as run_span:
current_node = start_node
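            # Step the graph manually (chat_graph.next) so a status update can
            # be yielded before every node execution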
while True:
# Yield current node's status before processing
yield StreamingChatResult.from_node(current_node, state)
try:
# Process the current node
next_node = await chat_graph.next(current_node, history, state=state, deps=deps, infer_name=False)
if isinstance(next_node, End):
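                        # End carries the graph's final ChatResult in .data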
result: ChatResult = next_node.data
history.append(EndStep(result=next_node))
run_span.set_attribute('history', history)
# Yield final result
yield StreamingChatResult.complete(result)
return
elif isinstance(next_node, BaseNode):
current_node = next_node
else:
raise ValueError(f"Invalid node type: {type(next_node)}")
except Exception as node_error:
LOG.error(f"Error in node {current_node.__class__.__name__}: {node_error}")
yield StreamingChatResult.error(str(node_error))
return
except Exception as e:
LOG.error(f"Error in stream_chat_progress: {e}")
yield StreamingChatResult.error(str(e))
return
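

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module's API). It assumes
# a persisted Chroma database at "./chroma" containing a collection named
# "code_chunks", and that AppConfig() loads usable defaults -- the path,
# collection name, and question are all assumptions; adapt them to your setup.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        client = chromadb.PersistentClient(path="./chroma")
        collection = client.get_collection("code_chunks")
        config = AppConfig()
        # Stream progress updates, then print the final answer
        async for update in stream_chat_progress(
            "Where is the embedding model configured?", collection, config
        ):
            console.print(f"{update.status.value}: {update.progress_message}")
            if update.status == ChatStatus.COMPLETE:
                console.print(update.answer)

    asyncio.run(_demo())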