Commit
·
5dbee3d
1
Parent(s):
4a4406e
Update DistilBERT.py
Browse files- DistilBERT.py +4 -2
DistilBERT.py
CHANGED
@@ -89,6 +89,10 @@ class DistilBERTClass(torch.nn.Module):
|
|
89 |
self.classifier = torch.nn.Linear(768, 1)
|
90 |
|
91 |
def forward(self, input_ids, attention_mask, token_type_ids):
|
|
|
|
|
|
|
|
|
92 |
output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
|
93 |
hidden_state = output_1[0]
|
94 |
pooler = hidden_state[:, 0]
|
@@ -98,8 +102,6 @@ class DistilBERTClass(torch.nn.Module):
|
|
98 |
output = self.classifier(pooler)
|
99 |
return output
|
100 |
|
101 |
-
model_DB = DistilBERTClass()
|
102 |
-
model_DB.to(device)
|
103 |
|
104 |
# Validation function
|
105 |
def validation(testing_loader):
|
|
|
89 |
self.classifier = torch.nn.Linear(768, 1)
|
90 |
|
91 |
def forward(self, input_ids, attention_mask, token_type_ids):
|
92 |
+
# Convert input_ids to PyTorch tensor if it's a list
|
93 |
+
if isinstance(input_ids, list):
|
94 |
+
input_ids = torch.tensor(input_ids)
|
95 |
+
|
96 |
output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
|
97 |
hidden_state = output_1[0]
|
98 |
pooler = hidden_state[:, 0]
|
|
|
102 |
output = self.classifier(pooler)
|
103 |
return output
|
104 |
|
|
|
|
|
105 |
|
106 |
# Validation function
|
107 |
def validation(testing_loader):
|