added tokenization as preprocessing
Browse files- tasks/text.py +31 -8
tasks/text.py
CHANGED
@@ -10,6 +10,9 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
|
10 |
import tensorflow as tf
|
11 |
from huggingface_hub import hf_hub_download
|
12 |
|
|
|
|
|
|
|
13 |
router = APIRouter()
|
14 |
|
15 |
DESCRIPTION = "Electra"
|
@@ -40,6 +43,13 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
40 |
"7_fossil_fuels_needed": 7
|
41 |
}
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
# Download our pre-trained model from Hugging Face
|
44 |
model_path = hf_hub_download(repo_id="jennasparks/frugal-ai-text-electra-base", filename="checkpoint_epoch_5.weights.h5")
|
45 |
|
@@ -49,12 +59,19 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
49 |
# Load and prepare the dataset
|
50 |
dataset = load_dataset(request.dataset_name)
|
51 |
|
52 |
-
# Convert string labels to integers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
|
|
|
54 |
|
55 |
# Split dataset
|
56 |
-
|
57 |
-
test_dataset = dataset["test"]
|
58 |
|
59 |
# Start tracking emissions
|
60 |
tracker.start()
|
@@ -64,12 +81,18 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
64 |
# YOUR MODEL INFERENCE CODE HERE
|
65 |
# Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
|
66 |
#--------------------------------------------------------------------------------------------
|
67 |
-
|
68 |
-
# Make predictions
|
69 |
-
predictions = model.predict(test_dataset)
|
70 |
|
71 |
-
#
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
#--------------------------------------------------------------------------------------------
|
75 |
# YOUR MODEL INFERENCE STOPS HERE
|
|
|
10 |
import tensorflow as tf
|
11 |
from huggingface_hub import hf_hub_download
|
12 |
|
13 |
+
from transformers import AutoTokenizer
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
router = APIRouter()
|
17 |
|
18 |
DESCRIPTION = "Electra"
|
|
|
43 |
"7_fossil_fuels_needed": 7
|
44 |
}
|
45 |
|
46 |
+
# Initialize tokenizer
|
47 |
+
tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
|
48 |
+
|
49 |
+
def preprocess_function(examples):
|
50 |
+
return tokenizer(examples["text"],
|
51 |
+
truncation=True, padding="max_length")
|
52 |
+
|
53 |
# Download our pre-trained model from Hugging Face
|
54 |
model_path = hf_hub_download(repo_id="jennasparks/frugal-ai-text-electra-base", filename="checkpoint_epoch_5.weights.h5")
|
55 |
|
|
|
59 |
# Load and prepare the dataset
|
60 |
dataset = load_dataset(request.dataset_name)
|
61 |
|
62 |
+
# # Convert string labels to integers
|
63 |
+
# dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
|
64 |
+
|
65 |
+
# # Split dataset
|
66 |
+
# train_test = dataset["train"]
|
67 |
+
# test_dataset = dataset["test"]
|
68 |
+
|
69 |
+
# Convert string labels to integers and tokenize
|
70 |
dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
|
71 |
+
tokenized_dataset = dataset.map(preprocess_function, batched=True)
|
72 |
|
73 |
# Split dataset
|
74 |
+
test_dataset = tokenized_dataset["test"]
|
|
|
75 |
|
76 |
# Start tracking emissions
|
77 |
tracker.start()
|
|
|
81 |
# YOUR MODEL INFERENCE CODE HERE
|
82 |
# Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
|
83 |
#--------------------------------------------------------------------------------------------
|
|
|
|
|
|
|
84 |
|
85 |
+
# Added error handling
|
86 |
+
try:
|
87 |
+
# Make predictions
|
88 |
+
predictions = model.predict(test_dataset["input_ids"])
|
89 |
+
predictions = np.argmax(predictions, axis=1)
|
90 |
+
|
91 |
+
# Get true labels
|
92 |
+
true_labels = test_dataset["label"]
|
93 |
+
except Exception as e:
|
94 |
+
print(f"An error occurred during prediction: {str(e)}")
|
95 |
+
raise
|
96 |
|
97 |
#--------------------------------------------------------------------------------------------
|
98 |
# YOUR MODEL INFERENCE STOPS HERE
|