Tonic committed on
Commit
dc058e1
·
unverified ·
1 Parent(s): 1a885c6
Files changed (1) hide show
  1. tasks/text.py +7 -5
tasks/text.py CHANGED
@@ -35,7 +35,7 @@ class TextClassifier:
35
  def __init__(self):
36
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
37
  max_retries = 3
38
- model_name = "Tonic/climate-guard-toxic-agent"
39
 
40
  for attempt in range(max_retries):
41
  try:
@@ -111,6 +111,7 @@ class TextClassifier:
111
  del self.model
112
  if torch.cuda.is_available():
113
  torch.cuda.empty_cache()
 
114
 
115
  @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
116
  async def evaluate_text(request: TextEvaluationRequest):
@@ -133,7 +134,7 @@ async def evaluate_text(request: TextEvaluationRequest):
133
 
134
  try:
135
  # Load and prepare the dataset
136
- dataset = load_dataset(request.dataset_name)
137
 
138
  # Convert string labels to integers
139
  dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
@@ -151,7 +152,7 @@ async def evaluate_text(request: TextEvaluationRequest):
151
  classifier = TextClassifier()
152
 
153
  # Prepare batches
154
- batch_size = 16 # Reduced batch size
155
  quotes = test_dataset["quote"]
156
  num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
157
  batches = [
@@ -163,7 +164,7 @@ async def evaluate_text(request: TextEvaluationRequest):
163
  batch_results = [[] for _ in range(num_batches)]
164
 
165
  # Process batches in parallel
166
- max_workers = min(os.cpu_count(), 2) # Reduced workers
167
  print(f"Processing with {max_workers} workers")
168
 
169
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
@@ -219,4 +220,5 @@ async def evaluate_text(request: TextEvaluationRequest):
219
 
220
  except Exception as e:
221
  print(f"Error in evaluate_text: {str(e)}")
222
- raise Exception(f"Failed to process request: {str(e)}")
 
 
35
  def __init__(self):
36
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
37
  max_retries = 3
38
+ model_name = "answerdotai/ModernBERT-base"
39
 
40
  for attempt in range(max_retries):
41
  try:
 
111
  del self.model
112
  if torch.cuda.is_available():
113
  torch.cuda.empty_cache()
114
+
115
 
116
  @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
117
  async def evaluate_text(request: TextEvaluationRequest):
 
134
 
135
  try:
136
  # Load and prepare the dataset
137
+ dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train", token=HF_TOKEN)
138
 
139
  # Convert string labels to integers
140
  dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
 
152
  classifier = TextClassifier()
153
 
154
  # Prepare batches
155
+ batch_size = 24
156
  quotes = test_dataset["quote"]
157
  num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
158
  batches = [
 
164
  batch_results = [[] for _ in range(num_batches)]
165
 
166
  # Process batches in parallel
167
+ max_workers = min(os.cpu_count(), 4)
168
  print(f"Processing with {max_workers} workers")
169
 
170
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
 
220
 
221
  except Exception as e:
222
  print(f"Error in evaluate_text: {str(e)}")
223
+ raise Exception(f"Failed to process request: {str(e)}")
224
+