Tonic committed on
Commit
b206095
·
unverified ·
1 Parent(s): 0831f97

improve model loading

Browse files
Files changed (1) hide show
  1. tasks/text.py +16 -6
tasks/text.py CHANGED
@@ -39,10 +39,12 @@ class TextClassifier:
39
  # Initialize tokenizer
40
  self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
41
 
42
- # Initialize model
43
- self.model = BertForSequenceClassification.from_pretrained(
44
  MODEL_NAME,
 
45
  num_labels=8,
 
46
  ignore_mismatched_sizes=True
47
  ).to(self.device)
48
 
@@ -117,19 +119,27 @@ async def evaluate_text(request: TextEvaluationRequest):
117
  examples["quote"],
118
  truncation=True,
119
  padding=True,
120
- max_length=512
 
121
  )
122
 
123
  # Tokenize dataset
124
- tokenized_test = test_dataset.map(preprocess_function, batched=True)
125
- tokenized_test.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
 
 
 
 
 
 
126
 
127
  # Create DataLoader
128
  data_collator = DataCollatorWithPadding(tokenizer=classifier.tokenizer)
129
  test_loader = DataLoader(
130
  tokenized_test,
131
  batch_size=16,
132
- collate_fn=data_collator
 
133
  )
134
 
135
  # Get predictions
 
39
  # Initialize tokenizer
40
  self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
41
 
42
+ # Initialize model with auto class
43
+ self.model = AutoModelForSequenceClassification.from_pretrained(
44
  MODEL_NAME,
45
+ trust_remote_code=True,
46
  num_labels=8,
47
+ problem_type="single_label_classification",
48
  ignore_mismatched_sizes=True
49
  ).to(self.device)
50
 
 
119
  examples["quote"],
120
  truncation=True,
121
  padding=True,
122
+ max_length=512,
123
+ return_tensors=None # Changed this to None for batched processing
124
  )
125
 
126
  # Tokenize dataset
127
+ tokenized_test = test_dataset.map(
128
+ preprocess_function,
129
+ batched=True,
130
+ remove_columns=test_dataset.column_names
131
+ )
132
+
133
+ # Set format for pytorch
134
+ tokenized_test.set_format("torch")
135
 
136
  # Create DataLoader
137
  data_collator = DataCollatorWithPadding(tokenizer=classifier.tokenizer)
138
  test_loader = DataLoader(
139
  tokenized_test,
140
  batch_size=16,
141
+ collate_fn=data_collator,
142
+ shuffle=False
143
  )
144
 
145
  # Get predictions