Commit 7abed63 by Tonic · unverified · 1 Parent(s): 08e3356

fix model loading

Files changed (2)
  1. requirements.txt +1 -1
  2. tasks/text.py +54 -59
requirements.txt CHANGED
@@ -1,6 +1,6 @@
 fastapi==0.103.2
 uvicorn==0.23.2
-transformers==4.34.0
+transformers #==4.34.0
 torch==2.0.1
 datasets==2.14.5
 scikit-learn==1.3.1
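
Relaxing the pin (the version is commented out rather than deleted) lets pip resolve whatever transformers release fits the rest of the environment. A minimal sketch, not part of the commit, for checking which version was actually installed:

# Illustrative only, not from this commit: print the transformers
# version pip resolved after the ==4.34.0 pin was relaxed.
import transformers
print(transformers.__version__)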
tasks/text.py CHANGED
@@ -7,7 +7,7 @@ import os
 from concurrent.futures import ThreadPoolExecutor
 from typing import List, Dict, Tuple
 import torch
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig

 from .utils.evaluation import TextEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info, start_tracking, stop_tracking
@@ -24,67 +24,85 @@ class TextClassifier:
     def __init__(self):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         max_retries = 3
+        model_name = "Tonic/climate-guard-toxic-agent"
+
         for attempt in range(max_retries):
             try:
-                # Initialize tokenizer and model separately
-                self.tokenizer = AutoTokenizer.from_pretrained("Tonic/climate-guard-toxic-agent")
-                self.model = AutoModelForSequenceClassification.from_pretrained("Tonic/climate-guard-toxic-agent")
+                # Load config first
+                config = AutoConfig.from_pretrained(model_name)
+
+                # Initialize tokenizer with specific model type
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    model_name,
+                    model_max_length=512,
+                    padding_side='right',
+                    truncation_side='right'
+                )
+
+                # Initialize model with config
+                self.model = AutoModelForSequenceClassification.from_pretrained(
+                    model_name,
+                    config=config,
+                    torch_dtype=torch.float32
+                )
+
                 self.model.to(self.device)
                 self.model.eval()
                 print("Model initialized successfully")
                 break
+
             except Exception as e:
                 if attempt == max_retries - 1:
                     raise Exception(f"Failed to initialize model after {max_retries} attempts: {str(e)}")
-                print(f"Attempt {attempt + 1} failed, retrying...")
+                print(f"Attempt {attempt + 1} failed, retrying... Error: {str(e)}")
                 time.sleep(1)

     def predict_single(self, text: str) -> int:
         """Predict single text instance"""
         try:
-            # Tokenize and prepare input
+            # Tokenize with explicit padding and truncation
             inputs = self.tokenizer(
                 text,
                 return_tensors="pt",
                 truncation=True,
                 max_length=512,
-                padding=True
+                padding='max_length'
             ).to(self.device)

             # Get prediction
             with torch.no_grad():
                 outputs = self.model(**inputs)
-                predictions = outputs.logits.argmax(-1)
+                predictions = torch.argmax(outputs.logits, dim=-1)
             return predictions.item()
+
         except Exception as e:
             print(f"Error in single prediction: {str(e)}")
             return 0 # Return default prediction on error

     def process_batch(self, batch: List[str], batch_idx: int) -> Tuple[List[int], int]:
         """Process a batch of texts and return their predictions"""
-        max_retries = 3
-        for attempt in range(max_retries):
-            try:
-                print(f"Processing batch {batch_idx} with {len(batch)} items (attempt {attempt + 1})")
-                predictions = []
-
-                # Process texts one by one for better error handling
-                for text in batch:
-                    pred = self.predict_single(text)
-                    predictions.append(pred)
-
-                if not predictions:
-                    raise Exception("No predictions generated for batch")
-
-                print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
-                return predictions, batch_idx
+        try:
+            print(f"Processing batch {batch_idx} with {len(batch)} items")

-            except Exception as e:
-                if attempt == max_retries - 1:
-                    print(f"Final error in batch {batch_idx}: {str(e)}")
-                    return [0] * len(batch), batch_idx
-                print(f"Error in batch {batch_idx} (attempt {attempt + 1}): {str(e)}")
-                time.sleep(1)
+            # Process entire batch at once
+            inputs = self.tokenizer(
+                batch,
+                return_tensors="pt",
+                truncation=True,
+                max_length=512,
+                padding='max_length'
+            ).to(self.device)
+
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                predictions = torch.argmax(outputs.logits, dim=-1).tolist()
+
+            print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
+            return predictions, batch_idx
+
+        except Exception as e:
+            print(f"Error in batch {batch_idx}: {str(e)}")
+            return [0] * len(batch), batch_idx

 @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
 async def evaluate_text(request: TextEvaluationRequest):
@@ -119,7 +137,7 @@ async def evaluate_text(request: TextEvaluationRequest):
     classifier = TextClassifier()

     # Prepare batches
-    batch_size = 16
+    batch_size = 32 # Increased batch size for efficiency
     quotes = test_dataset["quote"]
     num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
     batches = [
@@ -127,35 +145,12 @@ async def evaluate_text(request: TextEvaluationRequest):
         for i in range(num_batches)
     ]

-    # Initialize batch_results
-    batch_results = [[] for _ in range(num_batches)]
-
-    # Process batches in parallel
-    max_workers = min(os.cpu_count(), 4)
-    print(f"Processing with {max_workers} workers")
-
-    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        future_to_batch = {
-            executor.submit(classifier.process_batch, batch, idx): idx
-            for idx, batch in enumerate(batches)
-        }
-
-        for future in future_to_batch:
-            batch_idx = future_to_batch[future]
-            try:
-                predictions, idx = future.result()
-                if predictions:
-                    batch_results[idx] = predictions
-                    print(f"Stored results for batch {idx} ({len(predictions)} predictions)")
-            except Exception as e:
-                print(f"Failed to get results for batch {batch_idx}: {e}")
-                batch_results[batch_idx] = [0] * len(batches[batch_idx])
-
-    # Flatten predictions
+    # Process batches sequentially to avoid memory issues
     predictions = []
-    for batch_preds in batch_results:
-        if batch_preds is not None:
-            predictions.extend(batch_preds)
+    for idx, batch in enumerate(batches):
+        batch_preds, _ = classifier.process_batch(batch, idx)
+        predictions.extend(batch_preds)
+        print(f"Processed batch {idx + 1}/{num_batches}")

     # Stop tracking emissions
     emissions_data = stop_tracking()
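
For context, a minimal sketch, not part of the commit, of how the reworked classifier might be exercised on its own; the sample quotes and the clf variable are illustrative, and it assumes the Tonic/climate-guard-toxic-agent checkpoint is reachable:

# Illustrative usage of the committed TextClassifier (assumes the
# tasks package is importable and the model repo can be downloaded).
from tasks.text import TextClassifier

clf = TextClassifier()  # retries model loading up to 3 times before raising
preds, idx = clf.process_batch(
    ["sample quote one", "sample quote two"],  # placeholder inputs
    batch_idx=0,
)
print(idx, preds)  # batch index plus one predicted class id per quote

Note that padding='max_length' pads every sequence to 512 tokens, which keeps tensor shapes static but spends compute on short quotes; dynamic padding (padding=True, as in the removed code) is the usual lighter-weight alternative.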