File size: 17,305 Bytes
3cf6c5a
 
6b5ac9a
52d3389
6b5ac9a
 
52d3389
60532a1
 
6b5ac9a
3cf6c5a
eb592fa
52d3389
183e719
60532a1
 
 
212ff4c
 
6b5ac9a
 
183e719
6b5ac9a
52d3389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b5ac9a
 
 
 
 
 
 
 
aad4327
6b5ac9a
 
 
 
 
 
 
 
 
 
 
 
070f7e7
6b5ac9a
 
 
 
3cf6c5a
6b5ac9a
ca665cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b5ac9a
 
3cf6c5a
6b5ac9a
 
183e719
 
 
 
d141c1f
6b5ac9a
ca665cb
 
 
 
6b5ac9a
 
 
3cf6c5a
6b5ac9a
 
3cf6c5a
212ff4c
 
 
 
 
 
 
 
 
 
 
c9b82b3
 
212ff4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b5ac9a
3cf6c5a
6b5ac9a
212ff4c
 
6b5ac9a
212ff4c
 
 
 
 
5f5f1b6
212ff4c
 
 
 
 
 
 
5f5f1b6
3e469b9
6b5ac9a
212ff4c
 
 
 
6b5ac9a
212ff4c
 
f0bc02d
 
 
212ff4c
 
 
 
 
 
5f5f1b6
212ff4c
 
 
 
 
 
 
5f5f1b6
212ff4c
 
 
5f5f1b6
 
 
212ff4c
 
 
 
 
 
3e469b9
212ff4c
 
 
 
 
6b5ac9a
 
 
 
 
 
212ff4c
 
 
3cf6c5a
 
6b5ac9a
 
3cf6c5a
6b5ac9a
 
ca665cb
 
484f007
ca665cb
 
 
 
 
 
 
 
 
 
 
 
 
6b5ac9a
3cf6c5a
6b5ac9a
183e719
 
 
 
6b5ac9a
 
 
 
 
 
 
aad4327
6b5ac9a
 
 
028eb6e
 
 
6b5ac9a
ca665cb
 
 
 
 
028eb6e
ca665cb
 
6b5ac9a
 
 
 
 
 
aad4327
6b5ac9a
 
 
 
 
aad4327
6b5ac9a
 
 
 
3cf6c5a
6b5ac9a
 
 
 
 
070f7e7
6b5ac9a
 
 
 
 
 
 
 
6919cca
 
9644b48
 
6919cca
 
 
 
 
183e719
 
6919cca
 
 
 
52d3389
 
 
 
 
 
 
 
 
 
 
 
 
 
9644b48
 
52d3389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
# __future__ annotations is necessary for the type hints to work in this file
from __future__ import annotations
from dataclasses import dataclass
from typing import AsyncGenerator, List, Dict, Any, Optional
import chromadb
from pydantic import BaseModel
from pydantic_graph import BaseNode, EndStep, Graph, GraphRunContext, End, HistoryStep
from knowlang.configs.config import AppConfig, RerankerConfig, EmbeddingConfig
from knowlang.utils.fancy_log import FancyLogger
from pydantic_ai import Agent
import logfire
from pprint import pformat
from enum import Enum
from rich.console import Console
from knowlang.utils.model_provider import create_pydantic_model
from knowlang.utils.chunking_util import truncate_chunk
from knowlang.models.embeddings import EmbeddingInputType, generate_embedding
import voyageai
from voyageai.object.reranking import RerankingObject

# Module-level logger and rich console; console is used to pretty-print
# tracebacks on chat-processing failures.
LOG = FancyLogger(__name__)
console = Console()

class ChatStatus(str, Enum):
    """Enum for tracking chat progress status.

    Inherits from str so values serialize cleanly in pydantic models / JSON.
    """
    STARTING = "starting"      # pipeline accepted the question
    POLISHING = "polishing"    # PolishQuestionNode is refining the question
    RETRIEVING = "retrieving"  # RetrieveContextNode is searching the codebase
    ANSWERING = "answering"    # AnswerQuestionNode is generating the answer
    COMPLETE = "complete"      # final answer produced
    ERROR = "error"            # something went wrong; answer carries the message

