|
import os |
|
import re |
|
import logging |
|
import textwrap |
|
import autopep8 |
|
import gradio as gr |
|
from huggingface_hub import hf_hub_download |
|
from llama_cpp import Llama |
|
import jwt |
|
from typing import Dict, Any |
|
import datetime |
|
|
|
|
|
# Configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Secret used to sign and verify the API tokens. Start-up is aborted when the
# variable is missing or empty.
JWT_SECRET = os.getenv("JWT_SECRET")
if JWT_SECRET in (None, ""):
    raise ValueError("JWT_SECRET environment variable is not set")

# Algorithm name handed to jwt.encode / jwt.decode.
JWT_ALGORITHM = "HS256"
|
|
|
|
|
# File name of the model weights and the repository they are downloaded from.
MODEL_NAME = "leetmonkey_peft__q8_0.gguf"
REPO_ID = "sugiv/leetmonkey-peft-gguf"

# Keyword arguments unpacked into the model call that generates solutions.
# The stop list contains the closing code fence and the two section headers
# used by the generation prompt.
generation_kwargs = dict(
    max_tokens=512,
    stop=["```", "### Instruction:", "### Response:"],
    echo=False,
    temperature=0.05,
    top_k=10,
    top_p=0.9,
    repeat_penalty=1.1,
)
|
|
|
def download_model(model_name):
    """Download ``model_name`` from the ``REPO_ID`` repository.

    The request goes through ``hf_hub_download`` with ``./models`` as the
    cache directory. Progress is reported on the module logger before and
    after the call.

    Args:
        model_name: Name of the file to fetch from the repository.

    Returns:
        The value returned by ``hf_hub_download`` (logged as the model path).
    """
    logger.info(f"Downloading model: {model_name}")

    # NOTE(review): force_download=True looks like it bypasses the local
    # cache on every start-up -- confirm that a fresh download is intended.
    download_options = {
        "repo_id": REPO_ID,
        "filename": model_name,
        "cache_dir": "./models",
        "force_download": True,
        "resume_download": True,
    }
    local_path = hf_hub_download(**download_options)

    logger.info(f"Model downloaded: {local_path}")
    return local_path
|
|
|
|
|
# The weights are fetched and loaded once, at import time; the request
# handlers in this module all call this single `llm` object.
model_path = download_model(MODEL_NAME)

# Constructor options for the llama.cpp binding, gathered in one place.
# NOTE(review): n_gpu_layers=-1 presumably offloads every layer to the GPU --
# confirm against the llama-cpp-python documentation.
_llama_options = {
    "model_path": model_path,
    "n_ctx": 1024,
    "n_threads": 8,
    "n_gpu_layers": -1,
    "verbose": False,
    "n_batch": 512,
    "mlock": True,
}
llm = Llama(**_llama_options)

logger.info("8-bit model loaded successfully")
|
|
|
def generate_solution(instruction: str) -> str:
    """Ask the model for a Python implementation of a LeetCode problem.

    The instruction is embedded in an "Instruction / Response" prompt that
    ends with an opening ```python fence, and the model is called with the
    module-level ``generation_kwargs``.

    Args:
        instruction: Problem statement or function stub to implement.

    Returns:
        The ``text`` field of the first completion choice, unprocessed.
    """
    system_prompt = (
        "You are a Python coding assistant specialized in solving LeetCode problems. "
        "Provide only the complete implementation of the given function. "
        "Ensure proper indentation and formatting. "
        "Do not include any explanations or multiple solutions."
    )

    # The template is kept flush-left so that no indentation leaks into the
    # text sent to the model.
    full_prompt = f"""### Instruction:
{system_prompt}

Implement the following function for the LeetCode problem:

{instruction}

### Response:
Here's the complete Python function implementation:

```python
"""

    completion = llm(full_prompt, **generation_kwargs)
    first_choice = completion["choices"][0]
    return first_choice["text"]
|
|
|
# Matches a stripped line that opens a function or class definition. The
# trailing \b stops identifiers such as ``default`` or ``classes`` from being
# mistaken for the ``def`` / ``class`` keywords.
_DEFINITION_LINE = re.compile(r'(?:def|class)\b')


def extract_and_format_code(text: str) -> str:
    """Pull Python code out of model output and normalise its layout.

    Steps, in order:

    1. If ``text`` contains a ```python ... ``` fenced block, only the
       content of the first such block is used; otherwise all of ``text``.
    2. Common leading whitespace is removed with ``textwrap.dedent``.
    3. Every non-blank line that does not open a ``def`` / ``class`` is
       indented by four spaces; definition lines and blank lines are kept
       as they are.
    4. The result is passed to ``autopep8.fix_code``.

    Args:
        text: Raw text produced by the model.

    Returns:
        The code as returned by ``autopep8.fix_code``. If that call raises,
        the failure is logged and the text from step 3 is returned instead.
    """
    fenced = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
    code = fenced.group(1) if fenced else text
    code = textwrap.dedent(code)

    indented_lines = []
    for line in code.split('\n'):
        stripped = line.strip()
        if not stripped or _DEFINITION_LINE.match(stripped):
            # Blank lines and definition headers keep their current column.
            indented_lines.append(line)
        else:
            indented_lines.append('    ' + line)
    formatted_code = '\n'.join(indented_lines)

    try:
        return autopep8.fix_code(formatted_code)
    except Exception:
        # Formatting is best-effort: fall back to the unformatted code, but
        # leave a trace instead of hiding the failure. (A bare ``except``
        # would also swallow KeyboardInterrupt / SystemExit.)
        logging.getLogger(__name__).warning(
            "autopep8 formatting failed; returning unformatted code",
            exc_info=True,
        )
        return formatted_code
