Spaces:
Sleeping
Sleeping
complete code
Browse files- tasks/text.py +6 -7
tasks/text.py
CHANGED
@@ -7,7 +7,7 @@ import os
|
|
7 |
from concurrent.futures import ThreadPoolExecutor
|
8 |
from typing import List, Dict, Tuple
|
9 |
import torch
|
10 |
-
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
11 |
from huggingface_hub import login
|
12 |
from dotenv import load_dotenv
|
13 |
|
@@ -123,18 +123,17 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
123 |
# Load and prepare the dataset
|
124 |
dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train", token=HF_TOKEN)
|
125 |
|
126 |
-
# Convert string labels to integers
|
127 |
def convert_label(example):
|
128 |
try:
|
129 |
return {"label": LABEL_MAPPING[example["label"]]}
|
130 |
-
except KeyError
|
131 |
print(f"Warning: Unknown label {example['label']}")
|
132 |
-
|
133 |
-
return {"label": 0} # or raise e if you want to fail on unknown labels
|
134 |
|
135 |
dataset = dataset.map(convert_label)
|
136 |
|
137 |
-
#
|
138 |
test_dataset = dataset["test"]
|
139 |
|
140 |
# Start tracking emissions
|
@@ -147,7 +146,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
147 |
classifier = TextClassifier()
|
148 |
|
149 |
# Prepare batches
|
150 |
-
batch_size =
|
151 |
quotes = test_dataset["quote"]
|
152 |
num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
|
153 |
batches = [
|
|
|
7 |
from concurrent.futures import ThreadPoolExecutor
|
8 |
from typing import List, Dict, Tuple
|
9 |
import torch
|
10 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
11 |
from huggingface_hub import login
|
12 |
from dotenv import load_dotenv
|
13 |
|
|
|
123 |
# Load and prepare the dataset
|
124 |
dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train", token=HF_TOKEN)
|
125 |
|
126 |
+
# Convert string labels to integers
|
127 |
def convert_label(example):
|
128 |
try:
|
129 |
return {"label": LABEL_MAPPING[example["label"]]}
|
130 |
+
except KeyError:
|
131 |
print(f"Warning: Unknown label {example['label']}")
|
132 |
+
return {"label": 0}
|
|
|
133 |
|
134 |
dataset = dataset.map(convert_label)
|
135 |
|
136 |
+
# Get test dataset
|
137 |
test_dataset = dataset["test"]
|
138 |
|
139 |
# Start tracking emissions
|
|
|
146 |
classifier = TextClassifier()
|
147 |
|
148 |
# Prepare batches
|
149 |
+
batch_size = 16 # Reduced batch size for better stability
|
150 |
quotes = test_dataset["quote"]
|
151 |
num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
|
152 |
batches = [
|