from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
from peft import PeftModel


class EndpointHandler:
    """Endpoint handler that classifies URLs with a PEFT-wrapped classifier.

    The tokenizer and model are loaded once in ``__init__``; each request is
    served by ``__call__``.
    """

    def __init__(self, model_dir):
        """Load the tokenizer and the PEFT-wrapped sequence-classification model.

        Args:
            model_dir: Model directory supplied by the serving runtime.
                Currently unused: weights and tokenizer are fetched from the
                hard-coded Hub repositories below.
        """
        model_name = "munzirmuneer/phishing_url_gemma_pytorch"  # Replace with your specific model
        model_name2 = "google/gemma-2b"

        # Tokenizer comes from the base Gemma repo; the classifier from the
        # fine-tuned repo.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name2)
        base_model = AutoModelForSequenceClassification.from_pretrained(model_name)

        # Sequence-classification heads use config.pad_token_id to find the last
        # real token of each row; without it, padded batches of more than one
        # input are rejected. Borrow the tokenizer's pad id when the config has
        # none (single, unpadded inputs are scored exactly as before).
        if base_model.config.pad_token_id is None:
            base_model.config.pad_token_id = self.tokenizer.pad_token_id

        # NOTE(review): the same repo is used as both the base model and the
        # PEFT adapter source; if it contains only adapter weights, confirm the
        # adapter is not being applied twice.
        self.model = PeftModel.from_pretrained(base_model, model_name)
        # Inference only: ensure dropout-style layers are disabled.
        self.model.eval()

    def __call__(self, input_data):
        """Classify one URL (or several) and return class probabilities.

        Args:
            input_data: Request payload. Must support ``input_data["inputs"]``,
                holding either a single URL string or a non-empty list/tuple of
                URL strings.

        Returns:
            For a single URL (a string or a one-element list/tuple), a dict with
            ``"predicted_class"`` (int index of the most probable class) and
            ``"probabilities"`` (list of floats, one per class). For several
            URLs, a list of such dicts in input order.

        Raises:
            ValueError: If ``"inputs"`` is missing, or is neither a string nor a
                non-empty list/tuple of strings.
        """
        # EAFP: covers both a missing key and a payload that is not a mapping
        # (e.g. a bare string, where `'inputs' in data` would be a substring test).
        try:
            input_text = input_data["inputs"]
        except (KeyError, TypeError):
            raise ValueError(
                "Input data must contain the 'inputs' key with a URL."
            ) from None

        if isinstance(input_text, str):
            urls = [input_text]
        elif (
            isinstance(input_text, (list, tuple))
            and input_text
            and all(isinstance(url, str) for url in input_text)
        ):
            urls = list(input_text)
        else:
            raise ValueError(
                "'inputs' must be a URL string or a non-empty list of URL strings."
            )

        # Tokenize; padding aligns rows when several URLs are sent together.
        inputs = self.tokenizer(
            urls, return_tensors="pt", truncation=True, padding=True
        )

        with torch.no_grad():
            logits = self.model(**inputs).logits

        probs = F.softmax(logits, dim=-1)
        pred_classes = torch.argmax(probs, dim=-1)

        results = [
            {"predicted_class": cls, "probabilities": row.tolist()}
            for cls, row in zip(pred_classes.tolist(), probs)
        ]
        # Single-URL requests keep the original flat-dict response shape.
        return results[0] if len(results) == 1 else results