Spaces:
Running
Running
Update tasks/text.py
Browse files - tasks/text.py +38 -1
tasks/text.py
CHANGED
@@ -62,6 +62,43 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
62 |
tracker.start_task("inference")
|
63 |
|
64 |
#--------------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
|
67 |
#predictions = xgb.predict(embeddings)
|
@@ -80,7 +117,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
80 |
emissions_data = tracker.stop_task()
|
81 |
|
82 |
# Calculate accuracy
|
83 |
-
accuracy = accuracy_score(true_labels,
|
84 |
|
85 |
# Prepare results dictionary
|
86 |
results = {
|
|
|
62 |
tracker.start_task("inference")
|
63 |
|
64 |
#--------------------------------------------------------------------------------------------
|
65 |
+
# Load a pre-trained Sentence-BERT model
|
66 |
+
model = SentenceTransformer('sentence-transformers/all-MPNET-base-v2', device='cpu')
|
67 |
+
# Generate sentence embeddings
|
68 |
+
sentence_embeddings = model.encode(test_dataset["quote"])
|
69 |
+
|
70 |
+
#load the models
|
71 |
+
with open("xgb_bin.pkl","rb") as f:
|
72 |
+
xgb_bin = pickle.load(f)
|
73 |
+
|
74 |
+
with open("xgb_multi.pkl","rb") as f:
|
75 |
+
xgb_multi = pickle.load(f)
|
76 |
+
|
77 |
+
|
78 |
+
X_train = test_dataset["quote"]
|
79 |
+
|
80 |
+
y_train = test_dataset["label"].copy()
|
81 |
+
|
82 |
+
#binary
|
83 |
+
y_train_binary = y_train.copy()
|
84 |
+
y_train_binary[y_train_binary != 0] = 1
|
85 |
+
|
86 |
+
|
87 |
+
#multi class
|
88 |
+
X_train_multi = X_train[y_train != 0]
|
89 |
+
|
90 |
+
y_train_multi = y_train[y_train != 0]
|
91 |
+
|
92 |
+
#predictions
|
93 |
+
y_pred_bin = xgb_bin.predict(X_train)
|
94 |
+
|
95 |
+
y_pred_multi = xgb_multi.predict(X_train_multi) + 1
|
96 |
+
|
97 |
+
y_pred_bin[y_pred_bin==1] = y_pred_multi
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
|
103 |
|
104 |
#predictions = xgb.predict(embeddings)
|
|
|
117 |
emissions_data = tracker.stop_task()
|
118 |
|
119 |
# Calculate accuracy
|
120 |
+
accuracy = accuracy_score(true_labels, y_true))
|
121 |
|
122 |
# Prepare results dictionary
|
123 |
results = {
|