"""HTTP API that classifies a code snippet with a sequence-classification model.

At import time the module loads the
``mrm8488/codebert-base-finetuned-detect-insecure-code`` checkpoint and its
tokenizer, then exposes two routes:

* ``POST /detect`` -- run the model on a snippet and return label and score.
* ``GET /health``  -- liveness probe.
"""

import logging
import os

# Force the Hugging Face cache location (works around permission problems).
# This has to happen BEFORE ``transformers`` is imported: the Hugging Face
# libraries resolve their cache paths when they are first imported, so an
# assignment made afterwards does not change their defaults.
os.environ["HF_HOME"] = "/app/.cache/huggingface"

import torch  # noqa: E402
from fastapi import FastAPI  # noqa: E402
from transformers import (  # noqa: E402
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# Logging setup.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("CodeSecurityAPI")

# Checkpoint used for both the model weights and the tokenizer.
MODEL_NAME = "mrm8488/codebert-base-finetuned-detect-insecure-code"
# Character cap applied to the raw input before tokenization.
MAX_CODE_CHARS = 2000
# Token cap passed to the tokenizer (with truncation enabled).
MAX_TOKENS = 512

# Load the model and tokenizer once, at import time.
try:
    logger.info("Loading model...")
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        cache_dir=os.getenv("HF_HOME"),
    )
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        cache_dir=os.getenv("HF_HOME"),
    )
    logger.info("Model loaded successfully")
except Exception as e:
    logger.error("Model load failed: %s", e)
    # Chain the original exception so the root cause stays in the traceback.
    raise RuntimeError("模型加载失败,请检查网络连接或模型路径") from e

app = FastAPI()


@app.post("/detect")
async def detect(code: str) -> dict:
    """Classify ``code`` and return the predicted label with its probability.

    Args:
        code: Source text to classify. Only the first ``MAX_CODE_CHARS``
            characters are used; the tokenizer additionally truncates to
            ``MAX_TOKENS`` tokens.

    Returns:
        ``{"label": <str>, "score": <float>}`` on success, where ``label`` is
        looked up in ``model.config.id2label`` and ``score`` is the softmax
        probability of that label. On any failure, ``{"error": <message>}``
        is returned instead (the exception is also logged).
    """
    # NOTE(review): inference runs synchronously inside this coroutine, so the
    # event loop is blocked for the duration of the forward pass.
    try:
        # Input handling: cut over-long input before tokenizing.
        snippet = code[:MAX_CODE_CHARS]

        # Model inference (no gradients needed).
        inputs = tokenizer(
            snippet,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_TOKENS,
        )
        with torch.no_grad():
            outputs = model(**inputs)

        # Parse the result. A single snippet is tokenized, so row 0 is the
        # only row in the batch.
        logits = outputs.logits[0]
        label_id = logits.argmax().item()
        return {
            "label": model.config.id2label[label_id],
            "score": logits.softmax(dim=-1)[label_id].item(),
        }
    except Exception as e:
        # Keep the existing best-effort response shape, but record the
        # traceback instead of swallowing it silently.
        logger.exception("Detection failed")
        return {"error": str(e)}


@app.get("/health")
async def health() -> dict:
    """Liveness probe; always returns ``{"status": "ok"}``."""
    return {"status": "ok"}