import streamlit as st
import torch
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# Display name shown in the sidebar -> Hugging Face Hub repository id.
MODEL_MAPPING = {
    "text2shellcommands": "Canstralian/text2shellcommands",
    "pentest_ai": "Canstralian/pentest_ai",
}

# Display name -> inference task. This single table decides both which model
# class is loaded and how predictions are made, so the two cannot disagree.
MODEL_TASKS = {
    "text2shellcommands": "generation",
    "pentest_ai": "classification",
}

# Hub ids replaced by a stand-in checkpoint (fallback model used for testing).
FALLBACK_MODELS = {
    "Canstralian/text2shellcommands": "t5-small",
}


def select_model():
    """Render the sidebar model picker and return the chosen display name."""
    st.sidebar.header("Model Configuration")
    return st.sidebar.selectbox("Select a model", list(MODEL_MAPPING.keys()))


@st.cache_resource
def _load_model_and_tokenizer_cached(model_name, task):
    """Load ``(tokenizer, model)`` for ``model_name``; raise on any failure.

    Kept separate from the public loader so failures propagate as exceptions:
    ``st.cache_resource`` does not store the result of a call that raises,
    whereas a returned ``(None, None)`` would be cached and replayed on every
    rerun, turning a transient download error into a permanent one.
    """
    model_name = FALLBACK_MODELS.get(model_name, model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if task is None:
        # No explicit task: infer it from the checkpoint's architecture.
        config = AutoConfig.from_pretrained(model_name)
        task = "generation" if config.is_encoder_decoder else "classification"
    if task == "generation":
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    else:
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model


def load_model_and_tokenizer(model_name, task=None):
    """Return ``(tokenizer, model)`` for a Hub id, or ``(None, None)`` on error.

    Args:
        model_name: Hugging Face Hub repository id (values of ``MODEL_MAPPING``).
            Ids listed in ``FALLBACK_MODELS`` are swapped for their stand-in.
        task: ``"generation"`` loads a seq2seq LM, ``"classification"`` loads a
            sequence classifier, ``None`` picks based on whether the
            checkpoint's config is encoder-decoder.

    Load errors are reported in the UI with ``st.error`` instead of raising.
    """
    try:
        return _load_model_and_tokenizer_cached(model_name, task)
    except Exception as e:  # UI boundary: show any load failure to the user
        st.error(f"Error loading model: {e}")
        return None, None


def predict_with_model(user_input, model, tokenizer, model_choice, *, max_new_tokens=64):
    """Run inference on ``user_input`` and return labelled results.

    Args:
        user_input: Text to feed to the model.
        model: Model returned by ``load_model_and_tokenizer``.
        tokenizer: Tokenizer returned by ``load_model_and_tokenizer``.
        model_choice: Sidebar display name; its ``MODEL_TASKS`` entry selects
            generation or classification (unknown names classify).
        max_new_tokens: Upper bound on generated tokens; without it
            ``generate`` falls back to a short default length that can cut
            commands off.

    Returns:
        ``{"Generated Shell Command": str}`` for generation, otherwise
        ``{"Predicted Class": int, "Logits": list}``.
    """
    inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
    task = MODEL_TASKS.get(model_choice, "classification")

    if task == "generation":
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
        generated_command = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"Generated Shell Command": generated_command}

    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class = torch.argmax(logits, dim=-1).item()
    return {
        "Predicted Class": predicted_class,
        "Logits": logits.tolist(),
    }


def main():
    """Render the dashboard: pick a model, read text, and show predictions."""
    st.title("AI Model Inference Dashboard")

    model_choice = select_model()
    model_name = MODEL_MAPPING[model_choice]
    tokenizer, model = load_model_and_tokenizer(
        model_name, task=MODEL_TASKS.get(model_choice)
    )

    user_input = st.text_area("Enter text:")

    if model is None or tokenizer is None:
        # The loader has already displayed the error; nothing to predict with.
        return
    if not user_input.strip():
        st.info("Please enter some text for prediction.")
        return

    result = predict_with_model(user_input, model, tokenizer, model_choice)
    for key, value in result.items():
        st.write(f"{key}: {value}")


if __name__ == "__main__":
    main()