Tonic committed on
Commit
6af9c73
·
unverified ·
1 Parent(s): 4357468

complete code

Browse files
Files changed (1) hide show
  1. tasks/text.py +6 -7
tasks/text.py CHANGED
@@ -7,7 +7,7 @@ import os
7
  from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict, Tuple
9
  import torch
10
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline, AutoConfig
11
  from huggingface_hub import login
12
  from dotenv import load_dotenv
13
 
@@ -123,18 +123,17 @@ async def evaluate_text(request: TextEvaluationRequest):
123
  # Load and prepare the dataset
124
  dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train", token=HF_TOKEN)
125
 
126
- # Convert string labels to integers with error handling
127
  def convert_label(example):
128
  try:
129
  return {"label": LABEL_MAPPING[example["label"]]}
130
- except KeyError as e:
131
  print(f"Warning: Unknown label {example['label']}")
132
- # Return default label or raise exception
133
- return {"label": 0} # or raise e if you want to fail on unknown labels
134
 
135
  dataset = dataset.map(convert_label)
136
 
137
- # Split dataset
138
  test_dataset = dataset["test"]
139
 
140
  # Start tracking emissions
@@ -147,7 +146,7 @@ async def evaluate_text(request: TextEvaluationRequest):
147
  classifier = TextClassifier()
148
 
149
  # Prepare batches
150
- batch_size = 24
151
  quotes = test_dataset["quote"]
152
  num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
153
  batches = [
 
7
  from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict, Tuple
9
  import torch
10
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
11
  from huggingface_hub import login
12
  from dotenv import load_dotenv
13
 
 
123
  # Load and prepare the dataset
124
  dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train", token=HF_TOKEN)
125
 
126
+ # Convert string labels to integers
127
  def convert_label(example):
128
  try:
129
  return {"label": LABEL_MAPPING[example["label"]]}
130
+ except KeyError:
131
  print(f"Warning: Unknown label {example['label']}")
132
+ return {"label": 0}
 
133
 
134
  dataset = dataset.map(convert_label)
135
 
136
+ # Get test dataset
137
  test_dataset = dataset["test"]
138
 
139
  # Start tracking emissions
 
146
  classifier = TextClassifier()
147
 
148
  # Prepare batches
149
+ batch_size = 16 # Reduced batch size for better stability
150
  quotes = test_dataset["quote"]
151
  num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
152
  batches = [