Polo123 committed on
Commit
c3a89b9
·
verified ·
1 Parent(s): aeb0bfe

Update logic.py

Browse files
Files changed (1) hide show
  1. logic.py +89 -2
logic.py CHANGED
@@ -151,9 +151,74 @@ def GenresEncoder(movie_docs):
151
  x[i, mapping[genre]] = 1
152
  return x.to(device)
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
- #-------------------------------------------------------------------------------------------
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  def make_graph():
158
  metadata_path = './sampled_movie_dataset/movies_metadata.csv'
159
  df = pd.read_csv(metadata_path)
@@ -349,7 +414,29 @@ def make_pyg_graph(movie_rec_db):
349
  )(data)
350
 
351
  return train_data, val_data, test_data
352
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
 
354
 
355
 
 
151
  x[i, mapping[genre]] = 1
152
  return x.to(device)
153
 
154
+ def weighted_mse_loss(pred, target, weight=None):
155
+ weight = 1. if weight is None else weight[target].to(pred.dtype)
156
+ return (weight * (pred - target.to(pred.dtype)).pow(2)).mean()
157
+
158
@torch.no_grad()
def test(data):
    """Return the RMSE of the model's rating predictions on ``data`` as a float.

    Predictions for the ('user', 'movie') label edges are clamped to [0, 5]
    before being compared with the float-cast edge labels.

    NOTE(review): this reads ``model`` (and ``F``) from module scope; confirm
    a module-level ``model`` exists before this is called.
    """
    model.eval()
    rating_edges = data['user', 'movie']
    scores = model(
        data.x_dict,
        data.edge_index_dict,
        rating_edges.edge_label_index,
    )
    scores = scores.clamp(min=0, max=5)
    truth = rating_edges.edge_label.float()
    return float(F.mse_loss(scores, truth).sqrt())
167
+
168
def train():
    """Run one optimisation step on the training graph and return the loss as a float.

    NOTE(review): ``model``, ``optimizer``, ``train_data`` and ``weight`` are
    read from module scope; confirm they are defined there. A later
    ``def train(train_data, val_data, test_data)`` in this file rebinds the
    same name, so this zero-argument version is shadowed once the module has
    finished loading.
    """
    model.train()
    optimizer.zero_grad()
    rating_edges = train_data['user', 'movie']
    scores = model(
        train_data.x_dict,
        train_data.edge_index_dict,
        rating_edges.edge_label_index,
    )
    step_loss = weighted_mse_loss(scores, rating_edges.edge_label, weight)
    step_loss.backward()
    optimizer.step()
    return float(step_loss)
178
 
 
179
 
180
+ #-------------------------------------------------------------------------------------------
181
+ # SAGE model
182
class GNNEncoder(torch.nn.Module):
    """Two stacked ``SAGEConv`` layers with a ReLU between them."""

    def __init__(self, hidden_channels, out_channels):
        """Create the two convolutions.

        Args:
            hidden_channels: Output width of the first convolution.
            out_channels: Output width of the second convolution.
        """
        super().__init__()
        # Input sizes are given as (-1, -1); presumably they are inferred
        # lazily on the first forward pass -- TODO confirm against PyG docs.
        # When wrapped by to_hetero, these layers are replicated per edge type.
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        """Apply conv1, ReLU, then conv2 and return the result."""
        x = self.conv1(x, edge_index)
        x = x.relu()
        return self.conv2(x, edge_index)
193
+
194
class EdgeDecoder(torch.nn.Module):
    """Score (user, movie) pairs from their embeddings with a two-layer MLP."""

    def __init__(self, hidden_channels):
        """Create the MLP: ``2 * hidden_channels -> hidden_channels -> 1``."""
        super().__init__()
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        """Return one score per edge in ``edge_label_index`` as a flat tensor.

        Args:
            z_dict: Mapping with ``'user'`` and ``'movie'`` embedding tensors.
            edge_label_index: Pair of index tensors; the first indexes user
                embeddings, the second indexes movie embeddings.
        """
        user_idx, movie_idx = edge_label_index
        user_emb = z_dict['user'][user_idx]
        movie_emb = z_dict['movie'][movie_idx]
        # Each pair is represented by its two embeddings side by side.
        pair_features = torch.cat([user_emb, movie_emb], dim=-1)
        hidden = self.lin1(pair_features).relu()
        return self.lin2(hidden).view(-1)
