jennasparks committed
Commit 54b86de · verified · 1 Parent(s): 19efa32

changed editing style

Files changed (1)
  1. tasks/text.py  +18 -62
tasks/text.py CHANGED
@@ -18,30 +18,15 @@ router = APIRouter()
 DESCRIPTION = "Electra_Base"
 ROUTE = "/text"
 
-class CustomTFDataset(tf.data.Dataset):
-    def __init__(self, texts, labels, tokenizer, max_length=128):
-        self.texts = texts
-        self.labels = labels
-        self.tokenizer = tokenizer
-        self.max_length = max_length
-
-    def __len__(self):
-        return len(self.texts)
-
-    def __iter__(self):
-        for text, label in zip(self.texts, self.labels):
-            encoding = self.tokenizer(
-                text,
-                truncation=True,
-                padding='max_length',
-                max_length=self.max_length,
-                return_tensors='tf'
-            )
-            yield {
-                'input_ids': encoding['input_ids'][0],
-                'attention_mask': encoding['attention_mask'][0],
-                'label': tf.constant(label, dtype=tf.int32)
-            }
+# Load model and tokenizer
+model_weights_path = hf_hub_download(repo_id="jennasparks/electra-tf", filename="tf_model.h5")
+model_config_path = hf_hub_download(repo_id="jennasparks/electra-tf", filename="config.json")
+
+config = ElectraConfig.from_json_file(model_config_path)
+model = TFElectraForSequenceClassification(config)
+model.load_weights(model_weights_path)
+tokenizer = ElectraTokenizer.from_pretrained("google/electra-base-discriminator")
+
 
 @router.post(ROUTE, tags=["Text Task"],
              description=DESCRIPTION)
@@ -66,41 +51,15 @@ async def evaluate_text(request: TextEvaluationRequest):
         "7_fossil_fuels_needed": 7
     }
 
-    # Download pre-trained model weights and config from Hugging Face
-    model_weights_path = hf_hub_download(repo_id="jennasparks/electra-tf", filename="tf_model.h5")
-    model_config_path = hf_hub_download(repo_id="jennasparks/electra-tf", filename="config.json")
-
-    # Load the configuration
-    config = ElectraConfig.from_json_file(model_config_path)
-
-    # Create the model with the loaded configuration
-    model = TFElectraForSequenceClassification(config)
-
-    # Load the weights
-    model.load_weights(model_weights_path)
-
-    # Load the tokenizer
-    tokenizer = ElectraTokenizer.from_pretrained("google/electra-base-discriminator")
-
-    # Compile the model (if needed for inference)
-    model.compile(optimizer='adam',
-                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-                  metrics=['accuracy'])
-
     # Load and prepare the dataset
     dataset = load_dataset(request.dataset_name)
 
     # Convert string labels to integers
     dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
 
-    # Tokenize the dataset
-    def tokenize_function(examples):
-        return tokenizer(examples["text"], padding="max_length", truncation=True)
-
-    tokenized_dataset = dataset.map(tokenize_function, batched=True)
-
-    # Get the test dataset
-    test_dataset = tokenized_dataset["test"]
+    # Split dataset
+    train_test = dataset["train"]
+    test_dataset = dataset["test"]
 
     # Start tracking emissions
     tracker.start()
@@ -111,17 +70,14 @@ async def evaluate_text(request: TextEvaluationRequest):
     # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
     #--------------------------------------------------------------------------------------------
 
-    # Add error handling
     try:
-        # Prepare input for the model
-        input_ids = tf.convert_to_tensor(test_dataset["input_ids"])
-        attention_mask = tf.convert_to_tensor(test_dataset["attention_mask"])
-
+        # Tokenize the input texts
+        encoded_input = tokenizer(test_dataset["text"], truncation=True, padding=True, return_tensors="tf")
+
         # Make predictions
-        predictions = model(input_ids, attention_mask=attention_mask, training=False)
-        predictions = tf.nn.softmax(predictions.logits, axis=-1)
-        predictions = tf.argmax(predictions, axis=-1).numpy()
-
+        outputs = model(encoded_input["input_ids"], attention_mask=encoded_input["attention_mask"], training=False)
+        predictions = tf.argmax(outputs.logits, axis=1).numpy()
+
         # Get true labels
         true_labels = test_dataset["label"]
     except Exception as e:
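
Taken together, the hunks move model setup out of the request handler. As a reading aid, the following is a minimal sketch of how the changed pieces of tasks/text.py could fit together after this commit. The import block and the predict_texts helper are assumptions added for illustration (the diff does not show the file's imports, and the real inference code lives inline inside evaluate_text); only the loading and prediction lines are taken from the added hunks.

# Sketch only: the imports below are assumed, since the diff does not show them.
import tensorflow as tf
from huggingface_hub import hf_hub_download
from transformers import (
    ElectraConfig,
    ElectraTokenizer,
    TFElectraForSequenceClassification,
)

# Module-level setup (added by this commit): weights, config, and tokenizer are
# downloaded and built once at import time rather than on every request.
model_weights_path = hf_hub_download(repo_id="jennasparks/electra-tf", filename="tf_model.h5")
model_config_path = hf_hub_download(repo_id="jennasparks/electra-tf", filename="config.json")

config = ElectraConfig.from_json_file(model_config_path)
model = TFElectraForSequenceClassification(config)
model.load_weights(model_weights_path)
tokenizer = ElectraTokenizer.from_pretrained("google/electra-base-discriminator")


def predict_texts(texts):
    """Hypothetical helper mirroring the inference block inside evaluate_text."""
    # Tokenize the whole batch, padding to the longest example in it.
    encoded_input = tokenizer(texts, truncation=True, padding=True, return_tensors="tf")
    # Forward pass with dropout disabled; argmax over the logits gives label ids.
    outputs = model(
        encoded_input["input_ids"],
        attention_mask=encoded_input["attention_mask"],
        training=False,
    )
    return tf.argmax(outputs.logits, axis=1).numpy()

The design choice here is that the per-request work inside evaluate_text shrinks to tokenize-plus-forward-pass, while the CustomTFDataset class and the model.compile call, neither of which is needed for plain inference, are dropped.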