Polo123 commited on
Commit
03164f7
·
verified ·
1 Parent(s): 4c8efc5

Update logic2.py

Browse files
Files changed (1) hide show
  1. logic2.py +47 -0
logic2.py CHANGED
@@ -22,9 +22,56 @@ import yaml
22
 
23
  import pickle
24
  #----------------------------------------------
 
 
 
 
 
 
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def load_hetero_data():
27
  with open('Hgraph.pkl', 'rb') as file:
28
  data = pickle.load(file)
29
  return data
 
 
 
 
 
 
 
 
30
 
 
22
 
23
  import pickle
24
  #----------------------------------------------
25
+ # SAGE model
26
+ class GNNEncoder(torch.nn.Module):
27
+ def __init__(self, hidden_channels, out_channels):
28
+ super().__init__()
29
+ # these convolutions have been replicated to match the number of edge types
30
+ self.conv1 = SAGEConv((-1, -1), hidden_channels)
31
+ self.conv2 = SAGEConv((-1, -1), out_channels)
32
 
33
+ def forward(self, x, edge_index):
34
+ x = self.conv1(x, edge_index).relu()
35
+ x = self.conv2(x, edge_index)
36
+ return x
37
+
38
+ class EdgeDecoder(torch.nn.Module):
39
+ def __init__(self, hidden_channels):
40
+ super().__init__()
41
+ self.lin1 = Linear(2 * hidden_channels, hidden_channels)
42
+ self.lin2 = Linear(hidden_channels, 1)
43
+
44
+ def forward(self, z_dict, edge_label_index):
45
+ row, col = edge_label_index
46
+ # concat user and movie embeddings
47
+ z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1)
48
+ # concatenated embeddings passed to linear layer
49
+ z = self.lin1(z).relu()
50
+ z = self.lin2(z)
51
+ return z.view(-1)
52
+
53
+ class Model(torch.nn.Module):
54
+ def __init__(self, hidden_channels):
55
+ super().__init__()
56
+ self.encoder = GNNEncoder(hidden_channels, hidden_channels)
57
+ self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
58
+ self.decoder = EdgeDecoder(hidden_channels)
59
+
60
+ def forward(self, x_dict, edge_index_dict, edge_label_index):
61
+ # z_dict contains dictionary of movie and user embeddings returned from GraphSage
62
+ z_dict = self.encoder(x_dict, edge_index_dict)
63
+ return self.decoder(z_dict, edge_label_index)
64
+ #----------------------------------------------
65
  def load_hetero_data():
66
  with open('Hgraph.pkl', 'rb') as file:
67
  data = pickle.load(file)
68
  return data
69
+
70
+ def load_model(train_data, val_data, test_data):
71
+ model = Model(hidden_channels=32)
72
+ with torch.no_grad():
73
+ model.encoder(train_data.x_dict, train_data.edge_index_dict)
74
+ model.load_state_dict(torch.load('model.pt',map_location=torch.device('cpu')))
75
+ model.eval()
76
+ return model
77