Adding logic to rate limit, get JWT token with user identity
Browse files
    	
        app.py
    CHANGED
    
    | @@ -16,14 +16,12 @@ import threading | |
| 16 | 
             
            import hashlib
         | 
| 17 |  | 
| 18 | 
             
            # Rate limiting data structures
         | 
| 19 | 
            -
             | 
| 20 | 
            -
            session_usage = defaultdict(int)
         | 
| 21 | 
             
            last_reset_time = time.time()
         | 
| 22 | 
             
            rate_limit_lock = threading.Lock()
         | 
| 23 |  | 
| 24 | 
             
            # Constants
         | 
| 25 | 
            -
             | 
| 26 | 
            -
            MAX_SESSION_USAGE = 2
         | 
| 27 | 
             
            RESET_INTERVAL = 24 * 60 * 60  # 24 hours in seconds
         | 
| 28 |  | 
| 29 | 
             
            # Set up logging
         | 
| @@ -50,7 +48,6 @@ llm = Llama(model_path=model_path, n_ctx=1024, n_threads=8, n_gpu_layers=-1, ver | |
| 50 | 
             
            logger.info("8-bit model loaded successfully")
         | 
| 51 |  | 
| 52 | 
             
            # User data storage
         | 
| 53 | 
            -
            user_data = {}
         | 
| 54 | 
             
            token_to_problem_solution = {}
         | 
| 55 |  | 
| 56 | 
             
            # Generation parameters
         | 
| @@ -64,18 +61,6 @@ generation_kwargs = { | |
| 64 | 
             
                "repeat_penalty": 1.1
         | 
| 65 | 
             
            }
         | 
| 66 |  | 
| 67 | 
            -
            def generate_user_identifier(request: gr.Request) -> str:
         | 
| 68 | 
            -
                ip = request.client.ip
         | 
| 69 | 
            -
                user_agent = request.headers.get('User-Agent', '')
         | 
| 70 | 
            -
                return hashlib.sha256(f"{ip}{user_agent}".encode()).hexdigest()
         | 
| 71 | 
            -
             | 
| 72 | 
            -
            def generate_token(user_identifier: str) -> str:
         | 
| 73 | 
            -
                payload = {
         | 
| 74 | 
            -
                    'exp': int(time.time()) + 3600,  # 1 hour expiration
         | 
| 75 | 
            -
                    'user_id': user_identifier
         | 
| 76 | 
            -
                }
         | 
| 77 | 
            -
                return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
         | 
| 78 | 
            -
             | 
| 79 | 
             
            def verify_token(token: str) -> bool:
         | 
| 80 | 
             
                try:
         | 
| 81 | 
             
                    jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
         | 
| @@ -83,63 +68,52 @@ def verify_token(token: str) -> bool: | |
| 83 | 
             
                except jwt.PyJWTError:
         | 
| 84 | 
             
                    return False
         | 
| 85 |  | 
| 86 | 
            -
            def check_rate_limit( | 
| 87 | 
             
                global last_reset_time
         | 
| 88 | 
             
                with rate_limit_lock:
         | 
| 89 | 
             
                    current_time = time.time()
         | 
| 90 | 
             
                    if current_time - last_reset_time >= RESET_INTERVAL:
         | 
| 91 | 
            -
                         | 
| 92 | 
            -
                        session_usage.clear()
         | 
| 93 | 
             
                        last_reset_time = current_time
         | 
| 94 | 
            -
                    if  | 
| 95 | 
            -
                        return False, " | 
| 96 | 
            -
                     | 
| 97 | 
            -
                        return False, "Session rate limit exceeded. Please try again in 24 hours."
         | 
| 98 | 
            -
                    ip_usage[ip] += 1
         | 
| 99 | 
            -
                    session_usage[session] += 1
         | 
| 100 | 
             
                    return True, ""
         | 
| 101 |  | 
| 102 | 
            -
             | 
| 103 | 
             
            def extract_and_format_code(text):
         | 
| 104 | 
            -
                # Extract code between triple backticks
         | 
| 105 | 
             
                code_match = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
         | 
| 106 | 
             
                if code_match:
         | 
| 107 | 
             
                    code = code_match.group(1)
         | 
| 108 | 
             
                else:
         | 
| 109 | 
             
                    code = text
         | 
| 110 | 
            -
             | 
| 111 | 
            -
                # Dedent the code to remove any common leading whitespace
         | 
| 112 | 
             
                code = textwrap.dedent(code)
         | 
| 113 | 
            -
             | 
| 114 | 
            -
                # Split the code into lines
         | 
| 115 | 
             
                lines = code.split('\n')
         | 
| 116 | 
            -
             | 
| 117 | 
            -
                # Ensure proper indentation
         | 
| 118 | 
             
                indented_lines = []
         | 
| 119 | 
             
                for line in lines:
         | 
| 120 | 
             
                    if line.strip().startswith('class') or line.strip().startswith('def'):
         | 
| 121 | 
            -
                        indented_lines.append(line) | 
| 122 | 
            -
                    elif line.strip(): | 
| 123 | 
            -
                        indented_lines.append('    ' + line) | 
| 124 | 
             
                    else:
         | 
| 125 | 
            -
                        indented_lines.append(line) | 
| 126 | 
            -
             | 
| 127 | 
             
                formatted_code = '\n'.join(indented_lines)
         | 
| 128 | 
            -
             | 
| 129 | 
             
                try:
         | 
| 130 | 
             
                    return autopep8.fix_code(formatted_code)
         | 
| 131 | 
             
                except:
         | 
| 132 | 
             
                    return formatted_code
         | 
| 133 | 
            -
             | 
| 134 | 
             
            def generate_explanation(problem: str, solution: str, token: str) -> Dict[str, Any]:
         | 
| 135 | 
             
                if not verify_token(token):
         | 
| 136 | 
             
                    return {"error": "Invalid token"}
         | 
| 137 |  | 
|  | |
|  | |
|  | |
|  | |
| 138 | 
             
                problem_solution_hash = hashlib.sha256(f"{problem}{solution}".encode()).hexdigest()
         | 
| 139 | 
             
                if token not in token_to_problem_solution or token_to_problem_solution[token] != problem_solution_hash:
         | 
| 140 | 
             
                    return {"error": "No matching problem-solution pair found for this token"}
         | 
| 141 |  | 
| 142 | 
            -
                
         | 
| 143 | 
             
                system_prompt = "You are a Python coding assistant specialized in explaining LeetCode problem solutions. Provide a clear and concise explanation of the given solution."
         | 
| 144 | 
             
                full_prompt = f"""### Instruction:
         | 
| 145 | 
             
            {system_prompt}
         | 
| @@ -156,26 +130,20 @@ Explain this solution step by step. | |
| 156 | 
             
            Here's the explanation of the solution:
         | 
| 157 |  | 
| 158 | 
             
            """
         | 
| 159 | 
            -
                
         | 
| 160 | 
             
                generated_text = ""
         | 
| 161 | 
             
                for chunk in llm(full_prompt, stream=True, **generation_kwargs):
         | 
| 162 | 
            -
                    generated_text += chunk["choices"][0]["text"]
         | 
| 163 |  | 
| 164 | 
             
                return {"explanation": generated_text}
         | 
| 165 |  | 
| 166 | 
            -
             | 
| 167 | 
            -
            def generate_solution(instruction: str, token: str, request: gr.Request) -> Dict[str, Any]:
         | 
| 168 | 
            -
                ip = request.client.ip
         | 
| 169 | 
            -
                session = request.client.session
         | 
| 170 | 
            -
                user_identifier = generate_user_identifier(request)
         | 
| 171 | 
            -
             | 
| 172 | 
            -
                is_allowed, message = check_rate_limit(ip, session)
         | 
| 173 | 
            -
                if not is_allowed:
         | 
| 174 | 
            -
                    return {"error": message}
         | 
| 175 | 
            -
             | 
| 176 | 
             
                if not verify_token(token):
         | 
| 177 | 
             
                    return {"error": "Invalid token"}
         | 
| 178 |  | 
|  | |
|  | |
|  | |
|  | |
| 179 | 
             
                system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
         | 
| 180 | 
             
                full_prompt = f"""### Instruction:
         | 
| 181 | 
             
            {system_prompt}
         | 
| @@ -189,7 +157,6 @@ Here's the complete Python function implementation: | |
| 189 |  | 
| 190 | 
             
            ```python
         | 
| 191 | 
             
            """
         | 
| 192 | 
            -
                
         | 
| 193 | 
             
                generated_text = ""
         | 
| 194 | 
             
                for chunk in llm(full_prompt, stream=True, **generation_kwargs):
         | 
| 195 | 
             
                    generated_text += chunk["choices"][0]["text"]
         | 
| @@ -199,52 +166,24 @@ Here's the complete Python function implementation: | |
| 199 | 
             
                token_to_problem_solution[token] = problem_solution_hash
         | 
| 200 | 
             
                return {"solution": formatted_code}
         | 
| 201 |  | 
| 202 | 
            -
            def random_problem(token: str, request: gr.Request) -> Dict[str, Any]:
         | 
| 203 | 
            -
                ip = request.client.ip
         | 
| 204 | 
            -
                session = request.client.session
         | 
| 205 | 
            -
                user_identifier = generate_user_identifier(request)
         | 
| 206 | 
            -
             | 
| 207 | 
            -
                is_allowed, message = check_rate_limit(ip, session)
         | 
| 208 | 
            -
                if not is_allowed:
         | 
| 209 | 
            -
                    return {"error": message}
         | 
| 210 | 
            -
             | 
| 211 | 
             
                if not verify_token(token):
         | 
| 212 | 
             
                    return {"error": "Invalid token"}
         | 
| 213 |  | 
| 214 | 
            -
                 | 
| 215 | 
            -
                 | 
|  | |
| 216 |  | 
| 217 | 
            -
                 | 
| 218 | 
             
                problem = random_item['instruction']
         | 
| 219 | 
            -
                user_data[token] = {"problem": problem, "solution": None}
         | 
| 220 | 
            -
                
         | 
| 221 | 
             
                return {"problem": problem}
         | 
| 222 |  | 
| 223 | 
            -
            def explain_solution(token: str, problem: str, solution: str, request: gr.Request) -> Dict[str, Any]:
         | 
| 224 | 
            -
                ip = request.client.ip
         | 
| 225 | 
            -
                session = request.client.session
         | 
| 226 | 
            -
                user_identifier = generate_user_identifier(request)
         | 
| 227 | 
            -
             | 
| 228 | 
            -
                is_allowed, message = check_rate_limit(ip, session)
         | 
| 229 | 
            -
                if not is_allowed:
         | 
| 230 | 
            -
                    return {"error": message}
         | 
| 231 | 
            -
             | 
| 232 | 
            -
                if not verify_token(token):
         | 
| 233 | 
            -
                    return {"error": "Invalid token"}
         | 
| 234 | 
            -
             | 
| 235 | 
            -
                problem_solution_hash = hashlib.sha256(f"{problem}{solution}".encode()).hexdigest()
         | 
| 236 | 
            -
                if token not in token_to_problem_solution or token_to_problem_solution[token] != problem_solution_hash:
         | 
| 237 | 
            -
                    return {"error": "No matching problem-solution pair found for this token"}
         | 
| 238 | 
            -
             | 
| 239 | 
            -
                return generate_explanation(problem, solution, token)
         | 
| 240 | 
            -
             | 
| 241 | 
             
            # Create Gradio interfaces
         | 
| 242 | 
             
            generate_interface = gr.Interface(
         | 
| 243 | 
             
                fn=generate_solution,
         | 
| 244 | 
             
                inputs=[
         | 
| 245 | 
             
                    gr.Textbox(label="Problem Instruction"),
         | 
| 246 | 
            -
                    gr.Textbox(label="JWT Token"),
         | 
| 247 | 
            -
                    gr.Request()
         | 
| 248 | 
             
                ],
         | 
| 249 | 
             
                outputs=gr.JSON(),
         | 
| 250 | 
             
                title="Generate Solution API",
         | 
| @@ -253,26 +192,22 @@ generate_interface = gr.Interface( | |
| 253 |  | 
| 254 | 
             
            random_problem_interface = gr.Interface(
         | 
| 255 | 
             
                fn=random_problem,
         | 
| 256 | 
            -
                inputs=[
         | 
| 257 | 
            -
                    gr.Textbox(label="JWT Token"),
         | 
| 258 | 
            -
                    gr.Request()
         | 
| 259 | 
            -
                ],
         | 
| 260 | 
             
                outputs=gr.JSON(),
         | 
| 261 | 
             
                title="Random Problem API",
         | 
| 262 | 
             
                description="Provide a valid JWT token to get a random LeetCode problem."
         | 
| 263 | 
             
            )
         | 
| 264 |  | 
| 265 | 
             
            explain_interface = gr.Interface(
         | 
| 266 | 
            -
                fn=explain_solution,
         | 
| 267 | 
             
                inputs=[
         | 
| 268 | 
            -
                    gr.Textbox(label="JWT Token"),
         | 
| 269 | 
             
                    gr.Textbox(label="Problem"),
         | 
| 270 | 
             
                    gr.Textbox(label="Solution"),
         | 
| 271 | 
            -
                    gr.Request()
         | 
| 272 | 
             
                ],
         | 
| 273 | 
             
                outputs=gr.JSON(),
         | 
| 274 | 
             
                title="Explain Solution API",
         | 
| 275 | 
            -
                description="Provide a  | 
| 276 | 
             
            )
         | 
| 277 |  | 
| 278 | 
             
            demo = gr.TabbedInterface(
         | 
|  | |
| 16 | 
             
            import hashlib
         | 
| 17 |  | 
| 18 | 
             
            # Rate limiting data structures
         | 
| 19 | 
            +
            token_usage = defaultdict(int)
         | 
|  | |
| 20 | 
             
            last_reset_time = time.time()
         | 
| 21 | 
             
            rate_limit_lock = threading.Lock()
         | 
| 22 |  | 
| 23 | 
             
            # Constants
         | 
| 24 | 
            +
            MAX_TOKEN_USAGE = 10
         | 
|  | |
| 25 | 
             
            RESET_INTERVAL = 24 * 60 * 60  # 24 hours in seconds
         | 
| 26 |  | 
| 27 | 
             
            # Set up logging
         | 
|  | |
| 48 | 
             
            logger.info("8-bit model loaded successfully")
         | 
| 49 |  | 
| 50 | 
             
            # User data storage
         | 
|  | |
| 51 | 
             
            token_to_problem_solution = {}
         | 
| 52 |  | 
| 53 | 
             
            # Generation parameters
         | 
|  | |
| 61 | 
             
                "repeat_penalty": 1.1
         | 
| 62 | 
             
            }
         | 
| 63 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 64 | 
             
            def verify_token(token: str) -> bool:
         | 
| 65 | 
             
                try:
         | 
| 66 | 
             
                    jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
         | 
|  | |
| 68 | 
             
                except jwt.PyJWTError:
         | 
| 69 | 
             
                    return False
         | 
| 70 |  | 
| 71 | 
            +
            def check_rate_limit(token: str):
         | 
| 72 | 
             
                global last_reset_time
         | 
| 73 | 
             
                with rate_limit_lock:
         | 
| 74 | 
             
                    current_time = time.time()
         | 
| 75 | 
             
                    if current_time - last_reset_time >= RESET_INTERVAL:
         | 
| 76 | 
            +
                        token_usage.clear()
         | 
|  | |
| 77 | 
             
                        last_reset_time = current_time
         | 
| 78 | 
            +
                    if token_usage[token] >= MAX_TOKEN_USAGE:
         | 
| 79 | 
            +
                        return False, "Rate limit exceeded. Please try again later."
         | 
| 80 | 
            +
                    token_usage[token] += 1
         | 
|  | |
|  | |
|  | |
| 81 | 
             
                    return True, ""
         | 
| 82 |  | 
|  | |
| 83 | 
             
            def extract_and_format_code(text):
         | 
|  | |
| 84 | 
             
                code_match = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
         | 
| 85 | 
             
                if code_match:
         | 
| 86 | 
             
                    code = code_match.group(1)
         | 
| 87 | 
             
                else:
         | 
| 88 | 
             
                    code = text
         | 
|  | |
|  | |
| 89 | 
             
                code = textwrap.dedent(code)
         | 
|  | |
|  | |
| 90 | 
             
                lines = code.split('\n')
         | 
|  | |
|  | |
| 91 | 
             
                indented_lines = []
         | 
| 92 | 
             
                for line in lines:
         | 
| 93 | 
             
                    if line.strip().startswith('class') or line.strip().startswith('def'):
         | 
| 94 | 
            +
                        indented_lines.append(line)
         | 
| 95 | 
            +
                    elif line.strip():
         | 
| 96 | 
            +
                        indented_lines.append('    ' + line)
         | 
| 97 | 
             
                    else:
         | 
| 98 | 
            +
                        indented_lines.append(line)
         | 
|  | |
| 99 | 
             
                formatted_code = '\n'.join(indented_lines)
         | 
|  | |
| 100 | 
             
                try:
         | 
| 101 | 
             
                    return autopep8.fix_code(formatted_code)
         | 
| 102 | 
             
                except:
         | 
| 103 | 
             
                    return formatted_code
         | 
| 104 | 
            +
             | 
| 105 | 
             
            def generate_explanation(problem: str, solution: str, token: str) -> Dict[str, Any]:
         | 
| 106 | 
             
                if not verify_token(token):
         | 
| 107 | 
             
                    return {"error": "Invalid token"}
         | 
| 108 |  | 
| 109 | 
            +
                is_allowed, message = check_rate_limit(token)
         | 
| 110 | 
            +
                if not is_allowed:
         | 
| 111 | 
            +
                    return {"error": message}
         | 
| 112 | 
            +
                
         | 
| 113 | 
             
                problem_solution_hash = hashlib.sha256(f"{problem}{solution}".encode()).hexdigest()
         | 
| 114 | 
             
                if token not in token_to_problem_solution or token_to_problem_solution[token] != problem_solution_hash:
         | 
| 115 | 
             
                    return {"error": "No matching problem-solution pair found for this token"}
         | 
| 116 |  | 
|  | |
| 117 | 
             
                system_prompt = "You are a Python coding assistant specialized in explaining LeetCode problem solutions. Provide a clear and concise explanation of the given solution."
         | 
| 118 | 
             
                full_prompt = f"""### Instruction:
         | 
| 119 | 
             
            {system_prompt}
         | 
|  | |
| 130 | 
             
            Here's the explanation of the solution:
         | 
| 131 |  | 
| 132 | 
             
            """
         | 
|  | |
| 133 | 
             
                generated_text = ""
         | 
| 134 | 
             
                for chunk in llm(full_prompt, stream=True, **generation_kwargs):
         | 
| 135 | 
            +
                    generated_text += chunk["choices"][0]["text"]
         | 
| 136 |  | 
| 137 | 
             
                return {"explanation": generated_text}
         | 
| 138 |  | 
| 139 | 
            +
            def generate_solution(instruction: str, token: str) -> Dict[str, Any]:
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 140 | 
             
                if not verify_token(token):
         | 
| 141 | 
             
                    return {"error": "Invalid token"}
         | 
| 142 |  | 
| 143 | 
            +
                is_allowed, message = check_rate_limit(token)
         | 
| 144 | 
            +
                if not is_allowed:
         | 
| 145 | 
            +
                    return {"error": message}
         | 
| 146 | 
            +
             | 
| 147 | 
             
                system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
         | 
| 148 | 
             
                full_prompt = f"""### Instruction:
         | 
| 149 | 
             
            {system_prompt}
         | 
|  | |
| 157 |  | 
| 158 | 
             
            ```python
         | 
| 159 | 
             
            """
         | 
|  | |
| 160 | 
             
                generated_text = ""
         | 
| 161 | 
             
                for chunk in llm(full_prompt, stream=True, **generation_kwargs):
         | 
| 162 | 
             
                    generated_text += chunk["choices"][0]["text"]
         | 
|  | |
| 166 | 
             
                token_to_problem_solution[token] = problem_solution_hash
         | 
| 167 | 
             
                return {"solution": formatted_code}
         | 
| 168 |  | 
| 169 | 
            +
            def random_problem(token: str) -> Dict[str, Any]:
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 170 | 
             
                if not verify_token(token):
         | 
| 171 | 
             
                    return {"error": "Invalid token"}
         | 
| 172 |  | 
| 173 | 
            +
                is_allowed, message = check_rate_limit(token)
         | 
| 174 | 
            +
                if not is_allowed:
         | 
| 175 | 
            +
                    return {"error": message}
         | 
| 176 |  | 
| 177 | 
            +
                random_item = random.choice(train_dataset)
         | 
| 178 | 
             
                problem = random_item['instruction']
         | 
|  | |
|  | |
| 179 | 
             
                return {"problem": problem}
         | 
| 180 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 181 | 
             
            # Create Gradio interfaces
         | 
| 182 | 
             
            generate_interface = gr.Interface(
         | 
| 183 | 
             
                fn=generate_solution,
         | 
| 184 | 
             
                inputs=[
         | 
| 185 | 
             
                    gr.Textbox(label="Problem Instruction"),
         | 
| 186 | 
            +
                    gr.Textbox(label="JWT Token")
         | 
|  | |
| 187 | 
             
                ],
         | 
| 188 | 
             
                outputs=gr.JSON(),
         | 
| 189 | 
             
                title="Generate Solution API",
         | 
|  | |
| 192 |  | 
| 193 | 
             
            random_problem_interface = gr.Interface(
         | 
| 194 | 
             
                fn=random_problem,
         | 
| 195 | 
            +
                inputs=[gr.Textbox(label="JWT Token")],
         | 
|  | |
|  | |
|  | |
| 196 | 
             
                outputs=gr.JSON(),
         | 
| 197 | 
             
                title="Random Problem API",
         | 
| 198 | 
             
                description="Provide a valid JWT token to get a random LeetCode problem."
         | 
| 199 | 
             
            )
         | 
| 200 |  | 
| 201 | 
             
            explain_interface = gr.Interface(
         | 
| 202 | 
            +
                fn=generate_explanation,
         | 
| 203 | 
             
                inputs=[
         | 
|  | |
| 204 | 
             
                    gr.Textbox(label="Problem"),
         | 
| 205 | 
             
                    gr.Textbox(label="Solution"),
         | 
| 206 | 
            +
                    gr.Textbox(label="JWT Token")
         | 
| 207 | 
             
                ],
         | 
| 208 | 
             
                outputs=gr.JSON(),
         | 
| 209 | 
             
                title="Explain Solution API",
         | 
| 210 | 
            +
                description="Provide a problem, solution, and valid JWT token to get an explanation of the solution."
         | 
| 211 | 
             
            )
         | 
| 212 |  | 
| 213 | 
             
            demo = gr.TabbedInterface(
         | 
