Tonic committed (verified)
Commit: 4357468 · Parent(s): 520f169

revert to model loading

Files changed (1): tasks/text.py (+37 −44)
tasks/text.py CHANGED
@@ -27,7 +27,7 @@ os.environ["TORCH_COMPILE_DISABLE"] = "1"
 
 router = APIRouter()
 
-DESCRIPTION = "ModernBERT fine-tuned for climate disinformation detection"
+DESCRIPTION = "Climate Guard Toxic Agent is a ModernBERT fine-tuned for climate disinformation detection"
 ROUTE = "/text"
 MODEL_NAME = "Tonic/climate-guard-toxic-agent"
 
@@ -38,7 +38,7 @@ class TextClassifier:
 
         for attempt in range(max_retries):
             try:
-                # Initialize tokenizer first
+                # Initialize tokenizer
                 self.tokenizer = AutoTokenizer.from_pretrained(
                     MODEL_NAME,
                     model_max_length=512,
@@ -46,17 +46,18 @@ class TextClassifier:
                     truncation_side='right'
                 )
 
-                # Use pipeline directly without modifying config
-                self.classifier = pipeline(
-                    "text-classification",
-                    model=MODEL_NAME,
-                    tokenizer=self.tokenizer,
-                    device=self.device,
-                    max_length=512,
-                    truncation=True,
-                    batch_size=16
+                # Initialize model with basic configuration
+                self.model = AutoModelForSequenceClassification.from_pretrained(
+                    MODEL_NAME,
+                    num_labels=8,
+                    problem_type="single_label_classification",
+                    ignore_mismatched_sizes=True,
+                    trust_remote_code=True
                 )
 
+                # Move model to device
+                self.model = self.model.to(self.device)
+
                 print("Model initialized successfully")
                 break
 
@@ -68,42 +69,34 @@ class TextClassifier:
 
     def process_batch(self, batch: List[str], batch_idx: int) -> Tuple[List[int], int]:
         """Process a batch of texts and return their predictions"""
-        max_retries = 3
-        for attempt in range(max_retries):
-            try:
-                print(f"Processing batch {batch_idx} with {len(batch)} items")
-
-                # Process texts with error handling
-                predictions = []
-                for text in batch:
-                    try:
-                        result = self.classifier(text)
-                        # Extract the numeric label from the prediction
-                        label_str = result[0]['label']
-                        # Handle both numeric and string label formats
-                        if '_' in label_str:
-                            pred_label = int(label_str.split('_')[0])
-                        else:
-                            pred_label = int(label_str)
-                        predictions.append(pred_label)
-                    except Exception as e:
-                        print(f"Error processing text in batch {batch_idx}: {str(e)}")
-                        predictions.append(0)  # Default prediction
-
-                print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
-                return predictions, batch_idx
-
-            except Exception as e:
-                if attempt == max_retries - 1:
-                    print(f"Final error in batch {batch_idx}: {str(e)}")
-                    return [0] * len(batch), batch_idx
-                print(f"Error in batch {batch_idx} (attempt {attempt + 1}): {str(e)}")
-                time.sleep(1)
+        try:
+            print(f"Processing batch {batch_idx} with {len(batch)} items")
+
+            # Tokenize texts
+            inputs = self.tokenizer(
+                batch,
+                padding=True,
+                truncation=True,
+                max_length=512,
+                return_tensors="pt"
+            ).to(self.device)
+
+            # Get predictions
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                predictions = torch.argmax(outputs.logits, dim=1).cpu().numpy()
+
+            print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
+            return predictions.tolist(), batch_idx
+
+        except Exception as e:
+            print(f"Error in batch {batch_idx}: {str(e)}")
+            return [0] * len(batch), batch_idx
 
     def __del__(self):
        # Clean up CUDA memory
-        if hasattr(self, 'classifier'):
-            del self.classifier
+        if hasattr(self, 'model'):
+            del self.model
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
102