import streamlit as st
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
)
import torch
import pandas as pd

# Map task names to their corresponding Hugging Face models.
# Both entries are placeholders: t5-small is a generic seq2seq checkpoint and
# bert-base-uncased ships without a fine-tuned classification head, so swap in
# task-specific checkpoints for real use.
MODEL_MAPPING = {
    "text2shellcommands": "t5-small",  # Example seq2seq model for generating shell commands
    "pentest_ai": "bert-base-uncased",  # Example classification model for pentesting tasks
}


def select_model():
    """
    Add a dropdown to the Streamlit sidebar for selecting a model.

    Returns:
        str: The selected model key from MODEL_MAPPING.
    """
    st.sidebar.header("Model Configuration")
    selected_model = st.sidebar.selectbox("Select a model", list(MODEL_MAPPING.keys()))
    return selected_model


@st.cache_resource
def load_model_and_tokenizer(model_name):
    """
    Load the tokenizer and model for the given Hugging Face model name.
    Cached with st.cache_resource so the weights are loaded only once per session.

    Args:
        model_name (str): The name of the Hugging Face model to load.

    Returns:
        tuple: (tokenizer, model), or (None, None) if loading fails.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Choose the model class based on the checkpoint name.
        if "t5" in model_name or "seq2seq" in model_name:
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        else:
            model = AutoModelForSequenceClassification.from_pretrained(model_name)

        return tokenizer, model
    except Exception as e:
        # Surface the failure in the Streamlit UI rather than crashing.
        st.error(f"An error occurred while loading the model or tokenizer: {e}")
        return None, None


def predict_with_model(user_input, model, tokenizer, model_choice):
    """
    Run inference with the loaded model and tokenizer.

    Args:
        user_input (str): Text input from the user.
        model: Loaded Hugging Face model.
        tokenizer: Loaded Hugging Face tokenizer.
        model_choice (str): Selected model key from MODEL_MAPPING.

    Returns:
        dict: The prediction results.
    """
    inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)

    if model_choice == "text2shellcommands":
        # Seq2seq task: generate a shell command from the input text.
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=64)
        generated_command = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"Generated Shell Command": generated_command}
    else:
        # Classification task: pick the class with the highest logit.
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        predicted_class = torch.argmax(logits, dim=-1).item()
        return {
            "Predicted Class": predicted_class,
            "Logits": logits.tolist(),
        }


def process_uploaded_file(uploaded_file):
    """
    Read and process an uploaded file. Supports text and CSV files.

    Args:
        uploaded_file: The uploaded file object from st.file_uploader.

    Returns:
        str or None: The file content as a string, or None on failure.
    """
    try:
        if uploaded_file is not None:
            file_type = uploaded_file.type

            if "text" in file_type:
                return uploaded_file.read().decode("utf-8")
            elif "csv" in file_type:
                df = pd.read_csv(uploaded_file)
                return df.to_string()  # Render the dataframe as plain text
            else:
                st.error("Unsupported file type. Please upload a text or CSV file.")
                return None
        return None
    except Exception as e:
        st.error(f"Error processing file: {e}")
        return None
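
# Illustrative sketch (not executed by the app): what the classification branch of
# predict_with_model returns, assuming the default "pentest_ai" -> "bert-base-uncased"
# mapping. Since bert-base-uncased has no fine-tuned classification head, the
# predicted class index is effectively arbitrary until a fine-tuned checkpoint is
# substituted in MODEL_MAPPING.
#
#   tokenizer, model = load_model_and_tokenizer("bert-base-uncased")
#   result = predict_with_model("scan the target host", model, tokenizer, "pentest_ai")
#   # result -> {"Predicted Class": 0, "Logits": [[...]]}   (placeholder output)
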
def main():
    """Define the Streamlit app layout and wire the pieces together."""
    st.title("AI Model Inference Dashboard")
    st.markdown(
        """
        This dashboard lets you interact with different AI models for inference
        tasks, such as generating shell commands or performing text classification.
        """
    )

    # Model selection
    model_choice = select_model()
    model_name = MODEL_MAPPING.get(model_choice)
    tokenizer, model = load_model_and_tokenizer(model_name)

    # Stop early if loading failed; load_model_and_tokenizer has already
    # displayed an error message.
    if tokenizer is None or model is None:
        st.stop()

    # Input text area or file upload
    input_choice = st.radio("Choose Input Method", ("Text Input", "Upload File"))

    if input_choice == "Text Input":
        user_input = st.text_area("Enter your text input:", placeholder="Type your text here...")

        # Run prediction after submit
        if st.button("Submit") and user_input:
            st.write("### Prediction Results:")
            result = predict_with_model(user_input, model, tokenizer, model_choice)
            for key, value in result.items():
                st.write(f"**{key}:** {value}")

    elif input_choice == "Upload File":
        uploaded_file = st.file_uploader("Choose a text or CSV file", type=["txt", "csv"])

        # Run prediction after submit
        if st.button("Submit") and uploaded_file:
            file_content = process_uploaded_file(uploaded_file)
            if file_content:
                st.write("### File Content:")
                st.write(file_content)
                result = predict_with_model(file_content, model, tokenizer, model_choice)
                st.write("### Prediction Results:")
                for key, value in result.items():
                    st.write(f"**{key}:** {value}")
            else:
                st.info("No valid content found in the file.")


if __name__ == "__main__":
    main()
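
# To run the dashboard locally (assuming this file is saved as app.py and the
# dependencies below are available):
#
#   pip install streamlit torch transformers pandas
#   streamlit run app.py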