import logging
import os

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# === FastAPI configuration ===
app = FastAPI()

# Cross-origin (CORS) setup: accept requests from any origin, using any HTTP
# method and any request header.
# NOTE(review): the original comment described this as the key setting for
# working around CSP restrictions. This middleware only configures CORS
# handling, which is a different mechanism from Content-Security-Policy --
# confirm which restriction was actually being hit, and whether a wildcard
# origin is acceptable for this deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# === Model loading ===
# The Hugging Face libraries resolve their cache location from HF_HOME when
# they are first imported, and `transformers` has already been imported by the
# time this line runs, so the assignment alone cannot redirect this process's
# cache. It is kept for anything that reads the variable afterwards (for
# example child processes), and the cache directory is additionally passed
# explicitly to `from_pretrained` so the intended location is actually used.
_HF_HOME = "/app/.cache/huggingface"
_MODEL_ID = "mrm8488/codebert-base-finetuned-detect-insecure-code"

os.environ["HF_HOME"] = _HF_HOME

# "<HF_HOME>/hub" is the default hub cache layout, so files land where they
# would if HF_HOME had been exported before the interpreter started.
_HF_HUB_CACHE = os.path.join(_HF_HOME, "hub")

# Loaded once at import time and shared by every request.
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_ID, cache_dir=_HF_HUB_CACHE)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, cache_dir=_HF_HUB_CACHE)

# === HTTP API endpoint ===
logger = logging.getLogger(__name__)


@app.post("/detect")
async def detect(code: str):
    """Classify a code snippet with the module-level model and tokenizer.

    Args:
        code: Source text to classify. Declared as a plain ``str`` parameter,
            so FastAPI reads it from the query string rather than from the
            request body.

    Returns:
        On success, ``{"label": int, "score": float}`` where ``label`` is the
        index of the highest-scoring class (0 or 1 according to the original
        author's note) and ``score`` is the softmax probability of that class.
        On any failure, ``{"error": str}`` carrying the exception message; the
        failure is also logged with its traceback.

    NOTE(review): tokenization and the forward pass run synchronously inside
    this coroutine, so the event loop is blocked for the duration of each
    inference.
    """
    try:
        # Cap the raw text at 2000 characters before tokenizing; the tokenizer
        # then truncates further to at most 512 tokens.
        inputs = tokenizer(code[:2000], return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)

        # Exactly one snippet is scored per request, so select that single row
        # explicitly instead of relying on argmax over the flattened tensor.
        logits = outputs.logits[0]
        label_id = int(logits.argmax().item())
        return {
            "label": label_id,
            "score": logits.softmax(dim=-1)[label_id].item(),
        }
    except Exception as e:
        # Keep the existing response contract for clients, but record the
        # failure server-side so it is not silently swallowed.
        logger.exception("Insecure-code detection failed")
        return {"error": str(e)}