Update tasks/text.py
Browse files- tasks/text.py +4 -4
tasks/text.py
CHANGED
@@ -11,7 +11,7 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
|
11 |
from sentence_transformers import SentenceTransformer
|
12 |
from sklearn.preprocessing import MinMaxScaler
|
13 |
import numpy as np
|
14 |
-
import
|
15 |
|
16 |
router = APIRouter()
|
17 |
|
@@ -66,8 +66,8 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
66 |
query_prompt_name = "s2s_query"
|
67 |
model = SentenceTransformer("dunzhang/stella_en_400M_v5",trust_remote_code=True).cuda()
|
68 |
|
69 |
-
|
70 |
-
|
71 |
|
72 |
## Data prep
|
73 |
embeddings = model.encode(list(test_dataset['quote']), prompt_name=query_prompt_name)
|
@@ -81,11 +81,11 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
81 |
# YOUR MODEL INFERENCE STOPS HERE
|
82 |
#--------------------------------------------------------------------------------------------
|
83 |
|
84 |
-
|
85 |
# Stop tracking emissions
|
86 |
emissions_data = tracker.stop_task()
|
87 |
|
88 |
# Calculate accuracy
|
|
|
89 |
accuracy = accuracy_score(true_labels, predictions)
|
90 |
|
91 |
# Prepare results dictionary
|
|
|
11 |
from sentence_transformers import SentenceTransformer
|
12 |
from sklearn.preprocessing import MinMaxScaler
|
13 |
import numpy as np
|
14 |
+
import skops.io as sio
|
15 |
|
16 |
router = APIRouter()
|
17 |
|
|
|
66 |
query_prompt_name = "s2s_query"
|
67 |
model = SentenceTransformer("dunzhang/stella_en_400M_v5",trust_remote_code=True).cuda()
|
68 |
|
69 |
+
trusted_types = ['sklearn.feature_selection._univariate_selection.f_classif']
|
70 |
+
disp = sio.load('logistic_regression_model.skops',trusted=trusted_types)
|
71 |
|
72 |
## Data prep
|
73 |
embeddings = model.encode(list(test_dataset['quote']), prompt_name=query_prompt_name)
|
|
|
81 |
# YOUR MODEL INFERENCE STOPS HERE
|
82 |
#--------------------------------------------------------------------------------------------
|
83 |
|
|
|
84 |
# Stop tracking emissions
|
85 |
emissions_data = tracker.stop_task()
|
86 |
|
87 |
# Calculate accuracy
|
88 |
+
true_labels = test_dataset["label"]
|
89 |
accuracy = accuracy_score(true_labels, predictions)
|
90 |
|
91 |
# Prepare results dictionary
|