csk99 committed on
Commit
a8c9010
·
verified ·
1 Parent(s): e324040

Update tasks/text.py

Browse files

Add inference code with sentence transformer and XGBoost model

Files changed (1) hide show
  1. tasks/text.py +21 -1
tasks/text.py CHANGED
@@ -7,6 +7,12 @@ import random
7
  from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
 
 
 
 
 
 
10
  router = APIRouter()
11
 
12
  DESCRIPTION = "Random Baseline"
@@ -53,12 +59,26 @@ async def evaluate_text(request: TextEvaluationRequest):
53
 
54
  #--------------------------------------------------------------------------------------------
55
  # YOUR MODEL INFERENCE CODE HERE
 
 
 
 
 
 
 
56
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
57
  #--------------------------------------------------------------------------------------------
58
 
59
  # Make random predictions (placeholder for actual model inference)
60
  true_labels = test_dataset["label"]
61
- predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
 
 
 
 
 
 
 
62
 
63
  #--------------------------------------------------------------------------------------------
64
  # YOUR MODEL INFERENCE STOPS HERE
 
7
  from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
+ #packages needed for inference
11
+ from sentence_transformers import SentenceTransformer
12
+ from xgboost import XGBClassifier
13
+ import pickle
14
+
15
+
16
  router = APIRouter()
17
 
18
  DESCRIPTION = "Random Baseline"
 
59
 
60
  #--------------------------------------------------------------------------------------------
61
  # YOUR MODEL INFERENCE CODE HERE
62
+
63
+ #Load the embedding model
64
+ model = SentenceTransformer("dunzhang/stella_en_400M_v5",trust_remote_code=True)
65
+
66
+ # Convert each sentence into a vector representation (embedding)
67
+ embeddings = model.encode(test_dataset['quote'].tolist())
68
+
69
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
70
  #--------------------------------------------------------------------------------------------
71
 
72
  # Make random predictions (placeholder for actual model inference)
73
  true_labels = test_dataset["label"]
74
+
75
+
76
+ #load the xgboost model
77
+ with open("models/stella_400_xgb_500.pkl",'rb') as f:
78
+ xgbclassifier = pickle.load(f)
79
+
80
+ #make inference
81
+ predictions = xgbclassifier.predict(embeddings)
82
 
83
  #--------------------------------------------------------------------------------------------
84
  # YOUR MODEL INFERENCE STOPS HERE