Forrest99 committed on
Commit
338753c
·
verified ·
1 Parent(s): a324382

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -18
app.py CHANGED
@@ -1,13 +1,31 @@
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
- import torch
5
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # === FastAPI 配置 ===
8
  app = FastAPI()
9
-
10
- # 解决 CSP 限制的关键配置
11
  app.add_middleware(
12
  CORSMiddleware,
13
  allow_origins=["*"],
@@ -16,21 +34,22 @@ app.add_middleware(
16
  )
17
 
18
  # === 模型加载 ===
19
- os.environ["HF_HOME"] = "/app/.cache/huggingface"
20
- model = AutoModelForSequenceClassification.from_pretrained("mrm8488/codebert-base-finetuned-detect-insecure-code")
21
- tokenizer = AutoTokenizer.from_pretrained("mrm8488/codebert-base-finetuned-detect-insecure-code")
 
 
 
 
 
22
 
23
- # === HTTP API 接口 ===
24
  @app.post("/detect")
25
  async def detect(code: str):
26
- try:
27
- inputs = tokenizer(code[:2000], return_tensors="pt", truncation=True, max_length=512)
28
- with torch.no_grad():
29
- outputs = model(**inputs)
30
- label_id = outputs.logits.argmax().item()
31
- return {
32
- "label": int(label_id), # 严格返回 0/1
33
- "score": outputs.logits.softmax(dim=-1)[0][label_id].item()
34
- }
35
- except Exception as e:
36
- return {"error": str(e)}
 
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
4
  import os
5
+ import logging
6
+ from pathlib import Path
7
+
8
+ # === 初始化日志 ===
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # === 检查缓存目录权限 ===
13
def check_permissions():
    """Verify that the Hugging Face cache directory exists and is writable.

    The directory is taken from the ``HF_HOME`` environment variable. When
    that variable is unset or empty, the Hugging Face default location is
    probed instead (``$XDG_CACHE_HOME/huggingface`` if ``XDG_CACHE_HOME`` is
    set, otherwise ``~/.cache/huggingface``), so the check never silently
    tests the current working directory.

    The check creates the directory if needed, writes a small probe file
    into it and removes the file again.

    Raises:
        RuntimeError: If the directory cannot be resolved, created or
            written to. The underlying exception is chained as the cause.
    """
    try:
        configured = os.getenv("HF_HOME")
        if configured:
            cache_path = Path(configured).expanduser()
        else:
            # An empty path would resolve to "." and validate the working
            # directory rather than the cache the model download will use.
            xdg_cache = os.getenv("XDG_CACHE_HOME")
            cache_root = Path(xdg_cache) if xdg_cache else Path.home() / ".cache"
            cache_path = cache_root / "huggingface"

        cache_path.mkdir(parents=True, exist_ok=True)
        test_file = cache_path / "permission_test.txt"
        test_file.write_text("test")
        test_file.unlink()
        logger.info(f"✅ 缓存目录权限正常: {cache_path}")
    except Exception as e:
        logger.error(f"❌ 缓存目录权限异常: {str(e)}")
        # Keep the original exception attached for debugging.
        raise RuntimeError(f"Directory permission error: {str(e)}") from e
24
+
25
+ check_permissions()
26
 
27
  # === FastAPI 配置 ===
28
  app = FastAPI()
 
 
29
  app.add_middleware(
30
  CORSMiddleware,
31
  allow_origins=["*"],
 
34
  )
35
 
36
  # === 模型加载 ===
37
+ try:
38
+ logger.info("🔄 加载模型中...")
39
+ model = AutoModelForSequenceClassification.from_pretrained("mrm8488/codebert-base-finetuned-detect-insecure-code")
40
+ tokenizer = AutoTokenizer.from_pretrained("mrm8488/codebert-base-finetuned-detect-insecure-code")
41
+ logger.info("✅ 模型加载成功")
42
+ except Exception as e:
43
+ logger.error(f"❌ 模型加载失败: {str(e)}")
44
+ raise
45
 
46
+ # === API 接口 ===
47
@app.post("/detect")
async def detect(code: str):
    """Classify a code snippet with the loaded sequence-classification model.

    Args:
        code: Source text to classify. Only the first 2000 characters are
            passed to the tokenizer, which additionally truncates to 512
            tokens.

    Returns:
        A dict with ``label`` (index of the highest logit, as an int) and
        ``score`` (the highest softmax probability, as a float).
    """
    # Imported at function scope: the module's import block does not bring
    # ``torch`` into scope, but ``torch.no_grad()`` below requires it.
    import torch

    inputs = tokenizer(code[:2000], return_tensors="pt", truncation=True, max_length=512)
    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        outputs = model(**inputs)
    return {
        "label": int(outputs.logits.argmax()),
        "score": outputs.logits.softmax(dim=-1).max().item()
    }