|
import gradio as gr |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch |
|
import logging |
|
import re |
|
|
|
|
|
# Configure the root logger: records at INFO and above are written to the
# file "app.log", each formatted as "time:LEVEL:message".
# (logging.basicConfig does nothing if the root logger already has handlers.)
_LOG_FORMAT = "%(asctime)s:%(levelname)s:%(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, filename="app.log")
|
|
|
|
|
# Loaded (model, tokenizer) pairs keyed by model path. Repeated calls with the
# same path reuse the objects already in memory instead of reloading weights.
_MODEL_CACHE = {}


def load_model(model_path="Canstralian/pentest_ai"):
    """
    Load the pre-trained causal language model and its tokenizer, with caching.

    The first successful call for a given ``model_path`` loads both objects and
    stores them in ``_MODEL_CACHE``. Later calls with the same path return the
    cached pair without touching disk or the network. A failed load is logged
    and NOT cached, so a later call retries.

    Args:
        model_path (str): Model identifier or local path passed to
            ``from_pretrained``. Defaults to ``"Canstralian/pentest_ai"``.

    Returns:
        tuple: ``(model, tokenizer)`` on success, or ``(None, None)`` if loading
        raised an exception (the error and its traceback are logged).
    """
    cached = _MODEL_CACHE.get(model_path)
    if cached is not None:
        return cached

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            # float16 on GPU, float32 on CPU.
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            # Map the whole model onto the single selected device.
            device_map={"": device},
            # NOTE(review): allows code shipped with the model repository to run;
            # only acceptable for a trusted model source.
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    except Exception as e:
        # Keep the traceback in the log; callers only see the (None, None) sentinel.
        logging.exception("Error loading model: %s", e)
        return None, None

    logging.info("Model and tokenizer loaded successfully.")
    _MODEL_CACHE[model_path] = (model, tokenizer)
    return model, tokenizer
|
|
|
# Any character that is NOT an ASCII letter, an ASCII digit, whitespace,
# or one of the punctuation marks . , ! ? is considered disallowed.
_DISALLOWED_CHARS = re.compile(r"[^a-zA-Z0-9\s\.,!?]")


def sanitize_input(text):
    """
    Remove disallowed characters from user-supplied text.

    Only ASCII letters, ASCII digits, whitespace and the punctuation marks
    ``. , ! ?`` survive; every other character is deleted. Leading and
    trailing whitespace is then trimmed.

    Args:
        text (str): Raw user input.

    Returns:
        str: The filtered, trimmed text (may be empty).

    Raises:
        ValueError: If ``text`` is not a ``str``.
    """
    if isinstance(text, str):
        return _DISALLOWED_CHARS.sub("", text).strip()
    raise ValueError("Input must be a string.")
|
|
|
def generate_text(model, tokenizer, instruction, *, max_length=1024,
                  temperature=0.5, top_k=50, top_p=1.0):
    """
    Generate a completion for ``instruction`` with the loaded model.

    The instruction is first passed through ``sanitize_input``. Sampling is
    enabled explicitly (``do_sample=True``); without it, transformers uses greedy
    decoding and ignores ``temperature``/``top_k``/``top_p``, unless the model's
    generation config enables sampling.

    Args:
        model: Causal language model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        instruction (str): Instruction text for the model.
        max_length (int): Maximum total sequence length in tokens, prompt
            included.
        temperature (float): Sampling temperature.
        top_k (int): Top-k sampling cutoff.
        top_p (float): Nucleus (top-p) sampling cutoff.

    Returns:
        str: The decoded output sequence (which begins with the prompt), or
        ``"Error in text generation."`` if tokenization/generation fails.

    Raises:
        ValueError: If ``instruction`` is not a string or is empty after
            sanitization. This is raised outside the error-swallowing block so
            callers can report invalid input distinctly.
    """
    # Validation happens before the broad ``except`` so ValueError reaches the caller.
    instruction = sanitize_input(instruction)
    if not instruction:
        raise ValueError("Instruction is empty after sanitization.")

    try:
        # Put the inputs on whatever device the model was actually loaded onto.
        device = model.device
        input_ids = tokenizer.encode(instruction, return_tensors="pt").to(device)
        # A single, unpadded sequence: every position is a real token.
        attention_mask = torch.ones_like(input_ids)
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            # Causal LM tokenizers often lack a pad token; fall back to EOS.
            pad_token_id = tokenizer.eos_token_id
        generated_tokens = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            pad_token_id=pad_token_id,
        )
        generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    except Exception as e:
        logging.exception("Error generating text: %s", e)
        return "Error in text generation."

    logging.info("Text generated successfully.")
    return generated_text
|
|
|
|
|
def gradio_interface(instruction):
    """
    Gradio callback: obtain the model via ``load_model`` and generate a response.

    Args:
        instruction (str): Text entered by the user in the input textbox.

    Returns:
        str: The generated text, or a user-facing message if the model could not
        be loaded, the input was rejected (``ValueError``), or generation failed.
    """
    model, tokenizer = load_model()

    # load_model signals failure with (None, None); test identity rather than
    # truthiness, since objects defining __len__ would be falsy when "empty".
    if model is None or tokenizer is None:
        return "Failed to load model or tokenizer. Please check your configuration."

    try:
        return generate_text(model, tokenizer, instruction)
    except ValueError as ve:
        # Rejected input: show the validation reason to the user.
        return f"Invalid input: {ve}"
    except Exception as e:
        # Unexpected failure: log the traceback, keep the UI message generic.
        logging.exception("Error during text generation: %s", e)
        return "An error occurred. Please try again."
|
|
|
|
|
# Single-textbox-in / single-textbox-out UI wired to gradio_interface.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(label="Enter an instruction for the model:", placeholder="Type your instruction here..."),
    outputs=gr.Textbox(label="Generated Text:"),
    title="Penetration Testing AI Assistant",
    description="This tool allows you to interact with a pre-trained AI model for penetration testing assistance. Enter an instruction to generate a response.",
)


if __name__ == "__main__":
    # Start the web server only when executed as a script, so importing this
    # module (e.g. from tests or tooling) does not block on a running server.
    iface.launch()
|
|