Spaces:
Sleeping
Sleeping
fix typo
Browse files — tasks/text.py (+7 −5)
tasks/text.py
CHANGED
@@ -35,7 +35,7 @@ class TextClassifier:
|
|
35 |
def __init__(self):
|
36 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
37 |
max_retries = 3
|
38 |
-
model_name = "
|
39 |
|
40 |
for attempt in range(max_retries):
|
41 |
try:
|
@@ -111,6 +111,7 @@ class TextClassifier:
|
|
111 |
del self.model
|
112 |
if torch.cuda.is_available():
|
113 |
torch.cuda.empty_cache()
|
|
|
114 |
|
115 |
@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
|
116 |
async def evaluate_text(request: TextEvaluationRequest):
|
@@ -133,7 +134,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
133 |
|
134 |
try:
|
135 |
# Load and prepare the dataset
|
136 |
-
dataset = load_dataset(
|
137 |
|
138 |
# Convert string labels to integers
|
139 |
dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
|
@@ -151,7 +152,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
151 |
classifier = TextClassifier()
|
152 |
|
153 |
# Prepare batches
|
154 |
-
batch_size =
|
155 |
quotes = test_dataset["quote"]
|
156 |
num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
|
157 |
batches = [
|
@@ -163,7 +164,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
163 |
batch_results = [[] for _ in range(num_batches)]
|
164 |
|
165 |
# Process batches in parallel
|
166 |
-
max_workers = min(os.cpu_count(),
|
167 |
print(f"Processing with {max_workers} workers")
|
168 |
|
169 |
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
@@ -219,4 +220,5 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
219 |
|
220 |
except Exception as e:
|
221 |
print(f"Error in evaluate_text: {str(e)}")
|
222 |
-
raise Exception(f"Failed to process request: {str(e)}")
|
|
|
|
35 |
def __init__(self):
|
36 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
37 |
max_retries = 3
|
38 |
+
model_name = "answerdotai/ModernBERT-base"
|
39 |
|
40 |
for attempt in range(max_retries):
|
41 |
try:
|
|
|
111 |
del self.model
|
112 |
if torch.cuda.is_available():
|
113 |
torch.cuda.empty_cache()
|
114 |
+
|
115 |
|
116 |
@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
|
117 |
async def evaluate_text(request: TextEvaluationRequest):
|
|
|
134 |
|
135 |
try:
|
136 |
# Load and prepare the dataset
|
137 |
+
dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train", token=HF_TOKEN)
|
138 |
|
139 |
# Convert string labels to integers
|
140 |
dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
|
|
|
152 |
classifier = TextClassifier()
|
153 |
|
154 |
# Prepare batches
|
155 |
+
batch_size = 24
|
156 |
quotes = test_dataset["quote"]
|
157 |
num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
|
158 |
batches = [
|
|
|
164 |
batch_results = [[] for _ in range(num_batches)]
|
165 |
|
166 |
# Process batches in parallel
|
167 |
+
max_workers = min(os.cpu_count(), 4)
|
168 |
print(f"Processing with {max_workers} workers")
|
169 |
|
170 |
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
|
220 |
|
221 |
except Exception as e:
|
222 |
print(f"Error in evaluate_text: {str(e)}")
|
223 |
+
raise Exception(f"Failed to process request: {str(e)}")
|
224 |
+
|