brillm05 commited on
Commit
00b7bd2
·
1 Parent(s): e8fc320

update model.py

Browse files
Files changed (1) hide show
  1. model.py +3 -3
model.py CHANGED
@@ -43,9 +43,9 @@ class BraLM(nn.Module):
43
  self.node_bias = nn.Parameter(torch.randn(len(vocab.edge_dict), 1, self.hidden_size).uniform_(-0.5, 0.5))
44
 
45
  def to_device(self, device):
46
- self.weights.to(device)
47
- self.biases.to(device)
48
- self.node_bias.to(device)
49
  self.positions.data = self.positions.data.to(device)
50
  self.device = device
51
 
 
43
  self.node_bias = nn.Parameter(torch.randn(len(vocab.edge_dict), 1, self.hidden_size).uniform_(-0.5, 0.5))
44
 
45
  def to_device(self, device):
46
+ self.weights.data = self.weights.data.to(device)
47
+ self.biases.data = self.biases.data.to(device)
48
+ self.node_bias.data = self.node_bias.data.to(device)
49
  self.positions.data = self.positions.data.to(device)
50
  self.device = device
51