Spaces:
Sleeping
Sleeping
fix model loading error
Browse files- tasks/text.py +10 -6
tasks/text.py
CHANGED
@@ -7,7 +7,7 @@ import os
|
|
7 |
from concurrent.futures import ThreadPoolExecutor
|
8 |
from typing import List, Dict, Tuple
|
9 |
import torch
|
10 |
-
from transformers import
|
11 |
from huggingface_hub import login
|
12 |
from dotenv import load_dotenv
|
13 |
|
@@ -38,15 +38,19 @@ class TextClassifier:
|
|
38 |
|
39 |
for attempt in range(max_retries):
|
40 |
try:
|
41 |
-
# Load config and modify it
|
42 |
self.config = AutoConfig.from_pretrained(model_name)
|
|
|
|
|
43 |
if hasattr(self.config, 'norm_bias'):
|
44 |
delattr(self.config, 'norm_bias')
|
45 |
|
46 |
# Initialize tokenizer
|
47 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
48 |
model_name,
|
49 |
-
model_max_length=
|
|
|
|
|
50 |
)
|
51 |
|
52 |
# Initialize model with modified config
|
@@ -75,10 +79,10 @@ class TextClassifier:
|
|
75 |
# Tokenize
|
76 |
inputs = self.tokenizer(
|
77 |
batch,
|
78 |
-
|
79 |
truncation=True,
|
80 |
-
max_length=
|
81 |
-
|
82 |
).to(self.device)
|
83 |
|
84 |
# Get predictions
|
|
|
7 |
from concurrent.futures import ThreadPoolExecutor
|
8 |
from typing import List, Dict, Tuple
|
9 |
import torch
|
10 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
|
11 |
from huggingface_hub import login
|
12 |
from dotenv import load_dotenv
|
13 |
|
|
|
38 |
|
39 |
for attempt in range(max_retries):
|
40 |
try:
|
41 |
+
# Load config and modify it
|
42 |
self.config = AutoConfig.from_pretrained(model_name)
|
43 |
+
|
44 |
+
# Remove problematic bias parameters
|
45 |
if hasattr(self.config, 'norm_bias'):
|
46 |
delattr(self.config, 'norm_bias')
|
47 |
|
48 |
# Initialize tokenizer
|
49 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
50 |
model_name,
|
51 |
+
model_max_length=512,
|
52 |
+
padding_side='right',
|
53 |
+
truncation_side='right'
|
54 |
)
|
55 |
|
56 |
# Initialize model with modified config
|
|
|
79 |
# Tokenize
|
80 |
inputs = self.tokenizer(
|
81 |
batch,
|
82 |
+
return_tensors="pt",
|
83 |
truncation=True,
|
84 |
+
max_length=512,
|
85 |
+
padding='max_length'
|
86 |
).to(self.device)
|
87 |
|
88 |
# Get predictions
|