import os

import gradio as gr
import torch
from transformers import TextStreamer, AutoModelForCausalLM, AutoTokenizer

import spaces

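# Per-model settings. Only "model_name" and "dtype" are read by load_model();
# "max_seq_length" and "load_in_4bit" are kept for reference but are not
# currently passed to from_pretrained().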
model_configs = {
    "CyberSentinel": {
        "model_name": "dad1909/cybersentinal-2.0",
        "max_seq_length": 1028,
        "dtype": torch.float16,
        "load_in_4bit": True
    }
}

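# Hugging Face access token from the environment (needed for gated or private repos).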
hf_token = os.getenv("HF_TOKEN")

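# Cache of already-loaded (model, tokenizer) pairs, keyed by model name.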
loaded_models = {}


def load_model(selected_model):
    # Lazily load and cache the model and tokenizer for the selected entry.
    if selected_model not in loaded_models:
        config = model_configs[selected_model]
        model = AutoModelForCausalLM.from_pretrained(
            config["model_name"],
            torch_dtype=config["dtype"],
            device_map="auto",
            token=hf_token
        )
        tokenizer = AutoTokenizer.from_pretrained(
            config["model_name"],
            token=hf_token
        )
        loaded_models[selected_model] = (model, tokenizer)
    return loaded_models[selected_model]


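# Prompt templates; the user's input is substituted for the "{}" placeholder.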
alpaca_prompts = {
    "information": "Give me information about the following topic: {}",
    "vulnerable": """Identify the line of code that is vulnerable and describe the type of software vulnerability.
### Code Snippet:
{}
### Vulnerability Description:""",
    "Chat": "{}"
}


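# On ZeroGPU Spaces, @spaces.GPU requests a GPU for each call; duration is the allocation window in seconds.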
@spaces.GPU(duration=100)
def predict(selected_model, prompt, prompt_type, max_length=128):
    model, tokenizer = load_model(selected_model)
    selected_prompt = alpaca_prompts[prompt_type]
    formatted_prompt = selected_prompt.format(prompt)
    inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda")
    # Stream tokens to stdout as they are generated, then return the full decoded text.
    text_streamer = TextStreamer(tokenizer)
    output = model.generate(**inputs, streamer=text_streamer, max_new_tokens=max_length)
    return tokenizer.decode(output[0], skip_special_tokens=True)


theme = gr.themes.Default(
    primary_hue=gr.themes.colors.rose,
    secondary_hue=gr.themes.colors.blue,
    font=gr.themes.GoogleFont("Source Sans Pro")
)

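# Warm the cache at startup so the first request does not pay the model-loading cost.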
load_model("CyberSentinel")

with gr.Blocks(theme=theme) as demo:
    selected_model = gr.Dropdown(choices=list(model_configs.keys()), value="CyberSentinel", label="Model")
    prompt = gr.Textbox(lines=5, placeholder="Enter your code snippet or topic here...", label="Prompt")
    prompt_type = gr.Dropdown(choices=list(alpaca_prompts.keys()), value="Chat", label="Prompt Type")
    max_length = gr.Slider(minimum=128, maximum=512, step=128, value=128, label="Max Length")
    generated_text = gr.Textbox(label="Generated Text")

    generate_button = gr.Button("Generate")

    generate_button.click(predict, inputs=[selected_model, prompt, prompt_type, max_length], outputs=generated_text)

    gr.Examples(
        examples=[
            ["CyberSentinel", "What is SQL injection?", "information", 128],
            ["CyberSentinel", "$buff = 'A' x 10000;\nopen(myfile, '>>PASS.PK2');\nprint myfile $buff;\nclose(myfile);", "vulnerable", 128],
            ["CyberSentinel", "Can you tell me a joke?", "Chat", 128]
        ],
        inputs=[selected_model, prompt, prompt_type, max_length]
    )

demo.queue(default_concurrency_limit=20).launch(
    server_name="0.0.0.0",
    allowed_paths=["/"],
    share=True
)