gabykim committed
Commit 438162b · 2 parents: c7e758e 52d3389

Merge branch 'main' into huggingface

src/know_lang_bot/chat_bot/chat_graph.py CHANGED
@@ -1,20 +1,85 @@
  # __future__ annotations is necessary for the type hints to work in this file
  from __future__ import annotations
  from dataclasses import dataclass
- from typing import List, Dict, Any, Optional
+ from typing import AsyncGenerator, List, Dict, Any, Optional
  import chromadb
  from pydantic import BaseModel
- from pydantic_graph import BaseNode, Graph, GraphRunContext, End
+ from pydantic_graph import BaseNode, EndStep, Graph, GraphRunContext, End, HistoryStep
  import ollama
  from know_lang_bot.config import AppConfig
  from know_lang_bot.utils.fancy_log import FancyLogger
  from pydantic_ai import Agent
  import logfire
  from pprint import pformat
+ from enum import Enum
 
  LOG = FancyLogger(__name__)
 
- # Data Models
+ class ChatStatus(str, Enum):
+     """Enum for tracking chat progress status"""
+     STARTING = "starting"
+     POLISHING = "polishing"
+     RETRIEVING = "retrieving"
+     ANSWERING = "answering"
+     COMPLETE = "complete"
+     ERROR = "error"
+
+ class StreamingChatResult(BaseModel):
+     """Extended chat result with streaming information"""
+     answer: str
+     retrieved_context: Optional[RetrievedContext] = None
+     status: ChatStatus
+     progress_message: str
+
+     @classmethod
+     def from_node(cls, node: BaseNode, state: ChatGraphState) -> StreamingChatResult:
+         """Create a StreamingChatResult from a node's current state"""
+         if isinstance(node, PolishQuestionNode):
+             return cls(
+                 answer="",
+                 status=ChatStatus.POLISHING,
+                 progress_message=f"Refining question: '{state.original_question}'"
+             )
+         elif isinstance(node, RetrieveContextNode):
+             return cls(
+                 answer="",
+                 status=ChatStatus.RETRIEVING,
+                 progress_message=f"Searching codebase with: '{state.polished_question or state.original_question}'"
+             )
+         elif isinstance(node, AnswerQuestionNode):
+             context_msg = f"Found {len(state.retrieved_context.chunks)} relevant segments" if state.retrieved_context else "No context found"
+             return cls(
+                 answer="",
+                 retrieved_context=state.retrieved_context,
+                 status=ChatStatus.ANSWERING,
+                 progress_message=f"Generating answer... {context_msg}"
+             )
+         else:
+             return cls(
+                 answer="",
+                 status=ChatStatus.ERROR,
+                 progress_message=f"Unknown node type: {type(node).__name__}"
+             )
+
+     @classmethod
+     def complete(cls, result: ChatResult) -> StreamingChatResult:
+         """Create a completed StreamingChatResult"""
+         return cls(
+             answer=result.answer,
+             retrieved_context=result.retrieved_context,
+             status=ChatStatus.COMPLETE,
+             progress_message="Response complete"
+         )
+
+     @classmethod
+     def error(cls, error_msg: str) -> StreamingChatResult:
+         """Create an error StreamingChatResult"""
+         return cls(
+             answer=f"Error: {error_msg}",
+             status=ChatStatus.ERROR,
+             progress_message=f"An error occurred: {error_msg}"
+         )
+
  class RetrievedContext(BaseModel):
      """Structure for retrieved context"""
      chunks: List[str]
@@ -206,4 +271,66 @@ async def process_chat(
              answer="I encountered an error processing your question. Please try again."
          )
      finally:
-         return result
+         return result
+
+ async def stream_chat_progress(
+     question: str,
+     collection: chromadb.Collection,
+     config: AppConfig
+ ) -> AsyncGenerator[StreamingChatResult, None]:
+     """
+     Stream chat progress through the graph.
+     This is the main entry point for chat processing.
+     """
+     state = ChatGraphState(original_question=question)
+     deps = ChatGraphDeps(collection=collection, config=config)
+
+     start_node = PolishQuestionNode()
+     history: list[HistoryStep[ChatGraphState, ChatResult]] = []
+
+     try:
+         # Initial status
+         yield StreamingChatResult(
+             answer="",
+             status=ChatStatus.STARTING,
+             progress_message=f"Processing question: {question}"
+         )
+
+         with logfire.span(
+             '{graph_name} run {start=}',
+             graph_name='RAG_chat_graph',
+             start=start_node,
+         ) as run_span:
+             current_node = start_node
+
+             while True:
+                 # Yield current node's status before processing
+                 yield StreamingChatResult.from_node(current_node, state)
+
+                 try:
+                     # Process the current node
+                     next_node = await chat_graph.next(current_node, history, state=state, deps=deps, infer_name=False)
+
+                     if isinstance(next_node, End):
+                         result: ChatResult = next_node.data
+                         history.append(EndStep(result=next_node))
+                         run_span.set_attribute('history', history)
+                         # Yield final result
+                         yield StreamingChatResult.complete(result)
+                         return
+                     elif isinstance(next_node, BaseNode):
+                         current_node = next_node
+                     else:
+                         raise ValueError(f"Invalid node type: {type(next_node)}")
+
+                 except Exception as node_error:
+                     LOG.error(f"Error in node {current_node.__class__.__name__}: {node_error}")
+                     yield StreamingChatResult.error(str(node_error))
+                     return
+
+     except Exception as e:
+         LOG.error(f"Error in stream_chat_progress: {e}")
+         yield StreamingChatResult.error(str(e))
+         return
+
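For context on how the new entry point is consumed: stream_chat_progress is an async generator, so callers iterate it with async for and branch on status. Below is a minimal consumer sketch, assuming a persisted chromadb store at ./chroma_db, a collection named code_chunks, and a no-argument AppConfig() constructor; none of these appear in this commit and all are placeholders.

# Hypothetical consumer of stream_chat_progress. The chromadb path, the
# "code_chunks" collection name, and AppConfig() taking no arguments are
# assumptions, not part of this commit.
import asyncio
import chromadb
from know_lang_bot.config import AppConfig
from know_lang_bot.chat_bot.chat_graph import stream_chat_progress, ChatStatus

async def main() -> None:
    client = chromadb.PersistentClient(path="./chroma_db")  # assumed store location
    collection = client.get_collection("code_chunks")       # assumed collection name
    config = AppConfig()                                    # assumed construction

    async for result in stream_chat_progress(
        "What does the CodeParser class do?", collection, config
    ):
        if result.status == ChatStatus.COMPLETE:
            print(result.answer)  # final answer; retrieved_context is attached here
        elif result.status == ChatStatus.ERROR:
            print(result.progress_message)
            break
        else:
            # intermediate progress updates (starting/polishing/retrieving/answering)
            print(f"[{result.status.value}] {result.progress_message}")

if __name__ == "__main__":
    asyncio.run(main())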
src/know_lang_bot/chat_bot/chat_interface.py CHANGED
@@ -1,11 +1,12 @@
  import gradio as gr
  from know_lang_bot.config import AppConfig
  from know_lang_bot.utils.fancy_log import FancyLogger
- from know_lang_bot.chat_bot.chat_graph import ChatResult, process_chat
+ from know_lang_bot.chat_bot.chat_graph import stream_chat_progress, ChatStatus
  import chromadb
- from typing import List, Dict
- import logfire
+ from typing import List, Dict, AsyncGenerator
  from pathlib import Path
+ from gradio import ChatMessage
+
 
  LOG = FancyLogger(__name__)
 
@@ -35,16 +36,86 @@ class CodeQAChatInterface:
          except Exception as e:
              LOG.error(f"Error reading code block: {e}")
              return "Error reading code"
-
-     @logfire.instrument('Chatbot Process Question with {message=}')
-     async def process_question(
+
+     def _format_code_block(self, metadata: Dict) -> str:
+         """Format a single code block with metadata"""
+         file_path = metadata['file_path']
+         start_line = metadata['start_line']
+         end_line = metadata['end_line']
+
+         code = self._get_code_block(file_path, start_line, end_line)
+         if not code:
+             return None
+
+         title = f"📄 {file_path} (lines {start_line}-{end_line})"
+         if metadata.get('name'):
+             title += f" - {metadata['type']}: {metadata['name']}"
+
+
+         return f"<details><summary>{title}</summary>\n\n```python\n{code}\n```\n\n</details>"
+
+     async def stream_response(
          self,
          message: str,
-         history: List[Dict[str, str]]
-     ) -> ChatResult:
-         """Process a question and return the answer with references"""
-         return await process_chat(message, self.collection, self.config)
-
+         history: List[ChatMessage]
+     ) -> AsyncGenerator[List[ChatMessage], None]:
+         """Stream chat responses with progress updates"""
+         # Add user message
+         history.append(ChatMessage(role="user", content=message))
+         yield history
+
+         current_progress: ChatMessage | None = None
+         code_blocks_added = False
+
+         async for result in stream_chat_progress(message, self.collection, self.config):
+             # Handle progress updates
+             if result.status != ChatStatus.COMPLETE:
+                 if current_progress:
+                     history.remove(current_progress)
+
+                 current_progress = ChatMessage(
+                     role="assistant",
+                     content=result.progress_message,
+                     metadata={
+                         "title": f"{result.status.value.title()} Progress",
+                         "status": "pending" if result.status != ChatStatus.ERROR else "error"
+                     }
+                 )
+                 history.append(current_progress)
+                 yield history
+                 continue
+
+             # When complete, remove progress message and add final content
+             if current_progress:
+                 history.remove(current_progress)
+                 current_progress = None
+
+             # Add code blocks before final answer if not added yet
+             if not code_blocks_added and result.retrieved_context and result.retrieved_context.metadatas:
+                 total_code_blocks = []
+                 for metadata in result.retrieved_context.metadatas:
+                     code_block = self._format_code_block(metadata)
+                     if code_block:
+                         total_code_blocks.append(code_block)
+
+                 code_blocks_added = True
+                 history.append(ChatMessage(
+                     role="assistant",
+                     content='\n\n'.join(total_code_blocks),
+                     metadata={
+                         "title": "💻 Code Context",
+                         "collapsible": True
+                     }
+                 ))
+                 yield history
+
+             # Add final answer
+             history.append(ChatMessage(
+                 role="assistant",
+                 content=result.answer
+             ))
+             yield history
+
      def create_interface(self) -> gr.Blocks:
          """Create the Gradio interface"""
          with gr.Blocks() as interface:
@@ -54,66 +125,28 @@ class CodeQAChatInterface:
              chatbot = gr.Chatbot(
                  type="messages",
                  bubble_full_width=False,
-                 render_markdown=True
+                 render_markdown=True,
+                 height=600
              )
 
              msg = gr.Textbox(
                  label="Ask about the codebase",
                  placeholder="What does the CodeParser class do?",
-                 container=False
+                 container=False,
+                 scale=7
              )
 
              with gr.Row():
-                 submit = gr.Button("Submit")
-                 clear = gr.ClearButton([msg, chatbot])
-
-             async def respond(message, history):
-                 result = await self.process_question(message, history)
-
-                 # Format the answer with code blocks
-                 formatted_messages = []
-
-                 # Add user message
-                 formatted_messages.append({
-                     "role": "user",
-                     "content": message
-                 })
-
-                 # Collect code blocks first
-                 code_blocks = []
-                 if result.retrieved_context and result.retrieved_context.metadatas:
-                     for metadata in result.retrieved_context.metadatas:
-                         file_path = metadata['file_path']
-                         start_line = metadata['start_line']
-                         end_line = metadata['end_line']
-
-                         code = self._get_code_block(file_path, start_line, end_line)
-                         if code:
-                             title = f"📄 {file_path} (lines {start_line}-{end_line})"
-                             if metadata.get('name'):
-                                 title += f" - {metadata['type']}: {metadata['name']}"
-
-                             code_blocks.append({
-                                 "role": "assistant",
-                                 "content": f"<details><summary>{title}</summary>\n\n```python\n{code}\n```\n\n</details>",
-                             })
-
-                 # Add code blocks before the answer
-                 formatted_messages.extend(code_blocks)
-
-                 # Add assistant's answer
-                 formatted_messages.append({
-                     "role": "assistant",
-                     "content": result.answer
-                 })
-
-                 return {
-                     msg: "",
-                     chatbot: history + formatted_messages
-                 }
+                 submit = gr.Button("Submit", scale=1)
+                 clear = gr.ClearButton([msg, chatbot], scale=1)
 
-             msg.submit(respond, [msg, chatbot], [msg, chatbot])
-             submit.click(respond, [msg, chatbot], [msg, chatbot])
+             async def respond(message: str, history: List[ChatMessage]) -> AsyncGenerator[List[ChatMessage], None]:
+                 async for updated_history in self.stream_response(message, history):
+                     yield updated_history
+
+             # Set up event handlers
+             msg.submit(respond, [msg, chatbot], [chatbot])
+             submit.click(respond, [msg, chatbot], [chatbot])
 
          return interface
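Taken together, the two files now stream node-by-node progress into the Gradio UI. A minimal launch sketch follows; the CodeQAChatInterface(config, collection) constructor signature and the collection name are assumptions, since neither appears in these hunks. Enabling the queue is what lets Gradio drive the async generator handlers.

# Hypothetical launch script. The CodeQAChatInterface constructor signature,
# the chromadb path, and the "code_chunks" collection name are assumed here;
# only create_interface() is shown in this commit.
import chromadb
from know_lang_bot.config import AppConfig
from know_lang_bot.chat_bot.chat_interface import CodeQAChatInterface

config = AppConfig()                                    # assumed construction
client = chromadb.PersistentClient(path="./chroma_db")  # assumed store location
collection = client.get_collection("code_chunks")       # assumed collection name

app = CodeQAChatInterface(config, collection)           # assumed signature
demo = app.create_interface()
demo.queue()    # queue enables streaming from generator-based event handlers
demo.launch()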