# Hugging Face Spaces page header (extraction artifact, not code):
# "Spaces: Sleeping Sleeping" — kept as a comment so the module parses.
import os
import re

# HF_HOME must be exported BEFORE transformers is imported: the library
# resolves its cache directory at import time, so setting it afterwards
# (as the original code did) has no effect.
os.environ["HF_HOME"] = "/app/.cache/huggingface"

from fastapi import FastAPI
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Single application instance; routes are registered against it elsewhere.
app = FastAPI()
# Initialize tokenizer and model once at import time so requests share them.
try:
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-small")
    model = T5ForConditionalGeneration.from_pretrained("Salesforce/codet5-small")
except Exception as e:
    # Chain the original exception ("from e") so the real cause — network
    # failure, missing cache, bad weights — stays visible in the traceback.
    raise RuntimeError(f"模型加载失败: {str(e)}") from e
def sanitize_code(code: str) -> str:
    """Sanitize user-supplied code before it is embedded in a prompt.

    Strips characters considered risky for prompt/markup injection and
    caps the result at 1024 characters.
    """
    risky_chars = r"[<>&\"']"
    cleaned = re.sub(risky_chars, "", code)
    # Bound the payload length after filtering.
    return cleaned[:1024]
async def analyze_get(code: str):
    """Analyze a code snippet for security vulnerabilities with CodeT5.

    Args:
        code: Raw, untrusted source code supplied by the caller.

    Returns:
        ``{"result": <analysis text>}`` on success,
        ``{"error": <message>}`` on any failure.
    """
    # NOTE(review): no route decorator is visible in this chunk — presumably
    # this handler is registered with ``app`` elsewhere; confirm.
    try:
        # Sanitize the untrusted input.
        code = sanitize_code(code)
        # Build the prompt.
        prompt = f"""Analyze the following code for security vulnerabilities in Chinese.
重点检查SQL注入、XSS、命令注入、路径遍历等问题。
按此格式返回:\n[漏洞类型]: [风险描述]\n\n代码:\n{code}"""
        # Tokenize the prompt.
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding="max_length",
        )
        # Generate the analysis. With padding="max_length" the attention
        # mask MUST be passed explicitly, otherwise the model attends to
        # pad tokens and output quality degrades. ``temperature`` was
        # dropped: it is ignored under beam search (do_sample defaults to
        # False) and only emits a warning.
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=512,
            num_beams=5,
            early_stopping=True,
        )
        # Decode the generated token ids back to text.
        analysis = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"result": analysis}
    except Exception as e:
        # Boundary handler: surface the failure to the client as JSON
        # instead of letting the request crash with a 500.
        return {"error": str(e)}
async def health_check():
    """Liveness probe: report that the service is up."""
    payload = dict(status="active")
    return payload