sugiv committed on
Commit
394f072
·
1 Parent(s): 55f0c97

First version with APIs

Browse files
Files changed (2) hide show
  1. app.py +177 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import hf_hub_download
3
+ from llama_cpp import Llama
4
+ import re
5
+ from datasets import load_dataset
6
+ import random
7
+ import logging
8
+ import os
9
+ import autopep8
10
+ import textwrap
11
+ import jwt
12
+ from datetime import datetime, timedelta
13
+
14
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# JWT settings
# The signing secret is read from the environment. A well-known literal
# fallback would let anyone mint tokens that pass verification, so when the
# variable is unset or empty a random per-process secret is generated instead.
# Tokens signed with that generated secret stop validating after a restart.
JWT_SECRET = os.environ.get("JWT_SECRET")
if not JWT_SECRET:
    logger.warning(
        "JWT_SECRET is not set; using a random per-process secret. "
        "Tokens will not survive a restart."
    )
    JWT_SECRET = os.urandom(32).hex()
JWT_ALGORITHM = "HS256"

# Model settings
MODEL_NAME = "leetmonkey_peft__q8_0.gguf"
REPO_ID = "sugiv/leetmonkey-peft-gguf"
25
+
26
def download_model(model_name, force_download=False):
    """Fetch a model file from the Hugging Face Hub and return its local path.

    Args:
        model_name: File name inside the ``REPO_ID`` repository.
        force_download: When true, re-download even if a cached copy exists
            under ``./models``. Defaults to false so a restart reuses the
            cached file instead of transferring the whole model again.

    Returns:
        The path returned by ``hf_hub_download``.
    """
    logger.info("Downloading model: %s", model_name)
    # `resume_download` is intentionally not passed: it is deprecated in
    # huggingface_hub and is meaningless together with a forced download.
    model_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=model_name,
        cache_dir="./models",
        force_download=force_download,
    )
    logger.info("Model downloaded: %s", model_path)
    return model_path
37
+
38
# Download and load the 8-bit model at startup
model_path = download_model(MODEL_NAME)
llm = Llama(
    model_path=model_path,
    n_ctx=1024,
    n_threads=8,
    n_gpu_layers=-1,  # Use all available GPU layers
    verbose=False,
    n_batch=512,
    # The Llama constructor names this option `use_mlock`; an unknown keyword
    # such as `mlock` is swallowed by its **kwargs and has no effect.
    use_mlock=True,
)
logger.info("8-bit model loaded successfully")

# Load the dataset
dataset = load_dataset("sugiv/leetmonkey_python_dataset")
train_dataset = dataset["train"]

# Generation parameters
generation_kwargs = {
    "max_tokens": 512,
    # Stop at a code fence or at the start of another prompt section.
    "stop": ["```", "### Instruction:", "### Response:"],
    "echo": False,
    "temperature": 0.05,
    "top_k": 10,
    "top_p": 0.9,
    "repeat_penalty": 1.1,
}
65
+
66
def generate_solution(instruction):
    """Stream the model's completion for one LeetCode problem.

    Builds an instruction/response prompt that ends on an opened
    ```python fence, calls ``llm`` in streaming mode with
    ``generation_kwargs``, and yields the text of each streamed chunk.

    Args:
        instruction: Problem statement inserted into the prompt.

    Yields:
        The ``text`` field of the first choice of every streamed chunk.
    """
    system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
    full_prompt = (
        "### Instruction:\n"
        f"{system_prompt}\n"
        "\n"
        "Implement the following function for the LeetCode problem:\n"
        "\n"
        f"{instruction}\n"
        "\n"
        "### Response:\n"
        "Here's the complete Python function implementation:\n"
        "\n"
        "```python\n"
    )

    completion_stream = llm(full_prompt, stream=True, **generation_kwargs)
    for piece in completion_stream:
        first_choice = piece["choices"][0]
        yield first_choice["text"]
83
+
84
def extract_and_format_code(text):
    """Extract Python code from model output and normalise its indentation.

    If ``text`` contains a ```python fence, only the fenced content is used;
    the closing fence may be missing (for example when generation was cut
    off). Otherwise the whole text is treated as code. The code is dedented,
    then every non-blank line that is not a definition header gets four extra
    spaces, and finally autopep8 is applied on a best-effort basis.

    Args:
        text: Raw text produced by the model.

    Returns:
        The formatted code, or the re-indented code unchanged by autopep8
        if autopep8 fails.
    """
    # `\Z` lets an opening fence without a closing partner still match.
    code_match = re.search(r'```python\s*(.*?)\s*(?:```|\Z)', text, re.DOTALL)
    code = code_match.group(1) if code_match else text

    # Header lines are left where they are. The trailing space in 'def ' and
    # 'class ' keeps identifiers such as `default`, `defaultdict` or `classes`
    # from being mistaken for definitions; decorators and `async def` belong
    # with the definition they introduce.
    header_prefixes = ('class ', 'def ', 'async def ', '@')

    indented_lines = []
    for line in textwrap.dedent(code).split('\n'):
        stripped = line.strip()
        if not stripped or stripped.startswith(header_prefixes):
            indented_lines.append(line)  # Blank lines and headers stay as is
        else:
            indented_lines.append('    ' + line)  # Add 4 spaces of indentation

    formatted_code = '\n'.join(indented_lines)

    try:
        return autopep8.fix_code(formatted_code)
    except Exception:
        # autopep8 is only a best-effort clean-up; fall back to our own output.
        return formatted_code
114
+
115
def select_random_problem():
    """Return the ``instruction`` field of a randomly chosen ``train_dataset`` entry."""
    chosen_entry = random.choice(train_dataset)
    problem_statement = chosen_entry['instruction']
    return problem_statement
