csk99 commited on
Commit
bcc7121
·
verified ·
1 Parent(s): 0fb4e95

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +13 -1
tasks/text.py CHANGED
@@ -9,6 +9,7 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
 
11
  import os
 
12
  import numpy as np
13
  print(os.getcwd())
14
  #
@@ -19,6 +20,10 @@ import pickle
19
  import xgboost as xgb
20
 
21
 
 
 
 
 
22
 
23
  router = APIRouter()
24
 
@@ -67,7 +72,7 @@ async def evaluate_text(request: TextEvaluationRequest):
67
  #--------------------------------------------------------------------------------------------
68
  # Load a pre-trained Sentence-BERT model
69
  print("loading model")
70
- model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2', device='cpu')
71
  # Generate sentence embeddings
72
  sentence_embeddings = model.encode(test_dataset["quote"])
73
 
@@ -88,6 +93,9 @@ async def evaluate_text(request: TextEvaluationRequest):
88
  #xgb_multi.load_model("xgb_model_muli.bin")
89
 
90
 
 
 
 
91
  X_train = sentence_embeddings.copy()
92
 
93
  y_train = test_dataset["label"].copy()
@@ -101,6 +109,10 @@ async def evaluate_text(request: TextEvaluationRequest):
101
  X_train_multi = X_train[y_train != 0]
102
 
103
  y_train_multi = y_train[y_train != 0]
 
 
 
 
104
 
105
  #predictions
106
  y_pred_bin = xgb_bin.predict(X_train)
 
9
 
10
 
11
  import os
12
+ import logging
13
  import numpy as np
14
  print(os.getcwd())
15
  #
 
20
  import xgboost as xgb
21
 
22
 
23
+ #logging
24
+ logging.basicConfig(level=logging.INFO)
25
+
26
+ logging.info("LAS ESTRELLAS!!!!!")
27
 
28
  router = APIRouter()
29
 
 
72
  #--------------------------------------------------------------------------------------------
73
  # Load a pre-trained Sentence-BERT model
74
  print("loading model")
75
+ model = SentenceTransformer('sentence-transformers/all-MPNET-base-v2', device='cpu')
76
  # Generate sentence embeddings
77
  sentence_embeddings = model.encode(test_dataset["quote"])
78
 
 
93
  #xgb_multi.load_model("xgb_model_muli.bin")
94
 
95
 
96
+
97
+
98
+
99
  X_train = sentence_embeddings.copy()
100
 
101
  y_train = test_dataset["label"].copy()
 
109
  X_train_multi = X_train[y_train != 0]
110
 
111
  y_train_multi = y_train[y_train != 0]
112
+
113
+ logging.info(f"Xtrain_multi_shape:{X_train_multi.shape}")
114
+ logging.info(f"Xtrain shape:{X_train.shape}")
115
+
116
 
117
  #predictions
118
  y_pred_bin = xgb_bin.predict(X_train)