munzirmuneer committed
Commit 374ec5a · verified · 1 Parent(s): 1ca0e38

Update inference.py

Files changed (1)
  1. inference.py +29 -26
inference.py CHANGED
@@ -1,26 +1,29 @@
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
-import torch
-import torch.nn.functional as F
-from peft import PeftModel
-
-# Load model and tokenizer
-model_name = "munzirmuneer/phishing_url_gemma_pytorch" # Replace with your specific model
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForSequenceClassification.from_pretrained(model_name)
-model = PeftModel.from_pretrained(model, model_name)
-
-def predict(input_text):
-    # Tokenize input
-    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
-
-    # Run inference
-    with torch.no_grad():
-        outputs = model(**inputs)
-
-    # Get logits and probabilities
-    logits = outputs.logits
-    probs = F.softmax(logits, dim=-1)
-
-    # Get the predicted class (highest probability)
-    pred_class = torch.argmax(probs, dim=-1)
-    return pred_class.item(), probs[0].tolist()
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import torch
+from huggingface_hub import HfApi
+import torch.nn.functional as F
+from peft import PeftModel
+
+HfApi().set_access_token("HUGGINGFACE_HUB_TOKEN")
+
+# Load model and tokenizer
+model_name = "munzirmuneer/phishing_url_gemma_pytorch" # Replace with your specific model
+tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
+model = AutoModelForSequenceClassification.from_pretrained(model_name, use_auth_token=True)
+model = PeftModel.from_pretrained(model, model_name, use_auth_token=True)
+
+def predict(input_text):
+    # Tokenize input
+    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
+
+    # Run inference
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    # Get logits and probabilities
+    logits = outputs.logits
+    probs = F.softmax(logits, dim=-1)
+
+    # Get the predicted class (highest probability)
+    pred_class = torch.argmax(probs, dim=-1)
+    return pred_class.item(), probs[0].tolist()
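
Note on the authentication calls added in this commit: `HfApi.set_access_token` is no longer available in recent `huggingface_hub` releases, passing the literal string "HUGGINGFACE_HUB_TOKEN" would use that string itself as the token rather than reading it from the environment, and `transformers` has deprecated `use_auth_token=` in favor of `token=`. Below is a minimal sketch of the same load-and-predict flow against the current APIs; the environment-variable name and the example URL are assumptions for illustration, not part of the commit.

import os

import torch
import torch.nn.functional as F
from huggingface_hub import login
from peft import PeftModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Log in once; the token is cached and picked up by later downloads.
# The environment-variable name is an assumption, not part of the commit.
login(token=os.environ["HUGGINGFACE_HUB_TOKEN"])

model_name = "munzirmuneer/phishing_url_gemma_pytorch"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model = PeftModel.from_pretrained(model, model_name)

def predict(input_text):
    # Tokenize, run a forward pass, and return (predicted class, class probabilities).
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = F.softmax(logits, dim=-1)
    return torch.argmax(probs, dim=-1).item(), probs[0].tolist()

# Example usage (hypothetical URL):
# label, scores = predict("http://example.com/login")

After `login()`, `transformers` and `peft` resolve the cached token automatically, so no per-call `use_auth_token`/`token` argument should be needed.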