code-audit-api / app.py
Forrest99's picture
Update app.py
e6e9cd7 verified
raw
history blame
1.1 kB
import asyncio
import os
import threading

from fastapi import FastAPI
from transformers import AutoTokenizer, T5ForConditionalGeneration
# Hugging Face cache location for this service.
# NOTE: this was meant to run before anything else, but the transformers import
# above has already executed by this point. huggingface_hub (pulled in by
# transformers) resolves its default cache directory from these variables when
# it is imported, so assigning them here does not by itself redirect downloads.
# Passing cache_dir explicitly to from_pretrained below is what guarantees the
# files land in CACHE_DIR. The variables are still exported for anything that
# reads them later, such as the /health endpoint.
CACHE_DIR = "/app/.cache/huggingface"
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR

app = FastAPI()

# Checkpoint used by /analyze; loaded once at import so every request reuses it.
MODEL_NAME = "Salesforce/codet5-small"

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    model = T5ForConditionalGeneration.from_pretrained(
        MODEL_NAME, cache_dir=CACHE_DIR
    )
except Exception as e:
    # Report the failure, then re-raise so the app fails to start instead of
    # serving requests without a model.
    print(f"模型加载失败: {e}")
    raise
# Serializes use of the shared tokenizer/model. Inference runs in worker threads
# (see analyze_code), and a Hugging Face fast tokenizer used from several threads
# at once with truncation enabled can fail (e.g. "RuntimeError: Already borrowed").
_inference_lock = threading.Lock()


def _run_analysis(code: str) -> str:
    """Run the model synchronously on *code* and return the decoded output text.

    This is blocking, compute-bound work. Call it from a worker thread, not
    directly on the event loop.
    """
    prompt = f"Analyze security vulnerabilities:\n{code}"
    with _inference_lock:
        # The prompt is truncated to 512 tokens, so very long inputs are only
        # partially seen by the model.
        inputs = tokenizer(prompt, return_tensors="pt",
                           max_length=512, truncation=True)
        outputs = model.generate(
            **inputs,  # forwards attention_mask together with input_ids
            max_length=512,
            num_beams=5,
            early_stopping=True,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)


@app.post("/analyze")
async def analyze_code(code: str):
    """Return the model's generated security analysis for a source snippet.

    Args:
        code: Source code to analyse. It is declared as a bare ``str`` parameter,
            so FastAPI reads it from the query string (``?code=...``), not from
            the request body.

    Returns:
        ``{"result": <decoded model output>}``.

    NOTE(review): Salesforce/codet5-small looks like a pretrained, not
    instruction-tuned, checkpoint. Confirm that it produces meaningful output for
    this natural-language prompt.
    """
    # Running generation inline in an ``async def`` endpoint would block the
    # event loop, and with it every other request including /health. Offload it
    # to the default executor and await the result.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, _run_analysis, code)
    return {"result": result}
@app.get("/health")
def health_check():
    """Report that the service is up, plus the configured Hugging Face cache path.

    Returns:
        A dict with ``status`` (always ``"ok"``) and ``cache_path`` (the current
        value of the ``HF_HOME`` environment variable).
    """
    hf_home = os.environ["HF_HOME"]
    return dict(status="ok", cache_path=hf_home)