# __future__ annotations is necessary for the type hints to work in this file
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import AsyncGenerator, List, Dict, Any, Optional

import chromadb
import logfire
import voyageai
from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_graph import BaseNode, EndStep, Graph, GraphRunContext, End, HistoryStep
from rich.console import Console
from voyageai.object.reranking import RerankingObject

from knowlang.configs.config import AppConfig, RerankerConfig, EmbeddingConfig
from knowlang.models.embeddings import EmbeddingInputType, generate_embedding
from knowlang.utils.chunking_util import truncate_chunk
from knowlang.utils.fancy_log import FancyLogger
from knowlang.utils.model_provider import create_pydantic_model

LOG = FancyLogger(__name__)
console = Console()
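
# Pipeline overview: the chat graph runs PolishQuestionNode -> RetrieveContextNode
# -> AnswerQuestionNode. Retrieval performs a first pass of embedding search
# against Chroma, then a second pass of Voyage AI reranking, falling back to
# distance-based filtering if reranking fails or is disabled.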

class ChatStatus(str, Enum):
    """Enum for tracking chat progress status"""
    STARTING = "starting"
    POLISHING = "polishing"
    RETRIEVING = "retrieving"
    ANSWERING = "answering"
    COMPLETE = "complete"
    ERROR = "error"

class RetrievedContext(BaseModel):
    """Structure for retrieved context"""
    chunks: List[str]
    metadatas: List[Dict[str, Any]]

class ChatResult(BaseModel):
    """Final result from the chat graph"""
    answer: str
    retrieved_context: Optional[RetrievedContext] = None

class StreamingChatResult(BaseModel):
    """Extended chat result with streaming information"""
    answer: str
    retrieved_context: Optional[RetrievedContext] = None
    status: ChatStatus
    progress_message: str

    @classmethod
    def from_node(cls, node: BaseNode, state: ChatGraphState) -> StreamingChatResult:
        """Create a StreamingChatResult from a node's current state"""
        if isinstance(node, PolishQuestionNode):
            return cls(
                answer="",
                status=ChatStatus.POLISHING,
                progress_message=f"Refining question: '{state.original_question}'"
            )
        elif isinstance(node, RetrieveContextNode):
            return cls(
                answer="",
                status=ChatStatus.RETRIEVING,
                progress_message=f"Searching codebase with: '{state.polished_question or state.original_question}'"
            )
        elif isinstance(node, AnswerQuestionNode):
            context_msg = (
                f"Found {len(state.retrieved_context.chunks)} relevant segments"
                if state.retrieved_context else "No context found"
            )
            return cls(
                answer="",
                retrieved_context=state.retrieved_context,
                status=ChatStatus.ANSWERING,
                progress_message=f"Generating answer... {context_msg}"
            )
        else:
            return cls(
                answer="",
                status=ChatStatus.ERROR,
                progress_message=f"Unknown node type: {type(node).__name__}"
            )

    @classmethod
    def complete(cls, result: ChatResult) -> StreamingChatResult:
        """Create a completed StreamingChatResult"""
        return cls(
            answer=result.answer,
            retrieved_context=result.retrieved_context,
            status=ChatStatus.COMPLETE,
            progress_message="Response complete"
        )

    @classmethod
    def error(cls, error_msg: str) -> StreamingChatResult:
        """Create an error StreamingChatResult"""
        return cls(
            answer=f"Error: {error_msg}",
            status=ChatStatus.ERROR,
            progress_message=f"An error occurred: {error_msg}"
        )

@dataclass
class ChatGraphState:
    """State maintained throughout the graph execution"""
    original_question: str
    polished_question: Optional[str] = None
    retrieved_context: Optional[RetrievedContext] = None

@dataclass
class ChatGraphDeps:
    """Dependencies required by the graph"""
    collection: chromadb.Collection
    config: AppConfig

# Graph Nodes
class PolishQuestionNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
    """Node that polishes the user's question"""
    system_prompt = """You are a code question refinement expert. Your ONLY task is to rephrase questions
to be more precise for code context retrieval. Follow these rules strictly:
1. Output ONLY the refined question - no explanations or analysis
2. Preserve the original intent completely
3. Add missing technical terms if obvious
4. Keep the question concise - ideally one sentence
5. Focus on searchable technical terms
6. Do not add speculative terms not implied by the original question

Example Input: "How do I use transformers for translation?"
Example Output: "How do I use the Transformers pipeline for machine translation tasks?"

Example Input: "Where is the config stored?"
Example Output: "Where is the configuration file or configuration settings stored in this codebase?"
"""

    async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> RetrieveContextNode:
        # Create an agent for question polishing
        polish_agent = Agent(
            create_pydantic_model(
                model_provider=ctx.deps.config.llm.model_provider,
                model_name=ctx.deps.config.llm.model_name
            ),
            system_prompt=self.system_prompt
        )
        prompt = f"""Original question: "{ctx.state.original_question}"

Return ONLY the polished question - no explanations or analysis.
Focus on making the question more searchable while preserving its original intent."""
        result = await polish_agent.run(prompt)
        ctx.state.polished_question = result.data
        return RetrieveContextNode()

class RetrieveContextNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
    """Node that retrieves relevant code context using hybrid search: embeddings + reranking"""

    async def _get_initial_chunks(
        self,
        query: str,
        embedding_config: EmbeddingConfig,
        collection: chromadb.Collection,
        n_results: int
    ) -> tuple[List[str], List[Dict], List[float]]:
        """Get initial chunks using embedding search"""
        question_embedding = generate_embedding(
            input=query,
            config=embedding_config,
            input_type=EmbeddingInputType.QUERY
        )
        results = collection.query(
            query_embeddings=question_embedding,
            n_results=n_results,
            include=['metadatas', 'documents', 'distances']
        )
        # Chroma returns one result list per query embedding; we issue a single query
        return (
            results['documents'][0],
            results['metadatas'][0],
            results['distances'][0]
        )

    async def _rerank_chunks(
        self,
        query: str,
        chunks: List[str],
        reranker_config: RerankerConfig,
    ) -> RerankingObject:
        """Rerank chunks using Voyage AI"""
        # voyageai.Client() reads the API key from the VOYAGE_API_KEY environment variable
        voyage_client = voyageai.Client()
        return voyage_client.rerank(
            query=query,
            documents=chunks,
            model=reranker_config.model_name,
            top_k=reranker_config.top_k,
            truncation=True
        )

    def _filter_by_distance(
        self,
        chunks: List[str],
        metadatas: List[Dict],
        distances: List[float],
        threshold: float
    ) -> tuple[List[str], List[Dict]]:
        """Filter chunks by distance threshold"""
        filtered_chunks = []
        filtered_metadatas = []
        for chunk, meta, dist in zip(chunks, metadatas, distances):
            if dist <= threshold:
                filtered_chunks.append(chunk)
                filtered_metadatas.append(meta)
        return filtered_chunks, filtered_metadatas
    async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> AnswerQuestionNode:
        try:
            # Get query
            query = ctx.state.polished_question or ctx.state.original_question

            # First pass: get more candidates than needed via embedding search
            initial_chunks, initial_metadatas, distances = await self._get_initial_chunks(
                query=query,
                embedding_config=ctx.deps.config.embedding,
                collection=ctx.deps.collection,
                n_results=min(ctx.deps.config.chat.max_context_chunks * 2, 50)
            )

            # Log the top k initial results by distance (smaller distance = more similar)
            top_k_initial = sorted(
                zip(initial_chunks, distances),
                key=lambda x: x[1]
            )[:ctx.deps.config.reranker.top_k]
            logfire.info('top k embedding search results: {results}', results=top_k_initial)

            # Only proceed to reranking if we have initial results
            if not initial_chunks:
                LOG.warning("No initial chunks found through embedding search")
                raise Exception("No chunks found through embedding search")

            # Second pass: rerank the candidates
            try:
                if not ctx.deps.config.reranker.enabled:
                    raise Exception("Reranker is disabled")

                reranking = await self._rerank_chunks(
                    query=query,
                    chunks=initial_chunks,
                    reranker_config=ctx.deps.config.reranker
                )
                logfire.info('top k reranking search results: {results}', results=reranking.results)

                # Build the final context from reranked results
                relevant_chunks = []
                relevant_metadatas = []
                for result in reranking.results:
                    # Only include results whose relevance score clears the threshold
                    if result.relevance_score >= ctx.deps.config.reranker.relevance_threshold:
                        relevant_chunks.append(result.document)
                        # Look up the corresponding metadata via the original index
                        relevant_metadatas.append(initial_metadatas[result.index])

                if not relevant_chunks:
                    raise Exception("No relevant chunks found through reranking")
            except Exception as e:
                # Fall back to distance-based filtering if reranking fails or is disabled.
                # Pass the full, aligned lists so chunks, metadatas, and distances zip
                # correctly (the top-k subset would misalign with the full metadata list).
                LOG.error(f"Reranking failed, falling back to distance-based filtering: {e}")
                relevant_chunks, relevant_metadatas = self._filter_by_distance(
                    chunks=initial_chunks,
                    metadatas=initial_metadatas,
                    distances=distances,
                    threshold=ctx.deps.config.chat.similarity_threshold
                )

            ctx.state.retrieved_context = RetrievedContext(
                chunks=relevant_chunks,
                metadatas=relevant_metadatas,
            )
        except Exception as e:
            LOG.error(f"Error in context retrieval: {e}")
            ctx.state.retrieved_context = RetrievedContext(chunks=[], metadatas=[])
        # Always continue to the answer node; it handles the empty-context case
        return AnswerQuestionNode()

class AnswerQuestionNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
    """Node that generates the final answer"""
    system_prompt = """
You are an expert code assistant helping developers understand complex codebases. Follow these rules strictly:
1. ALWAYS answer the user's question - this is your primary task
2. Base your answer ONLY on the provided code context, not on general knowledge
3. When referencing code:
   - Cite specific files and line numbers
   - Quote relevant code snippets briefly
   - Explain why this code is relevant to the question
4. If you cannot find sufficient context to answer fully:
   - Clearly state what's missing
   - Explain what additional information would help
5. Focus on accuracy over comprehensiveness:
   - If you're unsure about part of your answer, explicitly say so
   - Better to acknowledge limitations than make assumptions

Remember: Your primary goal is answering the user's specific question, not explaining the entire codebase."""

    async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> End[ChatResult]:
        answer_agent = Agent(
            create_pydantic_model(
                model_provider=ctx.deps.config.llm.model_provider,
                model_name=ctx.deps.config.llm.model_name
            ),
            system_prompt=self.system_prompt
        )
        if not ctx.state.retrieved_context or not ctx.state.retrieved_context.chunks:
            return End(ChatResult(
                answer="I couldn't find any relevant code context for your question. "
                       "Could you please rephrase or be more specific?",
                retrieved_context=None,
            ))

        context = ctx.state.retrieved_context
        # Truncate each chunk and collect the results (reassigning the loop
        # variable would silently discard the truncation)
        truncated_chunks = [
            truncate_chunk(chunk, ctx.deps.config.chat.max_length_per_chunk)
            for chunk in context.chunks
        ]
        chunks_text = "\n\n".join(truncated_chunks)
        prompt = f"""
Question: {ctx.state.original_question}

Relevant Code Context:
{chunks_text}

Provide a focused answer to the question based on the provided context.

Important: Stay focused on answering the specific question asked.
"""
        try:
            result = await answer_agent.run(prompt)
            return End(ChatResult(
                answer=result.data,
                retrieved_context=context,
            ))
        except Exception as e:
            LOG.error(f"Error generating answer: {e}")
            return End(ChatResult(
                answer="I encountered an error processing your question. Please try again.",
                retrieved_context=context,
            ))

# Create the graph
chat_graph = Graph(
    nodes=[PolishQuestionNode, RetrieveContextNode, AnswerQuestionNode]
)
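
# Note: both entry points below currently start the run at RetrieveContextNode,
# bypassing PolishQuestionNode (see the "temporary fix" comments).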

async def process_chat(
    question: str,
    collection: chromadb.Collection,
    config: AppConfig
) -> ChatResult:
    """
    Process a chat question through the graph and return the final result.
    This is the entry point for non-streaming chat processing.
    """
    state = ChatGraphState(original_question=question)
    deps = ChatGraphDeps(collection=collection, config=config)

    try:
        result, _history = await chat_graph.run(
            # Temporary fix to disable PolishQuestionNode
            RetrieveContextNode(),
            state=state,
            deps=deps
        )
    except Exception as e:
        LOG.error(f"Error processing chat in graph: {e}")
        console.print_exception()
        result = ChatResult(
            answer="I encountered an error processing your question. Please try again."
        )
    # Return outside try/finally: a `return` inside `finally` would swallow
    # exceptions, and the except clause above already produces a fallback result
    return result

async def stream_chat_progress(
    question: str,
    collection: chromadb.Collection,
    config: AppConfig
) -> AsyncGenerator[StreamingChatResult, None]:
    """
    Stream chat progress through the graph, yielding a status update for each
    node followed by the final result. This is the entry point for streaming
    chat processing.
    """
    state = ChatGraphState(original_question=question)
    deps = ChatGraphDeps(collection=collection, config=config)
    # Temporary fix to disable PolishQuestionNode
    start_node = RetrieveContextNode()
    history: list[HistoryStep[ChatGraphState, ChatResult]] = []

    try:
        # Initial status
        yield StreamingChatResult(
            answer="",
            status=ChatStatus.STARTING,
            progress_message=f"Processing question: {question}"
        )

        with logfire.span(
            '{graph_name} run {start=}',
            graph_name='RAG_chat_graph',
            start=start_node,
        ) as run_span:
            current_node = start_node

            while True:
                # Yield the current node's status before processing it
                yield StreamingChatResult.from_node(current_node, state)

                try:
                    # Process the current node
                    next_node = await chat_graph.next(current_node, history, state=state, deps=deps, infer_name=False)

                    if isinstance(next_node, End):
                        result: ChatResult = next_node.data
                        history.append(EndStep(result=next_node))
                        run_span.set_attribute('history', history)
                        # Yield the final result
                        yield StreamingChatResult.complete(result)
                        return
                    elif isinstance(next_node, BaseNode):
                        current_node = next_node
                    else:
                        raise ValueError(f"Invalid node type: {type(next_node)}")
                except Exception as node_error:
                    LOG.error(f"Error in node {current_node.__class__.__name__}: {node_error}")
                    yield StreamingChatResult.error(str(node_error))
                    return
    except Exception as e:
        LOG.error(f"Error in stream_chat_progress: {e}")
        yield StreamingChatResult.error(str(e))
        return
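
# Example driver (a minimal sketch, not part of the module's public API).
# Assumptions: AppConfig() is constructible with defaults, a persisted Chroma
# database exists at ./chroma_db, and it holds a collection named "code_chunks"
# (hypothetical path and name -- adjust to your setup).
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        client = chromadb.PersistentClient(path="./chroma_db")  # assumed path
        demo_collection = client.get_collection("code_chunks")  # hypothetical name
        app_config = AppConfig()  # assumes default config is valid for your setup
        async for update in stream_chat_progress(
            "How does context retrieval work?", demo_collection, app_config
        ):
            print(f"[{update.status.value}] {update.progress_message}")

    asyncio.run(_demo())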