File size: 821 Bytes
3b1fc96
 
 
7f24230
3b1fc96
 
 
bb7eccc
3b1fc96
 
 
bb7eccc
3b1fc96
 
 
 
 
 
 
 
 
 
 
bb7eccc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import gradio as gr
from transformers import pipeline

# Model identifier handed to the transformers pipeline below.
MODEL_NAME = "CIRCL/cwe-vulnerability-classification-distilbert-base"

# Text-classification pipeline, created once when the module is executed so
# that every call to the classifier reuses the same loaded model.
# `top_k=None` requests the score of every label; it is the supported
# replacement for the deprecated `return_all_scores=True` flag.
classifier = pipeline("text-classification", model=MODEL_NAME, top_k=None)

def classify_cwe(text: str) -> dict[str, float]:
    """Return the five highest-scoring labels for a vulnerability description.

    Args:
        text: Free-text vulnerability description to classify.

    Returns:
        A mapping of label to confidence score (rounded to 4 decimals),
        ordered from highest to lowest score and containing at most five
        entries. An empty mapping is returned when ``text`` is empty,
        whitespace-only, or ``None``, so the model is never run on blank
        input.
    """
    # Guard: blank input carries nothing to classify; skip the model call.
    if not text or not text.strip():
        return {}

    results = classifier(text)
    # The pipeline may wrap the per-label scores for a single string in an
    # outer list ([[{...}, ...]]); unwrap it, but also accept a flat list.
    if results and isinstance(results[0], list):
        results = results[0]

    # Sort by confidence score descending and keep the five best labels.
    top_results = sorted(results, key=lambda res: res["score"], reverse=True)[:5]
    return {res["label"]: round(res["score"], 4) for res in top_results}

# Input widget: a five-line text box for the vulnerability description.
description_input = gr.Textbox(lines=5, placeholder="Enter vulnerability description...")

# Output widget: label display limited to the five top classes.
top_labels_output = gr.Label(num_top_classes=5)

# Wire the classification function to its input and output widgets.
interface = gr.Interface(
    fn=classify_cwe,
    inputs=description_input,
    outputs=top_labels_output,
    title="CWE Vulnerability Classifier",
    description="Enter a vulnerability description to predict the most likely CWE types.",
)

# Start the Gradio app as soon as this module is executed.
interface.launch()