Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -2,36 +2,54 @@ from fastapi import FastAPI | |
| 2 | 
             
            from transformers import AutoTokenizer, AutoModelForSequenceClassification
         | 
| 3 | 
             
            import torch
         | 
| 4 | 
             
            import os
         | 
|  | |
| 5 |  | 
| 6 | 
            -
            #  | 
| 7 | 
            -
             | 
|  | |
| 8 |  | 
| 9 | 
            -
            #  | 
| 10 | 
             
            os.environ["HF_HOME"] = "/app/.cache/huggingface"
         | 
| 11 |  | 
| 12 | 
            -
            #  | 
| 13 | 
             
            try:
         | 
| 14 | 
            -
                 | 
| 15 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 16 | 
             
            except Exception as e:
         | 
| 17 | 
            -
                 | 
|  | |
|  | |
|  | |
| 18 |  | 
| 19 | 
            -
            # 4. 接口定义
         | 
| 20 | 
             
            @app.post("/detect")
         | 
| 21 | 
             
            async def detect(code: str):
         | 
| 22 | 
             
                try:
         | 
| 23 | 
            -
                    #  | 
| 24 | 
            -
                     | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 27 | 
            -
                    inputs = tokenizer(code, return_tensors="pt", truncation=True)
         | 
| 28 | 
             
                    with torch.no_grad():
         | 
| 29 | 
             
                        outputs = model(**inputs)
         | 
| 30 | 
            -
             | 
|  | |
|  | |
| 31 | 
             
                    return {
         | 
| 32 | 
            -
                        "label": model.config.id2label[ | 
| 33 | 
            -
                        "score": outputs.logits.softmax(dim=-1). | 
| 34 | 
             
                    }
         | 
| 35 |  | 
| 36 | 
             
                except Exception as e:
         | 
| 37 | 
            -
                    return {"error": str(e)}
         | 
|  | |
|  | |
|  | |
|  | 
|  | |
| 2 | 
             
            from transformers import AutoTokenizer, AutoModelForSequenceClassification
         | 
| 3 | 
             
            import torch
         | 
| 4 | 
             
            import os
         | 
| 5 | 
            +
            import logging
         | 
| 6 |  | 
| 7 | 
            +
            # 初始化日志
         | 
| 8 | 
            +
            logging.basicConfig(level=logging.INFO)
         | 
| 9 | 
            +
            logger = logging.getLogger("CodeSecurityAPI")
         | 
| 10 |  | 
| 11 | 
            +
            # 强制设置缓存路径(解决权限问题)
         | 
| 12 | 
             
            os.environ["HF_HOME"] = "/app/.cache/huggingface"
         | 
| 13 |  | 
| 14 | 
            +
            # 加载模型
         | 
| 15 | 
             
            try:
         | 
| 16 | 
            +
                logger.info("Loading model...")
         | 
| 17 | 
            +
                model = AutoModelForSequenceClassification.from_pretrained(
         | 
| 18 | 
            +
                    "mrm8488/codebert-base-finetuned-detect-insecure-code",
         | 
| 19 | 
            +
                    cache_dir=os.getenv("HF_HOME")
         | 
| 20 | 
            +
                )
         | 
| 21 | 
            +
                tokenizer = AutoTokenizer.from_pretrained(
         | 
| 22 | 
            +
                    "mrm8488/codebert-base-finetuned-detect-insecure-code",
         | 
| 23 | 
            +
                    cache_dir=os.getenv("HF_HOME")
         | 
| 24 | 
            +
                )
         | 
| 25 | 
            +
                logger.info("Model loaded successfully")
         | 
| 26 | 
             
            except Exception as e:
         | 
| 27 | 
            +
                logger.error(f"Model load failed: {str(e)}")
         | 
| 28 | 
            +
                raise RuntimeError("模型加载失败,请检查网络连接或模型路径")
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            app = FastAPI()
         | 
| 31 |  | 
|  | |
| 32 | 
             
            @app.post("/detect")
         | 
| 33 | 
             
            async def detect(code: str):
         | 
| 34 | 
             
                try:
         | 
| 35 | 
            +
                    # 输入处理(限制长度)
         | 
| 36 | 
            +
                    code = code[:2000]  # 截断超长输入
         | 
| 37 | 
            +
                    
         | 
| 38 | 
            +
                    # 模型推理
         | 
| 39 | 
            +
                    inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512)
         | 
| 40 | 
             
                    with torch.no_grad():
         | 
| 41 | 
             
                        outputs = model(**inputs)
         | 
| 42 | 
            +
                    
         | 
| 43 | 
            +
                    # 解析结果
         | 
| 44 | 
            +
                    label_id = outputs.logits.argmax().item()
         | 
| 45 | 
             
                    return {
         | 
| 46 | 
            +
                        "label": model.config.id2label[label_id],
         | 
| 47 | 
            +
                        "score": outputs.logits.softmax(dim=-1)[0][label_id].item()
         | 
| 48 | 
             
                    }
         | 
| 49 |  | 
| 50 | 
             
                except Exception as e:
         | 
| 51 | 
            +
                    return {"error": str(e)}
         | 
| 52 | 
            +
             | 
| 53 | 
            +
            @app.get("/health")
         | 
| 54 | 
            +
            async def health():
         | 
| 55 | 
            +
                return {"status": "ok"}
         |