Oriaz commited on
Commit
c98f02f
·
verified ·
1 Parent(s): a1920b3

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +20 -3
tasks/text.py CHANGED
@@ -7,6 +7,12 @@ import random
7
  from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
 
 
 
 
 
 
10
  router = APIRouter()
11
 
12
  DESCRIPTION = "Random test"
@@ -56,9 +62,20 @@ async def evaluate_text(request: TextEvaluationRequest):
56
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
57
  #--------------------------------------------------------------------------------------------
58
 
59
- # Make random predictions (placeholder for actual model inference)
60
- true_labels = test_dataset["label"]
61
- predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  #--------------------------------------------------------------------------------------------
64
  # YOUR MODEL INFERENCE STOPS HERE
 
7
  from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
+ ## add-on imports
11
+ from sentence_transformers import SentenceTransformer
12
+ from sklearn.preprocessing import MinMaxScaler
13
+ import numpy as np
14
+ import pickle
15
+
16
  router = APIRouter()
17
 
18
  DESCRIPTION = "Random test"
 
62
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
63
  #--------------------------------------------------------------------------------------------
64
 
65
+ ## Models loading
66
+ query_prompt_name = "s2s_query"
67
+ model = SentenceTransformer("dunzhang/stella_en_400M_v5",trust_remote_code=True).cuda()
68
+
69
+ with open('logistic_regression_model.pkl', 'rb') as file:
70
+ disp = pickle.load(file)
71
+
72
+ ## Data prep
73
+ embeddings = model.encode(list(test_dataset['quote']), prompt_name=query_prompt_name)
74
+ scaler = MinMaxScaler()
75
+ X_scaled = scaler.fit_transform(embeddings)
76
+
77
+ ## Predictions
78
+ predictions= disp.predict(X_scaled)
79
 
80
  #--------------------------------------------------------------------------------------------
81
  # YOUR MODEL INFERENCE STOPS HERE