Commit
·
ceefc9f
1
Parent(s):
5dbee3d
Update DistilBERT.py
Browse files- DistilBERT.py +0 -4
DistilBERT.py
CHANGED
@@ -89,10 +89,6 @@ class DistilBERTClass(torch.nn.Module):
|
|
89 |
self.classifier = torch.nn.Linear(768, 1)
|
90 |
|
91 |
def forward(self, input_ids, attention_mask, token_type_ids):
|
92 |
-
# Convert input_ids to PyTorch tensor if it's a list
|
93 |
-
if isinstance(input_ids, list):
|
94 |
-
input_ids = torch.tensor(input_ids)
|
95 |
-
|
96 |
output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
|
97 |
hidden_state = output_1[0]
|
98 |
pooler = hidden_state[:, 0]
|
|
|
89 |
self.classifier = torch.nn.Linear(768, 1)
|
90 |
|
91 |
def forward(self, input_ids, attention_mask, token_type_ids):
|
|
|
|
|
|
|
|
|
92 |
output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
|
93 |
hidden_state = output_1[0]
|
94 |
pooler = hidden_state[:, 0]
|