csk99 committed on
Commit
13923c6
·
verified ·
1 Parent(s): 8e91a1e

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +23 -1
tasks/text.py CHANGED
@@ -7,6 +7,14 @@ import random
7
  from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
 
 
 
 
 
 
 
 
10
  router = APIRouter()
11
 
12
  DESCRIPTION = "Random Baseline"
@@ -52,13 +60,27 @@ async def evaluate_text(request: TextEvaluationRequest):
52
  tracker.start_task("inference")
53
 
54
  #--------------------------------------------------------------------------------------------
 
 
 
 
 
 
 
55
  # YOUR MODEL INFERENCE CODE HERE
56
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
57
  #--------------------------------------------------------------------------------------------
 
 
 
 
 
 
 
58
 
59
  # Make random predictions (placeholder for actual model inference)
60
  true_labels = test_dataset["label"]
61
- predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
62
 
63
  #--------------------------------------------------------------------------------------------
64
  # YOUR MODEL INFERENCE STOPS HERE
 
7
  from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
+
11
+
12
+ #
13
+ from sentence_transformers import SentenceTransformer
14
+ from xgboost import XGBClassifier
15
+ import pickle
16
+
17
+
18
  router = APIRouter()
19
 
20
  DESCRIPTION = "Random Baseline"
 
60
  tracker.start_task("inference")
61
 
62
  #--------------------------------------------------------------------------------------------
63
+
64
+ # Load the sentence-embedding model
65
+ # Step 1: Use Sentence-BERT to convert text to embeddings
66
+ model = SentenceTransformer("dunzhang/stella_en_400M_v5",trust_remote_code=True)
67
+
68
+ # Convert each sentence into a vector representation (embedding)
69
+ embeddings = model.encode(test_dataset['quote'])
70
  # YOUR MODEL INFERENCE CODE HERE
71
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
72
  #--------------------------------------------------------------------------------------------
73
+
74
+ #load model
75
+ with open("models/stella_400_xgb_500.pkl","rb") as f:
76
+ xgb = pickle.load(f)
77
+
78
+ #predictions = xgb.predict(embeddings)
79
+
80
 
81
  # Predict labels from the embeddings with the loaded XGBoost classifier
82
  true_labels = test_dataset["label"]
83
+ predictions = xgb.predict(embeddings)
84
 
85
  #--------------------------------------------------------------------------------------------
86
  # YOUR MODEL INFERENCE STOPS HERE