Forrest99 commited on
Commit
113ca35
·
verified ·
1 Parent(s): c70f44b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -17
app.py CHANGED
@@ -2,36 +2,54 @@ from fastapi import FastAPI
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
  import os
 
5
 
6
- # 1. 基础配置
7
- app = FastAPI()
 
8
 
9
- # 2. 强制设置缓存路径(解决权限问题)
10
  os.environ["HF_HOME"] = "/app/.cache/huggingface"
11
 
12
- # 3. 加载模型(自动缓存到指定路径)
13
  try:
14
- model = AutoModelForSequenceClassification.from_pretrained("mrm8488/codebert-base-finetuned-detect-insecure-code")
15
- tokenizer = AutoTokenizer.from_pretrained("mrm8488/codebert-base-finetuned-detect-insecure-code")
 
 
 
 
 
 
 
 
16
  except Exception as e:
17
- raise RuntimeError(f"模型加载失败: {str(e)}")
 
 
 
18
 
19
- # 4. 接口定义
20
  @app.post("/detect")
21
  async def detect(code: str):
22
  try:
23
- # 简单处理超长输入
24
- if len(code) > 2000:
25
- code = code[:2000]
26
-
27
- inputs = tokenizer(code, return_tensors="pt", truncation=True)
28
  with torch.no_grad():
29
  outputs = model(**inputs)
30
-
 
 
31
  return {
32
- "label": model.config.id2label[outputs.logits.argmax().item()],
33
- "score": outputs.logits.softmax(dim=-1).max().item()
34
  }
35
 
36
  except Exception as e:
37
- return {"error": str(e)}
 
 
 
 
 
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
  import os
5
+ import logging
6
 
7
+ # 初始化日志
8
+ logging.basicConfig(level=logging.INFO)
9
+ logger = logging.getLogger("CodeSecurityAPI")
10
 
11
+ # 强制设置缓存路径(解决权限问题)
12
  os.environ["HF_HOME"] = "/app/.cache/huggingface"
13
 
14
+ # 加载模型
15
  try:
16
+ logger.info("Loading model...")
17
+ model = AutoModelForSequenceClassification.from_pretrained(
18
+ "mrm8488/codebert-base-finetuned-detect-insecure-code",
19
+ cache_dir=os.getenv("HF_HOME")
20
+ )
21
+ tokenizer = AutoTokenizer.from_pretrained(
22
+ "mrm8488/codebert-base-finetuned-detect-insecure-code",
23
+ cache_dir=os.getenv("HF_HOME")
24
+ )
25
+ logger.info("Model loaded successfully")
26
  except Exception as e:
27
+ logger.error(f"Model load failed: {str(e)}")
28
+ raise RuntimeError("模型加载失败,请检查网络连接或模型路径")
29
+
30
+ app = FastAPI()
31
 
 
32
  @app.post("/detect")
33
  async def detect(code: str):
34
  try:
35
+ # 输入处理(限制长度)
36
+ code = code[:2000] # 截断超长输入
37
+
38
+ # 模型推理
39
+ inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512)
40
  with torch.no_grad():
41
  outputs = model(**inputs)
42
+
43
+ # 解析结果
44
+ label_id = outputs.logits.argmax().item()
45
  return {
46
+ "label": model.config.id2label[label_id],
47
+ "score": outputs.logits.softmax(dim=-1)[0][label_id].item()
48
  }
49
 
50
  except Exception as e:
51
+ return {"error": str(e)}
52
+
53
+ @app.get("/health")
54
+ async def health():
55
+ return {"status": "ok"}