Tonic commited on
Commit
7eb6153
·
unverified ·
1 Parent(s): 5ad260d

switch model loading technique

Browse files
Files changed (1) hide show
  1. tasks/text.py +56 -32
tasks/text.py CHANGED
@@ -31,11 +31,6 @@ router = APIRouter()
31
  DESCRIPTION = "Climate Guard Toxic Agent Classifier"
32
  ROUTE = "/text"
33
 
34
- # Custom LayerNorm that ignores bias parameter
35
- class CustomLayerNorm(nn.LayerNorm):
36
- def __init__(self, normalized_shape, eps=1e-5, **kwargs):
37
- super().__init__(normalized_shape, eps=eps)
38
-
39
  class TextClassifier:
40
  def __init__(self):
41
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -45,33 +40,34 @@ class TextClassifier:
45
  for attempt in range(max_retries):
46
  try:
47
  # Load config
48
- self.config = AutoConfig.from_pretrained(model_name)
 
 
 
49
 
50
  # Initialize tokenizer
51
  self.tokenizer = AutoTokenizer.from_pretrained(
52
  model_name,
53
  model_max_length=2048,
54
  padding_side='right',
55
- truncation_side='right'
 
56
  )
57
 
58
- # Patch LayerNorm
59
- original_layernorm = nn.LayerNorm
60
- nn.LayerNorm = CustomLayerNorm
 
 
 
 
 
 
61
 
62
- try:
63
- # Initialize model with patched LayerNorm
64
- self.model = AutoModelForSequenceClassification.from_pretrained(
65
- model_name,
66
- config=self.config,
67
- ignore_mismatched_sizes=True,
68
- low_cpu_mem_usage=True
69
- )
70
- finally:
71
- # Restore original LayerNorm
72
- nn.LayerNorm = original_layernorm
73
 
74
- self.model.to(self.device)
75
  self.model.eval()
76
  print("Model initialized successfully")
77
  break
@@ -87,14 +83,17 @@ class TextClassifier:
87
  try:
88
  print(f"Processing batch {batch_idx} with {len(batch)} items")
89
 
90
- # Tokenize
91
  inputs = self.tokenizer(
92
  batch,
93
  return_tensors="pt",
94
  truncation=True,
95
- max_length=2048,
96
- padding='max_length'
97
- ).to(self.device)
 
 
 
98
 
99
  # Get predictions
100
  with torch.no_grad():
@@ -108,6 +107,13 @@ class TextClassifier:
108
  print(f"Error in batch {batch_idx}: {str(e)}")
109
  return [0] * len(batch), batch_idx
110
 
 
 
 
 
 
 
 
111
  @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
112
  async def evaluate_text(request: TextEvaluationRequest):
113
  """Evaluate text classification for climate disinformation detection."""
@@ -128,8 +134,21 @@ async def evaluate_text(request: TextEvaluationRequest):
128
  }
129
 
130
  try:
131
- # Load and prepare the dataset
132
- dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train", token=HF_TOKEN)
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  # Convert string labels to integers
135
  dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
@@ -146,8 +165,8 @@ async def evaluate_text(request: TextEvaluationRequest):
146
  # Initialize the model once
147
  classifier = TextClassifier()
148
 
149
- # Prepare batches
150
- batch_size = 24
151
  quotes = test_dataset["quote"]
152
  num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