117
+
118
def stream_solution(problem):
    """Yield the growing solution text, then the formatted code as the last item.

    Args:
        problem: Problem statement passed to ``generate_solution``.

    Yields:
        After each generated piece, all text accumulated so far; once
        generation ends, the result of ``extract_and_format_code`` on the
        complete text.
    """
    logger.info("Generating solution")
    accumulated = ""
    for piece in generate_solution(problem):
        accumulated = accumulated + piece
        yield accumulated

    final_code = extract_and_format_code(accumulated)
    logger.info("Solution generated successfully")
    yield final_code
128
+
129
def verify_token(token):
    """Report whether ``token`` is a JWT that decodes with ``JWT_SECRET``.

    Args:
        token: The encoded JWT supplied by the caller.

    Returns:
        True if ``jwt.decode`` accepts the token using ``JWT_ALGORITHM``;
        False for empty or non-string input and for any PyJWT error.
    """
    # Reject obviously unusable input up front instead of relying on the
    # library to raise for it.
    if not token or not isinstance(token, (str, bytes)):
        return False
    try:
        jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.PyJWTError:
        # Only token problems mean "invalid"; anything else (including
        # KeyboardInterrupt/SystemExit) must not be swallowed.
        return False
    return True
135
+
136
def generate_token():
    """Create a JWT signed with ``JWT_SECRET`` that expires in one hour.

    Returns:
        The encoded token produced by ``jwt.encode``; its only claim is ``exp``.
    """
    # Imported here because the module-level import only brings in
    # `datetime` and `timedelta`.
    from datetime import timezone

    # `datetime.utcnow()` is deprecated and returns a naive value; use an
    # explicit, timezone-aware UTC timestamp for the expiry claim.
    expiration = datetime.now(timezone.utc) + timedelta(hours=1)
    return jwt.encode({"exp": expiration}, JWT_SECRET, algorithm=JWT_ALGORITHM)
139
+
140
def api_random_problem(token):
    """Return a random problem for a valid token, otherwise an error payload.

    Args:
        token: JWT checked with ``verify_token``.

    Returns:
        ``{"problem": ...}`` on success, ``{"error": "Invalid token"}`` otherwise.
    """
    if verify_token(token):
        return {"problem": select_random_problem()}
    return {"error": "Invalid token"}
144
+
145
def api_generate_solution(problem, token):
    """Generate a solution for ``problem`` if ``token`` is valid.

    Args:
        problem: Problem statement to solve.
        token: JWT checked with ``verify_token``.

    Returns:
        ``{"solution": ...}`` holding the final formatted code, or
        ``{"error": "Invalid token"}`` when the token is rejected.
    """
    if not verify_token(token):
        return {"error": "Invalid token"}
    # stream_solution yields the cumulative text after every token and the
    # formatted code last. Joining all items would repeat the partial text
    # many times over, so drain the generator and keep only the final item.
    solution = ""
    for solution in stream_solution(problem):
        pass
    return {"solution": solution}
150
+
151
def api_explain_solution(solution, token):
    """Ask the model to explain ``solution`` if ``token`` is valid.

    Args:
        solution: Python code to be explained.
        token: JWT checked with ``verify_token``.

    Returns:
        ``{"explanation": ...}`` with the text of the first completion choice,
        or ``{"error": "Invalid token"}`` when the token is rejected.
    """
    if verify_token(token):
        explanation_prompt = f"Explain the following Python code:\n\n{solution}\n\nExplanation:"
        completion = llm(explanation_prompt, max_tokens=256)
        first_choice = completion["choices"][0]
        return {"explanation": first_choice["text"]}
    return {"error": "Invalid token"}
157
+
158
# gr.Interface wraps a single callable, so every endpoint gets its own
# Interface declaring exactly the inputs that function accepts (in the
# function's parameter order); the four are then grouped into tabs.
API_DESCRIPTION = "API endpoints for generating and explaining LeetCode solutions."

random_problem_iface = gr.Interface(
    fn=api_random_problem,
    inputs=gr.Textbox(label="JWT Token"),
    outputs=gr.JSON(label="Random Problem"),
    description=API_DESCRIPTION,
)

generate_solution_iface = gr.Interface(
    fn=api_generate_solution,
    inputs=[
        gr.Textbox(label="Problem"),
        gr.Textbox(label="JWT Token"),
    ],
    outputs=gr.JSON(label="Generated Solution"),
    description=API_DESCRIPTION,
)

explain_solution_iface = gr.Interface(
    fn=api_explain_solution,
    inputs=[
        gr.Textbox(label="Solution"),
        gr.Textbox(label="JWT Token"),
    ],
    outputs=gr.JSON(label="Explanation"),
    description=API_DESCRIPTION,
)

# NOTE(review): token issuance is exposed without any authentication, so
# anyone who can reach the app can obtain a valid token — confirm this is
# intended before relying on the JWT check for access control.
token_iface = gr.Interface(
    fn=generate_token,
    inputs=None,
    outputs=gr.Textbox(label="New JWT Token"),
    description=API_DESCRIPTION,
)

iface = gr.TabbedInterface(
    [
        random_problem_iface,
        generate_solution_iface,
        explain_solution_iface,
        token_iface,
    ],
    tab_names=[
        "Random Problem",
        "Generated Solution",
        "Explanation",
        "New JWT Token",
    ],
    title="LeetCode Problem Solver API",
)

if __name__ == "__main__":
    logger.info("Starting Gradio API")
    iface.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ llama-cpp-python
3
+ datasets
4
+ transformers
5
+ autopep8
6
+ huggingface_hub