Update logic2.py
Browse files
logic2.py
CHANGED
@@ -22,9 +22,56 @@ import yaml
|
|
22 |
|
23 |
import pickle
|
24 |
#----------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
def load_hetero_data():
|
27 |
with open('Hgraph.pkl', 'rb') as file:
|
28 |
data = pickle.load(file)
|
29 |
return data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
|
|
22 |
|
23 |
import pickle
|
24 |
#----------------------------------------------
|
25 |
+
# SAGE model
|
26 |
+
class GNNEncoder(torch.nn.Module):
|
27 |
+
def __init__(self, hidden_channels, out_channels):
|
28 |
+
super().__init__()
|
29 |
+
# these convolutions have been replicated to match the number of edge types
|
30 |
+
self.conv1 = SAGEConv((-1, -1), hidden_channels)
|
31 |
+
self.conv2 = SAGEConv((-1, -1), out_channels)
|
32 |
|
33 |
+
def forward(self, x, edge_index):
|
34 |
+
x = self.conv1(x, edge_index).relu()
|
35 |
+
x = self.conv2(x, edge_index)
|
36 |
+
return x
|
37 |
+
|
38 |
+
class EdgeDecoder(torch.nn.Module):
|
39 |
+
def __init__(self, hidden_channels):
|
40 |
+
super().__init__()
|
41 |
+
self.lin1 = Linear(2 * hidden_channels, hidden_channels)
|
42 |
+
self.lin2 = Linear(hidden_channels, 1)
|
43 |
+
|
44 |
+
def forward(self, z_dict, edge_label_index):
|
45 |
+
row, col = edge_label_index
|
46 |
+
# concat user and movie embeddings
|
47 |
+
z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1)
|
48 |
+
# concatenated embeddings passed to linear layer
|
49 |
+
z = self.lin1(z).relu()
|
50 |
+
z = self.lin2(z)
|
51 |
+
return z.view(-1)
|
52 |
+
|
53 |
+
class Model(torch.nn.Module):
|
54 |
+
def __init__(self, hidden_channels):
|
55 |
+
super().__init__()
|
56 |
+
self.encoder = GNNEncoder(hidden_channels, hidden_channels)
|
57 |
+
self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
|
58 |
+
self.decoder = EdgeDecoder(hidden_channels)
|
59 |
+
|
60 |
+
def forward(self, x_dict, edge_index_dict, edge_label_index):
|
61 |
+
# z_dict contains dictionary of movie and user embeddings returned from GraphSage
|
62 |
+
z_dict = self.encoder(x_dict, edge_index_dict)
|
63 |
+
return self.decoder(z_dict, edge_label_index)
|
64 |
+
#----------------------------------------------
|
65 |
def load_hetero_data():
|
66 |
with open('Hgraph.pkl', 'rb') as file:
|
67 |
data = pickle.load(file)
|
68 |
return data
|
69 |
+
|
70 |
+
def load_model(train_data, val_data, test_data):
|
71 |
+
model = Model(hidden_channels=32)
|
72 |
+
with torch.no_grad():
|
73 |
+
model.encoder(train_data.x_dict, train_data.edge_index_dict)
|
74 |
+
model.load_state_dict(torch.load('model.pt',map_location=torch.device('cpu')))
|
75 |
+
model.eval()
|
76 |
+
return model
|
77 |
|