Spaces:
Running
Running
| from fastapi import FastAPI | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| import os | |
# --- Application setup --------------------------------------------------

app = FastAPI()

# Force the Hugging Face cache into a writable location (works around
# permission errors when the default home-dir cache is read-only, e.g.
# inside containers).
# NOTE(review): transformers may read HF_HOME at import time, and the
# `from transformers import ...` above runs before this assignment —
# confirm the cache path actually takes effect, or move this assignment
# above the transformers import.
os.environ["HF_HOME"] = "/app/.cache/huggingface"

# CodeBERT fine-tuned for insecure-code detection; used by the /detect endpoint.
_MODEL_ID = "mrm8488/codebert-base-finetuned-detect-insecure-code"

# Load model + tokenizer once at startup (downloaded and cached to the
# configured path on first run). Fail fast if loading is impossible.
try:
    model = AutoModelForSequenceClassification.from_pretrained(_MODEL_ID)
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
except Exception as e:
    # Chain the original exception so the real cause stays in the traceback.
    raise RuntimeError(f"模型加载失败: {str(e)}") from e
# --- Endpoint definition ------------------------------------------------
@app.post("/detect")  # BUG FIX: without this decorator the route was never registered
async def detect(code: str):
    """Classify a code snippet as secure or insecure.

    Args:
        code: Source code text to analyze; capped at 2000 characters.

    Returns:
        dict: ``{"label": <class name>, "score": <softmax confidence>}`` on
        success, or ``{"error": <message>}`` if inference fails.
    """
    try:
        # Cap very long inputs before tokenization (the tokenizer also
        # truncates to the model's maximum sequence length).
        if len(code) > 2000:
            code = code[:2000]
        inputs = tokenizer(code, return_tensors="pt", truncation=True)
        with torch.no_grad():  # inference only — no gradients needed
            outputs = model(**inputs)
        return {
            "label": model.config.id2label[outputs.logits.argmax().item()],
            "score": outputs.logits.softmax(dim=-1).max().item(),
        }
    except Exception as e:
        # Best-effort API: report the failure to the caller instead of a 500.
        return {"error": str(e)}