Tonic committed on
Commit
4477f42
·
unverified ·
1 Parent(s): 71340db

fix transformers: load the model and tokenizer directly instead of using pipeline

Browse files
Files changed (1) hide show
  1. tasks/text.py +47 -60
tasks/text.py CHANGED
@@ -1,14 +1,13 @@
1
  from fastapi import APIRouter
2
  from datetime import datetime
 
3
  from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
- import random
6
- from transformers import pipeline, AutoConfig
7
  import os
8
  from concurrent.futures import ThreadPoolExecutor
9
  from typing import List, Dict, Tuple
10
- import numpy as np
11
  import torch
 
12
 
13
  from .utils.evaluation import TextEvaluationRequest
14
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
@@ -18,23 +17,23 @@ os.environ["TORCH_COMPILE_DISABLE"] = "1"
18
 
19
  router = APIRouter()
20
 
21
- DESCRIPTION = "Random Baseline"
22
  ROUTE = "/text"
23
 
24
  class TextClassifier:
25
  def __init__(self):
26
- # Add retry mechanism for model initialization
27
  max_retries = 3
28
  for attempt in range(max_retries):
29
  try:
30
- self.config = AutoConfig.from_pretrained("Tonic/climate-guard-toxic-agent")
31
- self.label2id = self.config.label2id
32
- self.classifier = pipeline(
33
- "text-classification",
34
- "Tonic/climate-guard-toxic-agent",
35
- device="cpu",
36
- batch_size=16
37
  )
 
38
  print("Model initialized successfully")
39
  break
40
  except Exception as e:
@@ -43,21 +42,37 @@ class TextClassifier:
43
  print(f"Attempt {attempt + 1} failed, retrying...")
44
  time.sleep(1)
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def process_batch(self, batch: List[str], batch_idx: int) -> Tuple[List[int], int]:
47
  """Process a batch of texts and return their predictions"""
48
  max_retries = 3
49
  for attempt in range(max_retries):
50
  try:
51
  print(f"Processing batch {batch_idx} with {len(batch)} items (attempt {attempt + 1})")
52
- # Process texts one by one in case of errors
53
  predictions = []
 
 
54
  for text in batch:
55
- try:
56
- pred = self.classifier(text)
57
- pred_label = self.label2id[pred[0]["label"]]
58
- predictions.append(pred_label)
59
- except Exception as e:
60
- print(f"Error processing text in batch {batch_idx}: {str(e)}")
61
 
62
  if not predictions:
63
  raise Exception("No predictions generated for batch")
@@ -68,21 +83,14 @@ class TextClassifier:
68
  except Exception as e:
69
  if attempt == max_retries - 1:
70
  print(f"Final error in batch {batch_idx}: {str(e)}")
71
- return [0] * len(batch), batch_idx # Return default predictions instead of empty list
72
  print(f"Error in batch {batch_idx} (attempt {attempt + 1}): {str(e)}")
73
  time.sleep(1)
74
 
75
-
76
- @router.post(ROUTE, tags=["Text Task"],
77
- description=DESCRIPTION)
78
  async def evaluate_text(request: TextEvaluationRequest):
79
- """
80
- Evaluate text classification for climate disinformation detection.
81
 
82
- Current Model: Random Baseline
83
- - Makes random predictions from the label space (0-7)
84
- - Used as a baseline for comparison
85
- """
86
  # Get space info
87
  username, space_url = get_space_info()
88
 
@@ -100,30 +108,20 @@ async def evaluate_text(request: TextEvaluationRequest):
100
 
101
  # Load and prepare the dataset
102
  dataset = load_dataset(request.dataset_name)
103
-
104
- # Convert string labels to integers
105
  dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
106
-
107
- # Split dataset
108
- train_test = dataset["train"]
109
  test_dataset = dataset["test"]
110
 
111
  # Start tracking emissions
112
  tracker.start()
113
  tracker.start_task("inference")
114
 
