Tonic commited on
Commit
f3f30d7
·
unverified ·
1 Parent(s): 4477f42

fix emmissions

Browse files
Files changed (2) hide show
  1. tasks/text.py +6 -4
  2. tasks/utils/emissions.py +44 -4
tasks/text.py CHANGED
@@ -10,7 +10,7 @@ import torch
10
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
11
 
12
  from .utils.evaluation import TextEvaluationRequest
13
- from .utils.emissions import tracker, clean_emissions_data, get_space_info
14
 
15
  # Disable torch compile
16
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
@@ -112,8 +112,9 @@ async def evaluate_text(request: TextEvaluationRequest):
112
  test_dataset = dataset["test"]
113
 
114
  # Start tracking emissions
115
- tracker.start()
116
- tracker.start_task("inference")
 
117
 
118
  true_labels = test_dataset["label"]
119
 
@@ -160,7 +161,8 @@ async def evaluate_text(request: TextEvaluationRequest):
160
  predictions.extend(batch_preds)
161
 
162
  # Stop tracking emissions
163
- emissions_data = tracker.stop_task()
 
164
 
165
  # Calculate accuracy
166
  accuracy = accuracy_score(true_labels, predictions)
 
10
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
11
 
12
  from .utils.evaluation import TextEvaluationRequest
13
+ from .utils.emissions import tracker, clean_emissions_data, get_space_info, start_tracking, stop_tracking
14
 
15
  # Disable torch compile
16
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
 
112
  test_dataset = dataset["test"]
113
 
114
  # Start tracking emissions
115
+ start_tracking()
116
+
117
+ # tracker.start_task("inference")
118
 
119
  true_labels = test_dataset["label"]
120
 
 
161
  predictions.extend(batch_preds)
162
 
163
  # Stop tracking emissions
164
+ emissions_data = stop_tracking()
165
+ # emissions_data = tracker.stop_task()
166
 
167
  # Calculate accuracy
168
  accuracy = accuracy_score(true_labels, predictions)
tasks/utils/emissions.py CHANGED
@@ -1,17 +1,34 @@
1
  from codecarbon import EmissionsTracker
2
  import os
3
 
4
- # Initialize tracker
5
- tracker = EmissionsTracker(allow_multiple_runs=True)
 
 
 
 
6
 
7
  class EmissionsData:
8
  def __init__(self, energy_consumed: float, emissions: float):
9
  self.energy_consumed = energy_consumed
10
  self.emissions = emissions
 
 
 
 
 
11
 
12
  def clean_emissions_data(emissions_data):
13
  """Remove unwanted fields from emissions data"""
14
- data_dict = emissions_data.__dict__
 
 
 
 
 
 
 
 
15
  fields_to_remove = ['timestamp', 'project_name', 'experiment_id', 'latitude', 'longitude']
16
  return {k: v for k, v in data_dict.items() if k not in fields_to_remove}
17
 
@@ -25,4 +42,27 @@ def get_space_info():
25
  return username, space_url
26
  except Exception as e:
27
  print(f"Error getting space info: {e}")
28
- return "local-user", "local-development"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from codecarbon import EmissionsTracker
2
  import os
3
 
4
+ # Initialize tracker with basic configuration
5
+ tracker = EmissionsTracker(
6
+ project_name="climate-guard",
7
+ output_dir=os.getenv("OUTPUT_DIR", "./"),
8
+ log_level='error'
9
+ )
10
 
11
  class EmissionsData:
12
  def __init__(self, energy_consumed: float, emissions: float):
13
  self.energy_consumed = energy_consumed
14
  self.emissions = emissions
15
+ self.timestamp = None
16
+ self.project_name = None
17
+ self.experiment_id = None
18
+ self.latitude = None
19
+ self.longitude = None
20
 
21
  def clean_emissions_data(emissions_data):
22
  """Remove unwanted fields from emissions data"""
23
+ if hasattr(emissions_data, '__dict__'):
24
+ data_dict = emissions_data.__dict__
25
+ else:
26
+ # If emissions_data is not an object with __dict__
27
+ data_dict = {
28
+ 'energy_consumed': getattr(emissions_data, 'energy_consumed', 0),
29
+ 'emissions': getattr(emissions_data, 'emissions', 0)
30
+ }
31
+
32
  fields_to_remove = ['timestamp', 'project_name', 'experiment_id', 'latitude', 'longitude']
33
  return {k: v for k, v in data_dict.items() if k not in fields_to_remove}
34
 
 
42
  return username, space_url
43
  except Exception as e:
44
  print(f"Error getting space info: {e}")
45
+ return "local-user", "local-development"
46
+
47
+ def start_tracking():
48
+ """Safely start the emissions tracking"""
49
+ try:
50
+ if not tracker._tracking:
51
+ tracker.start()
52
+ return True
53
+ except Exception as e:
54
+ print(f"Error starting emissions tracking: {e}")
55
+ return False
56
+
57
+ def stop_tracking():
58
+ """Safely stop the emissions tracking and return data"""
59
+ try:
60
+ if tracker._tracking:
61
+ emissions = tracker.stop()
62
+ return EmissionsData(
63
+ energy_consumed=getattr(emissions, 'energy_consumed', 0),
64
+ emissions=getattr(emissions, 'emissions', 0)
65
+ )
66
+ except Exception as e:
67
+ print(f"Error stopping emissions tracking: {e}")
68
+ return EmissionsData(energy_consumed=0, emissions=0)