# File-view header copied along with the source:
#   author: elselse — last commit 7f24230 "Update app.py" (verified) — 821 bytes
import gradio as gr
from transformers import pipeline
# Model identifier passed to the transformers pipeline.
MODEL_NAME = "CIRCL/cwe-vulnerability-classification-distilbert-base"

# Text-classification pipeline that returns a score for every label.
# `top_k=None` is the supported replacement for the deprecated
# `return_all_scores=True`, which emits a UserWarning.
# NOTE(review): per the transformers TextClassificationPipeline implementation,
# a single-string call without a call-time `top_k` still wraps its output in an
# outer list, so `classifier(text)[0]` is the list of {"label", "score"} dicts.
# Re-verify this on transformers upgrades.
classifier = pipeline("text-classification", model=MODEL_NAME, top_k=None)
def classify_cwe(text, top_n=5):
    """Predict the most likely CWE labels for a vulnerability description.

    Args:
        text: Free-text vulnerability description (supplied by the Gradio
            textbox input).
        top_n: Maximum number of labels to return. Defaults to 5, matching
            the previous hard-coded behaviour.

    Returns:
        Dict mapping label -> score rounded to 4 decimals, ordered from the
        highest to the lowest score. Empty when ``text`` is empty, ``None``
        or whitespace only.
    """
    # Nothing to classify: skip the model call instead of scoring empty text.
    if not text or not text.strip():
        return {}

    raw = classifier(text)
    # Depending on how the pipeline is configured, a single-string call can
    # return either a flat list of {"label", "score"} dicts or that list
    # wrapped in an outer list; accept both shapes.
    if raw and isinstance(raw[0], list):
        raw = raw[0]

    # Sort explicitly rather than relying on the pipeline's output order.
    ranked = sorted(raw, key=lambda res: res["score"], reverse=True)
    return {res["label"]: round(res["score"], 4) for res in ranked[:top_n]}
# Gradio UI: a multi-line textbox as input, a ranked label widget as output.
interface = gr.Interface(
    fn=classify_cwe,
    inputs=gr.Textbox(lines=5, placeholder="Enter vulnerability description..."),
    # Show at most 5 classes; keep in sync with how many classify_cwe returns.
    outputs=gr.Label(num_top_classes=5),
    title="CWE Vulnerability Classifier",
    description="Enter a vulnerability description to predict the most likely CWE types.",
)

# Launch the Gradio app only when executed as a script, so importing this
# module (e.g. from tests or another app) does not start a server.
if __name__ == "__main__":
    interface.launch()