208
+
209
class Model(torch.nn.Module):
    """Heterogeneous GraphSAGE encoder followed by an edge-level rating decoder."""

    def __init__(self, hidden_channels, metadata=None):
        """Build the encoder/decoder pair.

        Args:
            hidden_channels: Width of the encoder's hidden and output layers
                and of the decoder's hidden layer.
            metadata: Graph metadata passed to ``to_hetero``. When omitted,
                ``data.metadata()`` is read from the module-level ``data``
                object, which is the original behaviour.
        """
        super().__init__()
        if metadata is None:
            # Backward-compatible fallback; raises NameError if no
            # module-level `data` has been defined.
            metadata = data.metadata()
        homogeneous_encoder = GNNEncoder(hidden_channels, hidden_channels)
        self.encoder = to_hetero(homogeneous_encoder, metadata, aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        """Return the decoder's score for every edge in ``edge_label_index``."""
        # z_dict maps node types ('user', 'movie') to their encoder embeddings.
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)
220
+
221
+ #-------------------------------------------------------------------------------------------
222
  def make_graph():
223
  metadata_path = './sampled_movie_dataset/movies_metadata.csv'
224
  df = pd.read_csv(metadata_path)
 
414
  )(data)
415
 
416
  return train_data, val_data, test_data
417
+
418
+
419
+
420
def train(train_data, val_data, test_data):
    """Train the rating model on ``train_data`` and print RMSE for every split.

    The per-epoch step and the evaluation are local closures. This function
    rebinds the module-level name ``train``, so calling ``train()`` from
    inside it would call itself without arguments, and the model, optimizer
    and weights created here are not visible to module-level helpers.

    Args:
        train_data: Graph split used for optimisation.
        val_data: Graph split used for validation RMSE.
        test_data: Graph split used for test RMSE.

    Returns:
        The trained ``Model`` instance.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The model is moved to `device`, so the graphs have to be there too.
    train_data = train_data.to(device)
    val_data = val_data.to(device)
    test_data = test_data.to(device)

    # Per-label weights: the most frequent label gets 1, rarer labels more.
    weight = torch.bincount(train_data['user', 'movie'].edge_label)
    weight = weight.max() / weight

    # NOTE(review): Model's constructor reads a module-level `data` for the
    # graph metadata; confirm it is defined before this runs.
    model = Model(hidden_channels=32).to(device)
    with torch.no_grad():
        # One forward pass before the optimizer is created; presumably this
        # initialises lazily-sized layers -- TODO confirm.
        model.encoder(train_data.x_dict, train_data.edge_index_dict)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    def _train_step():
        """Run one optimisation step and return the loss as a float."""
        model.train()
        optimizer.zero_grad()
        edges = train_data['user', 'movie']
        pred = model(train_data.x_dict, train_data.edge_index_dict,
                     edges.edge_label_index)
        loss = weighted_mse_loss(pred, edges.edge_label, weight)
        loss.backward()
        optimizer.step()
        return float(loss)

    @torch.no_grad()
    def _rmse(split):
        """Return the RMSE of predictions clamped to [0, 5] on ``split``."""
        model.eval()
        edges = split['user', 'movie']
        pred = model(split.x_dict, split.edge_index_dict,
                     edges.edge_label_index)
        pred = pred.clamp(min=0, max=5)
        target = edges.edge_label.float()
        return float(torch.nn.functional.mse_loss(pred, target).sqrt())

    # Train loop: epochs 1..299, matching the original range.
    for epoch in range(1, 300):
        loss = _train_step()
        train_rmse = _rmse(train_data)
        val_rmse = _rmse(val_data)
        test_rmse = _rmse(test_data)
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '
              f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}')

    return model
440
 
441
 
442