class StreamingChatResult(BaseModel):
    """Chat result extended with streaming progress information.

    Instances are yielded incrementally by stream_chat_progress: one per
    node transition, plus a terminal COMPLETE or ERROR result.
    """
    answer: str
    retrieved_context: Optional[RetrievedContext] = None
    status: ChatStatus
    progress_message: str

    @classmethod
    def from_node(cls, node: BaseNode, state: ChatGraphState) -> StreamingChatResult:
        """Build a progress snapshot for the node that is about to run."""
        if isinstance(node, PolishQuestionNode):
            return cls(
                answer="",
                status=ChatStatus.POLISHING,
                progress_message=f"Refining question: '{state.original_question}'",
            )
        if isinstance(node, RetrieveContextNode):
            query = state.polished_question or state.original_question
            return cls(
                answer="",
                status=ChatStatus.RETRIEVING,
                progress_message=f"Searching codebase with: '{query}'",
            )
        if isinstance(node, AnswerQuestionNode):
            if state.retrieved_context:
                context_msg = f"Found {len(state.retrieved_context.chunks)} relevant segments"
            else:
                context_msg = "No context found"
            return cls(
                answer="",
                retrieved_context=state.retrieved_context,
                status=ChatStatus.ANSWERING,
                progress_message=f"Generating answer... {context_msg}",
            )
        # Unrecognized node type: surface it as an error status rather than raising.
        return cls(
            answer="",
            status=ChatStatus.ERROR,
            progress_message=f"Unknown node type: {type(node).__name__}",
        )

    @classmethod
    def complete(cls, result: ChatResult) -> StreamingChatResult:
        """Wrap a finished ChatResult as the terminal COMPLETE event."""
        return cls(
            answer=result.answer,
            retrieved_context=result.retrieved_context,
            status=ChatStatus.COMPLETE,
            progress_message="Response complete",
        )

    @classmethod
    def error(cls, error_msg: str) -> StreamingChatResult:
        """Build the terminal ERROR event carrying *error_msg*."""
        return cls(
            answer=f"Error: {error_msg}",
            status=ChatStatus.ERROR,
            progress_message=f"An error occurred: {error_msg}",
        )

class RetrievedContext(BaseModel):
    """Structure for retrieved context.

    chunks and metadatas are parallel lists: metadatas[i] describes chunks[i].
    """
    chunks: List[str]
    metadatas: List[Dict[str, Any]]

class ChatResult(BaseModel):
    """Final result from the chat graph.

    retrieved_context is None when no relevant code context was found.
    """
    answer: str
    retrieved_context: Optional[RetrievedContext] = None

@dataclass
class ChatGraphState:
    """State maintained throughout the graph execution"""
    # The question exactly as the user asked it.
    original_question: str
    # Set by PolishQuestionNode; currently that node is disabled, so this stays None.
    polished_question: Optional[str] = None
    # Set by RetrieveContextNode; empty-but-not-None on retrieval failure.
    retrieved_context: Optional[RetrievedContext] = None

@dataclass
class ChatGraphDeps:
    """Dependencies required by the graph"""
    # Chroma collection holding the embedded code chunks.
    collection: chromadb.Collection
    # Application config (llm, embedding, reranker, chat sub-configs are read by the nodes).
    config: AppConfig


# Graph Nodes
@dataclass
class PolishQuestionNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
    """Node that rewrites the user's question into a more searchable form."""
    system_prompt = """You are a code question refinement expert. Your ONLY task is to rephrase questions 
to be more precise for code context retrieval. Follow these rules strictly:

1. Output ONLY the refined question - no explanations or analysis
2. Preserve the original intent completely
3. Add missing technical terms if obvious
4. Keep the question concise - ideally one sentence
5. Focus on searchable technical terms
6. Do not add speculative terms not implied by the original question

Example Input: "How do I use transformers for translation?"
Example Output: "How do I use the Transformers pipeline for machine translation tasks?"

Example Input: "Where is the config stored?"
Example Output: "Where is the configuration file or configuration settings stored in this codebase?"
    """

    async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> RetrieveContextNode:
        """Polish the question via the configured LLM, store it on state, hand off to retrieval."""
        llm_config = ctx.deps.config.llm
        model = create_pydantic_model(
            model_provider=llm_config.model_provider,
            model_name=llm_config.model_name,
        )
        refine_agent = Agent(model, system_prompt=self.system_prompt)

        prompt = f"""Original question: "{ctx.state.original_question}"

Return ONLY the polished question - no explanations or analysis.
Focus on making the question more searchable while preserving its original intent."""

        refined = await refine_agent.run(prompt)
        ctx.state.polished_question = refined.data
        return RetrieveContextNode()

