csk99 commited on
Commit
3a03067
·
verified ·
1 Parent(s): b70634b

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +38 -1
tasks/text.py CHANGED
@@ -62,6 +62,43 @@ async def evaluate_text(request: TextEvaluationRequest):
62
  tracker.start_task("inference")
63
 
64
  #--------------------------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
 
67
  #predictions = xgb.predict(embeddings)
@@ -80,7 +117,7 @@ async def evaluate_text(request: TextEvaluationRequest):
80
  emissions_data = tracker.stop_task()
81
 
82
  # Calculate accuracy
83
- accuracy = accuracy_score(true_labels, np.array([1]*len(true_labels)))
84
 
85
  # Prepare results dictionary
86
  results = {
 
62
  tracker.start_task("inference")
63
 
64
  #--------------------------------------------------------------------------------------------
65
+ # Load a pre-trained Sentence-BERT model
66
+ model = SentenceTransformer('sentence-transformers/all-MPNET-base-v2', device='cpu')
67
+ # Generate sentence embeddings
68
+ sentence_embeddings = model.encode(test_dataset["quote"])
69
+
70
+ #load the models
71
+ with open("xgb_bin.pkl","rb") as f:
72
+ xgb_bin = pickle.load(f)
73
+
74
+ with open("xgb_multi.pkl","rb") as f:
75
+ xgb_multi = pickle.load(f)
76
+
77
+
78
+ X_train = test_dataset["quote"]
79
+
80
+ y_train = test_dataset["label"].copy()
81
+
82
+ #binary
83
+ y_train_binary = y_train.copy()
84
+ y_train_binary[y_train_binary != 0] = 1
85
+
86
+
87
+ #multi class
88
+ X_train_multi = X_train[y_train != 0]
89
+
90
+ y_train_multi = y_train[y_train != 0]
91
+
92
+ #predictions
93
+ y_pred_bin = xgb_bin.predict(X_train)
94
+
95
+ y_pred_multi = xgb_multi.predict(X_train_multi) + 1
96
+
97
+ y_pred_bin[y_pred_bin==1] = y_pred_multi
98
+
99
+
100
+
101
+
102
 
103
 
104
  #predictions = xgb.predict(embeddings)
 
117
  emissions_data = tracker.stop_task()
118
 
119
  # Calculate accuracy
120
+ accuracy = accuracy_score(true_labels, y_true))
121
 
122
  # Prepare results dictionary
123
  results = {