|
|
|
def verify_token(token: str) -> bool:
    """Report whether ``token`` is accepted by ``jwt.decode``.

    Args:
        token: Encoded JWT supplied by the caller.

    Returns:
        ``True`` when ``jwt.decode`` succeeds with ``JWT_SECRET`` and
        ``JWT_ALGORITHM``; ``False`` when it raises ``jwt.PyJWTError``.
    """
    try:
        jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.PyJWTError:
        return False
    else:
        return True
|
|
|
def api_generate_solution(instruction: str, token: str) -> Dict[str, Any]:
    """Token-protected endpoint that generates and formats a solution.

    Args:
        instruction: LeetCode problem text passed to ``generate_solution``.
        token: JWT checked with ``verify_token``.

    Returns:
        ``{"error": "Invalid token"}`` when the token is rejected; otherwise
        ``{"solution": <code>}`` where the code is the model output after
        ``extract_and_format_code``.
    """
    if verify_token(token):
        raw_completion = generate_solution(instruction)
        return {"solution": extract_and_format_code(raw_completion)}
    return {"error": "Invalid token"}
|
|
|
def api_explain_solution(code: str, token: str) -> Dict[str, Any]:
    """Token-protected endpoint that asks the model to explain ``code``.

    Args:
        code: Python source to be explained; inserted verbatim in the prompt.
        token: JWT checked with ``verify_token``.

    Returns:
        ``{"error": "Invalid token"}`` when the token is rejected; otherwise
        ``{"explanation": <text>}`` with the text of the first completion
        choice. The model is called with ``max_tokens=256`` only, not with
        ``generation_kwargs``.
    """
    if verify_token(token):
        explanation_prompt = f"Explain the following Python code:\n\n{code}\n\nExplanation:"
        completion = llm(explanation_prompt, max_tokens=256)
        return {"explanation": completion["choices"][0]["text"]}
    return {"error": "Invalid token"}
|
|
|
def generate_token() -> str:
    """Create a signed JWT that expires one hour after it is issued.

    The payload contains a single ``exp`` claim and is signed with
    ``JWT_SECRET`` using ``JWT_ALGORITHM``.

    Returns:
        The encoded token as text.
    """
    # Use a timezone-aware "now": datetime.utcnow() returns a naive value and
    # is deprecated since Python 3.12.
    issued_at = datetime.datetime.now(datetime.timezone.utc)
    payload = {"exp": issued_at + datetime.timedelta(hours=1)}

    token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
    # PyJWT 1.x returns bytes while 2.x returns str; normalise so the
    # declared return type holds on both.
    if isinstance(token, bytes):
        token = token.decode("utf-8")
    return token
|
|
|
|
|
# "Generate Solution" form. The two text boxes are listed in the same order
# as the parameters of api_generate_solution (instruction, token).
iface_generate = gr.Interface(
    title="LeetCode Problem Solver API - Generate Solution",
    description="Provide a LeetCode problem instruction and a valid JWT token to generate a solution.",
    fn=api_generate_solution,
    inputs=[
        gr.Textbox(label="LeetCode Problem Instruction"),
        gr.Textbox(label="JWT Token"),
    ],
    outputs=gr.JSON(label="Generated Solution"),
)
|
|
|
# "Explain Solution" form. The two text boxes are listed in the same order
# as the parameters of api_explain_solution (code, token).
iface_explain = gr.Interface(
    title="LeetCode Problem Solver API - Explain Solution",
    description="Provide a code snippet and a valid JWT token to get an explanation.",
    fn=api_explain_solution,
    inputs=[
        gr.Textbox(label="Code to Explain"),
        gr.Textbox(label="JWT Token"),
    ],
    outputs=gr.JSON(label="Explanation"),
)
|
|
|
# "Generate Token" form: no inputs, one text box showing the new token.
# NOTE(review): generate_token is exposed here without asking for any
# credential, so whoever can reach this form can obtain a token that the
# other two endpoints accept -- confirm this is intended.
iface_token = gr.Interface(
    title="Generate JWT Token",
    description="Generate a new JWT token for API authentication.",
    fn=generate_token,
    inputs=[],
    outputs=gr.Textbox(label="Generated JWT Token"),
)
|
|
|
|
|
# Each tab title is kept next to the interface it labels; the two lists
# handed to TabbedInterface are derived from these pairs, in this order.
_TABS = (
    ("Generate Solution", iface_generate),
    ("Explain Solution", iface_explain),
    ("Generate Token", iface_token),
)
demo = gr.TabbedInterface(
    [interface for _, interface in _TABS],
    [tab_title for tab_title, _ in _TABS],
)


if __name__ == "__main__":
    # Only start the server when the file is run as a script.
    logger.info("Starting Gradio API")
    demo.launch(share=True)
|
|