import os

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

# Model configurations. Keeping the loading options in a dict makes it easy to
# register additional models later. Note: `max_seq_length` and `load_in_4bit`
# are declarative here and are not applied by the loading code below.
model_configs = {
    "CyberSentinel": {
        "model_name": "dad1909/cybersentinal-2.0",
        "max_seq_length": 1028,
        "dtype": torch.float16,
        "load_in_4bit": True,
    }
}

# Hugging Face token, used for gated or private model repositories.
hf_token = os.getenv("HF_TOKEN")

# Cache of already-loaded (model, tokenizer) pairs, keyed by model name.
loaded_models = {}


def load_model(selected_model):
    """Load and cache the model/tokenizer pair for the selected configuration."""
    if selected_model not in loaded_models:
        config = model_configs[selected_model]
        model = AutoModelForCausalLM.from_pretrained(
            config["model_name"],
            torch_dtype=config["dtype"],
            device_map="auto",
            token=hf_token,  # `use_auth_token` is deprecated in favor of `token`
        )
        tokenizer = AutoTokenizer.from_pretrained(
            config["model_name"],
            token=hf_token,
        )
        loaded_models[selected_model] = (model, tokenizer)
    return loaded_models[selected_model]


# Prompt templates for the three supported interaction modes.
alpaca_prompts = {
    "information": "Give me information about the following topic: {}",
    "vulnerable": """Identify the line of code that is vulnerable and describe the type of software vulnerability.

### Code Snippet:
{}

### Vulnerability Description:""",
    "Chat": "{}",
}


@spaces.GPU(duration=100)
def predict(selected_model, prompt, prompt_type, max_length=128):
    model, tokenizer = load_model(selected_model)
    selected_prompt = alpaca_prompts[prompt_type]
    formatted_prompt = selected_prompt.format(prompt)
    # Use the model's own device rather than hard-coding "cuda".
    inputs = tokenizer([formatted_prompt], return_tensors="pt").to(model.device)
    text_streamer = TextStreamer(tokenizer)
    output = model.generate(**inputs, streamer=text_streamer, max_new_tokens=max_length)
    return tokenizer.decode(output[0], skip_special_tokens=True)


theme = gr.themes.Default(
    primary_hue=gr.themes.colors.rose,
    secondary_hue=gr.themes.colors.blue,
    font=gr.themes.GoogleFont("Source Sans Pro"),
)

# Warm the cache so the first request does not pay the model-loading cost.
load_model("CyberSentinel")

with gr.Blocks(theme=theme) as demo:
    selected_model = gr.Dropdown(choices=list(model_configs.keys()), value="CyberSentinel", label="Model")
    prompt = gr.Textbox(lines=5, placeholder="Enter your code snippet or topic here...", label="Prompt")
    prompt_type = gr.Dropdown(choices=list(alpaca_prompts.keys()), value="Chat", label="Prompt Type")
    max_length = gr.Slider(minimum=128, maximum=512, step=128, value=128, label="Max Length")
    generated_text = gr.Textbox(label="Generated Text")
    generate_button = gr.Button("Generate")

    generate_button.click(
        predict,
        inputs=[selected_model, prompt, prompt_type, max_length],
        outputs=generated_text,
    )

    gr.Examples(
        examples=[
            ["CyberSentinel", "What is SQL injection?", "information", 128],
            ["CyberSentinel", "$buff = 'A' x 10000;\nopen(myfile, '>>PASS.PK2');\nprint myfile $buff;\nclose(myfile);", "vulnerable", 128],
            ["CyberSentinel", "Can you tell me a joke?", "Chat", 128],
        ],
        inputs=[selected_model, prompt, prompt_type, max_length],
    )

# Note: allowed_paths=["/"] exposes the entire filesystem via Gradio's file
# routes; narrow this to the specific directories you need in production.
demo.queue(default_concurrency_limit=20).launch(
    server_name="0.0.0.0",
    allowed_paths=["/"],
    share=True,
)
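
# -----------------------------------------------------------------------------
# Optional: querying the running app programmatically. A minimal sketch using
# `gradio_client`, assuming the app is reachable at http://127.0.0.1:7860 and
# that Gradio exposes the click handler under its default endpoint name
# "/predict" (both the URL and the endpoint name are assumptions, not defined
# by this app). Run it from a separate process, since `launch()` blocks:
#
#     from gradio_client import Client
#
#     client = Client("http://127.0.0.1:7860")
#     result = client.predict(
#         "CyberSentinel",           # selected_model
#         "What is SQL injection?",  # prompt
#         "information",             # prompt_type
#         128,                       # max_length
#         api_name="/predict",
#     )
#     print(result)
# -----------------------------------------------------------------------------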