Tonic commited on
Commit
2c8310a
·
unverified ·
1 Parent(s): b206095

revert to template

Browse files
Files changed (1) hide show
  1. tasks/text.py +95 -118
tasks/text.py CHANGED
@@ -1,92 +1,30 @@
1
  from fastapi import APIRouter
2
  from datetime import datetime
3
- import time
4
  from datasets import load_dataset
5
  from sklearn.metrics import accuracy_score
6
- import os
7
- from concurrent.futures import ThreadPoolExecutor
8
- from typing import List, Dict, Tuple
9
  import torch
10
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, BertForSequenceClassification
11
  from torch.utils.data import DataLoader
12
  from transformers import DataCollatorWithPadding
13
- from huggingface_hub import login
14
- from dotenv import load_dotenv
15
 
16
  from .utils.evaluation import TextEvaluationRequest
17
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
18
 
19
- # Load environment variables
20
- load_dotenv()
21
-
22
- # Authenticate with Hugging Face
23
- HF_TOKEN = os.getenv('HF_TOKEN')
24
- if HF_TOKEN:
25
- login(token=HF_TOKEN)
26
-
27
  router = APIRouter()
28
 
29
- DESCRIPTION = "Climate Guard Toxic Agent is a ModernBERT for Climate Disinformation Detection"
30
  ROUTE = "/text"
31
- MODEL_NAME = "Tonic/climate-guard-toxic-agent"
32
- TOKENIZER_NAME = "answerdotai/ModernBERT-base"
33
-
34
- class TextClassifier:
35
- def __init__(self):
36
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
-
38
- try:
39
- # Initialize tokenizer
40
- self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
41
-
42
- # Initialize model with auto class
43
- self.model = AutoModelForSequenceClassification.from_pretrained(
44
- MODEL_NAME,
45
- trust_remote_code=True,
46
- num_labels=8,
47
- problem_type="single_label_classification",
48
- ignore_mismatched_sizes=True
49
- ).to(self.device)
50
-
51
- # Convert to half precision and eval mode
52
- self.model = self.model.half()
53
- self.model.eval()
54
-
55
- print("Model initialized successfully")
56
-
57
- except Exception as e:
58
- print(f"Error initializing model: {str(e)}")
59
- raise
60
-
61
- def process_batch(self, batch):
62
- try:
63
- # Move batch to device
64
- input_ids = batch['input_ids'].to(self.device)
65
- attention_mask = batch['attention_mask'].to(self.device)
66
-
67
- # Get predictions
68
- with torch.no_grad():
69
- outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
70
- predictions = torch.argmax(outputs.logits, dim=-1)
71
-
72
- return predictions.cpu().numpy().tolist()
73
-
74
- except Exception as e:
75
- print(f"Error in batch processing: {str(e)}")
76
- return [0] * len(batch['input_ids'])
77
-
78
- def __del__(self):
79
- if hasattr(self, 'model'):
80
- del self.model
81
- if torch.cuda.is_available():
82
- torch.cuda.empty_cache()
83
 
84
- @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
 
85
  async def evaluate_text(request: TextEvaluationRequest):
86
- """Evaluate text classification for climate disinformation detection."""
87
-
 
 
88
  username, space_url = get_space_info()
89
 
 
90
  LABEL_MAPPING = {
91
  "0_not_relevant": 0,
92
  "1_not_happening": 1,
@@ -98,29 +36,51 @@ async def evaluate_text(request: TextEvaluationRequest):
98
  "7_fossil_fuels_needed": 7
99
  }
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  try:
102
- # Load dataset
103
- dataset = load_dataset(request.dataset_name)
104
 
105
- # Convert labels
106
- dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
107
- test_dataset = dataset["test"]
108
 
109
- # Start tracking emissions
110
- tracker.start()
111
- tracker.start_task("inference")
112
-
113
- # Initialize model
114
- classifier = TextClassifier()
 
 
 
 
 
 
115
 
116
- # Prepare tokenization function
117
  def preprocess_function(examples):
118
- return classifier.tokenizer(
119
  examples["quote"],
120
  truncation=True,
121
  padding=True,
122
  max_length=512,
123
- return_tensors=None # Changed this to None for batched processing
124
  )
125
 
126
  # Tokenize dataset
@@ -134,7 +94,7 @@ async def evaluate_text(request: TextEvaluationRequest):
134
  tokenized_test.set_format("torch")
135
 
136
  # Create DataLoader
137
- data_collator = DataCollatorWithPadding(tokenizer=classifier.tokenizer)
138
  test_loader = DataLoader(
139
  tokenized_test,
140
  batch_size=16,
@@ -143,37 +103,54 @@ async def evaluate_text(request: TextEvaluationRequest):
143
  )
144
 
145
  # Get predictions
146
- all_predictions = []
147
- for batch in test_loader:
148
- batch_preds = classifier.process_batch(batch)
149
- all_predictions.extend(batch_preds)
150
-
151
- # Stop tracking emissions
152
- emissions_data = tracker.stop_task()
153
-
154
- # Calculate accuracy
155
- accuracy = accuracy_score(test_dataset["label"], all_predictions)
156
-
157
- # Prepare results
158
- results = {
159
- "username": username,
160
- "space_url": space_url,
161
- "submission_timestamp": datetime.now().isoformat(),
162
- "model_description": DESCRIPTION,
163
- "accuracy": float(accuracy),
164
- "energy_consumed_wh": emissions_data.energy_consumed * 1000,
165
- "emissions_gco2eq": emissions_data.emissions * 1000,
166
- "emissions_data": clean_emissions_data(emissions_data),
167
- "api_route": ROUTE,
168
- "dataset_config": {
169
- "dataset_name": request.dataset_name,
170
- "test_size": request.test_size,
171
- "test_seed": request.test_seed
172
- }
173
- }
174
 
