sugiv committed on
Commit
283cfad
·
1 Parent(s): debd5cf

Adding logic to rate limit, get JWT token with user identity

Browse files
Files changed (1) hide show
  1. app.py +102 -23
app.py CHANGED
@@ -9,23 +9,31 @@ import jwt
9
  from typing import Dict, Any
10
  import autopep8
11
  import textwrap
12
-
13
  from datasets import load_dataset
14
- import random
15
- import asyncio
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # Set up logging
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
21
- # Load the dataset (you might want to do this once at the start of your script)
22
  dataset = load_dataset("sugiv/leetmonkey_python_dataset")
23
  train_dataset = dataset["train"]
24
 
25
- # Set up logging
26
- logging.basicConfig(level=logging.INFO)
27
- logger = logging.getLogger(__name__)
28
-
29
  # JWT settings
30
  JWT_SECRET = os.environ.get("JWT_SECRET")
31
  if not JWT_SECRET:
@@ -41,7 +49,9 @@ model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_NAME, cache_dir="./
41
  llm = Llama(model_path=model_path, n_ctx=1024, n_threads=8, n_gpu_layers=-1, verbose=False, mlock=True)
42
  logger.info("8-bit model loaded successfully")
43
 
 
44
  user_data = {}
 
45
 
46
  # Generation parameters
