Forrest99 commited on
Commit
fc986a8
·
verified ·
1 Parent(s): eb9892c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -40
app.py CHANGED
@@ -1,60 +1,52 @@
1
  from fastapi import FastAPI
 
 
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
  import os
5
- import logging
6
-
7
- # 初始化日志
8
- logging.basicConfig(level=logging.INFO)
9
- logger = logging.getLogger("CodeSecurityAPI")
10
-
11
- # 强制设置缓存路径
12
- os.environ["HF_HOME"] = "/app/.cache/huggingface"
13
 
 
14
  app = FastAPI()
15
 
16
- # === 新增根路径响应 ===
17
- @app.get("/")
18
- async def read_root():
19
- return {
20
- "message": "欢迎使用代码安全检测API",
21
- "endpoints": {
22
- "detect": "POST /detect",
23
- "health": "GET /health"
24
- }
25
- }
26
 
27
- # === 加载模型(必须放在FastAPI实例之后) ===
28
- try:
29
- logger.info("Loading model...")
30
- model = AutoModelForSequenceClassification.from_pretrained(
31
- "mrm8488/codebert-base-finetuned-detect-insecure-code",
32
- cache_dir=os.getenv("HF_HOME")
33
- )
34
- tokenizer = AutoTokenizer.from_pretrained(
35
- "mrm8488/codebert-base-finetuned-detect-insecure-code",
36
- cache_dir=os.getenv("HF_HOME")
37
- )
38
- logger.info("Model loaded successfully")
39
- except Exception as e:
40
- logger.error(f"Model load failed: {str(e)}")
41
- raise RuntimeError("模型加载失败,请检查网络连接或模型路径")
42
 
 
43
  @app.post("/detect")
44
- async def detect(code: str):
 
45
  try:
46
- code = code[:2000]
47
- inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512)
48
  with torch.no_grad():
49
  outputs = model(**inputs)
50
  label_id = outputs.logits.argmax().item()
51
  return {
52
- "label": model.config.id2label[label_id],
53
  "score": outputs.logits.softmax(dim=-1)[0][label_id].item()
54
  }
55
  except Exception as e:
56
  return {"error": str(e)}
57
 
58
- @app.get("/health")
59
- async def health():
60
- return {"status": "ok"}
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware # 新增 CORS 支持
3
+ import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  import torch
6
  import os
 
 
 
 
 
 
 
 
7
 
8
+ # === FastAPI 初始化 ===
9
  app = FastAPI()
10
 
11
+ # 添加 CORS 中间件(关键步骤)
12
+ app.add_middleware(
13
+ CORSMiddleware,
14
+ allow_origins=["*"], # 允许所有来源
15
+ allow_methods=["*"], # 允许所有 HTTP 方法
16
+ allow_headers=["*"], # 允许所有请求头
17
+ )
 
 
 
18
 
19
+ # === 模型加载 ===
20
+ os.environ["HF_HOME"] = "/app/.cache/huggingface"
21
+ model = AutoModelForSequenceClassification.from_pretrained("mrm8488/codebert-base-finetuned-detect-insecure-code")
22
+ tokenizer = AutoTokenizer.from_pretrained("mrm8488/codebert-base-finetuned-detect-insecure-code")
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ # === HTTP API 接口 ===
25
  @app.post("/detect")
26
+ async def api_detect(code: str):
27
+ """HTTP API 接口"""
28
  try:
29
+ inputs = tokenizer(code[:2000], return_tensors="pt", truncation=True, max_length=512)
 
30
  with torch.no_grad():
31
  outputs = model(**inputs)
32
  label_id = outputs.logits.argmax().item()
33
  return {
34
+ "label": int(label_id), # 强制返回 0/1 数字
35
  "score": outputs.logits.softmax(dim=-1)[0][label_id].item()
36
  }
37
  except Exception as e:
38
  return {"error": str(e)}
39
 
40
+ # === Gradio 界面(可选)===
41
+ def gradio_predict(code: str):
42
+ result = api_detect(code)
43
+ return f"Prediction: {result['label']} (Confidence: {result['score']:.2f})"
44
+
45
+ gr_interface = gr.Interface(
46
+ fn=gradio_predict,
47
+ inputs=gr.Textbox(lines=10, placeholder="Paste code here..."),
48
+ outputs="text",
49
+ title="Code Security Detector"
50
+ )
51
+
52
+ app = gr.mount_gradio_app(app, gr_interface, path="/")