from fastapi import FastAPI
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import os
import logging

# Initialize logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("CodeSecurityAPI")

# Force the Hugging Face cache path. Note that transformers resolves HF_HOME
# at import time, so setting it here only takes effect because cache_dir is
# also passed explicitly when loading the model below.
os.environ["HF_HOME"] = "/app/.cache/huggingface"

app = FastAPI()

# === Root path response ===
@app.get("/")
async def read_root():
    return {
        "message": "Welcome to the Code Security Detection API",
        "endpoints": {
            "detect": "POST /detect",
            "health": "GET /health"
        }
    }

# === Load the model (must come after the FastAPI instance) ===
try:
    logger.info("Loading model...")
    model = AutoModelForSequenceClassification.from_pretrained(
        "mrm8488/codebert-base-finetuned-detect-insecure-code",
        cache_dir=os.getenv("HF_HOME")
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "mrm8488/codebert-base-finetuned-detect-insecure-code",
        cache_dir=os.getenv("HF_HOME")
    )
    logger.info("Model loaded successfully")
except Exception as e:
    logger.error(f"Model load failed: {str(e)}")
    raise RuntimeError("Model load failed; check the network connection or the model path")

@app.post("/detect")
async def detect(code: str):
    # NOTE: FastAPI treats a bare `str` parameter as a query parameter,
    # so clients call this endpoint as POST /detect?code=...
    try:
        # Cap the raw input size before tokenization
        code = code[:2000]
        inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)
        label_id = outputs.logits.argmax().item()
        return {
            "label": model.config.id2label[label_id],
            "score": outputs.logits.softmax(dim=-1)[0][label_id].item()
        }
    except Exception as e:
        return {"error": str(e)}

@app.get("/health")
async def health():
    return {"status": "ok"}
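
# --- Usage sketch (hypothetical, not part of the service) ---
# A minimal client-side call against a locally running instance, assuming the
# app is served with `uvicorn main:app --port 8000` (the module name "main"
# and the port are assumptions, not taken from the source). Because `code` is
# a bare str parameter, FastAPI expects it as a query parameter:
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/detect",
#       params={"code": "import os\nos.system(user_input)"},
#   )
#   print(resp.json())  # -> {"label": ..., "score": ...} on success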