|
import gradio as gr |
|
from transformers import pipeline |
|
|
|
# Hugging Face Hub model ID of the CWE text-classification checkpoint that
# the pipeline is built from.
MODEL_NAME: str = "CIRCL/cwe-vulnerability-classification-distilbert-base"
|
|
|
# Text-classification pipeline that scores every label of the model.
# top_k=None requests the scores for all labels; it is the documented
# replacement for the deprecated return_all_scores=True.
classifier = pipeline("text-classification", model=MODEL_NAME, top_k=None)
|
|
|
def classify_cwe(text):
    """Predict the most likely CWE categories for a vulnerability description.

    Args:
        text: Free-form vulnerability description (the Gradio textbox value).

    Returns:
        A dict mapping up to five labels to their scores rounded to four
        decimals, ordered from highest to lowest score. Empty, whitespace-only
        or missing input returns an empty dict without running the model.
    """
    # An empty textbox arrives as "" (or None); don't produce a prediction
    # for input that carries no information.
    if not text or not text.strip():
        return {}

    # Truncate to the model's maximum sequence length so long descriptions
    # are classified instead of failing inside the model.
    output = classifier(text, truncation=True)

    # Depending on pipeline configuration (return_all_scores vs. top_k) and
    # transformers version, a single string yields either a list of
    # {"label", "score"} dicts or that list wrapped in an outer list.
    scores = output[0] if output and isinstance(output[0], list) else output

    top_five = sorted(scores, key=lambda res: res["score"], reverse=True)[:5]
    return {res["label"]: round(res["score"], 4) for res in top_five}
|
|
|
# UI components: a multi-line input box for the description and a label
# widget showing the five highest-scoring classes.
description_box = gr.Textbox(lines=5, placeholder="Enter vulnerability description...")
top_cwe_label = gr.Label(num_top_classes=5)

# Wire classify_cwe between the input box and the label display.
interface = gr.Interface(
    fn=classify_cwe,
    inputs=description_box,
    outputs=top_cwe_label,
    title="CWE Vulnerability Classifier",
    description="Enter a vulnerability description to predict the most likely CWE types.",
)

# Launched at module level, so the app starts as soon as this file executes.
interface.launch()
|
|
|
|