gabykim committed on
Commit
8f0a66b
·
1 Parent(s): 8e1fac7

gradio rate limiting

Browse files
src/know_lang_bot/chat_bot/chat_interface.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  from know_lang_bot.config import AppConfig
3
  from know_lang_bot.utils.fancy_log import FancyLogger
 
4
  from know_lang_bot.chat_bot.chat_graph import stream_chat_progress, ChatStatus
5
  import chromadb
6
  from typing import List, Dict, AsyncGenerator
@@ -15,6 +16,7 @@ class CodeQAChatInterface:
15
  self.config = config
16
  self._init_chroma()
17
  self.codebase_dir = Path(config.db.codebase_directory)
 
18
 
19
  def _init_chroma(self):
20
  """Initialize ChromaDB connection"""
@@ -57,12 +59,32 @@ class CodeQAChatInterface:
57
  async def stream_response(
58
  self,
59
  message: str,
60
- history: List[ChatMessage]
 
61
  ) -> AsyncGenerator[List[ChatMessage], None]:
62
  """Stream chat responses with progress updates"""
63
  # Add user message
64
  history.append(ChatMessage(role="user", content=message))
65
  yield history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  current_progress: ChatMessage | None = None
68
  code_blocks_added = False
@@ -140,8 +162,8 @@ class CodeQAChatInterface:
140
  submit = gr.Button("Submit", scale=1)
141
  clear = gr.ClearButton([msg, chatbot], scale=1)
142
 
143
- async def respond(message: str, history: List[ChatMessage]) -> AsyncGenerator[List[ChatMessage], None]:
144
- async for updated_history in self.stream_response(message, history):
145
  yield updated_history
146
 
147
  # Set up event handlers
 
1
  import gradio as gr
2
  from know_lang_bot.config import AppConfig
3
  from know_lang_bot.utils.fancy_log import FancyLogger
4
+ from know_lang_bot.utils.rate_limiter import RateLimiter
5
  from know_lang_bot.chat_bot.chat_graph import stream_chat_progress, ChatStatus
6
  import chromadb
7
  from typing import List, Dict, AsyncGenerator
 
16
  self.config = config
17
  self._init_chroma()
18
  self.codebase_dir = Path(config.db.codebase_directory)
19
+ self.rate_limiter = RateLimiter()
20
 
21
  def _init_chroma(self):
22
  """Initialize ChromaDB connection"""
 
59
  async def stream_response(
60
  self,
61
  message: str,
62
+ history: List[ChatMessage],
63
+ request: gr.Request, # gradio injects the request object
64
  ) -> AsyncGenerator[List[ChatMessage], None]:
65
  """Stream chat responses with progress updates"""
66
  # Add user message
67
  history.append(ChatMessage(role="user", content=message))
68
  yield history
69
+
70
+ # Check rate limit before processing
71
+ client_ip : str = request.request.client.host
72
+ print(f"Client IP: {client_ip}")
73
+ if self.rate_limiter.check_rate_limit(client_ip):
74
+ wait_time = self.rate_limiter.get_remaining_time(client_ip)
75
+ rate_limit_message = (
76
+ f"Rate limit exceeded. Please wait {wait_time:.0f} seconds before sending another message."
77
+ )
78
+ history.append(ChatMessage(
79
+ role="assistant",
80
+ content=rate_limit_message,
81
+ metadata={
82
+ "title": "⚠️ Rate Limit Warning",
83
+ "status": "done"
84
+ }
85
+ ))
86
+ yield history
87
+ return
88
 
89
  current_progress: ChatMessage | None = None
90
  code_blocks_added = False
 
162
  submit = gr.Button("Submit", scale=1)
163
  clear = gr.ClearButton([msg, chatbot], scale=1)
164
 
165
+ async def respond(message: str, history: List[ChatMessage], request: gr.Request) -> AsyncGenerator[List[ChatMessage], None]:
166
+ async for updated_history in self.stream_response(message, history, request):
167
  yield updated_history
168
 
169
  # Set up event handlers
src/know_lang_bot/utils/rate_limiter.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from collections import defaultdict
3
+ import time
4
+ from typing import Dict, List
5
+ from threading import Lock
6
+
7
+
@dataclass
class RateLimiter:
    """Sliding-window rate limiter for the chat interface, keyed by client IP.

    Every accepted request is recorded as a ``time.time()`` timestamp. A
    client is considered rate limited once it already has
    ``requests_per_minute`` accepted requests inside the trailing
    ``window_size`` seconds.

    Attributes:
        requests_per_minute: Maximum number of accepted requests per window.
            The window length is ``window_size`` seconds, so this is only
            literally "per minute" with the default window.
        window_size: Length of the sliding window, in seconds.
    """

    requests_per_minute: int = 2
    window_size: int = 60  # seconds

    # Accepted-request timestamps per user IP.
    _requests: Dict[str, List[float]] = field(default_factory=lambda: defaultdict(list))
    # Guards every access to ``_requests``. It is a plain (non-reentrant)
    # Lock, so methods that already hold it must call ``_prune_locked``
    # rather than ``_clean_old_requests``.
    _lock: Lock = field(default_factory=Lock)

    def _prune_locked(self, user_ip: str, now: float) -> List[float]:
        """Drop expired timestamps for ``user_ip``; caller must hold ``_lock``.

        Args:
            user_ip: Client key whose history is pruned.
            now: Current time, as returned by ``time.time()``.

        Returns:
            The timestamps still inside the window (possibly empty).
        """
        fresh = [
            timestamp
            for timestamp in self._requests.get(user_ip, ())
            if now - timestamp < self.window_size
        ]
        if fresh:
            self._requests[user_ip] = fresh
        else:
            # Remove the key entirely so clients that stop sending requests
            # do not leave empty lists behind forever.
            self._requests.pop(user_ip, None)
        return fresh

    def _clean_old_requests(self, user_ip: str) -> None:
        """Remove requests older than the window size for ``user_ip``."""
        now = time.time()
        with self._lock:
            self._prune_locked(user_ip, now)

    def check_rate_limit(self, user_ip: str) -> bool:
        """Check the limit for ``user_ip`` and record the request if allowed.

        Pruning, checking and recording happen under a single lock
        acquisition so concurrent calls for the same client cannot both slip
        under the limit.

        Args:
            user_ip: Client key to check.

        Returns:
            True if the rate limit is exceeded (the request is NOT recorded),
            False if the request is allowed (and has been recorded).
        """
        now = time.time()
        with self._lock:
            fresh = self._prune_locked(user_ip, now)
            if len(fresh) >= self.requests_per_minute:
                return True

            # setdefault works whether or not the key survived pruning.
            self._requests.setdefault(user_ip, []).append(now)
            return False

    def get_remaining_time(self, user_ip: str) -> float:
        """Get the remaining time until the next request is allowed.

        Args:
            user_ip: Client key to query. Querying never creates an entry.

        Returns:
            Seconds until the oldest in-window request expires when the
            client is at its limit; ``0.0`` when a request would be allowed
            right now (no history, or still under the limit).
        """
        now = time.time()
        with self._lock:
            fresh = self._prune_locked(user_ip, now)
            # ``not fresh`` is checked first so ``min`` below never sees an
            # empty list, even when ``requests_per_minute`` is 0.
            if not fresh or len(fresh) < self.requests_per_minute:
                return 0.0

            time_until_reset = self.window_size - (now - min(fresh))
            return max(0.0, time_until_reset)