Spaces:
Sleeping
Sleeping
update textclassifier
Browse files- tasks/text.py +16 -20
tasks/text.py
CHANGED
@@ -39,35 +39,31 @@ class TextClassifier:
|
|
39 |
|
40 |
for attempt in range(max_retries):
|
41 |
try:
|
42 |
-
#
|
43 |
-
self.config = AutoConfig.from_pretrained(
|
44 |
-
model_name,
|
45 |
-
num_labels=8,
|
46 |
-
problem_type="single_label_classification",
|
47 |
-
trust_remote_code=True
|
48 |
-
)
|
49 |
-
|
50 |
-
# Remove problematic config attributes
|
51 |
-
if hasattr(self.config, 'norm_bias'):
|
52 |
-
delattr(self.config, 'norm_bias')
|
53 |
-
if hasattr(self.config, 'bias'):
|
54 |
-
delattr(self.config, 'bias')
|
55 |
-
|
56 |
-
# Initialize tokenizer
|
57 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
58 |
model_name,
|
59 |
model_max_length=8192,
|
60 |
padding_side='right',
|
61 |
-
truncation_side='right'
|
62 |
-
trust_remote_code=True
|
63 |
)
|
64 |
|
65 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
self.model = AutoModelForSequenceClassification.from_pretrained(
|
67 |
model_name,
|
68 |
config=self.config,
|
69 |
-
trust_remote_code=True,
|
70 |
-
torch_dtype=torch.float32,
|
71 |
ignore_mismatched_sizes=True
|
72 |
)
|
73 |
|
|
|
39 |
|
40 |
for attempt in range(max_retries):
|
41 |
try:
|
42 |
+
# Initialize tokenizer first
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
44 |
model_name,
|
45 |
model_max_length=8192,
|
46 |
padding_side='right',
|
47 |
+
truncation_side='right'
|
|
|
48 |
)
|
49 |
|
50 |
+
# Load base config
|
51 |
+
self.config = AutoConfig.from_pretrained(
|
52 |
+
model_name,
|
53 |
+
num_labels=8,
|
54 |
+
problem_type="single_label_classification"
|
55 |
+
)
|
56 |
+
|
57 |
+
# Set required attributes
|
58 |
+
self.config.hidden_size = 768
|
59 |
+
self.config.num_attention_heads = 12
|
60 |
+
self.config.num_hidden_layers = 12
|
61 |
+
self.config.norm_eps = 1e-5
|
62 |
+
|
63 |
+
# Initialize model with basic config
|
64 |
self.model = AutoModelForSequenceClassification.from_pretrained(
|
65 |
model_name,
|
66 |
config=self.config,
|
|
|
|
|
67 |
ignore_mismatched_sizes=True
|
68 |
)
|
69 |
|