|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import torch |
|
import torch.nn.functional as F |
|
from peft import PeftModel |
|
|
|
|
|
class EndpointHandler:
    """Custom inference handler for a PEFT-adapted Gemma URL classifier.

    The Hugging Face Hub ids of the model/adapter and tokenizer are fixed
    defaults (overridable via keyword arguments); ``model_dir`` is accepted
    for handler-interface compatibility but is not read.
    """

    def __init__(
        self,
        model_dir,
        *,
        model_name="munzirmuneer/phishing_url_gemma_pytorch",
        tokenizer_name="google/gemma-2b",
    ):
        """Load the tokenizer, the sequence-classification model and the PEFT adapter.

        Args:
            model_dir: Accepted for compatibility with the handler interface;
                currently unused -- weights are fetched by hub id instead.
            model_name: Hub id used both for the sequence-classification model
                and for the PEFT adapter loaded on top of it.
            tokenizer_name: Hub id of the tokenizer.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # NOTE(review): the base model is loaded from the same repo as the
        # adapter. If that repo contains only adapter weights, the base was
        # presumably meant to come from ``tokenizer_name`` (google/gemma-2b);
        # confirm against the repo contents before changing this.
        base_model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model = PeftModel.from_pretrained(base_model, model_name)
        # This handler only runs inference: make evaluation mode explicit
        # (e.g. disables dropout) instead of relying on library defaults.
        self.model.eval()

    def __call__(self, input_data):
        """Classify one URL or a batch of URLs.

        Args:
            input_data: Dict with an ``"inputs"`` key holding a single string
                or a non-empty list of strings.

        Returns:
            For a single input (a string or a one-element list): a dict with
            ``"predicted_class"`` (int) and ``"probabilities"`` (list of
            floats from a softmax over the model's logits). For a batch of
            several inputs the same keys hold a list of ints and a list of
            per-input probability lists respectively.

        Raises:
            ValueError: if ``input_data`` is not a dict with an ``"inputs"``
                key, or if ``"inputs"`` is an empty list.
        """
        # Check the type first: for a plain string, ``"inputs" in input_data``
        # would silently perform a substring test instead of a key lookup.
        if not isinstance(input_data, dict) or "inputs" not in input_data:
            raise ValueError("Input data must contain the 'inputs' key with a URL.")
        input_text = input_data["inputs"]
        if isinstance(input_text, (list, tuple)) and not input_text:
            raise ValueError("The 'inputs' list must contain at least one URL.")

        inputs = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)

        with torch.no_grad():
            outputs = self.model(**inputs)

        probs = F.softmax(outputs.logits, dim=-1)
        pred_class = torch.argmax(probs, dim=-1)

        # Keep the original scalar response shape for a single input;
        # ``.item()`` raises for multi-element tensors, so batches get lists.
        if probs.shape[0] == 1:
            return {
                "predicted_class": pred_class[0].item(),
                "probabilities": probs[0].tolist(),
            }
        return {
            "predicted_class": pred_class.tolist(),
            "probabilities": probs.tolist(),
        }
|
|