Forrest99 committed on
Commit
e6e9cd7
·
verified ·
1 Parent(s): a69f5f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -36
app.py CHANGED
@@ -1,42 +1,39 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
  from transformers import AutoTokenizer, T5ForConditionalGeneration
4
- import torch
5
 
6
- app = FastAPI()
 
 
7
 
8
- # 全局加载模型
9
- tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-small")
10
- model = T5ForConditionalGeneration.from_pretrained("Salesforce/codet5-small")
11
 
12
- class CodeRequest(BaseModel):
13
- code: str
14
- max_length: int = 512
 
 
 
 
15
 
16
- @app.post("/v1/analyze")
17
- async def analyze_code(request: CodeRequest):
18
- try:
19
- # 构造提示词
20
- prompt = f"Analyze security vulnerabilities in this code:\n{request.code}"
21
-
22
- # 生成分析结果
23
- inputs = tokenizer(prompt, return_tensors="pt",
24
- max_length=512, truncation=True)
25
- outputs = model.generate(
26
- inputs.input_ids,
27
- max_length=request.max_length,
28
- num_beams=5,
29
- early_stopping=True
30
- )
31
-
32
- # 解码结果
33
- analysis = tokenizer.decode(outputs[0], skip_special_tokens=True)
34
-
35
- return {
36
- "status": "success",
37
- "analysis": analysis,
38
- "model": "Salesforce/codet5-small"
39
- }
40
 
41
- except Exception as e:
42
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from fastapi import FastAPI
3
  from transformers import AutoTokenizer, T5ForConditionalGeneration
 
4
 
5
+ # 设置缓存路径(必须放在最前面)
6
+ os.environ["HF_HOME"] = "/app/.cache/huggingface"
7
+ os.environ["TRANSFORMERS_CACHE"] = "/app/.cache/huggingface"
8
 
9
+ app = FastAPI()
 
 
10
 
11
+ # 加载模型
12
+ try:
13
+ tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-small")
14
+ model = T5ForConditionalGeneration.from_pretrained("Salesforce/codet5-small")
15
+ except Exception as e:
16
+ print(f"模型加载失败: {str(e)}")
17
+ raise
18
 
19
+ @app.post("/analyze")
20
+ async def analyze_code(code: str):
21
+ prompt = f"Analyze security vulnerabilities:\n{code}"
22
+
23
+ inputs = tokenizer(prompt, return_tensors="pt",
24
+ max_length=512, truncation=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ outputs = model.generate(
27
+ inputs.input_ids,
28
+ max_length=512,
29
+ num_beams=5,
30
+ early_stopping=True
31
+ )
32
+
33
+ return {
34
+ "result": tokenizer.decode(outputs[0], skip_special_tokens=True)
35
+ }
36
+
37
+ @app.get("/health")
38
+ def health_check():
39
+ return {"status": "ok", "cache_path": os.environ["HF_HOME"]}