Forrest99 committed
Commit b025fa1 · verified · 1 Parent(s): 9f324c2

Update app.py

Files changed (1)
  1. app.py +7 -31
app.py CHANGED
@@ -2,38 +2,14 @@ import os
 from fastapi import FastAPI
 from transformers import AutoTokenizer, T5ForConditionalGeneration
 
-# Set the cache paths (must come before anything else)
-os.environ["HF_HOME"] = "/app/.cache/huggingface"
-os.environ["TRANSFORMERS_CACHE"] = "/app/.cache/huggingface"
+os.environ["HF_HOME"] = "/app/.cache"
 
 app = FastAPI()
-
-# Load the model
-try:
-    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-small")
-    model = T5ForConditionalGeneration.from_pretrained("Salesforce/codet5-small")
-except Exception as e:
-    print(f"Model loading failed: {str(e)}")
-    raise
+tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-small")
+model = T5ForConditionalGeneration.from_pretrained("Salesforce/codet5-small")
 
 @app.post("/analyze")
-async def analyze_code(code: str):
-    prompt = f"Analyze security vulnerabilities:\n{code}"
-
-    inputs = tokenizer(prompt, return_tensors="pt",
-                       max_length=512, truncation=True)
-
-    outputs = model.generate(
-        inputs.input_ids,
-        max_length=512,
-        num_beams=5,
-        early_stopping=True
-    )
-
-    return {
-        "result": tokenizer.decode(outputs[0], skip_special_tokens=True)
-    }
-
-@app.get("/health")
-def health_check():
-    return {"status": "ok", "cache_path": os.environ["HF_HOME"]}
+async def analyze(code: str):
+    inputs = tokenizer(f"Analyze vulnerabilities:\n{code}", return_tensors="pt", max_length=512, truncation=True)
+    outputs = model.generate(inputs.input_ids, max_length=512)
+    return {"result": tokenizer.decode(outputs[0], skip_special_tokens=True)}