hbanduk committed on
Commit
68a6940
·
verified ·
1 Parent(s): 8afdb60

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +32 -23
tasks/text.py CHANGED
@@ -60,36 +60,45 @@ async def evaluate_text(request: TextEvaluationRequest):
60
  #true_labels = test_dataset["label"]
61
  #predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
62
 
63
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
64
- import torch
65
-
66
- # Load model and tokenizer from Hugging Face Hub
67
- MODEL_REPO = "ClimateDebunk/FineTunedDistilBert4SeqClass"
68
- MODEL_PATH = "./distilbert_trained.pth"
69
- tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
70
-
71
- config = DistilBertConfig.from_pretrained("distilbert-base-uncased", num_labels=8)
72
- model = DistilBertForSequenceClassification(config)
73
- model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu")))
74
- #model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
75
- model.eval()
76
 
 
 
 
 
 
 
 
 
 
77
  def preprocess(texts):
78
- """ Tokenize text inputs for DistilBERT """
79
- return tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
80
-
 
 
 
 
 
 
81
  def predict(texts):
82
- """ Run inference using the fine-tuned DistilBERT model """
83
  inputs = preprocess(texts)
84
- with torch.no_grad():
85
- outputs = model(**inputs)
86
- predictions = torch.argmax(outputs.logits, dim=1).tolist()
 
 
 
 
87
  return predictions
88
-
89
- # Run inference
90
  texts = test_dataset["quote"]
91
  predictions = predict(texts)
92
-
93
  true_labels = test_dataset["label"]
94
  #--------------------------------------------------------------------------------------------
95
  # YOUR MODEL INFERENCE STOPS HERE
 
60
  #true_labels = test_dataset["label"]
61
  #predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
62
 
63
+ from transformers import DistilBertTokenizer
64
+ import numpy as np
65
+ import onnxruntime as ort
66
+ from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
 
 
67
 
68
+ # Load the ONNX model and tokenizer
69
+ MODEL_REPO = "ClimateDebunk/Quantized_DistilBertForSequenceClassification"
70
+ MODEL_FILENAME = "distilbert_quantized_dynamic.onnx"
71
+ MODEL_PATH = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
72
+
73
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
74
+ ort_session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
75
+
76
+ # Preprocess the text data
77
  def preprocess(texts):
78
+ return tokenizer(
79
+ texts,
80
+ padding=True,
81
+ truncation=True,
82
+ max_length=365,
83
+ return_tensors="np"
84
+ )
85
+
86
+ # Run inference
87
  def predict(texts):
 
88
  inputs = preprocess(texts)
89
+ ort_inputs = {
90
+ "input_ids": inputs["input_ids"].astype(np.int64),
91
+ "attention_mask": inputs["attention_mask"].astype(np.int64)
92
+ }
93
+ ort_outputs = ort_session.run(None, ort_inputs)
94
+ logits = ort_outputs[0]
95
+ predictions = np.argmax(logits, axis=1)
96
  return predictions
97
+
98
+ # Replace the random predictions with actual model predictions
99
  texts = test_dataset["quote"]
100
  predictions = predict(texts)
101
+
102
  true_labels = test_dataset["label"]
103
  #--------------------------------------------------------------------------------------------
104
  # YOUR MODEL INFERENCE STOPS HERE