File size: 821 Bytes
3b1fc96 7f24230 3b1fc96 bb7eccc 3b1fc96 bb7eccc 3b1fc96 bb7eccc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import gradio as gr
from transformers import pipeline
# Hugging Face Hub identifier of the fine-tuned CWE classification model.
MODEL_NAME = "CIRCL/cwe-vulnerability-classification-distilbert-base"

# Build the text-classification pipeline once, at import time, so every
# request handled by the UI reuses the same loaded model. Scores for every
# label are requested (not just the best one) so callers can rank them.
# NOTE(review): newer transformers releases deprecate `return_all_scores`
# in favour of `top_k=None`; confirm the output shape before switching.
classifier = pipeline(
    task="text-classification",
    model=MODEL_NAME,
    return_all_scores=True,
)
def classify_cwe(text, top_n=5):
    """Return the most likely CWE labels for a vulnerability description.

    Args:
        text: Free-text vulnerability description entered by the user.
        top_n: Maximum number of labels to return. Defaults to 5, matching
            the number of classes the UI displays.

    Returns:
        A dict mapping each CWE label to its confidence score rounded to
        four decimal places, inserted from most to least confident. Blank
        or missing input returns an empty dict without running the model.
    """
    # Nothing to classify (e.g. Submit pressed on an empty textbox): skip
    # inference entirely instead of scoring an empty string.
    if not text or not text.strip():
        return {}

    # Truncate to the model's maximum sequence length so long descriptions
    # do not overflow the position embeddings and raise mid-request.
    output = classifier(text, truncation=True)

    # A single string can come back as [[{...}, ...]] (legacy all-scores
    # mode), [{...}, ...], or a lone {...} depending on the transformers
    # version and pipeline options; normalize to a flat list of dicts.
    if isinstance(output, dict):
        scores = [output]
    elif output and isinstance(output[0], list):
        scores = output[0]
    else:
        scores = list(output)

    # Highest confidence first, then keep only the requested number.
    ranked = sorted(scores, key=lambda item: item["score"], reverse=True)
    return {item["label"]: round(item["score"], 4) for item in ranked[:top_n]}
def _build_interface():
    """Assemble the Gradio UI: a multi-line description box in, top-5 CWE labels out."""
    description_box = gr.Textbox(lines=5, placeholder="Enter vulnerability description...")
    top_labels = gr.Label(num_top_classes=5)
    return gr.Interface(
        fn=classify_cwe,
        inputs=description_box,
        outputs=top_labels,
        title="CWE Vulnerability Classifier",
        description="Enter a vulnerability description to predict the most likely CWE types.",
    )


interface = _build_interface()
# Serve the UI defined above using Gradio's default launch settings.
start_server = interface.launch
start_server()
|