hbanduk commited on
Commit
bf8c867
·
verified ·
1 Parent(s): 0ae53cb

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +38 -1
tasks/text.py CHANGED
@@ -57,8 +57,45 @@ async def evaluate_text(request: TextEvaluationRequest):
57
  #--------------------------------------------------------------------------------------------
58
 
59
  # Make random predictions (placeholder for actual model inference)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  true_labels = test_dataset["label"]
61
- predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
62
 
63
  #--------------------------------------------------------------------------------------------
64
  # YOUR MODEL INFERENCE STOPS HERE
 
57
  #--------------------------------------------------------------------------------------------
58
 
59
  # Make random predictions (placeholder for actual model inference)
60
+ #true_labels = test_dataset["label"]
61
+ #predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
62
+
63
+ from transformers import DistilBertTokenizer
64
+ import numpy as np
65
+ import onnxruntime as ort
66
+
67
+ # Load the ONNX model and tokenizer
68
+ MODEL_PATH = "/Users/hinabandukwala/Documents/frugalai/submission-template/models/distilbert_quantized_dynamic.onnx"
69
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
70
+ ort_session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
71
+
72
+ # Preprocess the text data
73
+ def preprocess(texts):
74
+ return tokenizer(
75
+ texts,
76
+ padding=True,
77
+ truncation=True,
78
+ max_length=365,
79
+ return_tensors="np"
80
+ )
81
+
82
+ # Run inference
83
+ def predict(texts):
84
+ inputs = preprocess(texts)
85
+ ort_inputs = {
86
+ "input_ids": inputs["input_ids"].astype(np.int64),
87
+ "attention_mask": inputs["attention_mask"].astype(np.int64)
88
+ }
89
+ ort_outputs = ort_session.run(None, ort_inputs)
90
+ logits = ort_outputs[0]
91
+ predictions = np.argmax(logits, axis=1)
92
+ return predictions
93
+
94
+ # Replace the random predictions with actual model predictions
95
+ texts = test_dataset["text"]
96
+ predictions = predict(texts)
97
+
98
  true_labels = test_dataset["label"]
 
99
 
100
  #--------------------------------------------------------------------------------------------
101
  # YOUR MODEL INFERENCE STOPS HERE