brillm05
commited on
Commit
·
00b7bd2
1
Parent(s):
e8fc320
update model.py
Browse files
model.py
CHANGED
@@ -43,9 +43,9 @@ class BraLM(nn.Module):
|
|
43 |
self.node_bias = nn.Parameter(torch.randn(len(vocab.edge_dict), 1, self.hidden_size).uniform_(-0.5, 0.5))
|
44 |
|
45 |
def to_device(self, device):
|
46 |
-
self.weights.to(device)
|
47 |
-
self.biases.to(device)
|
48 |
-
self.node_bias.to(device)
|
49 |
self.positions.data = self.positions.data.to(device)
|
50 |
self.device = device
|
51 |
|
|
|
43 |
self.node_bias = nn.Parameter(torch.randn(len(vocab.edge_dict), 1, self.hidden_size).uniform_(-0.5, 0.5))
|
44 |
|
45 |
def to_device(self, device):
|
46 |
+
self.weights.data = self.weights.data.to(device)
|
47 |
+
self.biases.data = self.biases.data.to(device)
|
48 |
+
self.node_bias.data = self.node_bias.data.to(device)
|
49 |
self.positions.data = self.positions.data.to(device)
|
50 |
self.device = device
|
51 |
|