MatthiasPi committed
Commit 1f45c21 · verified · Parent: e39064f

Update tasks/text.py

Files changed (1):
  1. tasks/text.py +5 -3
tasks/text.py CHANGED
@@ -92,6 +92,8 @@ async def evaluate_text(request: TextEvaluationRequest):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = AutoModelForSequenceClassification.from_pretrained(path_model).to(device).eval()
     tokenizer = AutoTokenizer.from_pretrained(path_tokenizer)
+
+    model.half()
 
     # Use optimized tokenization
     def preprocess_function(df):
@@ -106,10 +108,10 @@ async def evaluate_text(request: TextEvaluationRequest):
         return {"input_ids": input_ids, "attention_mask": attention_mask}
 
     # Optimized inference function
-    def predict(dataset):
+    def predict(dataset, batch_size=16):
         all_preds = []
         with torch.no_grad():  # No gradient computation (saves energy)
-            for batch in torch.utils.data.DataLoader(dataset, batch_size=len(dataset), collate_fn=collate_fn):
+            for batch in torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn):
                 outputs = model(**batch)
                 preds = torch.argmax(outputs.logits, dim=-1).cpu().numpy()
                 all_preds.extend(preds)
@@ -117,7 +119,7 @@ async def evaluate_text(request: TextEvaluationRequest):
 
     # Run inference
     predictions = predict(tokenized_test)
-
+    print(predictions)
     # predictions = np.array([np.argmax(x) for x in preds[0]])
 
     #--------------------------------------------------------------------------------------------
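
Taken together, the file after this commit runs batched FP16 inference. Below is a minimal sketch of the resulting path, assuming a placeholder checkpoint path and a hypothetical collate_fn; the real path_model, path_tokenizer, collate_fn, and tokenized_test are defined elsewhere in tasks/text.py and are not shown in this diff.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Placeholder checkpoint; the real path_model / path_tokenizer live elsewhere in tasks/text.py.
path_model = path_tokenizer = "distilbert-base-uncased"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForSequenceClassification.from_pretrained(path_model).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(path_tokenizer)

model.half()  # FP16 weights: halves memory use and speeds up GPU inference

def collate_fn(batch):
    # Hypothetical collate: stack pre-tokenized examples and move them to the device.
    return {
        "input_ids": torch.stack([torch.as_tensor(x["input_ids"]) for x in batch]).to(device),
        "attention_mask": torch.stack([torch.as_tensor(x["attention_mask"]) for x in batch]).to(device),
    }

def predict(dataset, batch_size=16):
    all_preds = []
    with torch.no_grad():  # no gradient computation (saves energy)
        for batch in torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn):
            outputs = model(**batch)
            preds = torch.argmax(outputs.logits, dim=-1).cpu().numpy()
            all_preds.extend(preds)
    return all_preds

Two caveats worth noting. First, the old code used batch_size=len(dataset), which pushes the entire test set through the model as a single batch and can exhaust GPU memory, so the fixed batch_size=16 is the substantive fix here. Second, model.half() is applied unconditionally, so on a CPU-only host the model still runs in FP16, which PyTorch executes slowly (and some ops do not support at all on CPU); guarding it with "if device.type == 'cuda':" would be safer.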