Oriaz commited on
Commit
a8a6edb
·
verified ·
1 Parent(s): d2de887

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +4 -4
tasks/text.py CHANGED
@@ -11,7 +11,7 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info
11
  from sentence_transformers import SentenceTransformer
12
  from sklearn.preprocessing import MinMaxScaler
13
  import numpy as np
14
- import pickle
15
 
16
  router = APIRouter()
17
 
@@ -66,8 +66,8 @@ async def evaluate_text(request: TextEvaluationRequest):
66
  query_prompt_name = "s2s_query"
67
  model = SentenceTransformer("dunzhang/stella_en_400M_v5",trust_remote_code=True).cuda()
68
 
69
- with open('logistic_regression_model.pkl', 'rb') as file:
70
- disp = pickle.load(file)
71
 
72
  ## Data prep
73
  embeddings = model.encode(list(test_dataset['quote']), prompt_name=query_prompt_name)
@@ -81,11 +81,11 @@ async def evaluate_text(request: TextEvaluationRequest):
81
  # YOUR MODEL INFERENCE STOPS HERE
82
  #--------------------------------------------------------------------------------------------
83
 
84
-
85
  # Stop tracking emissions
86
  emissions_data = tracker.stop_task()
87
 
88
  # Calculate accuracy
 
89
  accuracy = accuracy_score(true_labels, predictions)
90
 
91
  # Prepare results dictionary
 
11
  from sentence_transformers import SentenceTransformer
12
  from sklearn.preprocessing import MinMaxScaler
13
  import numpy as np
14
+ import skops.io as sio
15
 
16
  router = APIRouter()
17
 
 
66
  query_prompt_name = "s2s_query"
67
  model = SentenceTransformer("dunzhang/stella_en_400M_v5",trust_remote_code=True).cuda()
68
 
69
+ trusted_types = ['sklearn.feature_selection._univariate_selection.f_classif']
70
+ disp = sio.load('logistic_regression_model.skops',trusted=trusted_types)
71
 
72
  ## Data prep
73
  embeddings = model.encode(list(test_dataset['quote']), prompt_name=query_prompt_name)
 
81
  # YOUR MODEL INFERENCE STOPS HERE
82
  #--------------------------------------------------------------------------------------------
83
 
 
84
  # Stop tracking emissions
85
  emissions_data = tracker.stop_task()
86
 
87
  # Calculate accuracy
88
+ true_labels = test_dataset["label"]
89
  accuracy = accuracy_score(true_labels, predictions)
90
 
91
  # Prepare results dictionary