Adding logic to rate limit, get JWT token with user identity
Browse files
app.py
CHANGED
@@ -9,23 +9,31 @@ import jwt
|
|
9 |
from typing import Dict, Any
|
10 |
import autopep8
|
11 |
import textwrap
|
12 |
-
|
13 |
from datasets import load_dataset
|
14 |
-
import
|
15 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
# Set up logging
|
18 |
logging.basicConfig(level=logging.INFO)
|
19 |
logger = logging.getLogger(__name__)
|
20 |
|
21 |
-
# Load the dataset
|
22 |
dataset = load_dataset("sugiv/leetmonkey_python_dataset")
|
23 |
train_dataset = dataset["train"]
|
24 |
|
25 |
-
# Set up logging
|
26 |
-
logging.basicConfig(level=logging.INFO)
|
27 |
-
logger = logging.getLogger(__name__)
|
28 |
-
|
29 |
# JWT settings
|
30 |
JWT_SECRET = os.environ.get("JWT_SECRET")
|
31 |
if not JWT_SECRET:
|
@@ -41,7 +49,9 @@ model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_NAME, cache_dir="./
|
|
41 |
llm = Llama(model_path=model_path, n_ctx=1024, n_threads=8, n_gpu_layers=-1, verbose=False, mlock=True)
|
42 |
logger.info("8-bit model loaded successfully")
|
43 |
|
|
|
44 |
user_data = {}
|
|
|
45 |
|
46 |
# Generation parameters
|
47 |
generation_kwargs = {
|
@@ -54,6 +64,18 @@ generation_kwargs = {
|
|
54 |
"repeat_penalty": 1.1
|
55 |
}
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
def verify_token(token: str) -> bool:
|
58 |
try:
|
59 |
jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
|
@@ -61,6 +83,23 @@ def verify_token(token: str) -> bool:
|
|
61 |
except jwt.PyJWTError:
|
62 |
return False
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
def extract_and_format_code(text):
|
65 |
# Extract code between triple backticks
|
66 |
code_match = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
|
@@ -96,6 +135,11 @@ def generate_explanation(problem: str, solution: str, token: str) -> Dict[str, A
|
|
96 |
if not verify_token(token):
|
97 |
return {"error": "Invalid token"}
|
98 |
|
|
|
|
|
|
|
|
|
|
|
99 |
system_prompt = "You are a Python coding assistant specialized in explaining LeetCode problem solutions. Provide a clear and concise explanation of the given solution."
|
100 |
full_prompt = f"""### Instruction:
|
101 |
{system_prompt}
|
@@ -120,7 +164,15 @@ Here's the explanation of the solution:
|
|
120 |
return {"explanation": generated_text}
|
121 |
|
122 |
|
123 |
-
def generate_solution(instruction: str, token: str) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
if not verify_token(token):
|
125 |
return {"error": "Invalid token"}
|
126 |
|
@@ -143,10 +195,19 @@ Here's the complete Python function implementation:
|
|
143 |
generated_text += chunk["choices"][0]["text"]
|
144 |
|
145 |
formatted_code = extract_and_format_code(generated_text)
|
146 |
-
|
|
|
147 |
return {"solution": formatted_code}
|
148 |
|
149 |
-
def random_problem(token: str) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
if not verify_token(token):
|
151 |
return {"error": "Invalid token"}
|
152 |
|
@@ -159,22 +220,32 @@ def random_problem(token: str) -> Dict[str, Any]:
|
|
159 |
|
160 |
return {"problem": problem}
|
161 |
|
162 |
-
def explain_solution(token: str) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
if not verify_token(token):
|
164 |
return {"error": "Invalid token"}
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
solution = user_data[token]["solution"]
|
171 |
-
|
172 |
return generate_explanation(problem, solution, token)
|
173 |
|
174 |
# Create Gradio interfaces
|
175 |
generate_interface = gr.Interface(
|
176 |
fn=generate_solution,
|
177 |
-
inputs=[
|
|
|
|
|
|
|
|
|
178 |
outputs=gr.JSON(),
|
179 |
title="Generate Solution API",
|
180 |
description="Provide a LeetCode problem instruction and a valid JWT token to generate a solution."
|
@@ -182,7 +253,10 @@ generate_interface = gr.Interface(
|
|
182 |
|
183 |
random_problem_interface = gr.Interface(
|
184 |
fn=random_problem,
|
185 |
-
inputs=
|
|
|
|
|
|
|
186 |
outputs=gr.JSON(),
|
187 |
title="Random Problem API",
|
188 |
description="Provide a valid JWT token to get a random LeetCode problem."
|
@@ -190,10 +264,15 @@ random_problem_interface = gr.Interface(
|
|
190 |
|
191 |
explain_interface = gr.Interface(
|
192 |
fn=explain_solution,
|
193 |
-
inputs=
|
|
|
|
|
|
|
|
|
|
|
194 |
outputs=gr.JSON(),
|
195 |
title="Explain Solution API",
|
196 |
-
description="Provide a valid JWT token to get an explanation of the
|
197 |
)
|
198 |
|
199 |
demo = gr.TabbedInterface(
|
|
|
9 |
from typing import Dict, Any
|
10 |
import autopep8
|
11 |
import textwrap
|
|
|
12 |
from datasets import load_dataset
|
13 |
+
import time
|
14 |
+
from collections import defaultdict
|
15 |
+
import threading
|
16 |
+
import hashlib
|
17 |
+
|
18 |
+
# Rate limiting data structures
|
19 |
+
ip_usage = defaultdict(int)
|
20 |
+
session_usage = defaultdict(int)
|
21 |
+
last_reset_time = time.time()
|
22 |
+
rate_limit_lock = threading.Lock()
|
23 |
+
|
24 |
+
# Constants
|
25 |
+
MAX_IP_USAGE = 10
|
26 |
+
MAX_SESSION_USAGE = 2
|
27 |
+
RESET_INTERVAL = 24 * 60 * 60 # 24 hours in seconds
|
28 |
|
29 |
# Set up logging
|
30 |
logging.basicConfig(level=logging.INFO)
|
31 |
logger = logging.getLogger(__name__)
|
32 |
|
33 |
+
# Load the dataset
|
34 |
dataset = load_dataset("sugiv/leetmonkey_python_dataset")
|
35 |
train_dataset = dataset["train"]
|
36 |
|
|
|
|
|
|
|
|
|
37 |
# JWT settings
|
38 |
JWT_SECRET = os.environ.get("JWT_SECRET")
|
39 |
if not JWT_SECRET:
|
|
|
49 |
llm = Llama(model_path=model_path, n_ctx=1024, n_threads=8, n_gpu_layers=-1, verbose=False, mlock=True)
|
50 |
logger.info("8-bit model loaded successfully")
|
51 |
|
52 |
+
# User data storage
|
53 |
user_data = {}
|
54 |
+
token_to_problem_solution = {}
|
55 |
|
56 |
# Generation parameters
|
57 |
generation_kwargs = {
|
|
|
64 |
"repeat_penalty": 1.1
|
65 |
}
|
66 |
|
67 |
+
def generate_user_identifier(request: "gr.Request") -> str:
    """Derive an anonymised identifier for the caller of a request.

    The identifier is the SHA-256 hex digest of the client's address
    concatenated with its ``User-Agent`` header value.

    Args:
        request: The incoming request. Only ``request.client`` and
            ``request.headers`` are read; either may be missing.

    Returns:
        A 64-character hexadecimal digest string.
    """
    # The peer address is exposed as ``client.host``; ``client.ip`` is kept
    # only as a fallback for request objects that happen to provide it.
    client = getattr(request, "client", None)
    ip = getattr(client, "host", None) or getattr(client, "ip", None) or ""

    headers = getattr(request, "headers", None)
    user_agent = headers.get('User-Agent', '') if headers is not None else ''

    return hashlib.sha256(f"{ip}{user_agent}".encode()).hexdigest()
|
71 |
+
|
72 |
+
def generate_token(user_identifier: str, expires_in: int = 3600) -> str:
    """Create a signed JWT that carries the caller's identifier.

    Args:
        user_identifier: Value stored in the token's ``user_id`` claim.
        expires_in: Token lifetime in seconds, added to the current time to
            form the ``exp`` claim. Defaults to one hour.

    Returns:
        The encoded token as a string, signed with ``JWT_SECRET`` using
        ``JWT_ALGORITHM``.
    """
    payload = {
        'exp': int(time.time()) + expires_in,
        'user_id': user_identifier
    }
    token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
    # PyJWT releases before 2.0 return bytes from encode(); normalise so the
    # declared ``str`` return type always holds.
    if isinstance(token, bytes):
        token = token.decode("utf-8")
    return token
|
78 |
+
|
79 |
def verify_token(token: str) -> bool:
|
80 |
try:
|
81 |
jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
|
|
|
83 |
except jwt.PyJWTError:
|
84 |
return False
|
85 |
|
86 |
+
def check_rate_limit(ip, session):
    """Check one request against the per-IP and per-session quotas.

    Both usage tables are cleared once ``RESET_INTERVAL`` seconds have passed
    since the previous reset. A request is counted only when it passes every
    limit that applies to it. The whole check runs under ``rate_limit_lock``.

    Args:
        ip: Key for the per-IP counter.
        session: Key for the per-session counter. When ``None`` the session
            quota is skipped, so callers without a session do not all share
            a single bucket.

    Returns:
        A ``(allowed, message)`` tuple. ``message`` is empty when the request
        is allowed and describes the exceeded limit otherwise.
    """
    global last_reset_time
    with rate_limit_lock:
        now = time.time()
        if now - last_reset_time >= RESET_INTERVAL:
            ip_usage.clear()
            session_usage.clear()
            last_reset_time = now

        # .get() reads the counter without creating a zero-valued entry for
        # callers that end up being rejected.
        if ip_usage.get(ip, 0) >= MAX_IP_USAGE:
            return False, "IP rate limit exceeded. Please try again in 24 hours."

        track_session = session is not None
        if track_session and session_usage.get(session, 0) >= MAX_SESSION_USAGE:
            return False, "Session rate limit exceeded. Please try again in 24 hours."

        ip_usage[ip] += 1
        if track_session:
            session_usage[session] += 1
        return True, ""
|
101 |
+
|
102 |
+
|
103 |
def extract_and_format_code(text):
|
104 |
# Extract code between triple backticks
|
105 |
code_match = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
|
|
|
135 |
if not verify_token(token):
|
136 |
return {"error": "Invalid token"}
|
137 |
|
138 |
+
problem_solution_hash = hashlib.sha256(f"{problem}{solution}".encode()).hexdigest()
|
139 |
+
if token not in token_to_problem_solution or token_to_problem_solution[token] != problem_solution_hash:
|
140 |
+
return {"error": "No matching problem-solution pair found for this token"}
|
141 |
+
|
142 |
+
|
143 |
system_prompt = "You are a Python coding assistant specialized in explaining LeetCode problem solutions. Provide a clear and concise explanation of the given solution."
|
144 |
full_prompt = f"""### Instruction:
|
145 |
{system_prompt}
|
|
|
164 |
return {"explanation": generated_text}
|
165 |
|
166 |
|
167 |
+
def generate_solution(instruction: str, token: str, request: gr.Request) -> Dict[str, Any]:
|
168 |
+
ip = request.client.ip
|
169 |
+
session = request.client.session
|
170 |
+
user_identifier = generate_user_identifier(request)
|
171 |
+
|
172 |
+
is_allowed, message = check_rate_limit(ip, session)
|
173 |
+
if not is_allowed:
|
174 |
+
return {"error": message}
|
175 |
+
|
176 |
if not verify_token(token):
|
177 |
return {"error": "Invalid token"}
|
178 |
|
|
|
195 |
generated_text += chunk["choices"][0]["text"]
|
196 |
|
197 |
formatted_code = extract_and_format_code(generated_text)
|
198 |
+
problem_solution_hash = hashlib.sha256(f"{instruction}{formatted_code}".encode()).hexdigest()
|
199 |
+
token_to_problem_solution[token] = problem_solution_hash
|
200 |
return {"solution": formatted_code}
|
201 |
|
202 |
+
def random_problem(token: str, request: gr.Request) -> Dict[str, Any]:
|
203 |
+
ip = request.client.ip
|
204 |
+
session = request.client.session
|
205 |
+
user_identifier = generate_user_identifier(request)
|
206 |
+
|
207 |
+
is_allowed, message = check_rate_limit(ip, session)
|
208 |
+
if not is_allowed:
|
209 |
+
return {"error": message}
|
210 |
+
|
211 |
if not verify_token(token):
|
212 |
return {"error": "Invalid token"}
|
213 |
|
|
|
220 |
|
221 |
return {"problem": problem}
|
222 |
|
223 |
+
def explain_solution(token: str, problem: str, solution: str, request: gr.Request) -> Dict[str, Any]:
    """Explain a previously generated solution for the given token.

    The request is first counted against the rate limits, then the token is
    verified, and finally the problem/solution pair is checked against the
    hash recorded for this token before the explanation is generated.

    Args:
        token: JWT presented by the caller.
        problem: Problem text the solution was generated for.
        solution: Solution text to explain.
        request: Incoming request, used to obtain the client address and
            session hash for rate limiting.

    Returns:
        The result of ``generate_explanation`` on success, otherwise a dict
        with a single ``"error"`` key describing why the call was refused.
    """
    # The client address lives on ``client.host`` and the session id on
    # ``session_hash``; ``client`` itself may be absent.
    client = getattr(request, "client", None)
    ip = getattr(client, "host", None)
    session = getattr(request, "session_hash", None)

    is_allowed, message = check_rate_limit(ip, session)
    if not is_allowed:
        return {"error": message}

    if not verify_token(token):
        return {"error": "Invalid token"}

    # Only explain the exact pair that was recorded for this token.
    problem_solution_hash = hashlib.sha256(f"{problem}{solution}".encode()).hexdigest()
    if token_to_problem_solution.get(token) != problem_solution_hash:
        return {"error": "No matching problem-solution pair found for this token"}

    return generate_explanation(problem, solution, token)
|
240 |
|
241 |
# Create Gradio interfaces
|
242 |
generate_interface = gr.Interface(
|
243 |
fn=generate_solution,
|
244 |
+
inputs=[
|
245 |
+
gr.Textbox(label="Problem Instruction"),
|
246 |
+
gr.Textbox(label="JWT Token"),
|
247 |
+
gr.Request()
|
248 |
+
],
|
249 |
outputs=gr.JSON(),
|
250 |
title="Generate Solution API",
|
251 |
description="Provide a LeetCode problem instruction and a valid JWT token to generate a solution."
|
|
|
253 |
|
254 |
random_problem_interface = gr.Interface(
|
255 |
fn=random_problem,
|
256 |
+
inputs=[
|
257 |
+
gr.Textbox(label="JWT Token"),
|
258 |
+
gr.Request()
|
259 |
+
],
|
260 |
outputs=gr.JSON(),
|
261 |
title="Random Problem API",
|
262 |
description="Provide a valid JWT token to get a random LeetCode problem."
|
|
|
264 |
|
265 |
# Only input components are listed here: the ``gr.Request`` parameter of
# ``explain_solution`` is supplied by Gradio from its type hint and is not a
# component, so it must not appear in ``inputs``.
explain_interface = gr.Interface(
    fn=explain_solution,
    inputs=[
        gr.Textbox(label="JWT Token"),
        gr.Textbox(label="Problem"),
        gr.Textbox(label="Solution"),
    ],
    outputs=gr.JSON(),
    title="Explain Solution API",
    description="Provide a valid JWT token, problem, and solution to get an explanation of the solution."
)
|
277 |
|
278 |
demo = gr.TabbedInterface(
|