@dataclass
class RetrieveContextNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
    """Node that retrieves relevant code context using hybrid search: embeddings + reranking.

    Strategy: over-fetch candidates by embedding similarity, then rerank them
    with Voyage AI. If reranking is disabled or fails, fall back to filtering
    the top-k candidates by embedding distance.
    """

    async def _get_initial_chunks(
        self,
        query: str,
        embedding_config: EmbeddingConfig,
        collection: chromadb.Collection,
        n_results: int
    ) -> tuple[List[str], List[Dict], List[float]]:
        """First-pass retrieval: embed the query and fetch candidates from Chroma.

        Returns aligned (documents, metadatas, distances) lists for the single
        query (Chroma returns one list per query, hence the [0] unpacking).
        """
        question_embedding = generate_embedding(
            input=query,
            config=embedding_config,
            input_type=EmbeddingInputType.QUERY
        )

        results = collection.query(
            query_embeddings=question_embedding,
            n_results=n_results,
            include=['metadatas', 'documents', 'distances']
        )

        return (
            results['documents'][0],
            results['metadatas'][0],
            results['distances'][0]
        )

    async def _rerank_chunks(
        self,
        query: str,
        chunks: List[str],
        reranker_config: RerankerConfig,
    ) -> RerankingObject:
        """Second-pass scoring: rerank candidate chunks with Voyage AI.

        Raises whatever the Voyage client raises (missing API key, network
        errors); the caller treats any exception as "fall back to distances".
        """
        voyage_client = voyageai.Client()
        return voyage_client.rerank(
            query=query,
            documents=chunks,
            model=reranker_config.model_name,
            top_k=reranker_config.top_k,
            truncation=True
        )

    def _filter_by_distance(
        self,
        chunks: List[str],
        metadatas: List[Dict],
        distances: List[float],
        threshold: float
    ) -> tuple[List[str], List[Dict]]:
        """Keep only (chunk, metadata) pairs whose distance is <= threshold.

        The three input lists must be aligned index-for-index.
        """
        filtered_chunks = []
        filtered_metadatas = []

        for chunk, meta, dist in zip(chunks, metadatas, distances):
            if dist <= threshold:
                filtered_chunks.append(chunk)
                filtered_metadatas.append(meta)

        return filtered_chunks, filtered_metadatas

    async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> AnswerQuestionNode:
        """Retrieve context for the question and stash it on the graph state.

        Never raises: on any failure an empty RetrievedContext is stored so
        AnswerQuestionNode can still run and report the missing context.
        """
        try:
            # Prefer the polished question when the polish node ran.
            query = ctx.state.polished_question or ctx.state.original_question

            # First pass: over-fetch candidates so the reranker has room to choose.
            initial_chunks, initial_metadatas, distances = await self._get_initial_chunks(
                query=query,
                embedding_config=ctx.deps.config.embedding,
                collection=ctx.deps.collection,
                n_results=min(ctx.deps.config.chat.max_context_chunks * 2, 50)
            )

            # Only proceed if embedding search returned anything at all.
            if not initial_chunks:
                LOG.warning("No initial chunks found through embedding search")
                raise Exception("No chunks found through embedding search")

            # Sort the (chunk, metadata, distance) triples TOGETHER by distance.
            # BUG FIX: previously only the chunks were sorted while metadatas and
            # distances stayed in retrieval order, so the fallback path paired
            # each chunk with the wrong metadata/distance.
            ranked = sorted(
                zip(initial_chunks, initial_metadatas, distances),
                key=lambda triple: triple[2]
            )[:ctx.deps.config.reranker.top_k]
            if ranked:
                top_chunks, top_metadatas, top_distances = (list(col) for col in zip(*ranked))
            else:
                top_chunks, top_metadatas, top_distances = [], [], []
            logfire.info(
                'top k embedding search results: {results}',
                results=list(zip(top_chunks, top_distances))
            )

            # Second pass: rerank the candidates; any failure falls through to
            # plain distance filtering.
            try:
                if not ctx.deps.config.reranker.enabled:
                    raise Exception("Reranker is disabled")

                reranking = await self._rerank_chunks(
                    query=query,
                    chunks=initial_chunks,
                    reranker_config=ctx.deps.config.reranker
                )
                logfire.info('top k reranking search results: {results}', results=reranking.results)

                # Build final context from reranked results.
                relevant_chunks = []
                relevant_metadatas = []

                for result in reranking.results:
                    # Only include results scoring above the configured threshold;
                    # result.index points back into the original candidate list.
                    if result.relevance_score >= ctx.deps.config.reranker.relevance_threshold:
                        relevant_chunks.append(result.document)
                        relevant_metadatas.append(initial_metadatas[result.index])

                if not relevant_chunks:
                    raise Exception("No relevant chunks found through reranking")

            except Exception as e:
                # Fallback: distance-threshold filtering over the aligned top-k triples.
                LOG.error(f"Reranking failed, falling back to distance-based filtering: {e}")
                relevant_chunks, relevant_metadatas = self._filter_by_distance(
                    chunks=top_chunks,
                    metadatas=top_metadatas,
                    distances=top_distances,
                    threshold=ctx.deps.config.chat.similarity_threshold
                )

            ctx.state.retrieved_context = RetrievedContext(
                chunks=relevant_chunks,
                metadatas=relevant_metadatas,
            )

        except Exception as e:
            LOG.error(f"Error in context retrieval: {e}")
            ctx.state.retrieved_context = RetrievedContext(chunks=[], metadatas=[])

        # Always advance to the answer node. Note: the original used
        # `finally: return ...`, which also swallows BaseException (e.g. async
        # task cancellation) — returning after the try/except avoids that.
        return AnswerQuestionNode()

