import os
import gradio as gr
import torch
from transformers import TextStreamer, AutoModelForCausalLM, AutoTokenizer
import spaces

# Define the model configurations
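# (load_model below only reads model_name and dtype; max_seq_length and load_in_4bit are kept for reference)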
model_configs = {
    "CyberSentinel": {
        "model_name": "dad1909/cybersentinal-2.0",
        "max_seq_length": 1028,
        "dtype": torch.float16,
        "load_in_4bit": True
    }
}

# Hugging Face token
hf_token = os.getenv("HF_TOKEN")

# Cache of loaded models and tokenizers, keyed by model name
loaded_models = {}

def load_model(selected_model):
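    # Lazily load the model and tokenizer the first time they are requested, then cache them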
    if selected_model not in loaded_models:
        config = model_configs[selected_model]
        model = AutoModelForCausalLM.from_pretrained(
            config["model_name"],
            torch_dtype=config["dtype"],
            device_map="auto",
            token=hf_token
        )
        tokenizer = AutoTokenizer.from_pretrained(
            config["model_name"],
            token=hf_token
        )
        loaded_models[selected_model] = (model, tokenizer)
    return loaded_models[selected_model]

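# Prompt templates for each prompt type exposed in the UI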
alpaca_prompts = {
    "information": "Give me information about the following topic: {}",
    "vulnerable": """Identify the line of code that is vulnerable and describe the type of software vulnerability.
### Code Snippet:
{}
### Vulnerability Description:""",
    "Chat": "{}"
}

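# Request a ZeroGPU slot for up to 100 seconds per call (Hugging Face Spaces)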
@spaces.GPU(duration=100)
def predict(selected_model, prompt, prompt_type, max_length=128):
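    # Fill the selected template with the user prompt, tokenize, and generate up to max_length new tokens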
    model, tokenizer = load_model(selected_model)
    selected_prompt = alpaca_prompts[prompt_type]
    formatted_prompt = selected_prompt.format(prompt)
    inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda")
    text_streamer = TextStreamer(tokenizer)
    output = model.generate(**inputs, streamer=text_streamer, max_new_tokens=max_length)
    return tokenizer.decode(output[0], skip_special_tokens=True)

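# Custom Gradio theme: rose/blue palette with the Source Sans Pro font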
theme = gr.themes.Default(
    primary_hue=gr.themes.colors.rose,
    secondary_hue=gr.themes.colors.blue,
    font=gr.themes.GoogleFont("Source Sans Pro")
)

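# Pre-load the default model at startup so the first request does not pay the loading cost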
load_model("CyberSentinel")

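# Build the Gradio interface: model selector, prompt inputs, output box, and example prompts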
with gr.Blocks(theme=theme) as demo:
    selected_model = gr.Dropdown(choices=list(model_configs.keys()), value="CyberSentinel", label="Model")
    prompt = gr.Textbox(lines=5, placeholder="Enter your code snippet or topic here...", label="Prompt")
    prompt_type = gr.Dropdown(choices=list(alpaca_prompts.keys()), value="Chat", label="Prompt Type")
    max_length = gr.Slider(minimum=128, maximum=512, step=128, value=128, label="Max Length")
    generated_text = gr.Textbox(label="Generated Text")

    generate_button = gr.Button("Generate")

    generate_button.click(predict, inputs=[selected_model, prompt, prompt_type, max_length], outputs=generated_text)

    gr.Examples(
        examples=[
            ["CyberSentinel", "What is SQL injection?", "information", 128],
            ["CyberSentinel", "$buff = 'A' x 10000;\nopen(myfile, '>>PASS.PK2');\nprint myfile $buff;\nclose(myfile);", "vulnerable", 128],
            ["CyberSentinel", "Can you tell me a joke?", "Chat", 128]
        ],
        inputs=[selected_model, prompt, prompt_type, max_length]
    )

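# Queue requests (up to 20 concurrent) and listen on all interfaces;
# note that allowed_paths=["/"] lets Gradio serve files from anywhere on the filesystem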
demo.queue(default_concurrency_limit=20).launch(
    server_name="0.0.0.0",
    allowed_paths=["/"],
    share=True
)