hbanduk commited on
Commit
31cb1dc
·
verified ·
1 Parent(s): 4707f08

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +5 -2
tasks/text.py CHANGED
@@ -63,6 +63,9 @@ async def evaluate_text(request: TextEvaluationRequest):
63
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
64
  import torch
65
  from torch.utils.data import DataLoader, TensorDataset
 
 
 
66
 
67
  # Load model and tokenizer from Hugging Face Hub
68
  MODEL_REPO = "ClimateDebunk/FineTunedDistilBert4SeqClass"
@@ -71,7 +74,7 @@ async def evaluate_text(request: TextEvaluationRequest):
71
  MAX_LENGTH = 365
72
 
73
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
74
- #model.to(device)
75
  model.eval() # Set to evaluation mode
76
 
77
 
@@ -85,7 +88,7 @@ async def evaluate_text(request: TextEvaluationRequest):
85
  predictions = []
86
  with torch.no_grad():
87
  for batch in test_loader:
88
- # input_ids, attention_mask, labels = [x.to(device) for x in batch]
89
  outputs = model(input_ids, attention_mask=attention_mask)
90
  preds = torch.argmax(outputs.logits, dim=1)
91
  predictions.extend(preds.cpu().numpy())
 
63
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
64
  import torch
65
  from torch.utils.data import DataLoader, TensorDataset
66
+
67
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
68
+ print(f"Using device: {device}")
69
 
70
  # Load model and tokenizer from Hugging Face Hub
71
  MODEL_REPO = "ClimateDebunk/FineTunedDistilBert4SeqClass"
 
74
  MAX_LENGTH = 365
75
 
76
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
77
+ model.to(device)
78
  model.eval() # Set to evaluation mode
79
 
80
 
 
88
  predictions = []
89
  with torch.no_grad():
90
  for batch in test_loader:
91
+ input_ids, attention_mask, labels = [x.to(device) for x in batch]
92
  outputs = model(input_ids, attention_mask=attention_mask)
93
  preds = torch.argmax(outputs.logits, dim=1)
94
  predictions.extend(preds.cpu().numpy())