hbanduk commited on
Commit
9086772
·
verified ·
1 Parent(s): 7804b0f

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +18 -32
tasks/text.py CHANGED
@@ -60,45 +60,31 @@ async def evaluate_text(request: TextEvaluationRequest):
60
  #true_labels = test_dataset["label"]
61
  #predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
62
 
63
- from transformers import DistilBertTokenizer
64
- import numpy as np
65
- import onnxruntime as ort
66
- from huggingface_hub import hf_hub_download
 
 
 
 
67
 
68
- # Load the ONNX model and tokenizer
69
- MODEL_REPO = "ClimateDebunk/Quantized_DistilBertForSequenceClassification"
70
- MODEL_FILENAME = "distilbert_quantized_dynamic.onnx"
71
- MODEL_PATH = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
72
-
73
- tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
74
- ort_session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
75
-
76
- # Preprocess the text data
77
  def preprocess(texts):
78
- return tokenizer(
79
- texts,
80
- padding='max_length',
81
- truncation=True,
82
- max_length=365,
83
- return_tensors="np"
84
- )
85
-
86
- # Run inference
87
  def predict(texts):
 
88
  inputs = preprocess(texts)
89
- ort_inputs = {
90
- "input_ids": inputs["input_ids"].astype(np.int64),
91
- "attention_mask": inputs["attention_mask"].astype(np.int64)
92
- }
93
- ort_outputs = ort_session.run(None, ort_inputs)
94
- logits = ort_outputs[0]
95
- predictions = np.argmax(logits, axis=1)
96
  return predictions
97
-
98
-
99
  texts = test_dataset["quote"]
100
  predictions = predict(texts)
101
-
102
  true_labels = test_dataset["label"]
103
  #--------------------------------------------------------------------------------------------
104
  # YOUR MODEL INFERENCE STOPS HERE
 
60
  #true_labels = test_dataset["label"]
61
  #predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
62
 
63
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
64
+ import torch
65
+
66
+ # Load model and tokenizer from Hugging Face Hub
67
+ MODEL_REPO = "ClimateDebunk/FineTunedDistilBert4SeqClass"
68
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
69
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
70
+ model.eval() # Set to evaluation mode
71
 
 
 
 
 
 
 
 
 
 
72
  def preprocess(texts):
73
+ """ Tokenize text inputs for DistilBERT """
74
+ return tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
75
+
 
 
 
 
 
 
76
  def predict(texts):
77
+ """ Run inference using the fine-tuned DistilBERT model """
78
  inputs = preprocess(texts)
79
+ with torch.no_grad():
80
+ outputs = model(**inputs)
81
+ predictions = torch.argmax(outputs.logits, dim=1).tolist()
 
 
 
 
82
  return predictions
83
+
84
+ # Run inference
85
  texts = test_dataset["quote"]
86
  predictions = predict(texts)
87
+
88
  true_labels = test_dataset["label"]
89
  #--------------------------------------------------------------------------------------------
90
  # YOUR MODEL INFERENCE STOPS HERE