jennasparks committed on
Commit bb9daa0 · verified · 1 Parent(s): 884cf55

loading model from hf

Files changed (1)
  1. tasks/text.py +11 -10
tasks/text.py CHANGED
@@ -8,13 +8,11 @@ from .utils.evaluation import TextEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
 import tensorflow as tf
-
-# Download from Google Drive
-import gdown
+from huggingface_hub import hf_hub_download
 
 router = APIRouter()
 
-DESCRIPTION = "Random Baseline"
+DESCRIPTION = "Electra"
 ROUTE = "/text"
 
 @router.post(ROUTE, tags=["Text Task"],
@@ -42,10 +40,11 @@ async def evaluate_text(request: TextEvaluationRequest):
         "7_fossil_fuels_needed": 7
     }
 
-    # Download our pre-trained model
-    url = 'https://drive.google.com/uc?id=1-HWE2G6ANbd7mqILdB9DPrvF3DrKI1e4'
-    output = 'checkpoint_epoch_5.weights.h5'
-    gdown.download(url, output, quiet=False)
+    # Download our pre-trained model from Hugging Face
+    model_path = hf_hub_download(repo_id="jennasparks/frugal-ai-text-electra-base", filename="checkpoint_epoch_5.weights.h5")
+
+    # Load the model
+    model = tf.keras.models.load_model(model_path)
 
     # Load and prepare the dataset
     dataset = load_dataset(request.dataset_name)
@@ -66,9 +65,11 @@ async def evaluate_text(request: TextEvaluationRequest):
     # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
     #--------------------------------------------------------------------------------------------
 
-    # Make random predictions (placeholder for actual model inference)
+    # Make predictions
+    predictions = model.predict(test_dataset)
+
+    # Get true labels
     true_labels = test_dataset["label"]
-    predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
 
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
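
For context, the sketch below shows one way the new download-and-inference path could be written end to end. It is only an illustration under assumptions: the repo_id and filename come from the diff, but the base checkpoint name, the 8-label head, the tokenization step, and the "quote" text column are guesses rather than anything the commit states. It also restores the checkpoint with load_weights() after rebuilding the architecture, since a *.weights.h5 file saved via save_weights() carries no model config for tf.keras.models.load_model() to reconstruct.

# Sketch only, not the author's exact setup. Assumptions (not confirmed by the
# commit): the checkpoint was produced with model.save_weights() from a
# TFElectraForSequenceClassification built on "google/electra-base-discriminator"
# with 8 labels, and the prepared test split exposes a "quote" text column and an
# integer "label" column, as in the Frugal AI text template.
import tensorflow as tf
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, TFElectraForSequenceClassification


def predict_with_electra(test_dataset):
    """Download the checkpoint from the Hub, rebuild the model, and predict labels."""
    weights_path = hf_hub_download(
        repo_id="jennasparks/frugal-ai-text-electra-base",
        filename="checkpoint_epoch_5.weights.h5",
    )

    # A *.weights.h5 file holds weights only, so the architecture is rebuilt
    # first and the trained weights are loaded into it.
    tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")
    model = TFElectraForSequenceClassification.from_pretrained(
        "google/electra-base-discriminator", num_labels=8
    )
    model.load_weights(weights_path)

    # Tokenize the raw quotes and take the argmax over the class logits.
    encodings = tokenizer(
        list(test_dataset["quote"]), padding=True, truncation=True, return_tensors="tf"
    )
    logits = model(encodings).logits
    return tf.argmax(logits, axis=-1).numpy().tolist()

In the route above, a helper like this (hypothetical name) would stand in for the load_model()/model.predict(test_dataset) pair, with true_labels still read from test_dataset["label"].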