"""FastAPI service exposing an insecure-code classifier.

Loads the ``mrm8488/codebert-base-finetuned-detect-insecure-code`` sequence
classifier at import time and serves predictions via ``POST /detect``.
"""
import logging
import os
import threading
from pathlib import Path

import torch
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# === Logging setup ===
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MODEL_NAME = "mrm8488/codebert-base-finetuned-detect-insecure-code"
# Characters kept from the submitted snippet before tokenization; the
# tokenizer then truncates further to MAX_TOKENS tokens.
MAX_CODE_CHARS = 2000
MAX_TOKENS = 512


def _resolve_cache_dir() -> Path:
    """Return the Hugging Face cache root used by ``from_pretrained``.

    Follows the huggingface_hub default: ``$HF_HOME`` if set, otherwise
    ``$XDG_CACHE_HOME/huggingface``, otherwise ``~/.cache/huggingface``.
    An unset ``HF_HOME`` must not fall back to ``Path("")`` (the current
    working directory), which is not where models are cached.
    """
    hf_home = os.getenv("HF_HOME")
    if hf_home:
        return Path(hf_home).expanduser()
    xdg_cache = os.getenv("XDG_CACHE_HOME")
    base = Path(xdg_cache).expanduser() if xdg_cache else Path.home() / ".cache"
    return base / "huggingface"


# === Cache directory permission check ===
def check_permissions() -> None:
    """Ensure the model cache directory exists and is writable.

    Creates the directory if needed, then writes and deletes a probe file.

    Raises:
        RuntimeError: if the directory cannot be created or written to.
    """
    cache_path = _resolve_cache_dir()
    try:
        cache_path.mkdir(parents=True, exist_ok=True)
        test_file = cache_path / "permission_test.txt"
        test_file.write_text("test")
        test_file.unlink()
    except OSError as e:
        logger.exception(f"❌ 缓存目录权限异常: {str(e)}")
        raise RuntimeError(f"Directory permission error: {str(e)}") from e
    logger.info(f"✅ 缓存目录权限正常: {cache_path}")


check_permissions()

# === FastAPI configuration ===
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# === Model loading ===
try:
    logger.info("🔄 加载模型中...")
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # from_pretrained already returns the model in eval mode; made explicit
    # so dropout is guaranteed off for inference.
    model.eval()
    logger.info("✅ 模型加载成功")
except Exception as e:
    logger.exception(f"❌ 模型加载失败: {str(e)}")
    raise

# Serializes access to the shared tokenizer/model: inference now runs in
# worker threads, and the tokenizer call passes truncation settings that
# reconfigure the shared tokenizer instance on each call.
_inference_lock = threading.Lock()


def _classify(code: str) -> dict:
    """Run the classifier on one snippet synchronously (blocking).

    Returns:
        ``{"label": int, "score": float}``: the index of the highest-scoring
        class and its softmax probability.
    """
    with _inference_lock:
        inputs = tokenizer(
            code[:MAX_CODE_CHARS],
            return_tensors="pt",
            truncation=True,
            max_length=MAX_TOKENS,
        )
        with torch.no_grad():
            outputs = model(**inputs)
    logits = outputs.logits
    return {
        "label": int(logits.argmax()),
        "score": logits.softmax(dim=-1).max().item(),
    }


# === API endpoints ===
@app.post("/detect")
async def detect(code: str):
    """Classify a code snippet with the insecure-code detector.

    ``code`` is a bare ``str`` parameter, so FastAPI reads it from the query
    string. The blocking model call is offloaded to a worker thread so the
    event loop keeps serving other requests.

    Returns:
        dict with ``label`` (index of the highest-scoring class) and
        ``score`` (that class's softmax probability).
    """
    return await run_in_threadpool(_classify, code)