Forrest99 commited on
Commit
7de8bc5
·
verified ·
1 Parent(s): db0f531

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -7
app.py CHANGED
@@ -1,15 +1,28 @@
1
  import os
2
  from fastapi import FastAPI
3
- from transformers import AutoTokenizer, T5ForConditionalGeneration
4
 
5
- os.environ["HF_HOME"] = "/app/.cache"
 
 
6
 
7
  app = FastAPI()
8
- tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-small")
9
- model = T5ForConditionalGeneration.from_pretrained("Salesforce/codet5-small")
 
 
 
 
 
 
 
 
10
 
11
  @app.post("/analyze")
12
  async def analyze(code: str):
13
- inputs = tokenizer(f"Analyze vulnerabilities:\n{code}", return_tensors="pt", max_length=512, truncation=True)
14
- outputs = model.generate(inputs.input_ids, max_length=512)
15
- return {"result": tokenizer.decode(outputs[0], skip_special_tokens=True)}
 
 
 
 
1
  import os
2
  from fastapi import FastAPI
3
+ from transformers import pipeline
4
 
5
+ # 必须在所有导入前设置环境变量
6
+ os.environ["HF_HOME"] = "/app/.cache/huggingface"
7
+ os.environ["TRANSFORMERS_CACHE"] = "/app/.cache/huggingface"
8
 
9
  app = FastAPI()
10
+
11
+ # 模型加载(带错误处理)
12
+ try:
13
+ analyzer = pipeline(
14
+ "text2text-generation",
15
+ model="Salesforce/codet5-small",
16
+ tokenizer="Salesforce/codet5-small"
17
+ )
18
+ except Exception as e:
19
+ raise RuntimeError(f"模型加载失败: {str(e)}")
20
 
21
  @app.post("/analyze")
22
  async def analyze(code: str):
23
+ result = analyzer(
24
+ f"Analyze code vulnerabilities:\n{code}",
25
+ max_length=512,
26
+ num_beams=5
27
+ )[0]['generated_text']
28
+ return {"result": result}