hbanduk committed on
Commit
e740326
·
verified ·
1 Parent(s): 12f942f

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +32 -19
tasks/text.py CHANGED
@@ -60,33 +60,46 @@ async def evaluate_text(request: TextEvaluationRequest):
60
  #true_labels = test_dataset["label"]
61
  #predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
62
 
63
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
64
- import torch
65
-
66
- # Load model and tokenizer from Hugging Face Hub
67
- MODEL_REPO = "ClimateDebunk/FineTunedDistilBert4SeqClass"
68
- tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
69
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
70
- model.eval() # Set to evaluation mode
71
 
 
 
 
 
 
 
 
 
 
72
  def preprocess(texts):
73
- """ Tokenize text inputs for DistilBERT """
74
- return tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
75
-
 
 
 
 
 
 
76
  def predict(texts):
77
- """ Run inference using the fine-tuned DistilBERT model """
78
  inputs = preprocess(texts)
79
- with torch.no_grad():
80
- outputs = model(**inputs)
81
- predictions = torch.argmax(outputs.logits, dim=1).tolist()
 
 
 
 
82
  return predictions
83
-
84
- # Run inference
85
  texts = test_dataset["text"]
86
  predictions = predict(texts)
87
-
88
- true_labels = test_dataset["label"]
89
 
 
90
  #--------------------------------------------------------------------------------------------
91
  # YOUR MODEL INFERENCE STOPS HERE
92
  #--------------------------------------------------------------------------------------------
 
60
  #true_labels = test_dataset["label"]
61
  #predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
62
 
63
+ from transformers import DistilBertTokenizer
64
+ import numpy as np
65
+ import onnxruntime as ort
66
+ from huggingface_hub import hf_hub_download
 
 
 
 
67
 
68
+ # Load the ONNX model and tokenizer
69
+ MODEL_REPO = "ClimateDebunk/Quantized_DistilBertForSequenceClassification"
70
+ MODEL_FILENAME = "distilbert_quantized_dynamic.onnx"
71
+ MODEL_PATH = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
72
+
73
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
74
+ ort_session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
75
+
76
+ # Preprocess the text data
77
  def preprocess(texts):
78
+ return tokenizer(
79
+ texts,
80
+ padding=True,
81
+ truncation=True,
82
+ max_length=365,
83
+ return_tensors="np"
84
+ )
85
+
86
+ # Run inference
87
  def predict(texts):
 
88
  inputs = preprocess(texts)
89
+ ort_inputs = {
90
+ "input_ids": inputs["input_ids"].astype(np.int64),
91
+ "attention_mask": inputs["attention_mask"].astype(np.int64)
92
+ }
93
+ ort_outputs = ort_session.run(None, ort_inputs)
94
+ logits = ort_outputs[0]
95
+ predictions = np.argmax(logits, axis=1)
96
  return predictions
97
+
98
+
99
  texts = test_dataset["text"]
100
  predictions = predict(texts)
 
 
101
 
102
+ true_labels = test_dataset["label"]
103
  #--------------------------------------------------------------------------------------------
104
  # YOUR MODEL INFERENCE STOPS HERE
105
  #--------------------------------------------------------------------------------------------