Polo123 commited on
Commit
d0c5644
·
verified ·
1 Parent(s): 64e5385

Update logic.py

Browse files
Files changed (1) hide show
  1. logic.py +79 -1
logic.py CHANGED
@@ -107,6 +107,50 @@ def create_ratings_graph(user_id, movie_id, ratings):
107
  print("Inserting batch the last batch!")
108
  edge_collection.import_bulk(batch)
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  #-------------------------------------------------------------------------------------------
112
 
@@ -268,9 +312,43 @@ def load_data_to_ArangoDB(login):
268
 
269
 
270
  return movie_rec_db
 
 
 
 
 
 
 
 
 
271
 
 
 
 
 
 
 
 
 
 
272
 
273
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
 
276
 
 
107
  print("Inserting batch the last batch!")
108
  edge_collection.import_bulk(batch)
109
 
110
+ def create_pyg_edges(rating_docs):
111
+ src = []
112
+ dst = []
113
+ ratings = []
114
+ for doc in rating_docs:
115
+ _from = int(doc['_from'].split('/')[1])
116
+ _to = int(doc['_to'].split('/')[1])
117
+
118
+ src.append(_from)
119
+ dst.append(_to)
120
+ ratings.append(int(doc['_rating']))
121
+
122
+ edge_index = torch.tensor([src, dst])
123
+ edge_attr = torch.tensor(ratings)
124
+
125
+ return edge_index, edge_attr
126
+
127
+ def SequenceEncoder(movie_docs , model_name=None):
128
+ movie_titles = [doc['movie_title'] for doc in movie_docs]
129
+ model = SentenceTransformer(model_name, device=device)
130
+ title_embeddings = model.encode(movie_titles, show_progress_bar=True,
131
+ convert_to_tensor=True, device=device)
132
+
133
+ return title_embeddings
134
+
135
+ def GenresEncoder(movie_docs):
136
+ gen = []
137
+ #sep = '|'
138
+ for doc in movie_docs:
139
+ gen.append(doc['genres'])
140
+ #genre = doc['movie_genres']
141
+ #gen.append(genre.split(sep))
142
+
143
+ # getting unique genres
144
+ unique_gen = set(list(itertools.chain(*gen)))
145
+ print("Number of unqiue genres we have:", unique_gen)
146
+
147
+ mapping = {g: i for i, g in enumerate(unique_gen)}
148
+ x = torch.zeros(len(gen), len(mapping))
149
+ for i, m_gen in enumerate(gen):
150
+ for genre in m_gen:
151
+ x[i, mapping[genre]] = 1
152
+ return x.to(device)
153
+
154
 
155
  #-------------------------------------------------------------------------------------------
156
 
 
312
 
313
 
314
  return movie_rec_db
315
+
316
+
317
+ def make_pyg_graph(movie_rec_db):
318
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
319
+ users = movie_rec_db.collection('Users')
320
+ movies = movie_rec_db.collection('Movie')
321
+ ratings_graph = movie_rec_db.collection('Ratings')
322
+
323
+ edge_index, edge_label = create_pyg_edges(movie_rec_db.aql.execute('FOR doc IN Ratings RETURN doc'))
324
 
325
+ title_emb = SequenceEncoder(movie_rec_db.aql.execute('FOR doc IN Movie RETURN doc'), model_name='all-MiniLM-L6-v2')
326
+ encoded_genres = GenresEncoder(movie_rec_db.aql.execute('FOR doc IN Movie RETURN doc'))
327
+ movie_x = torch.cat((title_emb, encoded_genres), dim=-1)
328
+
329
+ data = HeteroData()
330
+ data['user'].num_nodes = len(users) # Users do not have any features.
331
+ data['movie'].x = movie_x
332
+ data['user', 'rates', 'movie'].edge_index = edge_index
333
+ data['user', 'rates', 'movie'].edge_label = edge_label
334
 
335
+ # Add user node features for message passing:
336
+ data['user'].x = torch.eye(data['user'].num_nodes, device=device)
337
+ del data['user'].num_nodes
338
+ data = ToUndirected()(data)
339
+ del data['movie', 'rev_rates', 'user'].edge_label # Remove "reverse" label.
340
+
341
+ data = data.to(device)
342
+
343
+ train_data, val_data, test_data = T.RandomLinkSplit(
344
+ num_val=0.1,
345
+ num_test=0.1,
346
+ neg_sampling_ratio=0.0,
347
+ edge_types=[('user', 'rates', 'movie')],
348
+ rev_edge_types=[('movie', 'rev_rates', 'user')],
349
+ )(data)
350
+
351
+
352
 
353
 
354