"""Code security audit API: runs a CodeT5 model over a submitted snippet."""

import os

# HF_HOME must be set BEFORE transformers / huggingface_hub are imported: those
# libraries resolve their cache directories at import time, so assigning it
# after the import (as the previous version did) has no effect.
os.environ["HF_HOME"] = "/app/.cache/huggingface"

import re  # noqa: E402
import traceback  # noqa: E402

from fastapi import FastAPI  # noqa: E402
from fastapi.concurrency import run_in_threadpool  # noqa: E402
from transformers import AutoTokenizer, T5ForConditionalGeneration  # noqa: E402

# Single source of truth for values reused by the loader, the app and the
# health check.
MODEL_NAME = "Salesforce/codet5-small"
API_VERSION = "1.2.0"
MAX_CODE_CHARS = 2000  # input length cap (guards against oversized requests)
MAX_TOKENS = 512  # encoder truncation length and generate() max_length

# Characters removed from incoming code by sanitize_code().
_DANGEROUS_CHARS = re.compile(r"[<>&\"'\\]")

app = FastAPI(title="代码安全审计API", version=API_VERSION)

# Load the model once at import time; fail fast (and keep the original cause)
# if it cannot be loaded, so the service never starts half-initialised.
try:
    print("正在加载模型...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
    print("✅ 模型加载成功!")
except Exception as e:
    print(f"❌ 模型加载失败: {str(e)}")
    raise RuntimeError("模型初始化失败,请检查模型路径或网络连接") from e


def sanitize_code(code: str) -> str:
    """Strip the characters ``< > & " ' \\`` from *code* and cap it at 2000 chars.

    NOTE(review): the text is only fed to the model here (never rendered or
    executed), and quotes / angle brackets are exactly what SQL-injection and
    XSS patterns consist of -- stripping them likely hides the vulnerabilities
    this service is meant to detect. Kept as-is to preserve behaviour; confirm.
    """
    code = _DANGEROUS_CHARS.sub("", code)
    # Truncate after filtering to bound prompt size.
    return code[:MAX_CODE_CHARS]


def _build_prompt(code: str) -> str:
    """Return the audit prompt with the (already sanitised) *code* appended."""
    return f"""作为资深安全工程师,请分析以下代码的安全漏洞:
**要求**:
1. 检查以下漏洞类型:SQL注入、XSS、命令注入、路径遍历、敏感信息泄露
2. 用中文按格式返回:
[漏洞类型] [危险等级]:具体描述
[修复建议]:解决方案
**代码片段**:
{code}"""


def _generate_analysis(prompt: str) -> str:
    """Tokenize *prompt*, run beam-search generation and return decoded text.

    This is blocking, CPU/GPU-bound work; callers on the event loop must run it
    in a worker thread (see analyze_get).
    """
    # No padding: a single sequence needs none, and padding to max_length both
    # wasted encoder compute and made the token-count log always read 512.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=MAX_TOKENS,
        truncation=True,
    )
    print(f"🔢 输入Token数量: {inputs.input_ids.shape[-1]}")

    # temperature/top_p were removed: they only apply when do_sample=True, so
    # with pure beam search they were ignored (and triggered warnings).
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=MAX_TOKENS,
        num_beams=5,
        early_stopping=True,
        repetition_penalty=1.2,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


@app.get("/analyze", summary="代码漏洞分析", description="输入代码片段,返回安全分析结果")
async def analyze_get(code: str):
    """Analyse *code* and return ``{"result": text}``.

    On any failure returns ``{"error": ...}`` (still HTTP 200, as before).
    NOTE(review): codet5-small is a code-pretrained seq2seq model rather than
    an instruction-tuned one; its output may not follow the requested format.
    """
    try:
        raw_code = code
        code = sanitize_code(code)
        print(f"\n🔍 原始输入:\n{raw_code}\n🛡️ 清洗后:\n{code}\n")

        prompt = _build_prompt(code)
        print(f"📝 提示词:\n{prompt}\n")

        # Run the blocking model call in a worker thread so this async handler
        # does not freeze the event loop (and every other request) meanwhile.
        analysis = await run_in_threadpool(_generate_analysis, prompt)
        print(f"⚙️ 原始输出:\n{analysis}\n")

        # Drop the prompt section header if the model echoes it back.
        cleaned_analysis = analysis.replace("**代码片段**", "").strip()
        return {"result": cleaned_analysis}
    except Exception as e:
        print(f"❌ 错误追踪: {str(e)}")
        traceback.print_exc()
        return {"error": "分析失败,请检查输入格式"}


@app.get("/", include_in_schema=False)
async def health_check():
    """Report liveness plus the loaded model name and API version."""
    return {
        "status": "running",
        "model": MODEL_NAME,
        "version": API_VERSION,
    }