Tonic committed
Commit c3f000b · unverified · parent: 1f08781

fix model initialization with explicit loading

Files changed (1): tasks/text.py (+64, −48)
tasks/text.py CHANGED
@@ -34,57 +34,67 @@ TOKENIZER_NAME = "answerdotai/ModernBERT-base"
 class TextClassifier:
     def __init__(self):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        max_retries = 3
 
-        for attempt in range(max_retries):
-            try:
-                # Initialize tokenizer
-                self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
-
-                # Initialize model
-                self.model = AutoModelForSequenceClassification.from_pretrained(
-                    MODEL_NAME,
-                    num_labels=8,
-                    trust_remote_code=True
-                ).to(self.device)
-
-                # Convert to half precision
-                self.model = self.model.half()
-                self.model.eval()
-
-                print("Model initialized successfully")
-                break
+        try:
+            # Initialize tokenizer
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                TOKENIZER_NAME,
+                model_max_length=8192,
+                padding_side='right',
+                truncation_side='right'
+            )
+
+            # Load model configuration
+            model_config = {
+                "architectures": ["ModernBertForSequenceClassification"],
+                "model_type": "modernbert",
+                "num_labels": 8,
+                "problem_type": "single_label_classification",
+                "hidden_size": 768,
+                "num_attention_heads": 12,
+                "num_hidden_layers": 22,
+                "intermediate_size": 1152,
+                "max_position_embeddings": 8192,
+                "torch_dtype": "float32",
+                "transformers_version": "4.48.3",
+                "layer_norm_eps": 1e-05
+            }
+
+            # Initialize model
+            self.model = AutoModelForSequenceClassification.from_pretrained(
+                MODEL_NAME,
+                config=model_config,
+                ignore_mismatched_sizes=True,
+                trust_remote_code=True
+            ).to(self.device)
+
+            # Convert to half precision
+            self.model = self.model.half()
+            self.model.eval()
+
+            print("Model initialized successfully")
 
-            except Exception as e:
-                if attempt == max_retries - 1:
-                    raise Exception(f"Failed to initialize model after {max_retries} attempts: {str(e)}")
-                print(f"Attempt {attempt + 1} failed, retrying... Error: {str(e)}")
-                time.sleep(1)
+        except Exception as e:
+            print(f"Error initializing model: {str(e)}")
+            raise
 
-    def process_batch(self, texts: List[str]) -> List[int]:
+    def process_batch(self, batch):
         """Process a batch of texts and return their predictions"""
         try:
-            # Tokenize
-            inputs = self.tokenizer(
-                texts,
-                padding=True,
-                truncation=True,
-                return_tensors="pt"
-            )
-
-            # Move inputs to device
-            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+            # Move batch to device
+            input_ids = batch['input_ids'].to(self.device)
+            attention_mask = batch['attention_mask'].to(self.device)
 
             # Get predictions
             with torch.no_grad():
-                outputs = self.model(**inputs)
+                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                 predictions = torch.argmax(outputs.logits, dim=-1)
 
             return predictions.cpu().numpy().tolist()
 
         except Exception as e:
             print(f"Error in batch processing: {str(e)}")
-            return [0] * len(texts)
+            return [0] * len(batch['input_ids'])
 
     def __del__(self):
         if hasattr(self, 'model'):
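
The heart of the fix is replacing the blind retry loop with one explicit load that spells out the expected architecture and fails fast. For reference, a minimal sketch of the same explicit-loading idea using a typed AutoConfig object instead of a raw dict; this is an illustration under stated assumptions, not the committed code, and the literal checkpoint name stands in for the MODEL_NAME constant defined near the top of tasks/text.py:

import torch
from transformers import AutoConfig, AutoModelForSequenceClassification

# Sketch only: build an explicit config object, then load the checkpoint against it.
# "answerdotai/ModernBERT-base" stands in for MODEL_NAME from tasks/text.py.
config = AutoConfig.from_pretrained(
    "answerdotai/ModernBERT-base",
    num_labels=8,
    problem_type="single_label_classification",
)

model = AutoModelForSequenceClassification.from_pretrained(
    "answerdotai/ModernBERT-base",
    config=config,
    ignore_mismatched_sizes=True,  # mirrors the commit: tolerate head-shape mismatches
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
if device.type == "cuda":
    model = model.half()  # fp16 as in the commit, but gated to GPU; CPU fp16 support is spotty
model.eval()

Since answerdotai/ModernBERT-base ships no sequence-classification head, the 8-label head is freshly initialized either way; ignore_mismatched_sizes=True just keeps loading from aborting when stored and expected weight shapes disagree.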
@@ -121,35 +131,41 @@ async def evaluate_text(request: TextEvaluationRequest):
         tracker.start()
         tracker.start_task("inference")
 
-        # Get true labels
-        true_labels = test_dataset["label"]
-
         # Initialize model
         classifier = TextClassifier()
 
-        # Process in batches
-        batch_size = 16
-        data_collator = DataCollatorWithPadding(tokenizer=classifier.tokenizer)
+        # Prepare tokenization function
+        def preprocess_function(examples):
+            return classifier.tokenizer(
+                examples["quote"],
+                truncation=True,
+                padding=True,
+                max_length=512
+            )
+
+        # Tokenize dataset
+        tokenized_test = test_dataset.map(preprocess_function, batched=True)
+        tokenized_test.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
 
         # Create DataLoader
+        data_collator = DataCollatorWithPadding(tokenizer=classifier.tokenizer)
         test_loader = DataLoader(
-            test_dataset,
-            batch_size=batch_size,
+            tokenized_test,
+            batch_size=16,
             collate_fn=data_collator
         )
 
         # Get predictions
         all_predictions = []
         for batch in test_loader:
-            batch_texts = batch["quote"]
-            batch_preds = classifier.process_batch(batch_texts)
+            batch_preds = classifier.process_batch(batch)
             all_predictions.extend(batch_preds)
 
         # Stop tracking emissions
         emissions_data = tracker.stop_task()
 
         # Calculate accuracy
-        accuracy = accuracy_score(true_labels, all_predictions)
+        accuracy = accuracy_score(test_dataset["label"], all_predictions)
 
         # Prepare results
         results = {
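
The evaluation path changes to match: texts are tokenized once up front with Dataset.map, and DataCollatorWithPadding re-pads each batch to its own longest sequence, so process_batch now receives ready-made tensors rather than raw strings. Below is a self-contained sketch of that pattern with a toy two-row dataset; the quote and label column names match this file, everything else is illustrative. One detail: padding=True inside the map call becomes redundant once the collator pads per batch, so the sketch drops it.

from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

# Toy stand-in for test_dataset, which tasks/text.py loads elsewhere.
ds = Dataset.from_dict({
    "quote": ["a short example", "a noticeably longer second example sentence"],
    "label": [0, 3],
})

def preprocess(examples):
    # Truncate only; the collator handles padding dynamically per batch.
    return tokenizer(examples["quote"], truncation=True, max_length=512)

tokenized = ds.map(preprocess, batched=True, remove_columns=["quote"])
tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

loader = DataLoader(
    tokenized,
    batch_size=16,
    collate_fn=DataCollatorWithPadding(tokenizer=tokenizer),
)

for batch in loader:
    print(batch["input_ids"].shape)  # padded to the longest row in this batch

Note that DataCollatorWithPadding renames label to labels in the emitted batches; the commit sidesteps the rename by scoring against test_dataset["label"] directly.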
 