jennasparks committed on
Commit 69b42b9 · verified · 1 Parent(s): 0ae53cb

Add electra model

Files changed (1)
  1. tasks/text.py +20 -3
tasks/text.py CHANGED

@@ -4,12 +4,15 @@ from datasets import load_dataset
 from sklearn.metrics import accuracy_score
 import random
 
+# Load model using Keras
+from tensorflow import keras
+
 from .utils.evaluation import TextEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
 router = APIRouter()
 
-DESCRIPTION = "Random Baseline"
+DESCRIPTION = "electra fine tune"
 ROUTE = "/text"
 
 @router.post(ROUTE, tags=["Text Task"],
@@ -22,6 +25,15 @@ async def evaluate_text(request: TextEvaluationRequest):
     - Makes random predictions from the label space (0-7)
     - Used as a baseline for comparison
     """
+    # Download from Google Drive
+    import gdown
+
+    url = 'https://drive.google.com/uc?id=1-HWE2G6ANbd7mqILdB9DPrvF3DrKI1e4'
+    output = 'checkpoint_epoch_5.weights.h5'
+    gdown.download(url, output, quiet=False)
+
+    model = keras.models.load_model('checkpoint_epoch_5.weights.h5')
+
     # Get space info
     username, space_url = get_space_info()
 
@@ -57,13 +69,18 @@ async def evaluate_text(request: TextEvaluationRequest):
     #--------------------------------------------------------------------------------------------
 
     # Make random predictions (placeholder for actual model inference)
+    # true_labels = test_dataset["label"]
+    # predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
+
+    # Make predictions
+    predictions = model.predict(test_dataset)
+
+    # Get true labels
     true_labels = test_dataset["label"]
-    predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
 
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
     #--------------------------------------------------------------------------------------------
-
 
     # Stop tracking emissions
     emissions_data = tracker.stop_task()
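
Note on the checkpoint loading added above (a hedged sketch, not part of the commit): keras.models.load_model() expects a full saved model (a .keras archive or a legacy HDF5 model file), while a file named *.weights.h5 and written with model.save_weights() holds only the layer weights. If checkpoint_epoch_5.weights.h5 is weights-only, the architecture has to be rebuilt first and the weights loaded into it; the build_electra_classifier() helper below is hypothetical, standing in for whatever builder was used at training time.

# Hedged sketch: loading a weights-only Keras checkpoint.
# Assumption: the file was written with model.save_weights(), so the
# architecture must be rebuilt before the weights can be restored.
import gdown
from tensorflow import keras

def build_electra_classifier() -> keras.Model:
    # Placeholder: must recreate the exact architecture used at training time
    # (ELECTRA encoder + 8-way classification head); it is not part of this commit.
    raise NotImplementedError

url = 'https://drive.google.com/uc?id=1-HWE2G6ANbd7mqILdB9DPrvF3DrKI1e4'
output = 'checkpoint_epoch_5.weights.h5'
gdown.download(url, output, quiet=False)

model = build_electra_classifier()
model.load_weights(output)  # weights-only checkpoints go through load_weights(), not load_model()

If the checkpoint were instead a full model saved with model.save(), the keras.models.load_model() call in the diff would work as written.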
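
Similarly, model.predict(...) typically returns per-class scores rather than the integer labels that accuracy_score expects, and a raw Hugging Face dataset usually has to be converted to the model's input format first. A hedged sketch of that step, assuming model_inputs is the already tokenized/vectorized test set (tokenization is not shown in this commit) and that the classifier outputs one score per class for the 8 labels:

# Hedged sketch: turning raw model outputs into integer labels for accuracy_score.
import numpy as np
from sklearn.metrics import accuracy_score

# Assumption: model_inputs is the tokenized test set (not part of this commit).
scores = model.predict(model_inputs)        # shape: (num_examples, 8)
predictions = np.argmax(scores, axis=-1)    # integer class ids 0-7
true_labels = test_dataset["label"]

accuracy = accuracy_score(true_labels, predictions)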