Forrest99 committed on
Commit
eb9892c
·
verified ·
1 Parent(s): f1f2ca1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -8,10 +8,23 @@ import logging
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger("CodeSecurityAPI")
10
 
11
- # 强制设置缓存路径(解决权限问题)
12
  os.environ["HF_HOME"] = "/app/.cache/huggingface"
13
 
14
- # 加载模型
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  try:
16
  logger.info("Loading model...")
17
  model = AutoModelForSequenceClassification.from_pretrained(
@@ -27,26 +40,18 @@ except Exception as e:
27
  logger.error(f"Model load failed: {str(e)}")
28
  raise RuntimeError("模型加载失败,请检查网络连接或模型路径")
29
 
30
- app = FastAPI()
31
-
32
  @app.post("/detect")
33
  async def detect(code: str):
34
  try:
35
- # 输入处理(限制长度)
36
- code = code[:2000] # 截断超长输入
37
-
38
- # 模型推理
39
  inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512)
40
  with torch.no_grad():
41
  outputs = model(**inputs)
42
-
43
- # 解析结果
44
  label_id = outputs.logits.argmax().item()
45
  return {
46
  "label": model.config.id2label[label_id],
47
  "score": outputs.logits.softmax(dim=-1)[0][label_id].item()
48
  }
49
-
50
  except Exception as e:
51
  return {"error": str(e)}
52
 
 
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger("CodeSecurityAPI")
10
 
11
+ # 强制设置缓存路径
12
  os.environ["HF_HOME"] = "/app/.cache/huggingface"
13
 
14
+ app = FastAPI()
15
+
16
+ # === 新增根路径响应 ===
17
+ @app.get("/")
18
+ async def read_root():
19
+ return {
20
+ "message": "欢迎使用代码安全检测API",
21
+ "endpoints": {
22
+ "detect": "POST /detect",
23
+ "health": "GET /health"
24
+ }
25
+ }
26
+
27
+ # === 加载模型(必须放在FastAPI实例之后) ===
28
  try:
29
  logger.info("Loading model...")
30
  model = AutoModelForSequenceClassification.from_pretrained(
 
40
  logger.error(f"Model load failed: {str(e)}")
41
  raise RuntimeError("模型加载失败,请检查网络连接或模型路径")
42
 
 
 
43
  @app.post("/detect")
44
  async def detect(code: str):
45
  try:
46
+ code = code[:2000]
 
 
 
47
  inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512)
48
  with torch.no_grad():
49
  outputs = model(**inputs)
 
 
50
  label_id = outputs.logits.argmax().item()
51
  return {
52
  "label": model.config.id2label[label_id],
53
  "score": outputs.logits.softmax(dim=-1)[0][label_id].item()
54
  }
 
55
  except Exception as e:
56
  return {"error": str(e)}
57