hbanduk committed · verified
Commit 371a733 · 1 parent: 739d55a

Update tasks/text.py

Files changed (1): tasks/text.py (+20 -5)
tasks/text.py CHANGED

@@ -68,23 +68,38 @@ async def evaluate_text(request: TextEvaluationRequest):
     # Load the ONNX model and tokenizer
     MODEL_REPO = "ClimateDebunk/Quantized_DistilBertForSequenceClassification"
     MODEL_FILENAME = "distilbert_quantized_dynamic.onnx"
-    MODEL_PATH = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
 
-    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
-    ort_session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
+    try:
+        MODEL_PATH = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
+        print(f"Model successfully downloaded at: {MODEL_PATH}")
+
+        tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
+        print("Tokenizer loaded successfully!")
+
+        ort_session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
+        print("ONNX session initialized successfully!")
+    except Exception as e:
+        print(f"Error loading ONNX model: {e}")
+
+
 
     # Preprocess the text data
     def preprocess(texts):
-        return tokenizer(
+        print(f"📌 Preprocessing {len(texts)} text samples...")
+        inputs = tokenizer(
             texts,
-            padding=True,
+            padding='max_length',
             truncation=True,
             max_length=365,
             return_tensors="np"
         )
+        print(f"Tokenized input_ids shape: {inputs['input_ids'].shape}")
+        print(f"Tokenized attention_mask shape: {inputs['attention_mask'].shape}")
+        return inputs
 
     # Run inference
     def predict(texts):
+        print(f"📌 Running inference on {len(texts)} samples...")
         inputs = preprocess(texts)
         ort_inputs = {
             "input_ids": inputs["input_ids"].astype(np.int64),