Tonic commited on
Commit
30f3a06
·
unverified ·
1 Parent(s): bc4f464

fix model loading error

Browse files
Files changed (1) hide show
  1. tasks/text.py +33 -10
tasks/text.py CHANGED
@@ -7,7 +7,7 @@ import os
7
  from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict, Tuple
9
  import torch
10
- from transformers import pipeline
11
  from huggingface_hub import login
12
  from dotenv import load_dotenv
13
 
@@ -38,13 +38,26 @@ class TextClassifier:
38
 
39
  for attempt in range(max_retries):
40
  try:
41
- # Initialize pipeline
42
- self.classifier = pipeline(
43
- "text-classification",
44
- model=model_name,
45
- device=self.device,
46
- batch_size=32
 
 
 
 
 
 
 
 
 
 
47
  )
 
 
 
48
  print("Model initialized successfully")
49
  break
50
 
@@ -59,9 +72,19 @@ class TextClassifier:
59
  try:
60
  print(f"Processing batch {batch_idx} with {len(batch)} items")
61
 
62
- # Use pipeline for prediction
63
- results = self.classifier(batch)
64
- predictions = [int(result['label'].split('_')[0]) for result in results]
 
 
 
 
 
 
 
 
 
 
65
 
66
  print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
67
  return predictions, batch_idx
 
7
  from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict, Tuple
9
  import torch
10
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
11
  from huggingface_hub import login
12
  from dotenv import load_dotenv
13
 
 
38
 
39
  for attempt in range(max_retries):
40
  try:
41
+ # Load config and modify it to remove bias parameter
42
+ self.config = AutoConfig.from_pretrained(model_name)
43
+ if hasattr(self.config, 'norm_bias'):
44
+ delattr(self.config, 'norm_bias')
45
+
46
+ # Initialize tokenizer
47
+ self.tokenizer = AutoTokenizer.from_pretrained(
48
+ model_name,
49
+ model_max_length=2048
50
+ )
51
+
52
+ # Initialize model with modified config
53
+ self.model = AutoModelForSequenceClassification.from_pretrained(
54
+ model_name,
55
+ config=self.config,
56
+ ignore_mismatched_sizes=True
57
  )
58
+
59
+ self.model.to(self.device)
60
+ self.model.eval()
61
  print("Model initialized successfully")
62
  break
63
 
 
72
  try:
73
  print(f"Processing batch {batch_idx} with {len(batch)} items")
74
 
75
+ # Tokenize
76
+ inputs = self.tokenizer(
77
+ batch,
78
+ padding=True,
79
+ truncation=True,
80
+ max_length=2048,
81
+ return_tensors="pt"
82
+ ).to(self.device)
83
+
84
+ # Get predictions
85
+ with torch.no_grad():
86
+ outputs = self.model(**inputs)
87
+ predictions = torch.argmax(outputs.logits, dim=-1).cpu().tolist()
88
 
89
  print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
90
  return predictions, batch_idx