# codebertBase / app.py (1.75 kB)
# Web file-viewer residue kept for reference: "Last commit not found", "raw", "history blame".
import logging
import os
import threading

import gradio as gr
import torch
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware  # CORS support
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# === FastAPI initialization ===
app = FastAPI()

# Wide-open CORS policy: any origin may call the API, using any HTTP method
# and any request headers. allow_credentials is not passed, so it keeps the
# middleware's default.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# === Model loading ===
os.environ["HF_HOME"] = "/app/.cache/huggingface"
# transformers / huggingface_hub resolve their cache locations from HF_HOME
# when they are imported, which has already happened by this point. Setting
# the variable here therefore does not, by itself, move the download cache.
# Pass cache_dir explicitly, mirroring huggingface_hub's "$HF_HOME/hub" layout.
MODEL_NAME = "mrm8488/codebert-base-finetuned-detect-insecure-code"
_MODEL_CACHE_DIR = os.path.join(os.environ["HF_HOME"], "hub")
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, cache_dir=_MODEL_CACHE_DIR)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=_MODEL_CACHE_DIR)
# === HTTP API endpoint ===
# Inference is moved off the event loop into worker threads (see api_detect).
# Hugging Face fast tokenizers are not safe to call concurrently from several
# threads, so tokenization + forward pass are serialized with this lock.
_inference_lock = threading.Lock()


def _predict(code: str) -> dict:
    """Synchronously classify *code* with the loaded model.

    Returns ``{"label": int, "score": float}``: ``label`` is the argmax class
    index of the logits and ``score`` its softmax probability.
    """
    with _inference_lock:
        # Cap the raw text first to bound tokenizer work; the tokenizer then
        # truncates to at most 512 tokens.
        inputs = tokenizer(code[:2000], return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            logits = model(**inputs).logits
    label_id = int(logits.argmax().item())
    return {
        "label": label_id,  # plain int so it serializes as a JSON number
        "score": logits.softmax(dim=-1)[0][label_id].item(),
    }


@app.post("/detect")
async def api_detect(code: str):
    """HTTP endpoint: run the insecure-code classifier on *code*.

    ``code`` is a bare ``str`` parameter, so FastAPI reads it from the query
    string. Tokenization and the model forward pass are blocking, so they run
    in a worker thread. This keeps the event loop, and with it the mounted
    Gradio UI, responsive.

    Returns ``{"label": int, "score": float}`` on success, or
    ``{"error": str}`` if anything fails (the endpoint never raises).
    """
    try:
        return await run_in_threadpool(_predict, code)
    except Exception as e:
        # Keep the original JSON error contract, but record the traceback
        # server-side instead of discarding it.
        logging.getLogger(__name__).exception("Inference failed")
        return {"error": str(e)}
# === Gradio UI (optional) ===
async def gradio_predict(code: str):
    """Gradio callback: classify *code* and format the result as text.

    ``api_detect`` is a coroutine function, so its result must be awaited.
    Calling it without ``await`` yields a coroutine object, not a dict.
    Gradio accepts async callbacks. If the API reports an error, show that
    message rather than failing on the missing ``label``/``score`` keys.
    """
    result = await api_detect(code)
    if "error" in result:
        return f"Error: {result['error']}"
    return f"Prediction: {result['label']} (Confidence: {result['score']:.2f})"
# Single-textbox Gradio UI wrapping gradio_predict, mounted on the FastAPI
# app at the root path; the returned app replaces the module-level `app`.
gr_interface = gr.Interface(
    title="Code Security Detector",
    fn=gradio_predict,
    inputs=gr.Textbox(placeholder="Paste code here...", lines=10),
    outputs="text",
)
app = gr.mount_gradio_app(app=app, blocks=gr_interface, path="/")