Tonic commited on
Commit
89f8be4
·
unverified ·
1 Parent(s): c3f000b

improve model loading

Browse files
Files changed (1) hide show
  1. tasks/text.py +5 -28
tasks/text.py CHANGED
@@ -37,38 +37,16 @@ class TextClassifier:
37
 
38
  try:
39
  # Initialize tokenizer
40
- self.tokenizer = AutoTokenizer.from_pretrained(
41
- TOKENIZER_NAME,
42
- model_max_length=8192,
43
- padding_side='right',
44
- truncation_side='right'
45
- )
46
-
47
- # Load model configuration
48
- model_config = {
49
- "architectures": ["ModernBertForSequenceClassification"],
50
- "model_type": "modernbert",
51
- "num_labels": 8,
52
- "problem_type": "single_label_classification",
53
- "hidden_size": 768,
54
- "num_attention_heads": 12,
55
- "num_hidden_layers": 22,
56
- "intermediate_size": 1152,
57
- "max_position_embeddings": 8192,
58
- "torch_dtype": "float32",
59
- "transformers_version": "4.48.3",
60
- "layer_norm_eps": 1e-05
61
- }
62
 
63
  # Initialize model
64
- self.model = AutoModelForSequenceClassification.from_pretrained(
65
  MODEL_NAME,
66
- config=model_config,
67
- ignore_mismatched_sizes=True,
68
- trust_remote_code=True
69
  ).to(self.device)
70
 
71
- # Convert to half precision
72
  self.model = self.model.half()
73
  self.model.eval()
74
 
@@ -79,7 +57,6 @@ class TextClassifier:
79
  raise
80
 
81
  def process_batch(self, batch):
82
- """Process a batch of texts and return their predictions"""
83
  try:
84
  # Move batch to device
85
  input_ids = batch['input_ids'].to(self.device)
 
37
 
38
  try:
39
  # Initialize tokenizer
40
+ self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  # Initialize model
43
+ self.model = BertForSequenceClassification.from_pretrained(
44
  MODEL_NAME,
45
+ num_labels=8,
46
+ ignore_mismatched_sizes=True
 
47
  ).to(self.device)
48
 
49
+ # Convert to half precision and eval mode
50
  self.model = self.model.half()
51
  self.model.eval()
52
 
 
57
  raise
58
 
59
  def process_batch(self, batch):
 
60
  try:
61
  # Move batch to device
62
  input_ids = batch['input_ids'].to(self.device)