Spaces: Sleeping
improve model loading
Browse files - tasks/text.py +16 -6
tasks/text.py
CHANGED
@@ -39,10 +39,12 @@ class TextClassifier:
|
|
39 |
# Initialize tokenizer
|
40 |
self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
|
41 |
|
42 |
-
# Initialize model
|
43 |
-
self.model =
|
44 |
MODEL_NAME,
|
|
|
45 |
num_labels=8,
|
|
|
46 |
ignore_mismatched_sizes=True
|
47 |
).to(self.device)
|
48 |
|
@@ -117,19 +119,27 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
117 |
examples["quote"],
|
118 |
truncation=True,
|
119 |
padding=True,
|
120 |
-
max_length=512
|
|
|
121 |
)
|
122 |
|
123 |
# Tokenize dataset
|
124 |
-
tokenized_test = test_dataset.map(
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
# Create DataLoader
|
128 |
data_collator = DataCollatorWithPadding(tokenizer=classifier.tokenizer)
|
129 |
test_loader = DataLoader(
|
130 |
tokenized_test,
|
131 |
batch_size=16,
|
132 |
-
collate_fn=data_collator
|
|
|
133 |
)
|
134 |
|
135 |
# Get predictions
|
|
|
39 |
# Initialize tokenizer
|
40 |
self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
|
41 |
|
42 |
+
# Initialize model with auto class
|
43 |
+
self.model = AutoModelForSequenceClassification.from_pretrained(
|
44 |
MODEL_NAME,
|
45 |
+
trust_remote_code=True,
|
46 |
num_labels=8,
|
47 |
+
problem_type="single_label_classification",
|
48 |
ignore_mismatched_sizes=True
|
49 |
).to(self.device)
|
50 |
|
|
|
119 |
examples["quote"],
|
120 |
truncation=True,
|
121 |
padding=True,
|
122 |
+
max_length=512,
|
123 |
+
return_tensors=None # Changed this to None for batched processing
|
124 |
)
|
125 |
|
126 |
# Tokenize dataset
|
127 |
+
tokenized_test = test_dataset.map(
|
128 |
+
preprocess_function,
|
129 |
+
batched=True,
|
130 |
+
remove_columns=test_dataset.column_names
|
131 |
+
)
|
132 |
+
|
133 |
+
# Set format for pytorch
|
134 |
+
tokenized_test.set_format("torch")
|
135 |
|
136 |
# Create DataLoader
|
137 |
data_collator = DataCollatorWithPadding(tokenizer=classifier.tokenizer)
|
138 |
test_loader = DataLoader(
|
139 |
tokenized_test,
|
140 |
batch_size=16,
|
141 |
+
collate_fn=data_collator,
|
142 |
+
shuffle=False
|
143 |
)
|
144 |
|
145 |
# Get predictions
|