Spaces:
Runtime error
Runtime error
Upload 7 files
Browse files- lightGCNModel_num_layers_MovieLens100K_checkpoint.pt +3 -0
- processed_MVL_light.pt +3 -0
- requirements.txt +0 -0
- u.item +0 -0
- u1.base +0 -0
- utils.py +60 -0
lightGCNModel_num_layers_MovieLens100K_checkpoint.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fa4a3960c3b792ea38bcba6e94ac96d85c95eb1fab88f0ed67e36060dea6ae0a
|
| 3 |
+
size 2019334
|
processed_MVL_light.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1c19e3fab1ca895a2a7c25ce4692f9d19ace5179ccdce4b0786e2f61b1a24d44
|
| 3 |
+
size 1772240
|
requirements.txt
ADDED
|
Binary file (4.36 kB). View file
|
|
|
u.item
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
u1.base
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
utils.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from torch_geometric.datasets import AmazonBook, MovieLens100K, MovieLens1M
|
| 10 |
+
from torch_geometric.nn import GCNConv, LGConv
|
| 11 |
+
from torch_geometric.utils import degree
|
| 12 |
+
from torch_geometric.nn.conv import MessagePassing
|
| 13 |
+
from torch_geometric.data import HeteroData, Data
|
| 14 |
+
import torch_geometric.transforms as T
|
| 15 |
+
|
| 16 |
+
def predict(model, device, data, num_users, num_items, user_id, train_edge_label_index, k=5):
|
| 17 |
+
with torch.no_grad():
|
| 18 |
+
|
| 19 |
+
## ML100k
|
| 20 |
+
interaction_dataframe = pd.read_csv('./u1.base', delim_whitespace=True, header=None)
|
| 21 |
+
meta_dataframe = pd.read_csv('./u.item', sep='|', encoding='latin-1', header=None)
|
| 22 |
+
interaction_dataframe = interaction_dataframe[[0, 1]]
|
| 23 |
+
interaction_dataframe.columns = ['reviewerID', 'asin']
|
| 24 |
+
|
| 25 |
+
meta_dataframe = meta_dataframe[[0, 1]]
|
| 26 |
+
meta_dataframe.columns = ['asin', 'title']
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
out = model.get_embedding(data.edge_index)
|
| 30 |
+
user_emb, item_emb = out[:num_users], out[num_users:]
|
| 31 |
+
logits = user_emb @ item_emb.t()
|
| 32 |
+
logits = torch.nn.Sigmoid()(logits)
|
| 33 |
+
logits[train_edge_label_index[0], train_edge_label_index[1]-num_users] = float('-inf')
|
| 34 |
+
|
| 35 |
+
# Create unique users to find the index of it in embedding table
|
| 36 |
+
unique_users = interaction_dataframe['reviewerID'].unique().tolist()
|
| 37 |
+
unique_items = interaction_dataframe['asin'].unique().tolist()
|
| 38 |
+
|
| 39 |
+
random_row = random.randint(0, len(interaction_dataframe))
|
| 40 |
+
user_to_rec = interaction_dataframe.iloc[random_row]['reviewerID']
|
| 41 |
+
user_to_rec = user_id
|
| 42 |
+
#user_to_rec = 923
|
| 43 |
+
user_rates = logits[unique_users.index(user_to_rec)]
|
| 44 |
+
|
| 45 |
+
# print(f"ID of user we want to recommend to: {user_to_rec}")
|
| 46 |
+
|
| 47 |
+
ground_truth_asins = interaction_dataframe[interaction_dataframe['reviewerID'] == user_to_rec]['asin'].to_list()
|
| 48 |
+
|
| 49 |
+
ground_truth_items = meta_dataframe[meta_dataframe['asin'].isin(ground_truth_asins)].head(5)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
_, top_ratings = torch.topk(user_rates, k)
|
| 53 |
+
|
| 54 |
+
recommended_items = []
|
| 55 |
+
for index in top_ratings:
|
| 56 |
+
asin_of_item = unique_items[index]
|
| 57 |
+
recommended_item = meta_dataframe[meta_dataframe['asin'] == asin_of_item]['title'].values
|
| 58 |
+
recommended_items.append(recommended_item)
|
| 59 |
+
|
| 60 |
+
return ground_truth_items.to_list(), recommended_items
|