sugiv committed on
Commit
16fea32
·
1 Parent(s): 3081d77

First version with APIs

Browse files
Files changed (1) hide show
  1. app.py +75 -80
app.py CHANGED
@@ -1,28 +1,40 @@
1
- import gradio as gr
2
- from huggingface_hub import hf_hub_download
3
- from llama_cpp import Llama
4
  import re
5
- from datasets import load_dataset
6
- import random
7
  import logging
8
- import os
9
- import autopep8
10
  import textwrap
 
 
 
 
11
  import jwt
12
- from datetime import datetime, timedelta
 
13
 
14
  # Set up logging
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
  # JWT settings
19
- JWT_SECRET = os.environ.get("JWT_SECRET", "your-secret-key")
 
 
20
  JWT_ALGORITHM = "HS256"
21
 
22
  # Model settings
23
  MODEL_NAME = "leetmonkey_peft__q8_0.gguf"
24
  REPO_ID = "sugiv/leetmonkey-peft-gguf"
25
 
 
 
 
 
 
 
 
 
 
 
 
26
  def download_model(model_name):
27
  logger.info(f"Downloading model: {model_name}")
28
  model_path = hf_hub_download(
@@ -48,22 +60,7 @@ llm = Llama(
48
  )
49
  logger.info("8-bit model loaded successfully")
50
 
51
- # Load the dataset
52
- dataset = load_dataset("sugiv/leetmonkey_python_dataset")
53
- train_dataset = dataset["train"]
54
-
55
- # Generation parameters
56
- generation_kwargs = {
57
- "max_tokens": 512,
58
- "stop": ["```", "### Instruction:", "### Response:"],
59
- "echo": False,
60
- "temperature": 0.05,
61
- "top_k": 10,
62
- "top_p": 0.9,
63
- "repeat_penalty": 1.1
64
- }
65
-
66
- def generate_solution(instruction):
67
  system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
68
  full_prompt = f"""### Instruction:
69
  {system_prompt}
@@ -78,32 +75,27 @@ Here's the complete Python function implementation:
78
  ```python
79
  """
80
 
81
- for chunk in llm(full_prompt, stream=True, **generation_kwargs):
82
- yield chunk["choices"][0]["text"]
83
 
84
- def extract_and_format_code(text):
85
- # Extract code between triple backticks
86
  code_match = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
87
  if code_match:
88
  code = code_match.group(1)
89
  else:
90
  code = text
91
 
92
- # Dedent the code to remove any common leading whitespace
93
  code = textwrap.dedent(code)
94
-
95
- # Split the code into lines
96
  lines = code.split('\n')
97
 
98
- # Ensure proper indentation
99
  indented_lines = []
100
  for line in lines:
101
  if line.strip().startswith('class') or line.strip().startswith('def'):
102
- indented_lines.append(line) # Keep class and function definitions as is
103
- elif line.strip(): # If the line is not empty
104
- indented_lines.append(' ' + line) # Add 4 spaces of indentation
105
  else:
106
- indented_lines.append(line) # Keep empty lines as is
107
 
108
  formatted_code = '\n'.join(indented_lines)
109
 
@@ -112,66 +104,69 @@ def extract_and_format_code(text):
112
  except:
113
  return formatted_code
114
 
115
- def select_random_problem():
116
- return random.choice(train_dataset)['instruction']
117
-
118
- def stream_solution(problem):
119
- logger.info("Generating solution")
120
- generated_text = ""
121
- for token in generate_solution(problem):
122
- generated_text += token
123
- yield generated_text
124
-
125
- formatted_code = extract_and_format_code(generated_text)
126
- logger.info("Solution generated successfully")
127
- yield formatted_code
128
-
129
- def verify_token(token):
130
  try:
131
  jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
132
  return True
133
- except:
134
  return False
135
 
136
- def generate_token():
137
- expiration = datetime.utcnow() + timedelta(hours=1)
138
- return jwt.encode({"exp": expiration}, JWT_SECRET, algorithm=JWT_ALGORITHM)
139
-
140
- def api_random_problem(token):
141
  if not verify_token(token):
142
  return {"error": "Invalid token"}
143
- return {"problem": select_random_problem()}
144
-
145
- def api_generate_solution(problem, token):
146
- if not verify_token(token):
147
- return {"error": "Invalid token"}
148
- solution = "".join(list(stream_solution(problem)))
149
- return {"solution": solution}
150
 
151
- def api_explain_solution(solution, token):
152
  if not verify_token(token):
153
  return {"error": "Invalid token"}
154
- explanation_prompt = f"Explain the following Python code:\n\n{solution}\n\nExplanation:"
 
155
  explanation = llm(explanation_prompt, max_tokens=256)["choices"][0]["text"]
156
  return {"explanation": explanation}
157
 
158
- iface = gr.Interface(
159
- fn=[api_random_problem, api_generate_solution, api_explain_solution, generate_token],
 
 
 
 
 
 
 
160
  inputs=[
161
- gr.Textbox(label="JWT Token"),
162
- gr.Textbox(label="Problem"),
163
- gr.Textbox(label="Solution")
164
  ],
165
- outputs=[
166
- gr.JSON(label="Random Problem"),
167
- gr.JSON(label="Generated Solution"),
168
- gr.JSON(label="Explanation"),
169
- gr.Textbox(label="New JWT Token")
 
 
 
 
 
170
  ],
171
- title="LeetCode Problem Solver API",
172
- description="API endpoints for generating and explaining LeetCode solutions."
 
173
  )
174
 
 
 
 
 
 
 
 
 
 
 
 
175
  if __name__ == "__main__":
176
  logger.info("Starting Gradio API")
177
- iface.launch(share=True)
 
1
+ import os
 
 
2
  import re
 
 
3
  import logging
 
 
4
  import textwrap
5
+ import autopep8
6
+ import gradio as gr
7
+ from huggingface_hub import hf_hub_download
8
+ from llama_cpp import Llama
9
  import jwt
10
+ from typing import Dict, Any
11
+ import datetime
12
 
13
  # Set up logging
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
17
  # JWT settings
18
# The signing secret must come from the environment; there is no fallback
# value, so the module refuses to load when it is missing or empty.
JWT_SECRET = os.getenv("JWT_SECRET")
if not JWT_SECRET:
    raise ValueError("JWT_SECRET environment variable is not set")

# Algorithm used both when signing and when verifying tokens.
JWT_ALGORITHM = "HS256"
22
 
23
  # Model settings
24
  MODEL_NAME = "leetmonkey_peft__q8_0.gguf"
25
  REPO_ID = "sugiv/leetmonkey-peft-gguf"
26
 
27
# Generation parameters: keyword arguments unpacked into the llm(...) call
# that produces a solution.
generation_kwargs = dict(
    max_tokens=512,
    stop=["```", "### Instruction:", "### Response:"],
    echo=False,
    temperature=0.05,
    top_k=10,
    top_p=0.9,
    repeat_penalty=1.1,
)
37
+
38
  def download_model(model_name):
39
  logger.info(f"Downloading model: {model_name}")
40
  model_path = hf_hub_download(
 
60
  )
61
  logger.info("8-bit model loaded successfully")
62
 
63
+ def generate_solution(instruction: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
65
  full_prompt = f"""### Instruction:
66
  {system_prompt}
 
75
  ```python
76
  """
77
 
78
+ response = llm(full_prompt, **generation_kwargs)
79
+ return response["choices"][0]["text"]
80
 
81
+ def extract_and_format_code(text: str) -> str:
 
82
  code_match = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
83
  if code_match:
84
  code = code_match.group(1)
85
  else:
86
  code = text
87
 
 
88
  code = textwrap.dedent(code)
 
 
89
  lines = code.split('\n')
90
 
 
91
  indented_lines = []
92
  for line in lines:
93
  if line.strip().startswith('class') or line.strip().startswith('def'):
94
+ indented_lines.append(line)
95
+ elif line.strip():
96
+ indented_lines.append(' ' + line)
97
  else:
98
+ indented_lines.append(line)
99
 
100
  formatted_code = '\n'.join(indented_lines)
101
 
 
104
  except:
105
  return formatted_code
106
 
107
def verify_token(token: str) -> bool:
    """Report whether ``token`` decodes successfully as a JWT.

    The token is decoded with the module-level ``JWT_SECRET`` and restricted
    to ``JWT_ALGORITHM``. The decoded claims are not used; only success or
    failure of the decode matters.

    Args:
        token: Encoded JWT supplied by the caller.

    Returns:
        ``True`` when ``jwt.decode`` accepts the token, ``False`` when it
        raises a ``jwt.PyJWTError``.
    """
    try:
        jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.PyJWTError:
        return False
    return True
113
 
114
def api_generate_solution(instruction: str, token: str) -> Dict[str, Any]:
    """Generate a formatted solution for a problem, gated by a JWT.

    Args:
        instruction: Problem statement passed to ``generate_solution``.
        token: JWT checked with ``verify_token`` before any generation runs.

    Returns:
        ``{"solution": <code>}`` where the code is the model output passed
        through ``extract_and_format_code``, or
        ``{"error": "Invalid token"}`` when the token is rejected.
    """
    if verify_token(token):
        raw_output = generate_solution(instruction)
        return {"solution": extract_and_format_code(raw_output)}

    return {"error": "Invalid token"}
 
 
 
121
 
122
def api_explain_solution(code: str, token: str) -> Dict[str, Any]:
    """Ask the model to explain a piece of Python code, gated by a JWT.

    Args:
        code: Source text embedded verbatim in the explanation prompt.
        token: JWT checked with ``verify_token`` before the model is called.

    Returns:
        ``{"explanation": <text>}`` holding the text of the first completion
        choice, or ``{"error": "Invalid token"}`` when the token is rejected.
    """
    if not verify_token(token):
        return {"error": "Invalid token"}

    explanation_prompt = f"Explain the following Python code:\n\n{code}\n\nExplanation:"
    completion = llm(explanation_prompt, max_tokens=256)
    first_choice = completion["choices"][0]
    return {"explanation": first_choice["text"]}
129
 
130
def generate_token() -> str:
    """Create a signed JWT that expires one hour from now.

    The payload carries only an ``exp`` claim. The token is signed with the
    module-level ``JWT_SECRET`` using ``JWT_ALGORITHM``, the same pair that
    ``verify_token`` decodes with.

    Returns:
        The encoded token as returned by ``jwt.encode``.
    """
    # Use a timezone-aware "now": datetime.utcnow() is deprecated since
    # Python 3.12 and yields a naive value that is easy to misinterpret.
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    expiration = now_utc + datetime.timedelta(hours=1)
    return jwt.encode({"exp": expiration}, JWT_SECRET, algorithm=JWT_ALGORITHM)
135
+
136
# Gradio interfaces: one Interface per API function. The order of the input
# components matches the positional parameters of the wrapped function
# (payload first, JWT second).
iface_generate = gr.Interface(
    title="LeetCode Problem Solver API - Generate Solution",
    description="Provide a LeetCode problem instruction and a valid JWT token to generate a solution.",
    fn=api_generate_solution,
    inputs=[
        gr.Textbox(label="LeetCode Problem Instruction"),
        gr.Textbox(label="JWT Token"),
    ],
    outputs=gr.JSON(label="Generated Solution"),
)

iface_explain = gr.Interface(
    title="LeetCode Problem Solver API - Explain Solution",
    description="Provide a code snippet and a valid JWT token to get an explanation.",
    fn=api_explain_solution,
    inputs=[
        gr.Textbox(label="Code to Explain"),
        gr.Textbox(label="JWT Token"),
    ],
    outputs=gr.JSON(label="Explanation"),
)

# NOTE(review): this interface takes no inputs and returns a freshly signed
# token, so anyone who can reach the UI can obtain a token that verify_token
# accepts. Confirm that is intended, particularly with share=True below.
iface_token = gr.Interface(
    title="Generate JWT Token",
    description="Generate a new JWT token for API authentication.",
    fn=generate_token,
    inputs=[],
    outputs=gr.Textbox(label="Generated JWT Token"),
)

# Combine interfaces into a single tabbed app; tab names pair up with the
# interfaces by position.
demo = gr.TabbedInterface(
    interface_list=[iface_generate, iface_explain, iface_token],
    tab_names=["Generate Solution", "Explain Solution", "Generate Token"],
)

if __name__ == "__main__":
    logger.info("Starting Gradio API")
    demo.launch(share=True)