Spaces:
Sleeping
Sleeping
fix emmissions
Browse files- tasks/text.py +6 -4
- tasks/utils/emissions.py +44 -4
tasks/text.py
CHANGED
@@ -10,7 +10,7 @@ import torch
|
|
10 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
11 |
|
12 |
from .utils.evaluation import TextEvaluationRequest
|
13 |
-
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
14 |
|
15 |
# Disable torch compile
|
16 |
os.environ["TORCH_COMPILE_DISABLE"] = "1"
|
@@ -112,8 +112,9 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
112 |
test_dataset = dataset["test"]
|
113 |
|
114 |
# Start tracking emissions
|
115 |
-
|
116 |
-
|
|
|
117 |
|
118 |
true_labels = test_dataset["label"]
|
119 |
|
@@ -160,7 +161,8 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
160 |
predictions.extend(batch_preds)
|
161 |
|
162 |
# Stop tracking emissions
|
163 |
-
emissions_data =
|
|
|
164 |
|
165 |
# Calculate accuracy
|
166 |
accuracy = accuracy_score(true_labels, predictions)
|
|
|
10 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
11 |
|
12 |
from .utils.evaluation import TextEvaluationRequest
|
13 |
+
from .utils.emissions import tracker, clean_emissions_data, get_space_info, start_tracking, stop_tracking
|
14 |
|
15 |
# Disable torch compile
|
16 |
os.environ["TORCH_COMPILE_DISABLE"] = "1"
|
|
|
112 |
test_dataset = dataset["test"]
|
113 |
|
114 |
# Start tracking emissions
|
115 |
+
start_tracking()
|
116 |
+
|
117 |
+
# tracker.start_task("inference")
|
118 |
|
119 |
true_labels = test_dataset["label"]
|
120 |
|
|
|
161 |
predictions.extend(batch_preds)
|
162 |
|
163 |
# Stop tracking emissions
|
164 |
+
emissions_data = stop_tracking()
|
165 |
+
# emissions_data = tracker.stop_task()
|
166 |
|
167 |
# Calculate accuracy
|
168 |
accuracy = accuracy_score(true_labels, predictions)
|
tasks/utils/emissions.py
CHANGED
@@ -1,17 +1,34 @@
|
|
1 |
from codecarbon import EmissionsTracker
|
2 |
import os
|
3 |
|
4 |
-
# Initialize tracker
|
5 |
-
tracker = EmissionsTracker(
|
|
|
|
|
|
|
|
|
6 |
|
7 |
class EmissionsData:
|
8 |
def __init__(self, energy_consumed: float, emissions: float):
|
9 |
self.energy_consumed = energy_consumed
|
10 |
self.emissions = emissions
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
def clean_emissions_data(emissions_data):
|
13 |
"""Remove unwanted fields from emissions data"""
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
fields_to_remove = ['timestamp', 'project_name', 'experiment_id', 'latitude', 'longitude']
|
16 |
return {k: v for k, v in data_dict.items() if k not in fields_to_remove}
|
17 |
|
@@ -25,4 +42,27 @@ def get_space_info():
|
|
25 |
return username, space_url
|
26 |
except Exception as e:
|
27 |
print(f"Error getting space info: {e}")
|
28 |
-
return "local-user", "local-development"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from codecarbon import EmissionsTracker
|
2 |
import os
|
3 |
|
4 |
+
# Initialize tracker with basic configuration
|
5 |
+
tracker = EmissionsTracker(
|
6 |
+
project_name="climate-guard",
|
7 |
+
output_dir=os.getenv("OUTPUT_DIR", "./"),
|
8 |
+
log_level='error'
|
9 |
+
)
|
10 |
|
11 |
class EmissionsData:
|
12 |
def __init__(self, energy_consumed: float, emissions: float):
|
13 |
self.energy_consumed = energy_consumed
|
14 |
self.emissions = emissions
|
15 |
+
self.timestamp = None
|
16 |
+
self.project_name = None
|
17 |
+
self.experiment_id = None
|
18 |
+
self.latitude = None
|
19 |
+
self.longitude = None
|
20 |
|
21 |
def clean_emissions_data(emissions_data):
|
22 |
"""Remove unwanted fields from emissions data"""
|
23 |
+
if hasattr(emissions_data, '__dict__'):
|
24 |
+
data_dict = emissions_data.__dict__
|
25 |
+
else:
|
26 |
+
# If emissions_data is not an object with __dict__
|
27 |
+
data_dict = {
|
28 |
+
'energy_consumed': getattr(emissions_data, 'energy_consumed', 0),
|
29 |
+
'emissions': getattr(emissions_data, 'emissions', 0)
|
30 |
+
}
|
31 |
+
|
32 |
fields_to_remove = ['timestamp', 'project_name', 'experiment_id', 'latitude', 'longitude']
|
33 |
return {k: v for k, v in data_dict.items() if k not in fields_to_remove}
|
34 |
|
|
|
42 |
return username, space_url
|
43 |
except Exception as e:
|
44 |
print(f"Error getting space info: {e}")
|
45 |
+
return "local-user", "local-development"
|
46 |
+
|
47 |
+
def start_tracking():
|
48 |
+
"""Safely start the emissions tracking"""
|
49 |
+
try:
|
50 |
+
if not tracker._tracking:
|
51 |
+
tracker.start()
|
52 |
+
return True
|
53 |
+
except Exception as e:
|
54 |
+
print(f"Error starting emissions tracking: {e}")
|
55 |
+
return False
|
56 |
+
|
57 |
+
def stop_tracking():
|
58 |
+
"""Safely stop the emissions tracking and return data"""
|
59 |
+
try:
|
60 |
+
if tracker._tracking:
|
61 |
+
emissions = tracker.stop()
|
62 |
+
return EmissionsData(
|
63 |
+
energy_consumed=getattr(emissions, 'energy_consumed', 0),
|
64 |
+
emissions=getattr(emissions, 'emissions', 0)
|
65 |
+
)
|
66 |
+
except Exception as e:
|
67 |
+
print(f"Error stopping emissions tracking: {e}")
|
68 |
+
return EmissionsData(energy_consumed=0, emissions=0)
|