@dataclass
class AnswerQuestionNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
    """Node that generates the final answer from the retrieved context."""
    system_prompt = """
You are an expert code assistant helping developers understand complex codebases. Follow these rules strictly:

1. ALWAYS answer the user's question - this is your primary task
2. Base your answer ONLY on the provided code context, not on general knowledge
3. When referencing code:
   - Cite specific files and line numbers
   - Quote relevant code snippets briefly
   - Explain why this code is relevant to the question
4. If you cannot find sufficient context to answer fully:
   - Clearly state what's missing
   - Explain what additional information would help
5. Focus on accuracy over comprehensiveness:
   - If you're unsure about part of your answer, explicitly say so
   - Better to acknowledge limitations than make assumptions

Remember: Your primary goal is answering the user's specific question, not explaining the entire codebase."""

    async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> End[ChatResult]:
        """Generate the answer with the configured LLM and end the graph.

        Returns an apologetic ChatResult (no exception) when context is
        missing or the LLM call fails.
        """
        answer_agent = Agent(
            create_pydantic_model(
                model_provider=ctx.deps.config.llm.model_provider,
                model_name=ctx.deps.config.llm.model_name
            ),
            system_prompt=self.system_prompt
        )

        # No usable context: ask the user to rephrase instead of hallucinating.
        if not ctx.state.retrieved_context or not ctx.state.retrieved_context.chunks:
            return End(ChatResult(
                answer="I couldn't find any relevant code context for your question. "
                      "Could you please rephrase or be more specific?",
                retrieved_context=None,
            ))

        context = ctx.state.retrieved_context
        # BUG FIX: the original looped `for chunk in context.chunks: chunk = truncate_chunk(...)`,
        # which only rebound the loop variable and never truncated anything.
        # Build a truncated copy and use it in the prompt instead.
        truncated_chunks = [
            truncate_chunk(chunk, ctx.deps.config.chat.max_length_per_chunk)
            for chunk in context.chunks
        ]

        prompt = f"""
Question: {ctx.state.original_question}

Relevant Code Context:
{truncated_chunks}

Provide a focused answer to the question based on the provided context.

Important: Stay focused on answering the specific question asked.
        """

        try:
            result = await answer_agent.run(prompt)
            return End(ChatResult(
                answer=result.data,
                retrieved_context=context,
            ))
        except Exception as e:
            LOG.error(f"Error generating answer: {e}")
            return End(ChatResult(
                answer="I encountered an error processing your question. Please try again.",
                retrieved_context=context,
            ))

