Update logic.py
logic.py CHANGED
@@ -151,9 +151,74 @@ def GenresEncoder(movie_docs):
             x[i, mapping[genre]] = 1
     return x.to(device)

-
-#-------------------------------------------------------------------------------------------
-
+def weighted_mse_loss(pred, target, weight=None):
+    # MSE in which each example is scaled by the class weight of its target rating.
+    weight = 1. if weight is None else weight[target].to(pred.dtype)
+    return (weight * (pred - target.to(pred.dtype)).pow(2)).mean()
+
+@torch.no_grad()
+def test(model, data):
+    # Report RMSE of the clamped rating predictions on the given split.
+    model.eval()
+    pred = model(data.x_dict, data.edge_index_dict,
+                 data['user', 'movie'].edge_label_index)
+    pred = pred.clamp(min=0, max=5)
+    target = data['user', 'movie'].edge_label.float()
+    rmse = F.mse_loss(pred, target).sqrt()
+    return float(rmse)
+
+def train_step(model, optimizer, train_data, weight):
+    # One full-batch optimization step over the training graph.
+    model.train()
+    optimizer.zero_grad()
+    pred = model(train_data.x_dict, train_data.edge_index_dict,
+                 train_data['user', 'movie'].edge_label_index)
+    target = train_data['user', 'movie'].edge_label
+    loss = weighted_mse_loss(pred, target, weight)
+    loss.backward()
+    optimizer.step()
+    return float(loss)
+
+
+#-------------------------------------------------------------------------------------------
+# SAGE model
+class GNNEncoder(torch.nn.Module):
+    def __init__(self, hidden_channels, out_channels):
+        super().__init__()
+        # these convolutions have been replicated to match the number of edge types
+        self.conv1 = SAGEConv((-1, -1), hidden_channels)
+        self.conv2 = SAGEConv((-1, -1), out_channels)
+
+    def forward(self, x, edge_index):
+        x = self.conv1(x, edge_index).relu()
+        x = self.conv2(x, edge_index)
+        return x
+
+class EdgeDecoder(torch.nn.Module):
+    def __init__(self, hidden_channels):
+        super().__init__()
+        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
+        self.lin2 = Linear(hidden_channels, 1)
+
+    def forward(self, z_dict, edge_label_index):
+        row, col = edge_label_index
+        # concat user and movie embeddings
+        z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1)
+        # concatenated embeddings passed to linear layer
+        z = self.lin1(z).relu()
+        z = self.lin2(z)
+        return z.view(-1)
+
+class Model(torch.nn.Module):
+    def __init__(self, hidden_channels, metadata):
+        super().__init__()
+        self.encoder = GNNEncoder(hidden_channels, hidden_channels)
+        # convert the homogeneous GraphSAGE encoder into a heterogeneous one
+        self.encoder = to_hetero(self.encoder, metadata, aggr='sum')
+        self.decoder = EdgeDecoder(hidden_channels)
+
+    def forward(self, x_dict, edge_index_dict, edge_label_index):
+        # z_dict contains the dictionary of movie and user embeddings returned from GraphSAGE
+        z_dict = self.encoder(x_dict, edge_index_dict)
+        return self.decoder(z_dict, edge_label_index)
+
+#-------------------------------------------------------------------------------------------
 def make_graph():
     metadata_path = './sampled_movie_dataset/movies_metadata.csv'
     df = pd.read_csv(metadata_path)
@@ -349,7 +414,29 @@ def make_pyg_graph(movie_rec_db):
     )(data)

     return train_data, val_data, test_data
-
+
+
+
+def train(train_data, val_data, test_data):
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    # move the splits to the training device (assumes they arrive on CPU)
+    train_data = train_data.to(device)
+    val_data = val_data.to(device)
+    test_data = test_data.to(device)
+
+    # weight each rating value by its inverse frequency in the training edges
+    weight = torch.bincount(train_data['user', 'movie'].edge_label)
+    weight = weight.max() / weight
+
+    model = Model(hidden_channels=32, metadata=train_data.metadata()).to(device)
+    # lazy modules: run one encoder forward pass to initialize parameter shapes
+    with torch.no_grad():
+        model.encoder(train_data.x_dict, train_data.edge_index_dict)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+
+    # Train loop
+    for epoch in range(1, 300):
+        loss = train_step(model, optimizer, train_data, weight)
+        train_rmse = test(model, train_data)
+        val_rmse = test(model, val_data)
+        test_rmse = test(model, test_data)
+        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '
+              f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}')
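A minimal usage sketch of how the pieces touched by this commit might be wired together; the exact objects returned by make_graph() and make_pyg_graph() are assumptions based on the hunk headers and are not shown in this change:

# Hypothetical driver wiring, assuming make_graph() builds the movie database object
# that make_pyg_graph() consumes, and that the latter returns the three link-split HeteroData objects.
if __name__ == '__main__':
    movie_rec_db = make_graph()
    train_data, val_data, test_data = make_pyg_graph(movie_rec_db)
    train(train_data, val_data, test_data)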