code-audit-api / app.py
Forrest99's picture
Update app.py
d763957 verified
raw
history blame
3 kB
import os
import re

# Environment configuration. This MUST run before transformers (and the
# huggingface_hub package it imports) is loaded: huggingface_hub resolves
# HF_HOME into its cache-path constants at import time, so assigning it
# after `from transformers import ...` would not redirect the model cache.
os.environ["HF_HOME"] = "/app/.cache/huggingface"

from fastapi import FastAPI  # noqa: E402  (deliberately after env setup)
from transformers import AutoTokenizer, T5ForConditionalGeneration  # noqa: E402

app = FastAPI(title="代码安全审计API", version="1.2.0")

# Model initialization: load tokenizer and model once at startup and fail
# fast (abort the import) if loading is not possible.
try:
    print("正在加载模型...")
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-small")
    model = T5ForConditionalGeneration.from_pretrained("Salesforce/codet5-small")
    print("✅ 模型加载成功!")
except Exception as e:
    print(f"❌ 模型加载失败: {str(e)}")
    # Chain the original error so the real cause stays in the traceback.
    raise RuntimeError("模型初始化失败,请检查模型路径或网络连接") from e
# Characters removed from client code before it is embedded in the prompt.
_UNSAFE_CHARS_RE = re.compile(r"[<>&\"'\\]")


def sanitize_code(code: str, *, max_length: int = 2000) -> str:
    """Clean untrusted input code before it is placed into the model prompt.

    Removes every occurrence of the characters ``< > & " ' \\`` and then
    truncates the result to at most *max_length* characters.

    Args:
        code: Raw code snippet supplied by the client.
        max_length: Maximum number of characters to keep. Defaults to 2000,
            the limit this function has always applied. Must be >= 0.

    Returns:
        The filtered, truncated string.

    Raises:
        ValueError: If *max_length* is negative.
    """
    if max_length < 0:
        # A negative slice bound would silently drop characters from the
        # end of the string instead of enforcing a length cap.
        raise ValueError(f"max_length must be non-negative, got {max_length}")
    # NOTE(review): stripping quotes and angle brackets also removes the very
    # tokens that reveal SQL injection / XSS in the submitted snippet; the
    # caller in this file only tokenizes the text (never executes or renders
    # it), so confirm this filtering is actually wanted.
    filtered = _UNSAFE_CHARS_RE.sub("", code)
    # Bound the prompt size (guards against oversized input).
    return filtered[:max_length]
def _build_analysis_prompt(code: str) -> str:
    """Embed the (already sanitized) code snippet into the analysis prompt."""
    return (
        "作为资深安全工程师,请分析以下代码的安全漏洞:\n"
        "**要求**:\n"
        "1. 检查以下漏洞类型:SQL注入、XSS、命令注入、路径遍历、敏感信息泄露\n"
        "2. 用中文按格式返回:\n"
        "[漏洞类型] [危险等级]:具体描述\n"
        "[修复建议]:解决方案\n"
        "**代码片段**:\n"
        f"{code}"
    )


def _run_generation(input_ids, attention_mask):
    """Run blocking beam-search generation with the module-level model.

    Intended to be executed in a worker thread (see ``analyze_get``).
    ``temperature``/``top_p`` are intentionally not passed: without
    ``do_sample=True`` beam search ignores them (transformers only warns).
    """
    return model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=512,
        num_beams=5,
        early_stopping=True,
        repetition_penalty=1.2,
    )


@app.get("/analyze", summary="代码漏洞分析", description="输入代码片段,返回安全分析结果")
async def analyze_get(code: str):
    """Analyze a code snippet for security vulnerabilities (main endpoint).

    Args:
        code: Code snippet passed as the ``code`` query parameter.

    Returns:
        ``{"result": <cleaned model output>}`` on success, or
        ``{"error": <message>}`` if any step raises (HTTP 200 in both cases).
    """
    # Local imports keep this block self-contained.
    import traceback

    from fastapi.concurrency import run_in_threadpool

    try:
        # Input preprocessing
        raw_code = code
        code = sanitize_code(code)
        print(f"\n🔍 原始输入:\n{raw_code}\n🛡️ 清洗后:\n{code}\n")

        prompt = _build_analysis_prompt(code)
        print(f"📝 提示词:\n{prompt}\n")

        # No padding: for a single prompt, padding to 512 only wastes encoder
        # compute and made the logged token count always read 512.
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        print(f"🔢 输入Token数量: {len(inputs.input_ids[0])}")

        # model.generate is compute-bound and blocking; calling it directly in
        # an async endpoint would stall the event loop for all other requests.
        # Only generation is offloaded; tokenize/decode stay on this thread.
        outputs = await run_in_threadpool(
            _run_generation, inputs.input_ids, inputs.attention_mask
        )

        # Decode and post-process
        analysis = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"⚙️ 原始输出:\n{analysis}\n")

        # Strip an echoed prompt heading, if the model reproduced it.
        cleaned_analysis = analysis.replace("**代码片段**", "").strip()
        return {"result": cleaned_analysis}
    except Exception as e:
        # Top-level boundary: log the full traceback, return a stable error.
        print(f"❌ 错误追踪: {str(e)}")
        print(traceback.format_exc())
        return {"error": "分析失败,请检查输入格式"}
@app.get("/", include_in_schema=False)
async def health_check():
    """Liveness probe: report that the service is up, with model and version."""
    status_payload = dict(
        status="running",
        model="Salesforce/codet5-small",
        version="1.2.0",
    )
    return status_payload