Tonic committed on
Commit
1a885c6
·
unverified ·
1 Parent(s): 08f1c39

revert to reference code

Browse files
Files changed (1) hide show
  1. tasks/text.py +15 -19
tasks/text.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  from fastapi import APIRouter
3
  from datetime import datetime
4
  import time
@@ -14,7 +13,7 @@ from huggingface_hub import login
14
  from dotenv import load_dotenv
15
 
16
  from .utils.evaluation import TextEvaluationRequest
17
- from .utils.emissions import tracker, clean_emissions_data, get_space_info, start_tracking, stop_tracking
18
 
19
  # Load environment variables
20
  load_dotenv()
@@ -29,7 +28,7 @@ os.environ["TORCH_COMPILE_DISABLE"] = "1"
29
 
30
  router = APIRouter()
31
 
32
- DESCRIPTION = "Climate Guard Toxic Agent Classifier"
33
  ROUTE = "/text"
34
 
35
  class TextClassifier:
@@ -43,13 +42,15 @@ class TextClassifier:
43
  # Load config
44
  self.config = AutoConfig.from_pretrained(
45
  model_name,
 
 
46
  trust_remote_code=True
47
  )
48
 
49
  # Initialize tokenizer
50
  self.tokenizer = AutoTokenizer.from_pretrained(
51
  model_name,
52
- model_max_length=2048,
53
  padding_side='right',
54
  truncation_side='right',
55
  trust_remote_code=True
@@ -60,15 +61,11 @@ class TextClassifier:
60
  model_name,
61
  config=self.config,
62
  trust_remote_code=True,
63
- torch_dtype=torch.float32,
64
- device_map="auto",
65
- low_cpu_mem_usage=True
66
  )
67
 
68
- # Force model to CPU if CUDA is not available
69
- if not torch.cuda.is_available():
70
- self.model = self.model.cpu()
71
-
72
  self.model.eval()
73
  print("Model initialized successfully")
74
  break
@@ -84,12 +81,12 @@ class TextClassifier:
84
  try:
85
  print(f"Processing batch {batch_idx} with {len(batch)} items")
86
 
87
- # Tokenize with smaller max length
88
  inputs = self.tokenizer(
89
  batch,
90
  return_tensors="pt",
91
  truncation=True,
92
- max_length=512, # Reduced max length
93
  padding=True
94
  )
95
 
@@ -129,14 +126,14 @@ async def evaluate_text(request: TextEvaluationRequest):
129
  "2_not_human": 2,
130
  "3_not_bad": 3,
131
  "4_solutions_harmful_unnecessary": 4,
132
- "5_science_unreliable": 5,
133
  "6_proponents_biased": 6,
134
  "7_fossil_fuels_needed": 7
135
  }
136
 
137
  try:
138
  # Load and prepare the dataset
139
- dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train", token=HF_TOKEN)
140
 
141
  # Convert string labels to integers
142
  dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
@@ -154,7 +151,7 @@ async def evaluate_text(request: TextEvaluationRequest):
154
  classifier = TextClassifier()
155
 
156
  # Prepare batches
157
- batch_size = 24
158
  quotes = test_dataset["quote"]
159
  num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
160
  batches = [
@@ -166,7 +163,7 @@ async def evaluate_text(request: TextEvaluationRequest):
166
  batch_results = [[] for _ in range(num_batches)]
167
 
168
  # Process batches in parallel
169
- max_workers = min(os.cpu_count(), 4)
170
  print(f"Processing with {max_workers} workers")
171
 
172
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
@@ -222,5 +219,4 @@ async def evaluate_text(request: TextEvaluationRequest):
222
 
223
  except Exception as e:
224
  print(f"Error in evaluate_text: {str(e)}")
225
- raise Exception(f"Failed to process request: {str(e)}")
226
-
 
 
1
  from fastapi import APIRouter
2
  from datetime import datetime
3
  import time
 
13
  from dotenv import load_dotenv
14
 
15
  from .utils.evaluation import TextEvaluationRequest
16
+ from .utils.emissions import tracker, clean_emissions_data, get_space_info
17
 
18
  # Load environment variables
19
  load_dotenv()
 
28
 
29
  router = APIRouter()
30
 
31
+ DESCRIPTION = "ModernBERT Climate Claims Classifier"
32
  ROUTE = "/text"
33
 
34
  class TextClassifier:
 
42
  # Load config
43
  self.config = AutoConfig.from_pretrained(
44
  model_name,
45
+ num_labels=8,
46
+ problem_type="single_label_classification",
47
  trust_remote_code=True
48
  )
49
 
50
  # Initialize tokenizer
51
  self.tokenizer = AutoTokenizer.from_pretrained(
52
  model_name,
53
+ model_max_length=8192,
54
  padding_side='right',
55
  truncation_side='right',
56
  trust_remote_code=True
 
61
  model_name,
62
  config=self.config,
63
  trust_remote_code=True,
64
+ torch_dtype=torch.float32
 
 
65
  )
66
 
67
+ # Move model to appropriate device
68
+ self.model = self.model.to(self.device)
 
 
69
  self.model.eval()
70
  print("Model initialized successfully")
71
  break
 
81
  try:
82
  print(f"Processing batch {batch_idx} with {len(batch)} items")
83
 
84
+ # Tokenize
85
  inputs = self.tokenizer(
86
  batch,
87
  return_tensors="pt",
88
  truncation=True,
89
+ max_length=512,
90
  padding=True
91
  )
92
 
 
126
  "2_not_human": 2,
127
  "3_not_bad": 3,
128
  "4_solutions_harmful_unnecessary": 4,
129
+ "5_science_is_unreliable": 5,
130
  "6_proponents_biased": 6,
131
  "7_fossil_fuels_needed": 7
132
  }
133
 
134
  try:
135
  # Load and prepare the dataset
136
+ dataset = load_dataset(request.dataset_name)
137
 
138
  # Convert string labels to integers
139
  dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
 
151
  classifier = TextClassifier()
152
 
153
  # Prepare batches
154
+ batch_size = 16 # Reduced batch size
155
  quotes = test_dataset["quote"]
156
  num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
157
  batches = [
 
163
  batch_results = [[] for _ in range(num_batches)]
164
 
165
  # Process batches in parallel
166
+ max_workers = min(os.cpu_count(), 2) # Reduced workers
167
  print(f"Processing with {max_workers} workers")
168
 
169
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
 
219
 
220
  except Exception as e:
221
  print(f"Error in evaluate_text: {str(e)}")
222
+ raise Exception(f"Failed to process request: {str(e)}")