hbanduk commited on
Commit
4707f08
·
verified ·
1 Parent(s): 1b0a6ea

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +2 -2
tasks/text.py CHANGED
@@ -71,7 +71,7 @@ async def evaluate_text(request: TextEvaluationRequest):
71
  MAX_LENGTH = 365
72
 
73
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
74
- model.to(device)
75
  model.eval() # Set to evaluation mode
76
 
77
 
@@ -85,7 +85,7 @@ async def evaluate_text(request: TextEvaluationRequest):
85
  predictions = []
86
  with torch.no_grad():
87
  for batch in test_loader:
88
- input_ids, attention_mask, labels = [x.to(device) for x in batch]
89
  outputs = model(input_ids, attention_mask=attention_mask)
90
  preds = torch.argmax(outputs.logits, dim=1)
91
  predictions.extend(preds.cpu().numpy())
 
71
  MAX_LENGTH = 365
72
 
73
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
74
+ #model.to(device)
75
  model.eval() # Set to evaluation mode
76
 
77
 
 
85
  predictions = []
86
  with torch.no_grad():
87
  for batch in test_loader:
88
+ # input_ids, attention_mask, labels = [x.to(device) for x in batch]
89
  outputs = model(input_ids, attention_mask=attention_mask)
90
  preds = torch.argmax(outputs.logits, dim=1)
91
  predictions.extend(preds.cpu().numpy())