153
  batches = [
@@ -158,8 +177,8 @@ async def evaluate_text(request: TextEvaluationRequest):
158
  # Initialize batch_results
159
  batch_results = [[] for _ in range(num_batches)]
160
 
161
- # Process batches in parallel
162
- max_workers = min(os.cpu_count(), 4)
163
  print(f"Processing with {max_workers} workers")
164
 
165
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
@@ -192,6 +211,11 @@ async def evaluate_text(request: TextEvaluationRequest):
192
  accuracy = accuracy_score(true_labels, predictions)
193
  print("accuracy:", accuracy)
194
 
 
 
 
 
 
195
  # Prepare results
196
  results = {
197
  "username": username,
 
31
  DESCRIPTION = "Climate Guard Toxic Agent Classifier"
32
  ROUTE = "/text"
33
 
 
 
 
 
 
34
  class TextClassifier:
35
  def __init__(self):
36
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
40
  for attempt in range(max_retries):
41
  try:
42
  # Load config
43
+ self.config = AutoConfig.from_pretrained(
44
+ model_name,
45
+ trust_remote_code=True
46
+ )
47
 
48
  # Initialize tokenizer
49
  self.tokenizer = AutoTokenizer.from_pretrained(
50
  model_name,
51
  model_max_length=2048,
52
  padding_side='right',
53
+ truncation_side='right',
54
+ trust_remote_code=True
55
  )
56
 
57
+ # Initialize model
58
+ self.model = AutoModelForSequenceClassification.from_pretrained(
59
+ model_name,
60
+ config=self.config,
61
+ trust_remote_code=True,
62
+ torch_dtype=torch.float32,
63
+ device_map="auto",
64
+ low_cpu_mem_usage=True
65
+ )
66
 
67
+ # Force model to CPU if CUDA is not available
68
+ if not torch.cuda.is_available():
69
+ self.model = self.model.cpu()
 
 
 
 
 
 
 
 
70
 
 
71
  self.model.eval()
72
  print("Model initialized successfully")
73
  break
 
83
  try:
84
  print(f"Processing batch {batch_idx} with {len(batch)} items")
85
 
86
+ # Tokenize with smaller max length
87
  inputs = self.tokenizer(
88
  batch,
89
  return_tensors="pt",
90
  truncation=True,
91
+ max_length=512, # Reduced max length
92
+ padding=True
93
+ )
94
+
95
+ # Move inputs to device
96
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
97
 
98
  # Get predictions
99
  with torch.no_grad():
 
107
  print(f"Error in batch {batch_idx}: {str(e)}")
108
  return [0] * len(batch), batch_idx
109
 
110
+ def __del__(self):
111
+ # Clean up CUDA memory
112
+ if hasattr(self, 'model'):
113
+ del self.model
114
+ if torch.cuda.is_available():
115
+ torch.cuda.empty_cache()
116
+
117
  @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
118
  async def evaluate_text(request: TextEvaluationRequest):
119
  """Evaluate text classification for climate disinformation detection."""
 
134
  }
135
 
136
  try:
137
+ # Load and prepare the dataset with retry mechanism
138
+ max_retries = 3
139
+ for attempt in range(max_retries):
140
+ try:
141
+ dataset = load_dataset(
142
+ "QuotaClimat/frugalaichallenge-text-train",
143
+ token=HF_TOKEN,
144
+ trust_remote_code=True
145
+ )
146
+ break
147
+ except Exception as e:
148
+ if attempt == max_retries - 1:
149
+ raise Exception(f"Failed to load dataset after {max_retries} attempts: {str(e)}")
150
+ print(f"Dataset loading attempt {attempt + 1} failed, retrying... Error: {str(e)}")
151
+ time.sleep(2)
152
 
153
  # Convert string labels to integers
154
  dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
 
165
  # Initialize the model once
166
  classifier = TextClassifier()
167
 
168
+ # Prepare batches with smaller batch size
169
+ batch_size = 16 # Reduced batch size
170
  quotes = test_dataset["quote"]
171
  num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
172
  batches = [
 
177
  # Initialize batch_results
178
  batch_results = [[] for _ in range(num_batches)]
179
 
180
+ # Process batches in parallel with fewer workers
181
+ max_workers = min(os.cpu_count(), 2) # Reduced number of workers
182
  print(f"Processing with {max_workers} workers")
183
 
184
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
 
211
  accuracy = accuracy_score(true_labels, predictions)
212
  print("accuracy:", accuracy)
213
 
214
+ # Clean up
215
+ del classifier
216
+ if torch.cuda.is_available():
217
+ torch.cuda.empty_cache()
218
+
219
  # Prepare results
220
  results = {
221
  "username": username,