Tonic committed on
Commit
1f08781
·
unverified ·
1 Parent(s): 6af9c73

fix model loading

Browse files
Files changed (1) hide show
  1. tasks/text.py +47 -88
tasks/text.py CHANGED
@@ -8,6 +8,8 @@ from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict, Tuple
9
  import torch
10
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 
11
  from huggingface_hub import login
12
  from dotenv import load_dotenv
13
 
@@ -18,45 +20,37 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info
18
  load_dotenv()
19
 
20
  # Authenticate with Hugging Face
21
- HF_TOKEN = os.getenv('HUGGINGFACE_TOKEN')
22
  if HF_TOKEN:
23
  login(token=HF_TOKEN)
24
 
25
- # Disable torch compile
26
- os.environ["TORCH_COMPILE_DISABLE"] = "1"
27
-
28
  router = APIRouter()
29
 
30
- DESCRIPTION = "Climate Guard Toxic Agent is a ModernBERT fine-tuned for climate disinformation detection"
31
  ROUTE = "/text"
32
  MODEL_NAME = "Tonic/climate-guard-toxic-agent"
 
33
 
34
  class TextClassifier:
35
  def __init__(self):
36
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
37
  max_retries = 3
38
 
39
  for attempt in range(max_retries):
40
  try:
41
  # Initialize tokenizer
42
- self.tokenizer = AutoTokenizer.from_pretrained(
43
- MODEL_NAME,
44
- model_max_length=512,
45
- padding_side='right',
46
- truncation_side='right'
47
- )
48
 
49
- # Initialize model with basic configuration
50
  self.model = AutoModelForSequenceClassification.from_pretrained(
51
  MODEL_NAME,
52
  num_labels=8,
53
- problem_type="single_label_classification",
54
- ignore_mismatched_sizes=True,
55
  trust_remote_code=True
56
- )
57
 
58
- # Move model to device
59
- self.model = self.model.to(self.device)
 
60
 
61
  print("Model initialized successfully")
62
  break
@@ -67,34 +61,32 @@ class TextClassifier:
67
  print(f"Attempt {attempt + 1} failed, retrying... Error: {str(e)}")
68
  time.sleep(1)
69
 
70
- def process_batch(self, batch: List[str], batch_idx: int) -> Tuple[List[int], int]:
71
  """Process a batch of texts and return their predictions"""
72
  try:
73
- print(f"Processing batch {batch_idx} with {len(batch)} items")
74
-
75
- # Tokenize texts
76
  inputs = self.tokenizer(
77
- batch,
78
  padding=True,
79
  truncation=True,
80
- max_length=512,
81
  return_tensors="pt"
82
- ).to(self.device)
 
 
 
83
 
84
  # Get predictions
85
  with torch.no_grad():
86
  outputs = self.model(**inputs)
87
- predictions = torch.argmax(outputs.logits, dim=1).cpu().numpy()
88
-
89
- print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
90
- return predictions.tolist(), batch_idx
91
 
92
  except Exception as e:
93
- print(f"Error in batch {batch_idx}: {str(e)}")
94
- return [0] * len(batch), batch_idx
95
 
96
  def __del__(self):
97
- # Clean up CUDA memory
98
  if hasattr(self, 'model'):
99
  del self.model
100
  if torch.cuda.is_available():
@@ -104,10 +96,8 @@ class TextClassifier:
104
  async def evaluate_text(request: TextEvaluationRequest):
105
  """Evaluate text classification for climate disinformation detection."""
106
 
107
- # Get space info
108
  username, space_url = get_space_info()
109
 
110
- # Define the label mapping
111
  LABEL_MAPPING = {
112
  "0_not_relevant": 0,
113
  "1_not_happening": 1,
@@ -120,76 +110,46 @@ async def evaluate_text(request: TextEvaluationRequest):
120
  }
121
 
122
  try:
123
- # Load and prepare the dataset
124
- dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train", token=HF_TOKEN)
125
 
126
- # Convert string labels to integers
127
- def convert_label(example):
128
- try:
129
- return {"label": LABEL_MAPPING[example["label"]]}
130
- except KeyError:
131
- print(f"Warning: Unknown label {example['label']}")
132
- return {"label": 0}
133
-
134
- dataset = dataset.map(convert_label)
135
-
136
- # Get test dataset
137
  test_dataset = dataset["test"]
138
 
139
  # Start tracking emissions
140
  tracker.start()
141
  tracker.start_task("inference")
142
 
 
143
  true_labels = test_dataset["label"]
144
 
145
- # Initialize the model once
146
  classifier = TextClassifier()
147
-
148
- # Prepare batches
149
- batch_size = 16 # Reduced batch size for better stability
150
- quotes = test_dataset["quote"]
151
- num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
152
- batches = [
153
- quotes[i * batch_size:(i + 1) * batch_size]
154
- for i in range(num_batches)
155
- ]
156
-
157
- # Initialize batch_results
158
- batch_results = [[] for _ in range(num_batches)]
159
 
160
- # Process batches in parallel
161
- max_workers = min(os.cpu_count(), 4)
162
- print(f"Processing with {max_workers} workers")
163
 
