Tonic committed
Commit ed458ce · unverified · 1 Parent(s): 485bf3f

attempt to remove all bias configurations last time

Files changed (1):
  1. tasks/text.py +46 -1
tasks/text.py CHANGED
@@ -8,7 +8,7 @@ from torch.utils.data import DataLoader
 from transformers import DataCollatorWithPadding
 
 from .utils.evaluation import TextEvaluationRequest
-from .utils.emissions import tracker, clean_emissions_data, get_space_info
+from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
 router = APIRouter()
 
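(Note: the hunk above keeps the emissions import in place. Judging by the module name, `tracker` is most likely a codecarbon `EmissionsTracker`; the diff does not show `.utils.emissions`, so the following is only a sketch of how such a tracker is typically used around inference, with hypothetical stand-in names:)

```python
# Sketch only: assumes `tracker` wraps a codecarbon EmissionsTracker,
# which the .utils.emissions module name suggests but the diff does not show.
from codecarbon import EmissionsTracker

tracker = EmissionsTracker()  # hypothetical stand-in for .utils.emissions.tracker

tracker.start()                    # begin measuring before inference
_ = sum(range(100_000))            # placeholder workload standing in for model inference
emissions_kg = tracker.stop()      # stop() returns estimated emissions in kg CO2eq
print(f"Estimated emissions: {emissions_kg} kg CO2eq")
```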
@@ -104,6 +104,51 @@ async def evaluate_text(request: TextEvaluationRequest):
         # Set model to evaluation mode
         model.eval()
 
+        # Preprocess function
+        def preprocess_function(examples):
+            return tokenizer(
+                examples["quote"],
+                padding=False,
+                truncation=True,
+                max_length=512,
+                return_tensors=None
+            )
+
+        # Tokenize dataset
+        tokenized_test = test_dataset.map(
+            preprocess_function,
+            batched=True,
+            remove_columns=test_dataset.column_names
+        )
+
+        # Set format for pytorch
+        tokenized_test.set_format("torch")
+
+        # Create DataLoader
+        data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+        test_loader = DataLoader(
+            tokenized_test,
+            batch_size=16,
+            collate_fn=data_collator,
+            shuffle=False
+        )
+
+        # Get predictions
+        predictions = []
+        with torch.no_grad():
+            for batch in test_loader:
+                batch = {k: v.to(device) for k, v in batch.items()}
+                outputs = model(**batch)
+                preds = torch.argmax(outputs.logits, dim=-1)
+                predictions.extend(preds.cpu().numpy().tolist())
+
+        # Clean up GPU memory
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    except Exception as e:
+        print(f"Error during model inference: {str(e)}")
+        raise
     #--------------------------------------------------------------------------------------------
     # MODEL INFERENCE ENDS HERE
     #--------------------------------------------------------------------------------------------
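
For context, here is a self-contained sketch of the batched-inference pattern this hunk adds. The checkpoint name and the two-row dataset are placeholders, not part of the commit; the real endpoint builds `tokenizer`, `model`, and `test_dataset` earlier in `evaluate_text()`, which the diff does not show:

```python
# Minimal sketch of the batched-inference pattern the diff adds.
# MODEL_NAME and the toy dataset are hypothetical placeholders.
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
)

MODEL_NAME = "distilbert-base-uncased"  # hypothetical checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(device)
model.eval()

# Toy stand-in for the challenge's test split, which stores text in "quote".
test_dataset = Dataset.from_dict({"quote": ["example one", "example two"]})

def preprocess_function(examples):
    # padding=False here: the collator pads each batch dynamically instead.
    return tokenizer(examples["quote"], truncation=True, max_length=512)

tokenized_test = test_dataset.map(
    preprocess_function, batched=True, remove_columns=test_dataset.column_names
)
tokenized_test.set_format("torch")

test_loader = DataLoader(
    tokenized_test,
    batch_size=16,
    collate_fn=DataCollatorWithPadding(tokenizer=tokenizer),
    shuffle=False,
)

predictions = []
with torch.no_grad():
    for batch in test_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        logits = model(**batch).logits
        predictions.extend(torch.argmax(logits, dim=-1).cpu().tolist())

print(predictions)
```

Dynamic padding (padding=False at tokenization, DataCollatorWithPadding at batch time) pads each batch only to its longest member, and shuffle=False keeps predictions aligned with the dataset's row order for later scoring.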