Update logic.py
Browse files
logic.py
CHANGED
@@ -107,6 +107,50 @@ def create_ratings_graph(user_id, movie_id, ratings):
|
|
107 |
print("Inserting batch the last batch!")
|
108 |
edge_collection.import_bulk(batch)
|
109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
#-------------------------------------------------------------------------------------------
|
112 |
|
@@ -268,9 +312,43 @@ def load_data_to_ArangoDB(login):
|
|
268 |
|
269 |
|
270 |
return movie_rec_db
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
271 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
272 |
|
273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
|
275 |
|
276 |
|
|
|
107 |
print("Inserting batch the last batch!")
|
108 |
edge_collection.import_bulk(batch)
|
109 |
|
110 |
+
def create_pyg_edges(rating_docs):
    """Convert rating edge documents into edge tensors.

    Every document must provide ``_from`` and ``_to`` as
    ``"<collection>/<key>"`` strings whose key part parses with ``int()``,
    plus a ``_rating`` value accepted by ``int()`` (a fractional rating is
    truncated by that conversion).

    Args:
        rating_docs: Iterable of mapping-like rating documents.

    Returns:
        A tuple ``(edge_index, edge_attr)`` where ``edge_index`` is a
        ``torch.long`` tensor of shape ``(2, num_edges)`` holding the source
        keys in row 0 and the destination keys in row 1, and ``edge_attr`` is
        a ``torch.long`` tensor of shape ``(num_edges,)`` with the ratings.

    Raises:
        KeyError: If a document lacks ``_from``, ``_to`` or ``_rating``.
        IndexError: If ``_from`` / ``_to`` contains no ``/`` separator.
        ValueError: If a key or rating cannot be converted to ``int``.
    """
    src = []
    dst = []
    ratings = []
    for doc in rating_docs:
        # Document ids look like "<collection>/<key>"; keep the numeric key.
        src.append(int(doc['_from'].split('/')[1]))
        dst.append(int(doc['_to'].split('/')[1]))
        ratings.append(int(doc['_rating']))

    # Pin the dtype: for an empty input torch.tensor() would otherwise infer
    # float32 from the empty lists, which is not a valid index dtype.
    edge_index = torch.tensor([src, dst], dtype=torch.long)
    edge_attr = torch.tensor(ratings, dtype=torch.long)

    return edge_index, edge_attr
|
126 |
+
|
127 |
+
def SequenceEncoder(movie_docs, model_name=None):
    """Embed the ``movie_title`` field of every document.

    The titles are collected in iteration order and passed in one call to a
    ``SentenceTransformer`` built from ``model_name``; the encoder's tensor
    output is returned unchanged.

    NOTE(review): ``device`` is not a parameter or local here, so it is read
    from module scope -- confirm it is defined at module level.

    Args:
        movie_docs: Iterable of mapping-like documents with a
            ``movie_title`` key.
        model_name: Model identifier forwarded to ``SentenceTransformer``.

    Returns:
        The result of ``encode(..., convert_to_tensor=True)`` for the titles.
    """
    titles = []
    for movie in movie_docs:
        titles.append(movie['movie_title'])

    encoder = SentenceTransformer(model_name, device=device)
    return encoder.encode(
        titles,
        show_progress_bar=True,
        convert_to_tensor=True,
        device=device,
    )
|
134 |
+
|
135 |
+
def GenresEncoder(movie_docs):
    """Multi-hot encode the ``genres`` field of every document.

    Columns are assigned to genres in sorted order so the feature layout is
    the same on every run; enumerating a plain ``set`` of strings is not
    stable across interpreter runs because of hash randomization.

    NOTE(review): ``device`` is not a parameter or local here, so it is read
    from module scope -- confirm it is defined at module level.

    Args:
        movie_docs: Iterable of mapping-like documents with a ``genres`` key.
            Assumes each ``genres`` value is an iterable of genre labels
            (e.g. a list of strings) -- TODO confirm against the loader.

    Returns:
        A float tensor of shape ``(num_docs, num_unique_genres)`` moved to
        ``device``; entry ``[i, j]`` is 1 when document ``i`` lists genre
        ``j`` and 0 otherwise.
    """
    genres_per_movie = [doc['genres'] for doc in movie_docs]

    # Sort the distinct genres to get a reproducible column order.
    unique_genres = sorted(set(itertools.chain.from_iterable(genres_per_movie)))
    print("Number of unique genres we have:", len(unique_genres))

    column_of = {genre: col for col, genre in enumerate(unique_genres)}
    x = torch.zeros(len(genres_per_movie), len(column_of))
    for row, movie_genres in enumerate(genres_per_movie):
        for genre in movie_genres:
            x[row, column_of[genre]] = 1
    return x.to(device)
|
153 |
+
|
154 |
|
155 |
#-------------------------------------------------------------------------------------------
|
156 |
|
|
|
312 |
|
313 |
|
314 |
return movie_rec_db
|
315 |
+
|
316 |
+
|
317 |
+
def make_pyg_graph(movie_rec_db):
    """Build a heterogeneous user/movie graph from the database and split it.

    Reads the ``Ratings`` and ``Movie`` collections through AQL, turns the
    ratings into ``('user', 'rates', 'movie')`` edges labelled with the
    rating, gives movies a feature vector made of the title embedding
    concatenated with the genre encoding, gives users identity (one-hot)
    features, adds reverse edges, and splits the ``rates`` edges into
    train/validation/test sets (10% validation, 10% test, no negative
    sampling).

    NOTE(review): edge endpoints are the integer document keys used directly
    as node indices; this presumes user keys lie in ``[0, len(Users))`` and
    movie keys match the row order of the Movie query -- confirm against the
    loader.

    Args:
        movie_rec_db: Database handle exposing ``collection(name)`` and
            ``aql.execute(query)``.

    Returns:
        A tuple ``(data, train_data, val_data, test_data)``: the full graph
        and the three splits produced by ``RandomLinkSplit``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    users = movie_rec_db.collection('Users')

    edge_index, edge_label = create_pyg_edges(
        movie_rec_db.aql.execute('FOR doc IN Ratings RETURN doc'))

    # Fetch the movies once and reuse the list: two separate unsorted queries
    # are not guaranteed to return rows in the same order, which would
    # misalign title embeddings with genre encodings.
    movie_docs = list(movie_rec_db.aql.execute('FOR doc IN Movie RETURN doc'))
    title_emb = SequenceEncoder(movie_docs, model_name='all-MiniLM-L6-v2')
    encoded_genres = GenresEncoder(movie_docs)
    movie_x = torch.cat((title_emb, encoded_genres), dim=-1)

    data = HeteroData()
    data['user'].num_nodes = len(users)  # Users do not have any features.
    data['movie'].x = movie_x
    data['user', 'rates', 'movie'].edge_index = edge_index
    data['user', 'rates', 'movie'].edge_label = edge_label

    # Add user node features for message passing:
    data['user'].x = torch.eye(data['user'].num_nodes, device=device)
    del data['user'].num_nodes
    data = ToUndirected()(data)
    del data['movie', 'rev_rates', 'user'].edge_label  # Remove "reverse" label.

    data = data.to(device)

    train_data, val_data, test_data = T.RandomLinkSplit(
        num_val=0.1,
        num_test=0.1,
        neg_sampling_ratio=0.0,
        edge_types=[('user', 'rates', 'movie')],
        rev_edge_types=[('movie', 'rev_rates', 'user')],
    )(data)

    # Previously the function fell off the end here and returned None,
    # discarding everything it had built.
    return data, train_data, val_data, test_data
|
350 |
+
|
351 |
+
|
352 |
|
353 |
|
354 |
|