File size: 4,038 Bytes
58022ea 16bf80f 4604847 940bdb8 3d25e27 16bf80f 940bdb8 16bf80f 4604847 16bf80f 940bdb8 16bf80f 3d25e27 16bf80f 4604847 16bf80f 4604847 58022ea 16bf80f 58022ea 16bf80f 58022ea 16bf80f 58022ea 0ec3ad8 58022ea 16bf80f 58022ea 16bf80f 58022ea 16bf80f 58022ea 0ec3ad8 58022ea 4604847 58022ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
import re
# Send application logs to app.log: INFO and above, one timestamped record per line.
logging.basicConfig(
    format="%(asctime)s:%(levelname)s:%(message)s",
    level=logging.INFO,
    filename="app.log",
)
# Cached (model, tokenizer) pair; set by the first successful load_model() call.
_MODEL_CACHE = None


def load_model():
    """
    Load the pre-trained language model and tokenizer, caching them on success.

    The first successful call stores the pair in a module-level cache; later
    calls return that same pair without loading the weights again. Failed
    loads are not cached, so a subsequent call retries.

    Returns:
        tuple: ``(model, tokenizer)`` on success, or ``(None, None)`` if
        loading failed (the error and traceback are logged).
    """
    global _MODEL_CACHE
    if _MODEL_CACHE is not None:
        return _MODEL_CACHE

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model_path = "Canstralian/pentest_ai"  # Replace with the actual path if different
        # NOTE(security): trust_remote_code=True executes Python code shipped
        # with the model repository; only use it with a repository you trust.
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            # Half precision on GPU, full precision on CPU.
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map={"": device},  # Place the whole model on the chosen device
            load_in_8bit=False,  # Disabled for stability
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    except Exception as e:
        # logging.exception keeps the original message and adds the traceback.
        logging.exception("Error loading model: %s", e)
        return None, None

    _MODEL_CACHE = (model, tokenizer)
    logging.info("Model and tokenizer loaded successfully.")
    return _MODEL_CACHE
def sanitize_input(text):
    """
    Filter user-supplied text down to a small set of allowed characters.

    Letters ``a-z``/``A-Z``, digits, whitespace, and the punctuation
    ``. , ! ?`` are kept; every other character is removed. Leading and
    trailing whitespace is then trimmed.

    Args:
        text (str): Raw user input.

    Returns:
        str: The filtered, trimmed text.

    Raises:
        ValueError: If ``text`` is not a string.
    """
    if isinstance(text, str):
        # Drop everything outside the allow-list, then trim the edges.
        kept = re.sub(r"[^a-zA-Z0-9\s\.,!?]", "", text)
        return kept.strip()
    raise ValueError("Input must be a string.")
def generate_text(model, tokenizer, instruction):
    """
    Generate a text response for ``instruction`` using the given model.

    The instruction is sanitized first; invalid or empty input is reported
    with ``ValueError`` so the caller can distinguish bad input from a
    failure inside the model. Any error raised while tokenizing, generating,
    or decoding is logged and reported through the returned message.

    Args:
        model: The language model; must provide ``generate`` and ``device``.
        tokenizer: Tokenizer used for encoding the prompt and decoding output.
        instruction (str): Instruction text for the model.

    Returns:
        str: The decoded model output, or ``"Error in text generation."`` if
        generation failed.

    Raises:
        ValueError: If ``instruction`` is not a string, or is empty after
            sanitization.
    """
    # Sanitize outside the try block: a ValueError here describes bad input
    # and must not be folded into the generic generation-error message.
    instruction = sanitize_input(instruction)
    if not instruction:
        raise ValueError("Instruction is empty after sanitization.")

    try:
        # Use the device the model actually lives on rather than re-deriving it.
        device = model.device
        encoded = tokenizer(instruction, return_tensors="pt")
        input_ids = encoded["input_ids"].to(device)
        # Pass the attention mask explicitly when the tokenizer provides one.
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        generated_tokens = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=1024,
            # Without do_sample=True the sampling settings below are ignored.
            do_sample=True,
            top_p=1.0,
            temperature=0.5,
            top_k=50,
        )
        generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    except Exception as e:
        # Top-level boundary for model errors: log with traceback, return a message.
        logging.exception("Error generating text: %s", e)
        return "Error in text generation."

    logging.info("Text generated successfully.")
    return generated_text
# Gradio Interface Function
def gradio_interface(instruction):
    """
    Gradio callback: load the model and generate a response for ``instruction``.

    Args:
        instruction (str): Text entered by the user.

    Returns:
        str: The generated text, or a user-facing error message if the model
        could not be loaded, the input was rejected, or generation failed.
    """
    model, tokenizer = load_model()
    # load_model() signals failure with (None, None). Compare against None
    # explicitly; truthiness of arbitrary model/tokenizer objects is not a
    # reliable "loaded" signal.
    if model is None or tokenizer is None:
        return "Failed to load model or tokenizer. Please check your configuration."

    try:
        return generate_text(model, tokenizer, instruction)
    except ValueError as ve:
        # Input validation problems are safe to show to the user verbatim.
        return f"Invalid input: {ve}"
    except Exception as e:
        # Unexpected failure: keep details in the log, show a generic message.
        logging.exception("Error during text generation: %s", e)
        return "An error occurred. Please try again."
# Build the Gradio UI: one textbox for the instruction, one for the model output.
instruction_box = gr.Textbox(
    label="Enter an instruction for the model:",
    placeholder="Type your instruction here...",
)
output_box = gr.Textbox(label="Generated Text:")

iface = gr.Interface(
    fn=gradio_interface,
    inputs=instruction_box,
    outputs=output_box,
    title="Penetration Testing AI Assistant",
    description="This tool allows you to interact with a pre-trained AI model for penetration testing assistance. Enter an instruction to generate a response.",
)

# Start serving the interface.
iface.launch()
|