Update tasks/text.py
tasks/text.py (+5 -3)
@@ -92,6 +92,8 @@ async def evaluate_text(request: TextEvaluationRequest):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = AutoModelForSequenceClassification.from_pretrained(path_model).to(device).eval()
     tokenizer = AutoTokenizer.from_pretrained(path_tokenizer)
+
+    model.half()
 
     # Use optimized tokenization
     def preprocess_function(df):
@@ -106,10 +108,10 @@ async def evaluate_text(request: TextEvaluationRequest):
         return {"input_ids": input_ids, "attention_mask": attention_mask}
 
     # Optimized inference function
-    def predict(dataset):
+    def predict(dataset, batch_size=16):
         all_preds = []
         with torch.no_grad():  # No gradient computation (saves energy)
-            for batch in torch.utils.data.DataLoader(dataset, batch_size=
+            for batch in torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn):
                 outputs = model(**batch)
                 preds = torch.argmax(outputs.logits, dim=-1).cpu().numpy()
                 all_preds.extend(preds)
@@ -117,7 +119,7 @@ async def evaluate_text(request: TextEvaluationRequest):
 
     # Run inference
     predictions = predict(tokenized_test)
-
+    print(predictions)
     # predictions = np.array([np.argmax(x) for x in preds[0]])
 
     #--------------------------------------------------------------------------------------------
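The new DataLoader call references a collate_fn that is not defined anywhere in this hunk, and model.half() switches the weights to fp16 (mainly a saving on CUDA; some ops are slow or unsupported in half precision on CPU). A minimal sketch of what such a collate function could look like, assuming the tokenized dataset yields dicts with input_ids/attention_mask lists and reusing the tokenizer and device set up above (the padding strategy and helper shown here are assumptions, not part of this commit):

    # Hypothetical collate_fn sketch: pads each batch dynamically and moves the
    # resulting tensors onto the same device as the model.
    def collate_fn(features):
        encoded = tokenizer.pad(features, padding=True, return_tensors="pt")
        return {k: v.to(device) for k, v in encoded.items()}

Dynamic per-batch padding like this only pads to the longest sequence in each batch, which typically does less work than padding every example to a fixed maximum length.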