115
- #--------------------------------------------------------------------------------------------
116
- # YOUR MODEL INFERENCE CODE HERE
117
- # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
118
- #--------------------------------------------------------------------------------------------
119
-
120
  true_labels = test_dataset["label"]
121
 
122
  # Initialize the model once
123
  classifier = TextClassifier()
124
 
125
  # Prepare batches
126
- batch_size = 32
127
  quotes = test_dataset["quote"]
128
  num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
129
  batches = [
@@ -131,54 +129,44 @@ async def evaluate_text(request: TextEvaluationRequest):
131
  for i in range(num_batches)
132
  ]
133
 
134
- # Initialize batch_results before parallel processing
135
  batch_results = [[] for _ in range(num_batches)]
136
 
137
  # Process batches in parallel
138
- max_workers = min(os.cpu_count(), 4) # Limit to 4 workers or CPU count
139
  print(f"Processing with {max_workers} workers")
140
 
141
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
142
- # Submit all batches for processing
143
  future_to_batch = {
144
- executor.submit(
145
- classifier.process_batch,
146
- batch,
147
- idx
148
- ): idx for idx, batch in enumerate(batches)
149
  }
150
 
151
- # Collect results in order
152
  for future in future_to_batch:
153
  batch_idx = future_to_batch[future]
154
  try:
155
  predictions, idx = future.result()
156
- if predictions: # Only store non-empty predictions
157
  batch_results[idx] = predictions
158
  print(f"Stored results for batch {idx} ({len(predictions)} predictions)")
159
  except Exception as e:
160
  print(f"Failed to get results for batch {batch_idx}: {e}")
161
- # Use default predictions instead of empty list
162
  batch_results[batch_idx] = [0] * len(batches[batch_idx])
163
 
164
- # Flatten predictions while maintaining order
165
  predictions = []
166
  for batch_preds in batch_results:
167
  if batch_preds is not None:
168
  predictions.extend(batch_preds)
169
-
170
- #--------------------------------------------------------------------------------------------
171
- # YOUR MODEL INFERENCE STOPS HERE
172
- #--------------------------------------------------------------------------------------------
173
 
174
  # Stop tracking emissions
175
  emissions_data = tracker.stop_task()
176
 
177
  # Calculate accuracy
178
  accuracy = accuracy_score(true_labels, predictions)
179
- print("accuracy : ", accuracy)
180
 
181
- # Prepare results dictionary
182
  results = {
183
  "username": username,
184
  "space_url": space_url,
@@ -196,6 +184,5 @@ async def evaluate_text(request: TextEvaluationRequest):
196
  }
197
  }
198
 
199
- print("results : ", results)
200
-
201
  return results
 
1
  from fastapi import APIRouter
2
  from datetime import datetime
3
+ import time
4
  from datasets import load_dataset
5
  from sklearn.metrics import accuracy_score
 
 
6
  import os
7
  from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict, Tuple
 
9
  import torch
10
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
11
 
12
  from .utils.evaluation import TextEvaluationRequest
13
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
17
 
18
  router = APIRouter()
19
 
20
+ DESCRIPTION = "Climate Guard Toxic Agent Classifier"
21
  ROUTE = "/text"
22
 
23
  class TextClassifier:
24
  def __init__(self):
25
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
26
  max_retries = 3
27
  for attempt in range(max_retries):
28
  try:
29
+ # Load model and tokenizer directly instead of using pipeline
30
+ self.model = AutoModelForSequenceClassification.from_pretrained(
31
+ "Tonic/climate-guard-toxic-agent"
32
+ ).to(self.device)
33
+ self.tokenizer = AutoTokenizer.from_pretrained(
34
+ "Tonic/climate-guard-toxic-agent"
 
35
  )
36
+ self.model.eval() # Set to evaluation mode
37
  print("Model initialized successfully")
38
  break
39
  except Exception as e:
 
42
  print(f"Attempt {attempt + 1} failed, retrying...")
43
  time.sleep(1)
44
 