47
  generation_kwargs = {
@@ -54,6 +64,18 @@ generation_kwargs = {
54
  "repeat_penalty": 1.1
55
  }
56
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def verify_token(token: str) -> bool:
58
  try:
59
  jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
@@ -61,6 +83,23 @@ def verify_token(token: str) -> bool:
61
  except jwt.PyJWTError:
62
  return False
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def extract_and_format_code(text):
65
  # Extract code between triple backticks
66
  code_match = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
@@ -96,6 +135,11 @@ def generate_explanation(problem: str, solution: str, token: str) -> Dict[str, A
96
  if not verify_token(token):
97
  return {"error": "Invalid token"}
98
 
 
 
 
 
 
99
  system_prompt = "You are a Python coding assistant specialized in explaining LeetCode problem solutions. Provide a clear and concise explanation of the given solution."
100
  full_prompt = f"""### Instruction:
101
  {system_prompt}
@@ -120,7 +164,15 @@ Here's the explanation of the solution:
120
  return {"explanation": generated_text}
121
 
122
 
123
- def generate_solution(instruction: str, token: str) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
124
  if not verify_token(token):
125
  return {"error": "Invalid token"}
126
 
@@ -143,10 +195,19 @@ Here's the complete Python function implementation:
143
  generated_text += chunk["choices"][0]["text"]
144
 
145
  formatted_code = extract_and_format_code(generated_text)
146
- user_data[token] = {"problem": instruction, "solution": formatted_code}
 
147
  return {"solution": formatted_code}
148
 
149
- def random_problem(token: str) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
150
  if not verify_token(token):
151
  return {"error": "Invalid token"}
152
 
@@ -159,22 +220,32 @@ def random_problem(token: str) -> Dict[str, Any]:
159
 
160
  return {"problem": problem}
161
 
162
- def explain_solution(token: str) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
163
  if not verify_token(token):
164
  return {"error": "Invalid token"}
165
-
166
- if token not in user_data or not user_data[token].get("solution"):
167
- return {"error": "No solution available to explain. Please generate a solution first."}
168
-
169
- problem = user_data[token]["problem"]
170
- solution = user_data[token]["solution"]
171
-
172
  return generate_explanation(problem, solution, token)
173
 
174
  # Create Gradio interfaces
175
  generate_interface = gr.Interface(
176
  fn=generate_solution,
177
- inputs=[gr.Textbox(label="Problem Instruction"), gr.Textbox(label="JWT Token")],
 
 
 
 
178
  outputs=gr.JSON(),
179
  title="Generate Solution API",
180
  description="Provide a LeetCode problem instruction and a valid JWT token to generate a solution."
@@ -182,7 +253,10 @@ generate_interface = gr.Interface(
182
 
183
  random_problem_interface = gr.Interface(
184
  fn=random_problem,
185
- inputs=gr.Textbox(label="JWT Token"),
 
 
 
186
  outputs=gr.JSON(),
187
  title="Random Problem API",
188
  description="Provide a valid JWT token to get a random LeetCode problem."
@@ -190,10 +264,15 @@ random_problem_interface = gr.Interface(
190
 
191
  explain_interface = gr.Interface(
192
  fn=explain_solution,
193
- inputs=gr.Textbox(label="JWT Token"),
 
 
 
 
 
194
  outputs=gr.JSON(),
195
  title="Explain Solution API",
196
- description="Provide a valid JWT token to get an explanation of the last generated solution."
197
  )
198
 
199
  demo = gr.TabbedInterface(
 
9
  from typing import Dict, Any
10
  import autopep8
11
  import textwrap
 
12
  from datasets import load_dataset
13
+ import time
14
+ from collections import defaultdict
15
+ import threading
16
+ import hashlib
17
+
18
+ # Rate limiting data structures
19
+ ip_usage = defaultdict(int)
20
+ session_usage = defaultdict(int)
21
+ last_reset_time = time.time()
22
+ rate_limit_lock = threading.Lock()
23
+
24
+ # Constants
25
+ MAX_IP_USAGE = 10
26
+ MAX_SESSION_USAGE = 2
27
+ RESET_INTERVAL = 24 * 60 * 60 # 24 hours in seconds
28
 
29
  # Set up logging
30
  logging.basicConfig(level=logging.INFO)
31
  logger = logging.getLogger(__name__)
32
 
33
+ # Load the dataset
34
  dataset = load_dataset("sugiv/leetmonkey_python_dataset")
35
  train_dataset = dataset["train"]
36
 
 
 
 
 
37
  # JWT settings
38
  JWT_SECRET = os.environ.get("JWT_SECRET")
39
  if not JWT_SECRET:
 
49
  llm = Llama(model_path=model_path, n_ctx=1024, n_threads=8, n_gpu_layers=-1, verbose=False, mlock=True)
50
  logger.info("8-bit model loaded successfully")
51
 
52
+ # User data storage
53
  user_data = {}
54
+ token_to_problem_solution = {}
55
 
56
  # Generation parameters
57
  generation_kwargs = {
 
64
  "repeat_penalty": 1.1
65
  }
66
 
67
def generate_user_identifier(request: "gr.Request") -> str:
    """Derive a stable, opaque identifier for the caller of a request.

    The identifier is the SHA-256 hex digest of the client address
    concatenated with the ``User-Agent`` header.

    Args:
        request: The incoming request. Its ``client`` attribute is expected
            to expose ``host`` (the client address object carries ``host``
            and ``port``; it has no ``ip`` attribute), and ``headers`` must
            support ``.get``.

    Returns:
        A 64-character hexadecimal digest string.
    """
    # ``client`` can be absent (e.g. when no transport address is known),
    # so fall back to an empty address instead of raising.
    client = getattr(request, "client", None)
    ip = getattr(client, "host", "") if client is not None else ""
    user_agent = request.headers.get('User-Agent', '')
    return hashlib.sha256(f"{ip}{user_agent}".encode()).hexdigest()
71
+
72
def generate_token(user_identifier: str) -> str:
    """Issue a signed JWT carrying the given user identifier.

    Args:
        user_identifier: Value stored under the ``user_id`` claim.

    Returns:
        The result of ``jwt.encode`` for a payload with ``exp`` set one hour
        from now and ``user_id`` set to ``user_identifier``, signed with
        ``JWT_SECRET`` using ``JWT_ALGORITHM``.
    """
    expires_at = int(time.time()) + 3600  # 1 hour expiration
    claims = dict(exp=expires_at, user_id=user_identifier)
    return jwt.encode(claims, JWT_SECRET, algorithm=JWT_ALGORITHM)
78
+
79
  def verify_token(token: str) -> bool:
80
  try:
81
  jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
 
83
  except jwt.PyJWTError:
84
  return False
85
 
86
def check_rate_limit(ip, session):
    """Check and record one use against the per-IP and per-session quotas.

    Runs entirely under ``rate_limit_lock``. If at least ``RESET_INTERVAL``
    seconds have elapsed since ``last_reset_time``, both usage counters are
    cleared and the reset time is moved to now. The IP quota is checked
    first, then the session quota; counters are incremented only when both
    checks pass.

    Args:
        ip: Key into ``ip_usage``.
        session: Key into ``session_usage``.

    Returns:
        ``(True, "")`` when the call is allowed, otherwise ``(False, message)``
        where ``message`` names the quota that was exceeded.
    """
    global last_reset_time
    with rate_limit_lock:
        now = time.time()
        window_expired = now - last_reset_time >= RESET_INTERVAL
        if window_expired:
            ip_usage.clear()
            session_usage.clear()
            last_reset_time = now

        # (counter, key, limit, rejection message), evaluated in this order.
        quotas = (
            (ip_usage, ip, MAX_IP_USAGE,
             "IP rate limit exceeded. Please try again in 24 hours."),
            (session_usage, session, MAX_SESSION_USAGE,
             "Session rate limit exceeded. Please try again in 24 hours."),
        )
        for counter, key, limit, rejection in quotas:
            if counter[key] >= limit:
                return False, rejection

        # Both quotas have room: record this use against each of them.
        for counter, key, _limit, _rejection in quotas:
            counter[key] += 1
        return True, ""
101
+
102
+
103
  def extract_and_format_code(text):
104
  # Extract code between triple backticks
105
  code_match = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
 
135
  if not verify_token(token):
136
  return {"error": "Invalid token"}
137
 
138
+ problem_solution_hash = hashlib.sha256(f"{problem}{solution}".encode()).hexdigest()
139
+ if token not in token_to_problem_solution or token_to_problem_solution[token] != problem_solution_hash:
140
+ return {"error": "No matching problem-solution pair found for this token"}
141
+
142
+
143
  system_prompt = "You are a Python coding assistant specialized in explaining LeetCode problem solutions. Provide a clear and concise explanation of the given solution."
144
  full_prompt = f"""### Instruction:
145
  {system_prompt}
 
164
  return {"explanation": generated_text}
165
 
166
 
167
+ def generate_solution(instruction: str, token: str, request: gr.Request) -> Dict[str, Any]:
168
+ ip = request.client.ip
169
+ session = request.client.session
170
+ user_identifier = generate_user_identifier(request)
171
+
172
+ is_allowed, message = check_rate_limit(ip, session)
173
+ if not is_allowed:
174
+ return {"error": message}
175
+
176
  if not verify_token(token):
177
  return {"error": "Invalid token"}
178
 
 
195
  generated_text += chunk["choices"][0]["text"]
196
 
197
  formatted_code = extract_and_format_code(generated_text)
198
+ problem_solution_hash = hashlib.sha256(f"{instruction}{formatted_code}".encode()).hexdigest()
199
+ token_to_problem_solution[token] = problem_solution_hash
200
  return {"solution": formatted_code}
201
 
202
+ def random_problem(token: str, request: gr.Request) -> Dict[str, Any]:
203
+ ip = request.client.ip
204
+ session = request.client.session
205
+ user_identifier = generate_user_identifier(request)
206
+
207
+ is_allowed, message = check_rate_limit(ip, session)
208
+ if not is_allowed:
209
+ return {"error": message}
210
+
211
  if not verify_token(token):
212
  return {"error": "Invalid token"}
213
 
 
220
 
221
  return {"problem": problem}
222
 
223
+ def explain_solution(token: str, problem: str, solution: str, request: gr.Request) -> Dict[str, Any]:
224
+ ip = request.client.ip
225
+ session = request.client.session
226
+ user_identifier = generate_user_identifier(request)
227
+
228
+ is_allowed, message = check_rate_limit(ip, session)
229
+ if not is_allowed:
230
+ return {"error": message}
231
+
232
  if not verify_token(token):
233
  return {"error": "Invalid token"}
234
+
235
+ problem_solution_hash = hashlib.sha256(f"{problem}{solution}".encode()).hexdigest()
236
+ if token not in token_to_problem_solution or token_to_problem_solution[token] != problem_solution_hash:
237
+ return {"error": "No matching problem-solution pair found for this token"}
238
+
 
 
239
  return generate_explanation(problem, solution, token)
240
 
241
  # Create Gradio interfaces
242
  generate_interface = gr.Interface(
243
  fn=generate_solution,
244
+ inputs=[
245
+ gr.Textbox(label="Problem Instruction"),
246
+ gr.Textbox(label="JWT Token"),
247
+ gr.Request()
248
+ ],
249
  outputs=gr.JSON(),
250
  title="Generate Solution API",
251
  description="Provide a LeetCode problem instruction and a valid JWT token to generate a solution."
 
253
 
254
  random_problem_interface = gr.Interface(
255
  fn=random_problem,
256
+ inputs=[
257
+ gr.Textbox(label="JWT Token"),
258
+ gr.Request()
259
+ ],
260
  outputs=gr.JSON(),
261
  title="Random Problem API",
262
  description="Provide a valid JWT token to get a random LeetCode problem."
 
264
 
265
# Gradio interface exposing explain_solution as an API tab.
explain_interface = gr.Interface(
    fn=explain_solution,
    # Only real input components belong here. gr.Request is not a component:
    # Gradio passes the request to the handler automatically because
    # explain_solution declares a parameter type-hinted as gr.Request.
    inputs=[
        gr.Textbox(label="JWT Token"),
        gr.Textbox(label="Problem"),
        gr.Textbox(label="Solution"),
    ],
    outputs=gr.JSON(),
    title="Explain Solution API",
    description="Provide a valid JWT token, problem, and solution to get an explanation of the solution."
)
277
 
278
  demo = gr.TabbedInterface(