Tonic commited on
Commit
e6d07cd
·
unverified ·
1 Parent(s): 0ee5862

revert inference code

Browse files
Files changed (1) hide show
  1. tasks/text.py +182 -132
tasks/text.py CHANGED
@@ -1,151 +1,201 @@
1
- # tasks/text.py
2
- from fastapi import APIRouter, HTTPException
3
  from datetime import datetime
4
  from datasets import load_dataset
5
  from sklearn.metrics import accuracy_score
 
 
 
 
 
 
6
  import torch
7
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
- from torch.utils.data import Dataset, DataLoader
9
- import logging
10
 
11
  from .utils.evaluation import TextEvaluationRequest
12
- from .utils.emissions import start_tracking, stop_tracking, clean_emissions_data, get_space_info
13
 
14
- # Set up logging
15
- logging.basicConfig(level=logging.INFO)
16
- logger = logging.getLogger(__name__)
17
 
18
  router = APIRouter()
19
 
20
- DESCRIPTION = "Climate Guard Toxic Agent Model"
21
  ROUTE = "/text"
22
 
23
- class TextDataset(Dataset):
24
- def __init__(self, texts, labels, tokenizer, max_len=128):
25
- self.texts = texts
26
- self.labels = labels
27
- self.tokenizer = tokenizer
28
- self.max_len = max_len
29
-
30
- def __len__(self):
31
- return len(self.texts)
32
-
33
- def __getitem__(self, idx):
34
- text = str(self.texts[idx])
35
- label = self.labels[idx]
36
-
37
- encoding = self.tokenizer(
38
- text,
39
- max_length=self.max_len,
40
- padding='max_length',
41
- truncation=True,
42
- return_tensors="pt"
43
- )
44
-
45
- return {
46
- 'input_ids': encoding['input_ids'].squeeze(0),
47
- 'attention_mask': encoding['attention_mask'].squeeze(0),
48
- 'labels': torch.tensor(label, dtype=torch.long)
49
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
 
52
  async def evaluate_text(request: TextEvaluationRequest):
53
  """
54
  Evaluate text classification for climate disinformation detection.
 
 
 
 
55
  """
56
- try:
57
- logger.info("Starting evaluation")
58
- username, space_url = get_space_info()
59
-
60
- # Label mapping
61
- LABEL_MAPPING = {
62
- "0_not_relevant": 0,
63
- "1_not_happening": 1,
64
- "2_not_human": 2,
65
- "3_not_bad": 3,
66
- "4_solutions_harmful_unnecessary": 4,
67
- "5_science_unreliable": 5,
68
- "6_proponents_biased": 6,
69
- "7_fossil_fuels_needed": 7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  }
71
 
72
- logger.info("Loading dataset")
73
- dataset = load_dataset(request.dataset_name)
74
- dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
75
- test_dataset = dataset["test"]
76
-
77
- logger.info("Starting emissions tracking")
78
- start_tracking()
79
-
80
- try:
81
- logger.info("Loading model and tokenizer")
82
- model_name = "Tonic/climate-guard-toxic-agent"
83
- tokenizer = AutoTokenizer.from_pretrained(model_name)
84
- model = AutoModelForSequenceClassification.from_pretrained(
85
- model_name,
86
- num_labels=len(LABEL_MAPPING)
87
- )
88
-
89
- logger.info("Preparing dataset")
90
- test_data = TextDataset(
91
- texts=test_dataset["text"],
92
- labels=test_dataset["label"],
93
- tokenizer=tokenizer
94
- )
95
-
96
- test_loader = DataLoader(test_data, batch_size=16)
97
-
98
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
99
- logger.info(f"Using device: {device}")
100
- model = model.to(device)
101
- model.eval()
102
-
103
- predictions = []
104
- ground_truth = []
105
-
106
- logger.info("Running inference")
107
- with torch.no_grad():
108
- for batch in test_loader:
109
- input_ids = batch['input_ids'].to(device)
110
- attention_mask = batch['attention_mask'].to(device)
111
- labels = batch['labels'].to(device)
112
-
113
- outputs = model(input_ids=input_ids, attention_mask=attention_mask)
114
- _, predicted = torch.max(outputs.logits, 1)
115
-
116
- predictions.extend(predicted.cpu().numpy())
117
- ground_truth.extend(labels.cpu().numpy())
118
-
119
- accuracy = accuracy_score(ground_truth, predictions)
120
- logger.info(f"Accuracy: {accuracy}")
121
-
122
- emissions_data = stop_tracking()
123
-
124
- results = {
125
- "username": username,
126
- "space_url": space_url,
127
- "submission_timestamp": datetime.now().isoformat(),
128
- "model_description": DESCRIPTION,
129
- "accuracy": float(accuracy),
130
- "energy_consumed_wh": float(emissions_data.energy_consumed * 1000),
131
- "emissions_gco2eq": float(emissions_data.emissions * 1000),
132
- "emissions_data": clean_emissions_data(emissions_data.__dict__),
133
- "api_route": ROUTE,
134
- "dataset_config": {
135
- "dataset_name": request.dataset_name,
136
- "test_size": request.test_size,
137
- "test_seed": request.test_seed
138
- }
139
- }
140
-
141
- logger.info("Evaluation completed successfully")
142
- return results
143
-
144
- except Exception as e:
145
- logger.error(f"Error during evaluation: {str(e)}")
146
- stop_tracking()
147
- raise HTTPException(status_code=500, detail=str(e))
148
-
149
- except Exception as e:
150
- logger.error(f"Error in evaluate_text: {str(e)}")
151
- raise HTTPException(status_code=500, detail=str(e))
 
1
+ from fastapi import APIRouter
 
2
  from datetime import datetime
3
  from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
+ import random
6
+ from transformers import pipeline, AutoConfig
7
+ import os
8
+ from concurrent.futures import ThreadPoolExecutor
9
+ from typing import List, Dict, Tuple
10
+ import numpy as np
11
  import torch
 
 
 
12
 
13
  from .utils.evaluation import TextEvaluationRequest
14
+ from .utils.emissions import tracker, clean_emissions_data, get_space_info
15
 
16
+ # Disable torch compile
17
+ os.environ["TORCH_COMPILE_DISABLE"] = "1"
 
18
 
19
  router = APIRouter()
20
 
21
+ DESCRIPTION = "Random Baseline"
22
  ROUTE = "/text"
23
 
24
+ class TextClassifier:
25
+ def __init__(self):
26
+ # Add retry mechanism for model initialization
27
+ max_retries = 3
28
+ for attempt in range(max_retries):
29
+ try:
30
+ self.config = AutoConfig.from_pretrained("Tonic/climate-guard-toxic-agent")
31
+ self.label2id = self.config.label2id
32
+ self.classifier = pipeline(
33
+ "text-classification",
34
+ "Tonic/climate-guard-toxic-agent",
35
+ device="cpu",
36
+ batch_size=16
37
+ )
38
+ print("Model initialized successfully")
39
+ break
40
+ except Exception as e:
41
+ if attempt == max_retries - 1:
42
+ raise Exception(f"Failed to initialize model after {max_retries} attempts: {str(e)}")
43
+ print(f"Attempt {attempt + 1} failed, retrying...")
44
+ time.sleep(1)
45
+
46
+ def process_batch(self, batch: List[str], batch_idx: int) -> Tuple[List[int], int]:
47
+ """Process a batch of texts and return their predictions"""
48
+ max_retries = 3
49
+ for attempt in range(max_retries):
50
+ try:
51
+ print(f"Processing batch {batch_idx} with {len(batch)} items (attempt {attempt + 1})")
52
+ # Process texts one by one in case of errors
53
+ predictions = []
54
+ for text in batch:
55
+ try:
56
+ pred = self.classifier(text)
57
+ pred_label = self.label2id[pred[0]["label"]]
58
+ predictions.append(pred_label)
59
+ except Exception as e:
60
+ print(f"Error processing text in batch {batch_idx}: {str(e)}")
61
+
62
+ if not predictions:
63
+ raise Exception("No predictions generated for batch")
64
+
65
+ print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
66
+ return predictions, batch_idx
67
+
68
+ except Exception as e:
69
+ if attempt == max_retries - 1:
70
+ print(f"Final error in batch {batch_idx}: {str(e)}")
71
+ return [0] * len(batch), batch_idx # Return default predictions instead of empty list
72
+ print(f"Error in batch {batch_idx} (attempt {attempt + 1}): {str(e)}")
73
+ time.sleep(1)
74
+
75
 
76
+ @router.post(ROUTE, tags=["Text Task"],
77
+ description=DESCRIPTION)
78
  async def evaluate_text(request: TextEvaluationRequest):
79
  """
80
  Evaluate text classification for climate disinformation detection.
81
+
82
+ Current Model: Random Baseline
83
+ - Makes random predictions from the label space (0-7)
84
+ - Used as a baseline for comparison
85
  """
86
+ # Get space info
87
+ username, space_url = get_space_info()
88
+
89
+ # Define the label mapping
90
+ LABEL_MAPPING = {
91
+ "0_not_relevant": 0,
92
+ "1_not_happening": 1,
93
+ "2_not_human": 2,
94
+ "3_not_bad": 3,
95
+ "4_solutions_harmful_unnecessary": 4,
96
+ "5_science_unreliable": 5,
97
+ "6_proponents_biased": 6,
98
+ "7_fossil_fuels_needed": 7
99
+ }
100
+
101
+ # Load and prepare the dataset
102
+ dataset = load_dataset(request.dataset_name)
103
+
104
+ # Convert string labels to integers
105
+ dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
106
+
107
+ # Split dataset
108
+ train_test = dataset["train"]
109
+ test_dataset = dataset["test"]
110
+
111
+ # Start tracking emissions
112
+ tracker.start()
113
+ tracker.start_task("inference")
114
+
115
+ #--------------------------------------------------------------------------------------------
116
+ # YOUR MODEL INFERENCE CODE HERE
117
+ # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
118
+ #--------------------------------------------------------------------------------------------
119
+
120
+ true_labels = test_dataset["label"]
121
+
122
+ # Initialize the model once
123
+ classifier = TextClassifier()
124
+
125
+ # Prepare batches
126
+ batch_size = 32
127
+ quotes = test_dataset["quote"]
128
+ num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
129
+ batches = [
130
+ quotes[i * batch_size:(i + 1) * batch_size]
131
+ for i in range(num_batches)
132
+ ]
133
+
134
+ # Initialize batch_results before parallel processing
135
+ batch_results = [[] for _ in range(num_batches)]
136
+
137
+ # Process batches in parallel
138
+ max_workers = min(os.cpu_count(), 4) # Limit to 4 workers or CPU count
139
+ print(f"Processing with {max_workers} workers")
140
+
141
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
142
+ # Submit all batches for processing
143
+ future_to_batch = {
144
+ executor.submit(
145
+ classifier.process_batch,
146
+ batch,
147
+ idx
148
+ ): idx for idx, batch in enumerate(batches)
149
  }
150
 
151
+ # Collect results in order
152
+ for future in future_to_batch:
153
+ batch_idx = future_to_batch[future]
154
+ try:
155
+ predictions, idx = future.result()
156
+ if predictions: # Only store non-empty predictions
157
+ batch_results[idx] = predictions
158
+ print(f"Stored results for batch {idx} ({len(predictions)} predictions)")
159
+ except Exception as e:
160
+ print(f"Failed to get results for batch {batch_idx}: {e}")
161
+ # Use default predictions instead of empty list
162
+ batch_results[batch_idx] = [0] * len(batches[batch_idx])
163
+
164
+ # Flatten predictions while maintaining order
165
+ predictions = []
166
+ for batch_preds in batch_results:
167
+ if batch_preds is not None:
168
+ predictions.extend(batch_preds)
169
+
170
+ #--------------------------------------------------------------------------------------------
171
+ # YOUR MODEL INFERENCE STOPS HERE
172
+ #--------------------------------------------------------------------------------------------
173
+
174
+ # Stop tracking emissions
175
+ emissions_data = tracker.stop_task()
176
+
177
+ # Calculate accuracy
178
+ accuracy = accuracy_score(true_labels, predictions)
179
+ print("accuracy : ", accuracy)
180
+
181
+ # Prepare results dictionary
182
+ results = {
183
+ "username": username,
184
+ "space_url": space_url,
185
+ "submission_timestamp": datetime.now().isoformat(),
186
+ "model_description": DESCRIPTION,
187
+ "accuracy": float(accuracy),
188
+ "energy_consumed_wh": emissions_data.energy_consumed * 1000,
189
+ "emissions_gco2eq": emissions_data.emissions * 1000,
190
+ "emissions_data": clean_emissions_data(emissions_data),
191
+ "api_route": ROUTE,
192
+ "dataset_config": {
193
+ "dataset_name": request.dataset_name,
194
+ "test_size": request.test_size,
195
+ "test_seed": request.test_seed
196
+ }
197
+ }
198
+
199
+ print("results : ", results)
200
+
201
+ return results