"""FastAPI service exposing a CodeBERT insecure-code classifier.

Serves a JSON endpoint (``POST /detect``) and a Gradio web UI mounted at ``/``.
"""
import logging
import os

# huggingface_hub (imported by both transformers and gradio) resolves its cache
# paths from HF_HOME when it is first imported, so the variable must be set
# *before* those imports; setting it afterwards does not change the cache dir.
os.environ["HF_HOME"] = "/app/.cache/huggingface"

import gradio as gr  # noqa: E402
import torch  # noqa: E402
from fastapi import FastAPI  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from transformers import (  # noqa: E402
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

logger = logging.getLogger(__name__)

MODEL_NAME = "mrm8488/codebert-base-finetuned-detect-insecure-code"
# Input is first cut to MAX_CHARS characters, then the tokenizer truncates
# the result to at most MAX_TOKENS tokens.
MAX_CHARS = 2000
MAX_TOKENS = 512

# === FastAPI initialization ===
app = FastAPI()

# Add CORS middleware so browser clients on other origins can call the API.
# NOTE(review): "*" permits every origin, method and header; restrict
# allow_origins if the API should not be callable from arbitrary sites.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow all origins
    allow_methods=["*"],  # allow all HTTP methods
    allow_headers=["*"],  # allow all request headers
)

# === Model loading ===
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def _predict(code: str) -> dict:
    """Run the classifier on *code* and return its top prediction.

    Returns ``{"label": int, "score": float}`` where ``label`` is the arg-max
    class index and ``score`` is that class's softmax probability.
    """
    inputs = tokenizer(
        code[:MAX_CHARS],
        return_tensors="pt",
        truncation=True,
        max_length=MAX_TOKENS,
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    # Single input, so row 0 holds the logits for this snippet.
    label_id = int(logits[0].argmax().item())
    score = logits.softmax(dim=-1)[0][label_id].item()
    return {"label": label_id, "score": score}


# === HTTP API endpoint ===
@app.post("/detect")
def api_detect(code: str):
    """HTTP API endpoint: classify the ``code`` parameter.

    Declared as a plain ``def`` rather than ``async def`` so FastAPI runs the
    blocking model inference in its thread pool instead of stalling the event
    loop, and so the Gradio handler below can call it synchronously.

    Returns ``{"label": int, "score": float}`` on success (label is returned
    as a plain int class index), or ``{"error": message}`` if inference raises.
    """
    try:
        return _predict(code)
    except Exception as e:  # top-level boundary: report failure as JSON
        logger.exception("Inference failed")
        return {"error": str(e)}


# === Gradio UI (optional) ===
def gradio_predict(code: str) -> str:
    """Classify *code* and format the result for the Gradio text output."""
    result = api_detect(code)
    if "error" in result:
        # Surface the failure instead of crashing on a missing "label" key.
        return f"Error: {result['error']}"
    return f"Prediction: {result['label']} (Confidence: {result['score']:.2f})"


gr_interface = gr.Interface(
    fn=gradio_predict,
    inputs=gr.Textbox(lines=10, placeholder="Paste code here..."),
    outputs="text",
    title="Code Security Detector",
)

# Mounted after /detect is registered, so the API route still takes precedence
# over the UI mounted at the root path.
app = gr.mount_gradio_app(app, gr_interface, path="/")