ribesstefano committed on
Commit
c06df22
·
1 Parent(s): e21cad4

Fixed code on loading cell embeddings for package

Browse files
protac_degradation_predictor/data_utils.py CHANGED
@@ -19,24 +19,25 @@ memory = Memory(cachedir, verbose=0)
19
 
20
 
21
  @memory.cache
22
- def load_cell2embedding(
23
- embeddings_path: Optional[str] = None,
24
  ) -> Dict[str, np.ndarray]:
25
- """ Load the cell line embeddings from a file.
26
-
27
  Args:
28
  embeddings_path (str): The path to the embeddings file.
29
-
30
  Returns:
31
- Dict[str, np.ndarray]: A dictionary of cell line embeddings.
32
  """
33
  if embeddings_path is None:
34
- with pkg_resources.resource_stream(__name__, 'data/cell2embedding.pkl') as f:
35
- cell2embedding = pickle.load(f)
36
- else:
37
- with open(embeddings_path, 'rb') as f:
38
- cell2embedding = pickle.load(f)
39
- return cell2embedding
 
40
 
41
 
42
  @memory.cache
@@ -52,9 +53,11 @@ def load_cell2embedding(
52
  Dict[str, np.ndarray]: A dictionary of cell line embeddings.
53
  """
54
  if embeddings_path is None:
55
- embeddings_path = pkg_resources.resource_stream(__name__, 'data/cell2embedding.pkl')
56
- with open(embeddings_path, 'rb') as f:
57
- cell2embedding = pickle.load(f)
 
 
58
  return cell2embedding
59
 
60
 
 
19
 
20
 
21
  @memory.cache
22
+ def load_protein2embedding(
23
+ embeddings_path: Optional[str] = None,
24
  ) -> Dict[str, np.ndarray]:
25
+ """ Load the protein embeddings from a file.
26
+
27
  Args:
28
  embeddings_path (str): The path to the embeddings file.
29
+
30
  Returns:
31
+ Dict[str, np.ndarray]: A dictionary of protein embeddings.
32
  """
33
  if embeddings_path is None:
34
+ embeddings_path = pkg_resources.resource_stream(__name__, 'data/uniprot2embedding.h5')
35
+ protein2embedding = {}
36
+ with h5py.File(embeddings_path, "r") as file:
37
+ for sequence_id in file.keys():
38
+ embedding = file[sequence_id][:]
39
+ protein2embedding[sequence_id] = np.array(embedding)
40
+ return protein2embedding
41
 
42
 
43
  @memory.cache
 
53
  Dict[str, np.ndarray]: A dictionary of cell line embeddings.
54
  """
55
  if embeddings_path is None:
56
+ with pkg_resources.resource_stream(__name__, 'data/cell2embedding.pkl') as f:
57
+ cell2embedding = pickle.load(f)
58
+ else:
59
+ with open(embeddings_path, 'rb') as f:
60
+ cell2embedding = pickle.load(f)
61
  return cell2embedding
62
 
63