File size: 1,826 Bytes
8f0a66b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from dataclasses import dataclass, field
from collections import defaultdict
import time
from typing import Dict, List
from threading import Lock


@dataclass
class RateLimiter:
    """Rate limiter implementation for chat interface"""
    requests_per_minute: int = 2
    window_size: int = 60  # seconds
    
    # Track request timestamps per user IP
    _requests: Dict[str, List[float]] = field(default_factory=lambda: defaultdict(list))
    _lock: Lock = field(default_factory=Lock)

    def _clean_old_requests(self, user_ip: str) -> None:
        """Remove requests older than the window size"""
        current_time = time.time()
        with self._lock:
            self._requests[user_ip] = [
                timestamp for timestamp in self._requests[user_ip]
                if current_time - timestamp < self.window_size
            ]

    def check_rate_limit(self, user_ip: str) -> bool:
        """
        Check if user has exceeded rate limit
        Returns True if rate limit exceeded, False otherwise
        """
        self._clean_old_requests(user_ip)
        
        with self._lock:
            if len(self._requests[user_ip]) >= self.requests_per_minute:
                return True
                
            self._requests[user_ip].append(time.time())
            return False

    def get_remaining_time(self, user_ip: str) -> float:
        """Get remaining time until next request is allowed"""
        self._clean_old_requests(user_ip)
        
        with self._lock:
            if not self._requests[user_ip]:
                return 0
                
            oldest_request = min(self._requests[user_ip])
            current_time = time.time()
            time_until_reset = self.window_size - (current_time - oldest_request)
            
            return max(0, time_until_reset)