Tonic committed
Commit 0c9dbe5 · unverified · 1 Parent(s): 30f3a06

fix model loading error

Files changed (1):
  tasks/text.py  +10 -6
tasks/text.py CHANGED
@@ -7,7 +7,7 @@ import os
 from concurrent.futures import ThreadPoolExecutor
 from typing import List, Dict, Tuple
 import torch
-from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
 from huggingface_hub import login
 from dotenv import load_dotenv
 
@@ -38,15 +38,19 @@ class TextClassifier:
 
         for attempt in range(max_retries):
             try:
-                # Load config and modify it to remove bias parameter
+                # Load config and modify it
                 self.config = AutoConfig.from_pretrained(model_name)
+
+                # Remove problematic bias parameters
                 if hasattr(self.config, 'norm_bias'):
                     delattr(self.config, 'norm_bias')
 
                 # Initialize tokenizer
                 self.tokenizer = AutoTokenizer.from_pretrained(
                     model_name,
-                    model_max_length=2048
+                    model_max_length=512,
+                    padding_side='right',
+                    truncation_side='right'
                 )
 
                 # Initialize model with modified config
@@ -75,10 +79,10 @@ class TextClassifier:
             # Tokenize
             inputs = self.tokenizer(
                 batch,
-                padding=True,
+                return_tensors="pt",
                 truncation=True,
-                max_length=2048,
+                max_length=512,
+                padding='max_length'
             ).to(self.device)
 
             # Get predictions
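
For context, a minimal sketch of the loading path this commit settles on: the config's norm_bias attribute is dropped before the model is instantiated, and the tokenizer is capped at 512 tokens with right-side padding and truncation. The function name, retry count, and error handling below are illustrative assumptions, not the repository's actual code.

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

def load_classifier(model_name: str, max_retries: int = 3):
    """Sketch of the post-commit loading sequence (names are placeholders)."""
    for attempt in range(max_retries):
        try:
            # Load the config and strip the norm_bias field that breaks model loading.
            config = AutoConfig.from_pretrained(model_name)
            if hasattr(config, "norm_bias"):
                delattr(config, "norm_bias")

            # Tokenizer capped at 512 tokens, padding and truncating on the right.
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                model_max_length=512,
                padding_side="right",
                truncation_side="right",
            )

            # Instantiate the classification model with the modified config.
            model = AutoModelForSequenceClassification.from_pretrained(
                model_name, config=config
            )
            model.eval()
            return model, tokenizer
        except Exception:
            if attempt == max_retries - 1:
                raise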
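
A matching sketch of the batched inference call with the new tokenizer arguments (return_tensors="pt", truncation to 512 tokens, padding='max_length'). The helper name, device handling, and argmax post-processing are assumptions for illustration only.

import torch

def predict_batch(texts, model, tokenizer, device="cpu"):
    # Tokenize the whole batch to fixed-length 512-token tensors.
    inputs = tokenizer(
        texts,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding="max_length",
    ).to(device)

    # Get predictions without tracking gradients.
    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)

    return probs.argmax(dim=-1).tolist()

Usage would follow the same order as the diff: load once with load_classifier, then call predict_batch on each batch of texts.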