164
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
165
- future_to_batch = {
166
- executor.submit(classifier.process_batch, batch, idx): idx
167
- for idx, batch in enumerate(batches)
168
- }
169
-
170
- for future in future_to_batch:
171
- batch_idx = future_to_batch[future]
172
- try:
173
- predictions, idx = future.result()
174
- if predictions:
175
- batch_results[idx] = predictions
176
- print(f"Stored results for batch {idx} ({len(predictions)} predictions)")
177
- except Exception as e:
178
- print(f"Failed to get results for batch {batch_idx}: {e}")
179
- batch_results[batch_idx] = [0] * len(batches[batch_idx])
180
-
181
- # Flatten predictions
182
- predictions = []
183
- for batch_preds in batch_results:
184
- if batch_preds is not None:
185
- predictions.extend(batch_preds)
186
 
187
  # Stop tracking emissions
188
  emissions_data = tracker.stop_task()
189
 
190
  # Calculate accuracy
191
- accuracy = accuracy_score(true_labels, predictions)
192
- print("accuracy:", accuracy)
193
 
194
  # Prepare results
195
  results = {
@@ -209,7 +169,6 @@ async def evaluate_text(request: TextEvaluationRequest):
209
  }
210
  }
211
 
212
- print("results:", results)
213
  return results
214
 
215
  except Exception as e:
 
8
  from typing import List, Dict, Tuple
9
  import torch
10
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
11
+ from torch.utils.data import DataLoader
12
+ from transformers import DataCollatorWithPadding
13
  from huggingface_hub import login
14
  from dotenv import load_dotenv
15
 
 
20
  load_dotenv()
21
 
22
  # Authenticate with Hugging Face
23
+ HF_TOKEN = os.getenv('HF_TOKEN')
24
  if HF_TOKEN:
25
  login(token=HF_TOKEN)
26
 
 
 
 
27
  router = APIRouter()
28
 
29
+ DESCRIPTION = "Climate Guard Toxic Agent is a ModernBERT for Climate Disinformation Detection"
30
  ROUTE = "/text"
31
  MODEL_NAME = "Tonic/climate-guard-toxic-agent"
32
+ TOKENIZER_NAME = "answerdotai/ModernBERT-base"
33
 
34
  class TextClassifier:
35
  def __init__(self):
36
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
  max_retries = 3
38
 
39
  for attempt in range(max_retries):
40
  try:
41
  # Initialize tokenizer
42
+ self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
 
 
 
 
 
43
 
44
+ # Initialize model
45
  self.model = AutoModelForSequenceClassification.from_pretrained(
46
  MODEL_NAME,
47
  num_labels=8,
 
 
48
  trust_remote_code=True
49
+ ).to(self.device)
50
 
51
+ # Convert to half precision
52
+ self.model = self.model.half()
53
+ self.model.eval()
54
 
55
  print("Model initialized successfully")
56
  break
 
61
  print(f"Attempt {attempt + 1} failed, retrying... Error: {str(e)}")
62
  time.sleep(1)
63
 
64
+ def process_batch(self, texts: List[str]) -> List[int]:
65
  """Process a batch of texts and return their predictions"""
66
  try:
67
+ # Tokenize
 
 
68
  inputs = self.tokenizer(
69
+ texts,
70
  padding=True,
71
  truncation=True,
 
72
  return_tensors="pt"
73
+ )
74
+
75
+ # Move inputs to device
76
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
77
 
78
  # Get predictions
79
  with torch.no_grad():
80
  outputs = self.model(**inputs)
81
+ predictions = torch.argmax(outputs.logits, dim=-1)
82
+
83
+ return predictions.cpu().numpy().tolist()
 
84
 
85
  except Exception as e:
86
+ print(f"Error in batch processing: {str(e)}")
87
+ return [0] * len(texts)
88
 
89
  def __del__(self):
 
90
  if hasattr(self, 'model'):
91
  del self.model
92
  if torch.cuda.is_available():
 
96
  async def evaluate_text(request: TextEvaluationRequest):
97
  """Evaluate text classification for climate disinformation detection."""
98
 
 
99
  username, space_url = get_space_info()
100
 
 
101
  LABEL_MAPPING = {
102
  "0_not_relevant": 0,
103
  "1_not_happening": 1,
 
110
  }
111
 
112
  try:
113
+ # Load dataset
114
+ dataset = load_dataset(request.dataset_name)
115
 
116
+ # Convert labels
117
+ dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
 
 
 
 
 
 
 
 
 
118
  test_dataset = dataset["test"]
119
 
120
  # Start tracking emissions
121
  tracker.start()
122
  tracker.start_task("inference")
123
 
124
+ # Get true labels
125
  true_labels = test_dataset["label"]
126
 
127
+ # Initialize model
128
  classifier = TextClassifier()
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
+ # Process in batches
131
+ batch_size = 16
132
+ data_collator = DataCollatorWithPadding(tokenizer=classifier.tokenizer)
133
 
134
+ # Create DataLoader
135
+ test_loader = DataLoader(
136
+ test_dataset,
137
+ batch_size=batch_size,
138
+ collate_fn=data_collator
139
+ )
140
+
141
+ # Get predictions
142
+ all_predictions = []
143
+ for batch in test_loader:
144
+ batch_texts = batch["quote"]
145
+ batch_preds = classifier.process_batch(batch_texts)
146
+ all_predictions.extend(batch_preds)
 
 
 
 
 
 
 
 
 
147
 
148
  # Stop tracking emissions
149
  emissions_data = tracker.stop_task()
150
 
151
  # Calculate accuracy
152
+ accuracy = accuracy_score(true_labels, all_predictions)
 
153
 
154
  # Prepare results
155
  results = {
 
169
  }
170
  }
171
 
 
172
  return results
173
 
174
  except Exception as e: