hbanduk committed on
Commit
bc7edfa
·
verified ·
1 Parent(s): 371a733

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +22 -45
tasks/text.py CHANGED
@@ -60,60 +60,37 @@ async def evaluate_text(request: TextEvaluationRequest):
60
  #true_labels = test_dataset["label"]
61
  #predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
62
 
63
- from transformers import DistilBertTokenizer
64
- import numpy as np
65
- import onnxruntime as ort
66
- from huggingface_hub import hf_hub_download
67
-
68
- # Load the ONNX model and tokenizer
69
- MODEL_REPO = "ClimateDebunk/Quantized_DistilBertForSequenceClassification"
70
- MODEL_FILENAME = "distilbert_quantized_dynamic.onnx"
71
 
72
- try:
73
- MODEL_PATH = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
74
- print(f"Model successfully downloaded at: {MODEL_PATH}")
75
-
76
- tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
77
- print("Tokenizer loaded successfully!")
78
-
79
- ort_session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
80
- print("ONNX session initialized successfully!")
81
- except Exception as e:
82
- print(f"Error loading ONNX model: {e}")
83
-
84
 
 
 
85
 
86
- # Preprocess the text data
 
 
87
  def preprocess(texts):
88
- print(f"📌 Preprocessing {len(texts)} text samples...")
89
- inputs = tokenizer(
90
- texts,
91
- padding='max_length',
92
- truncation=True,
93
- max_length=365,
94
- return_tensors="np"
95
- )
96
- print(f"Tokenized input_ids shape: {inputs['input_ids'].shape}")
97
- print(f"Tokenized attention_mask shape: {inputs['attention_mask'].shape}")
98
- return inputs
99
-
100
- # Run inference
101
  def predict(texts):
102
- print(f"📌 Running inference on {len(texts)} samples...")
103
  inputs = preprocess(texts)
104
- ort_inputs = {
105
- "input_ids": inputs["input_ids"].astype(np.int64),
106
- "attention_mask": inputs["attention_mask"].astype(np.int64)
107
- }
108
- ort_outputs = ort_session.run(None, ort_inputs)
109
- logits = ort_outputs[0]
110
- predictions = np.argmax(logits, axis=1)
111
  return predictions
112
-
113
- # Replace the random predictions with actual model predictions
114
  texts = test_dataset["quote"]
115
  predictions = predict(texts)
116
-
117
  true_labels = test_dataset["label"]
118
  #--------------------------------------------------------------------------------------------
119
  # YOUR MODEL INFERENCE STOPS HERE
 
60
  #true_labels = test_dataset["label"]
61
  #predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
62
 
63
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
64
+ import torch
65
+
66
+ # Load model and tokenizer from Hugging Face Hub
67
+ MODEL_REPO = "ClimateDebunk/FineTunedDistilBert4SeqClass"
68
+ MODEL_FILENAME = "distilbert_trained.pth"
 
 
69
 
70
+ model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
 
 
 
 
 
 
 
 
 
 
 
71
 
72
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
73
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
74
 
75
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
76
+ model.eval() # Set to evaluation mode
77
+
78
  def preprocess(texts):
79
+ """ Tokenize text inputs for DistilBERT """
80
+ return tokenizer(texts, padding='max_length', truncation=True, max_length=365, return_tensors="pt")
81
+
 
 
 
 
 
 
 
 
 
 
82
  def predict(texts):
83
+ """ Run inference using the fine-tuned DistilBERT model """
84
  inputs = preprocess(texts)
85
+ with torch.no_grad():
86
+ outputs = model(**inputs)
87
+ predictions = torch.argmax(outputs.logits, dim=1).tolist()
 
 
 
 
88
  return predictions
89
+
90
+ # Run inference
91
  texts = test_dataset["quote"]
92
  predictions = predict(texts)
93
+
94
  true_labels = test_dataset["label"]
95
  #--------------------------------------------------------------------------------------------
96
  # YOUR MODEL INFERENCE STOPS HERE