# Create the graph. Node order here only registers the node types; execution
# order is determined by each node's run() return value.
chat_graph = Graph(
    nodes=[PolishQuestionNode, RetrieveContextNode, AnswerQuestionNode]
)

async def process_chat(
    question: str,
    collection: chromadb.Collection,
    config: AppConfig
) -> ChatResult:
    """
    Process a chat question through the graph.
    This is the main entry point for chat processing.

    Args:
        question: the user's raw question.
        collection: Chroma collection of embedded code chunks.
        config: application configuration consumed by the graph nodes.

    Returns:
        ChatResult with the answer; on any failure a generic error answer
        is returned instead of raising.
    """
    state = ChatGraphState(original_question=question)
    deps = ChatGraphDeps(collection=collection, config=config)

    try:
        result, _history = await chat_graph.run(
            # Temporary fix to disable PolishQuestionNode
            RetrieveContextNode(),
            state=state,
            deps=deps
        )
    except Exception as e:
        LOG.error(f"Error processing chat in graph: {e}")
        console.print_exception()

        result = ChatResult(
            answer="I encountered an error processing your question. Please try again."
        )
    # BUG FIX: the original `return result` lived in a `finally` block, which
    # silently swallows BaseException (KeyboardInterrupt, asyncio.CancelledError)
    # and suppresses any in-flight re-raise. Returning here preserves the
    # normal/handled paths while letting cancellation propagate.
    return result
    
async def stream_chat_progress(
    question: str,
    collection: chromadb.Collection,
    config: AppConfig
) -> AsyncGenerator[StreamingChatResult, None]:
    """
    Stream chat progress through the graph.
    This is the main entry point for chat processing.

    Yields one StreamingChatResult per node transition (status updates),
    then a terminal COMPLETE result on success or a single ERROR result
    on failure. Never raises to the caller.
    """
    state = ChatGraphState(original_question=question)
    deps = ChatGraphDeps(collection=collection, config=config)

    # Temporary fix to disable PolishQuestionNode
    start_node = RetrieveContextNode()
    # Step history accumulated for logfire; also required by chat_graph.next().
    history: list[HistoryStep[ChatGraphState, ChatResult]] = []

    try:
        # Initial status
        yield StreamingChatResult(
            answer="",
            status=ChatStatus.STARTING,
            progress_message=f"Processing question: {question}"
        )

        with logfire.span(
            '{graph_name} run {start=}',
            graph_name='RAG_chat_graph',
            start=start_node,
        ) as run_span:
            current_node = start_node

            # Drive the graph one node at a time so we can yield a progress
            # update before each node executes.
            while True:
                # Yield current node's status before processing
                yield StreamingChatResult.from_node(current_node, state)

                try:
                    # Process the current node; infer_name=False because the
                    # graph name is set explicitly on the span above.
                    next_node = await chat_graph.next(current_node, history, state=state, deps=deps, infer_name=False)

                    if isinstance(next_node, End):
                        result: ChatResult = next_node.data
                        history.append(EndStep(result=next_node))
                        run_span.set_attribute('history', history)
                        # Yield final result
                        yield StreamingChatResult.complete(result)
                        return
                    elif isinstance(next_node, BaseNode):
                        current_node = next_node
                    else:
                        raise ValueError(f"Invalid node type: {type(next_node)}")

                except Exception as node_error:
                    # A node failed: emit one ERROR event and stop streaming.
                    LOG.error(f"Error in node {current_node.__class__.__name__}: {node_error}")
                    yield StreamingChatResult.error(str(node_error))
                    return

    except Exception as e:
        # Failure outside node execution (e.g. span setup): emit ERROR and stop.
        LOG.error(f"Error in stream_chat_progress: {e}")
        yield StreamingChatResult.error(str(e))
        return