175
- return results
 
 
176
 
177
- except Exception as e:
178
- print(f"Error in evaluate_text: {str(e)}")
179
- raise Exception(f"Failed to process request: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import APIRouter
2
  from datetime import datetime
 
3
  from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
 
 
 
5
  import torch
6
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
  from torch.utils.data import DataLoader
8
  from transformers import DataCollatorWithPadding
 
 
9
 
10
  from .utils.evaluation import TextEvaluationRequest
11
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
12
 
 
 
 
 
 
 
 
 
13
  router = APIRouter()
14
 
15
+ DESCRIPTION = "ModernBERT for Climate Disinformation Detection"
16
  ROUTE = "/text"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ @router.post(ROUTE, tags=["Text Task"],
19
+ description=DESCRIPTION)
20
  async def evaluate_text(request: TextEvaluationRequest):
21
+ """
22
+ Evaluate text classification for climate disinformation detection using ModernBERT.
23
+ """
24
+ # Get space info
25
  username, space_url = get_space_info()
26
 
27
+ # Define the label mapping
28
  LABEL_MAPPING = {
29
  "0_not_relevant": 0,
30
  "1_not_happening": 1,
 
36
  "7_fossil_fuels_needed": 7
37
  }
38
 
39
+ # Load and prepare the dataset
40
+ dataset = load_dataset(request.dataset_name)
41
+
42
+ # Convert string labels to integers
43
+ dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
44
+
45
+ # Get test dataset
46
+ test_dataset = dataset["test"]
47
+
48
+ # Start tracking emissions
49
+ tracker.start()
50
+ tracker.start_task("inference")
51
+
52
+ #--------------------------------------------------------------------------------------------
53
+ # MODEL INFERENCE CODE
54
+ #--------------------------------------------------------------------------------------------
55
+
56
  try:
57
+ # Set device
58
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
 
60
+ # Initialize tokenizer
61
+ tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
 
62
 
63
+ # Initialize model with configuration that avoids bias parameter
64
+ model = AutoModelForSequenceClassification.from_pretrained(
65
+ "Tonic/climate-guard-toxic-agent",
66
+ trust_remote_code=True,
67
+ num_labels=8,
68
+ problem_type="single_label_classification",
69
+ ignore_mismatched_sizes=True,
70
+ torch_dtype=torch.float16 # Use float16 for efficiency
71
+ ).to(device)
72
+
73
+ # Set model to evaluation mode
74
+ model.eval()
75
 
76
+ # Tokenize function
77
  def preprocess_function(examples):
78
+ return tokenizer(
79
  examples["quote"],
80
  truncation=True,
81
  padding=True,
82
  max_length=512,
83
+ return_tensors=None
84
  )
85
 
86
  # Tokenize dataset
 
94
  tokenized_test.set_format("torch")
95
 
96
  # Create DataLoader
97
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
98
  test_loader = DataLoader(
99
  tokenized_test,
100
  batch_size=16,
 
103
  )
104
 
105
  # Get predictions
106
+ predictions = []
107
+ with torch.no_grad():
108
+ for batch in test_loader:
109
+ # Move batch to device
110
+ input_ids = batch['input_ids'].to(device)
111
+ attention_mask = batch['attention_mask'].to(device)
112
+
113
+ # Get model outputs
114
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
115
+ preds = torch.argmax(outputs.logits, dim=-1)
116
+
117
+ # Add batch predictions to list
118
+ predictions.extend(preds.cpu().numpy().tolist())
119
+
120
+ # Clean up GPU memory
121
+ if torch.cuda.is_available():
122
+ torch.cuda.empty_cache()
123
+
124
+ except Exception as e:
125
+ print(f"Error during model inference: {str(e)}")
126
+ raise
 
 
 
 
 
 
 
127
 
128
+ #--------------------------------------------------------------------------------------------
129
+ # MODEL INFERENCE ENDS HERE
130
+ #--------------------------------------------------------------------------------------------
131
 
132
+ # Stop tracking emissions
133
+ emissions_data = tracker.stop_task()
134
+
135
+ # Calculate accuracy
136
+ accuracy = accuracy_score(test_dataset["label"], predictions)
137
+
138
+ # Prepare results dictionary
139
+ results = {
140
+ "username": username,
141
+ "space_url": space_url,
142
+ "submission_timestamp": datetime.now().isoformat(),
143
+ "model_description": DESCRIPTION,
144
+ "accuracy": float(accuracy),
145
+ "energy_consumed_wh": emissions_data.energy_consumed * 1000,
146
+ "emissions_gco2eq": emissions_data.emissions * 1000,
147
+ "emissions_data": clean_emissions_data(emissions_data),
148
+ "api_route": ROUTE,
149
+ "dataset_config": {
150
+ "dataset_name": request.dataset_name,
151
+ "test_size": request.test_size,
152
+ "test_seed": request.test_seed
153
+ }
154
+ }
155
+
156
+ return results