hbanduk commited on
Commit
1b0a6ea
·
verified ·
1 Parent(s): d9fb9e1

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +6 -4
tasks/text.py CHANGED
@@ -62,7 +62,7 @@ async def evaluate_text(request: TextEvaluationRequest):
62
 
63
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
64
  import torch
65
- from torch.utils.data import DataLoader, Dataset
66
 
67
  # Load model and tokenizer from Hugging Face Hub
68
  MODEL_REPO = "ClimateDebunk/FineTunedDistilBert4SeqClass"
@@ -71,14 +71,16 @@ async def evaluate_text(request: TextEvaluationRequest):
71
  MAX_LENGTH = 365
72
 
73
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
 
74
  model.eval() # Set to evaluation mode
75
 
76
 
77
  # tokenize texts
78
  test_encodings = tokenizer(test_dataset["quote"], padding='max_length', truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
79
-
80
- test_dataset = TensorDataset(test_encodings["input_ids"], test_encodings["attention_mask"], test_labels)
81
- test_loader = DataLoader(test_dataset, batch_size=16)
 
82
 
83
  predictions = []
84
  with torch.no_grad():
 
62
 
63
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
64
  import torch
65
+ from torch.utils.data import DataLoader, TensorDataset
66
 
67
  # Load model and tokenizer from Hugging Face Hub
68
  MODEL_REPO = "ClimateDebunk/FineTunedDistilBert4SeqClass"
 
71
  MAX_LENGTH = 365
72
 
73
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
74
+ model.to(device)
75
  model.eval() # Set to evaluation mode
76
 
77
 
78
  # tokenize texts
79
  test_encodings = tokenizer(test_dataset["quote"], padding='max_length', truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
80
+ test_labels = torch.tensor(test_dataset["label"])
81
+
82
+ test_dataset_0 = TensorDataset(test_encodings["input_ids"], test_encodings["attention_mask"], test_labels)
83
+ test_loader = DataLoader(test_dataset_0, batch_size=16)
84
 
85
  predictions = []
86
  with torch.no_grad():