yagnik12 commited on
Commit
a713ea4
·
verified ·
1 Parent(s): d4db00e

Update ai_text_detector_valid_final.py

Browse files
Files changed (1) hide show
  1. ai_text_detector_valid_final.py +22 -6
ai_text_detector_valid_final.py CHANGED
@@ -12,15 +12,31 @@ MODELS = {
12
 
13
  def load_model(model_id):
14
  tokenizer = AutoTokenizer.from_pretrained(model_id)
15
- model = AutoModelForSequenceClassification.from_pretrained(model_id)
 
 
 
 
16
  return tokenizer, model
17
 
18
  def predict(text, tokenizer, model):
19
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
20
- with torch.no_grad():
21
- outputs = model(**inputs)
22
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
23
- return probs[0].numpy() # [human_prob, ai_prob]
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  def verdict(ai_prob):
 
12
 
13
  def load_model(model_id):
14
  tokenizer = AutoTokenizer.from_pretrained(model_id)
15
+ # Use the zero-shot classification pipeline for NLI models
16
+ if model_id == "roberta-large-mnli":
17
+ model = pipeline("zero-shot-classification", model=model_id, device=0 if torch.cuda.is_available() else -1)
18
+ else:
19
+ model = AutoModelForSequenceClassification.from_pretrained(model_id)
20
  return tokenizer, model
21
 
22
  def predict(text, tokenizer, model):
23
+ if isinstance(model, pipeline):
24
+ # Use the roberta-mnli model for zero-shot classification
25
+ candidate_labels = ["This text was written by a human.", "This text was written by an AI."]
26
+ result = model(text, candidate_labels)
27
+
28
+ # The entailment score for each label is the probability
29
+ human_prob = result["scores"][result["labels"].index("This text was written by a human.")]
30
+ ai_prob = result["scores"][result["labels"].index("This text was written by an AI.")]
31
+
32
+ return np.array([human_prob, ai_prob])
33
+ else:
34
+ # The existing code for other models
35
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
36
+ with torch.no_grad():
37
+ outputs = model(**inputs)
38
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
39
+ return probs[0].numpy()
40
 
41
 
42
  def verdict(ai_prob):