fastapi-deepseek / server.py
Elaineyy's picture
Update server.py
bb40de3 verified
raw
history blame
2.25 kB
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import subprocess
import tempfile
import os
# Application object that the route decorators in this module register with.
app = FastAPI()

# Load the DeepSeek-Coder-V2-Base tokenizer and model once, at import time,
# so every request handler reuses the same instances.
model_name = "deepseek-ai/DeepSeek-Coder-V2-Base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # request half-precision weights
    device_map="auto",  # let the library decide device placement
)
class CodeRequest(BaseModel):
    """Request body for the ``/generate-code`` endpoint."""

    # Free-text user story that is interpolated into the generation prompt.
    user_story: str
class TestRequest(BaseModel):
    """Request body for the ``/test-code`` endpoint."""

    # Python source that is written to a temporary file and passed to pytest.
    code: str
@app.post("/generate-code")
def generate_code(request: CodeRequest):
    """Generate code for a user story using the module-level model.

    Args:
        request: Request body whose ``user_story`` is embedded in the prompt.

    Returns:
        A dict with a single key, ``generated_code``, holding the decoded
        model output. The whole output sequence is decoded, so for a causal
        language model the text starts with the prompt itself.
    """
    prompt = f"Generate structured code for: {request.user_story}"
    # Send the inputs to the device the model actually lives on. The model is
    # loaded with device_map="auto", so guessing "cuda"/"cpu" here can
    # disagree with the placement chosen at load time.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # max_new_tokens budgets only the continuation. max_length also counts
    # the prompt tokens, so a long user story would leave little or no room
    # for the model to generate anything.
    output = model.generate(**inputs, max_new_tokens=300)
    generated_code = tokenizer.decode(output[0], skip_special_tokens=True)
    return {"generated_code": generated_code}
# Upper bound on one pytest run so a hanging test cannot block a worker forever.
_PYTEST_TIMEOUT_SECONDS = 60


@app.post("/test-code")
def test_code(request: TestRequest):
    """Run pytest against the submitted code and report the outcome.

    The code is written to a temporary ``.py`` file, pytest is run on that
    file in a child process, and the file is removed afterwards whether or
    not the run succeeded.

    Args:
        request: Request body whose ``code`` is the Python source to test.

    Returns:
        ``{"test_status": "All tests passed!"}`` when pytest exits with 0;
        otherwise ``{"test_status": "Test failed!", "details": ...}`` where
        ``details`` is pytest's combined stdout and stderr, or a timeout
        message.

    Raises:
        HTTPException: Status 500 if the file cannot be written or pytest
            cannot be started.
    """
    # NOTE(security): this executes caller-supplied code on the server with
    # the server's privileges. It must only be reachable by trusted clients
    # or be run inside a sandbox.
    temp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
            temp_file.write(request.code.encode())
            temp_path = temp_file.name
        # The with-block has closed (and flushed) the file, so pytest sees
        # the complete source.
        result = subprocess.run(
            ["pytest", temp_path],
            capture_output=True,
            text=True,
            timeout=_PYTEST_TIMEOUT_SECONDS,
        )
    except subprocess.TimeoutExpired:
        return {
            "test_status": "Test failed!",
            "details": f"Test run exceeded the {_PYTEST_TIMEOUT_SECONDS} second time limit.",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always remove the temp file, including when pytest could not be
        # started or timed out.
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)

    if result.returncode == 0:
        return {"test_status": "All tests passed!"}
    # pytest prints its failure report to stdout; stderr alone is usually
    # empty, so return both streams.
    return {"test_status": "Test failed!", "details": result.stdout + result.stderr}
@app.get("/execute-code")
def execute_code():
    """Run a fixed sample snippet in a child Python process.

    Returns:
        ``{"status": "Execution successful!", "output": <stdout>}`` when the
        child exits with code 0; otherwise
        ``{"status": "Execution failed!", "error": <message>}``, where the
        message is the child's stderr or the text of the exception raised
        while starting it.
    """
    # Local import: only this function needs the interpreter path.
    import sys

    sample_code = "print('Hello from AI-generated code!')"
    try:
        # Use the interpreter running this server instead of relying on a
        # "python3" binary being present on PATH.
        result = subprocess.run(
            [sys.executable, "-c", sample_code],
            capture_output=True,
            text=True,
        )
    except Exception as e:
        return {"status": "Execution failed!", "error": str(e)}

    # A non-zero exit code means the snippet itself failed, so it must not
    # be reported as a success.
    if result.returncode != 0:
        return {"status": "Execution failed!", "error": result.stderr}
    return {"status": "Execution successful!", "output": result.stdout}
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run as a script.
    import uvicorn

    # "0.0.0.0" binds on all network interfaces.
    # NOTE(review): port 7860 is presumably what the hosting platform
    # expects - confirm before changing.
    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)