perman2011 commited on
Commit
5dbee3d
·
1 Parent(s): 4a4406e

Update DistilBERT.py

Browse files
Files changed (1) hide show
  1. DistilBERT.py +4 -2
DistilBERT.py CHANGED
@@ -89,6 +89,10 @@ class DistilBERTClass(torch.nn.Module):
89
  self.classifier = torch.nn.Linear(768, 1)
90
 
91
  def forward(self, input_ids, attention_mask, token_type_ids):
 
 
 
 
92
  output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
93
  hidden_state = output_1[0]
94
  pooler = hidden_state[:, 0]
@@ -98,8 +102,6 @@ class DistilBERTClass(torch.nn.Module):
98
  output = self.classifier(pooler)
99
  return output
100
 
101
- model_DB = DistilBERTClass()
102
- model_DB.to(device)
103
 
104
  # Validation function
105
  def validation(testing_loader):
 
89
  self.classifier = torch.nn.Linear(768, 1)
90
 
91
  def forward(self, input_ids, attention_mask, token_type_ids):
92
+ # Convert input_ids to PyTorch tensor if it's a list
93
+ if isinstance(input_ids, list):
94
+ input_ids = torch.tensor(input_ids)
95
+
96
  output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
97
  hidden_state = output_1[0]
98
  pooler = hidden_state[:, 0]
 
102
  output = self.classifier(pooler)
103
  return output
104
 
 
 
105
 
106
  # Validation function
107
  def validation(testing_loader):