45
+ def predict_single(self, text: str) -> int:
46
+ """Predict single text instance"""
47
+ try:
48
+ inputs = self.tokenizer(
49
+ text,
50
+ return_tensors="pt",
51
+ truncation=True,
52
+ max_length=512,
53
+ padding=True
54
+ ).to(self.device)
55
+
56
+ with torch.no_grad():
57
+ outputs = self.model(**inputs)
58
+ predictions = outputs.logits.argmax(-1)
59
+ return predictions.item()
60
+ except Exception as e:
61
+ print(f"Error in single prediction: {str(e)}")
62
+ return 0 # Return default prediction on error
63
+
64
  def process_batch(self, batch: List[str], batch_idx: int) -> Tuple[List[int], int]:
65
  """Process a batch of texts and return their predictions"""
66
  max_retries = 3
67
  for attempt in range(max_retries):
68
  try:
69
  print(f"Processing batch {batch_idx} with {len(batch)} items (attempt {attempt + 1})")
 
70
  predictions = []
71
+
72
+ # Process texts one by one for better error handling
73
  for text in batch:
74
+ pred = self.predict_single(text)
75
+ predictions.append(pred)
 
 
 
 
76
 
77
  if not predictions:
78
  raise Exception("No predictions generated for batch")
 
83
  except Exception as e:
84
  if attempt == max_retries - 1:
85
  print(f"Final error in batch {batch_idx}: {str(e)}")
86
+ return [0] * len(batch), batch_idx
87
  print(f"Error in batch {batch_idx} (attempt {attempt + 1}): {str(e)}")
88
  time.sleep(1)
89
 
90
+ @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
 
 
91
  async def evaluate_text(request: TextEvaluationRequest):
92
+ """Evaluate text classification for climate disinformation detection."""
 
93
 
 
 
 
 
94
  # Get space info
95
  username, space_url = get_space_info()
96
 
 
108
 
109
  # Load and prepare the dataset
110
  dataset = load_dataset(request.dataset_name)
 
 
111
  dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
 
 
 
112
  test_dataset = dataset["test"]
113
 
114
  # Start tracking emissions
115
  tracker.start()
116
  tracker.start_task("inference")
117
 
 
 
 
 
 
118
  true_labels = test_dataset["label"]
119
 
120
  # Initialize the model once
121
  classifier = TextClassifier()
122
 
123
  # Prepare batches
124
+ batch_size = 16 # Reduced batch size for better memory management
125
  quotes = test_dataset["quote"]
126
  num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
127
  batches = [
 
129
  for i in range(num_batches)
130
  ]
131
 
132
+ # Initialize batch_results
133
  batch_results = [[] for _ in range(num_batches)]
134
 
135
  # Process batches in parallel
136
+ max_workers = min(os.cpu_count(), 4)
137
  print(f"Processing with {max_workers} workers")
138
 
139
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
 
140
  future_to_batch = {
141
+ executor.submit(classifier.process_batch, batch, idx): idx
142
+ for idx, batch in enumerate(batches)
 
 
 
143
  }
144
 
 
145
  for future in future_to_batch:
146
  batch_idx = future_to_batch[future]
147
  try:
148
  predictions, idx = future.result()
149
+ if predictions:
150
  batch_results[idx] = predictions
151
  print(f"Stored results for batch {idx} ({len(predictions)} predictions)")
152
  except Exception as e:
153
  print(f"Failed to get results for batch {batch_idx}: {e}")
 
154
  batch_results[batch_idx] = [0] * len(batches[batch_idx])
155
 
156
+ # Flatten predictions
157
  predictions = []
158
  for batch_preds in batch_results:
159
  if batch_preds is not None:
160
  predictions.extend(batch_preds)
 
 
 
 
161
 
162
  # Stop tracking emissions
163
  emissions_data = tracker.stop_task()
164
 
165
  # Calculate accuracy
166
  accuracy = accuracy_score(true_labels, predictions)
167
+ print("accuracy:", accuracy)
168
 
169
+ # Prepare results
170
  results = {
171
  "username": username,
172
  "space_url": space_url,
 
184
  }
185
  }
186
 
187
+ print("results:", results)
 
188
  return results