Tonic commited on
Commit
45a2367
·
unverified ·
1 Parent(s): a036e74

improve text classifier

Browse files
Files changed (1) hide show
  1. tasks/text.py +46 -23
tasks/text.py CHANGED
@@ -7,8 +7,7 @@ import os
7
  from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict, Tuple
9
  import torch
10
- import torch.nn as nn
11
- from transformers import AutoTokenizer, pipeline
12
  from huggingface_hub import login
13
  from dotenv import load_dotenv
14
 
@@ -28,34 +27,44 @@ os.environ["TORCH_COMPILE_DISABLE"] = "1"
28
 
29
  router = APIRouter()
30
 
31
- DESCRIPTION = "Climate Guard Toxic Agent model for climate disinformation detection"
32
  ROUTE = "/text"
 
33
 
34
  class TextClassifier:
35
  def __init__(self):
36
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
37
  max_retries = 3
38
- model_name = "Tonic/climate-guard-toxic-agent"
39
 
40
  for attempt in range(max_retries):
41
  try:
42
- # Initialize tokenizer first
43
  self.tokenizer = AutoTokenizer.from_pretrained(
44
- model_name,
45
- model_max_length=512, # Reduced from 8192
46
  padding_side='right',
47
  truncation_side='right'
48
  )
49
 
50
- # Use pipeline for simpler initialization
 
 
 
 
 
 
 
 
 
 
51
  self.classifier = pipeline(
52
  "text-classification",
53
- model=model_name,
54
  tokenizer=self.tokenizer,
55
  device=self.device,
56
  max_length=512,
57
  truncation=True,
58
- batch_size=32
59
  )
60
 
61
  print("Model initialized successfully")
@@ -69,22 +78,36 @@ class TextClassifier:
69
 
70
  def process_batch(self, batch: List[str], batch_idx: int) -> Tuple[List[int], int]:
71
  """Process a batch of texts and return their predictions"""
72
- try:
73
- print(f"Processing batch {batch_idx} with {len(batch)} items")
74
-
75
- # Use pipeline for prediction
76
- outputs = self.classifier(batch)
77
- predictions = [int(output['label'].split('_')[0]) for output in outputs]
78
-
79
- print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
80
- return predictions, batch_idx
81
-
82
- except Exception as e:
83
- print(f"Error in batch {batch_idx}: {str(e)}")
84
- return [0] * len(batch), batch_idx
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  def __del__(self):
87
  # Clean up CUDA memory
 
 
88
  if hasattr(self, 'classifier'):
89
  del self.classifier
90
  if torch.cuda.is_available():
 
7
  from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict, Tuple
9
  import torch
10
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
 
11
  from huggingface_hub import login
12
  from dotenv import load_dotenv
13
 
 
27
 
28
  router = APIRouter()
29
 
30
+ DESCRIPTION = "ModernBERT fine-tuned for climate disinformation detection"
31
  ROUTE = "/text"
32
+ MODEL_NAME = "answerdotai/ModernBERT-base"
33
 
34
  class TextClassifier:
35
  def __init__(self):
36
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
37
  max_retries = 3
 
38
 
39
  for attempt in range(max_retries):
40
  try:
41
+ # Initialize tokenizer
42
  self.tokenizer = AutoTokenizer.from_pretrained(
43
+ MODEL_NAME,
44
+ model_max_length=512,
45
  padding_side='right',
46
  truncation_side='right'
47
  )
48
 
49
+ # Initialize model with specific configuration
50
+ self.model = AutoModelForSequenceClassification.from_pretrained(
51
+ MODEL_NAME,
52
+ num_labels=8,
53
+ problem_type="single_label_classification"
54
+ )
55
+
56
+ # Move model to appropriate device
57
+ self.model = self.model.to(self.device)
58
+
59
+ # Initialize pipeline with the model and tokenizer
60
  self.classifier = pipeline(
61
  "text-classification",
62
+ model=self.model,
63
  tokenizer=self.tokenizer,
64
  device=self.device,
65
  max_length=512,
66
  truncation=True,
67
+ batch_size=16
68
  )
69
 
70
  print("Model initialized successfully")
 
78
 
79
  def process_batch(self, batch: List[str], batch_idx: int) -> Tuple[List[int], int]:
80
  """Process a batch of texts and return their predictions"""
81
+ max_retries = 3
82
+ for attempt in range(max_retries):
83
+ try:
84
+ print(f"Processing batch {batch_idx} with {len(batch)} items")
85
+
86
+ # Process texts with error handling
87
+ predictions = []
88
+ for text in batch:
89
+ try:
90
+ result = self.classifier(text)
91
+ pred_label = int(result[0]['label'].split('_')[0])
92
+ predictions.append(pred_label)
93
+ except Exception as e:
94
+ print(f"Error processing text in batch {batch_idx}: {str(e)}")
95
+ predictions.append(0) # Default prediction
96
+
97
+ print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
98
+ return predictions, batch_idx
99
+
100
+ except Exception as e:
101
+ if attempt == max_retries - 1:
102
+ print(f"Final error in batch {batch_idx}: {str(e)}")
103
+ return [0] * len(batch), batch_idx
104
+ print(f"Error in batch {batch_idx} (attempt {attempt + 1}): {str(e)}")
105
+ time.sleep(1)
106
 
107
  def __del__(self):
108
  # Clean up CUDA memory
109
+ if hasattr(self, 'model'):
110
+ del self.model
111
  if hasattr(self, 'classifier'):
112
  del self.classifier
113
  if torch.cuda.is_available():