Emaad commited on
Commit
22f2c54
·
1 Parent(s): 61dc572

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +16 -8
prediction.py CHANGED
@@ -1,8 +1,9 @@
1
  from dataloader import CellLoader
2
 
3
- def run_image_prediction(
4
  sequence_input,
5
  nucleus_image,
 
6
  model,
7
  device
8
  ):
@@ -15,6 +16,7 @@ def run_image_prediction(
15
  :param model_ckpt_path: Path to model checkpoint
16
  :param model_config_path: Path to model config
17
  """
 
18
  # Instantiate dataset object
19
  dataset = CellLoader(
20
  sequence_mode="embedding",
@@ -28,20 +30,26 @@ def run_image_prediction(
28
  threshold="median",
29
  )
30
 
 
 
 
 
 
 
 
31
  # Convert SEQUENCE to sequence using dataset.tokenize_sequence()
32
  sequence = dataset.tokenize_sequence(sequence_input)
33
 
34
  # Sample from model using provided sequence and nucleus image
35
- _, _, _, predicted_threshold, predicted_heatmap = model.celle.sample(
36
  text=sequence.to(device),
37
  condition=nucleus_image.to(device),
38
- timesteps=1,
 
39
  temperature=1,
40
  progress=False,
41
  )
 
 
42
 
43
- # Move predicted_threshold and predicted_heatmap to CPU and select first element of batch
44
- predicted_threshold = predicted_threshold.cpu()[0, 0]
45
- predicted_heatmap = predicted_heatmap.cpu()[0, 0]
46
-
47
- return predicted_threshold, predicted_heatmap
 
1
  from dataloader import CellLoader
2
 
3
+ def run_sequence_prediction(
4
  sequence_input,
5
  nucleus_image,
6
+ protein_image,
7
  model,
8
  device
9
  ):
 
16
  :param model_ckpt_path: Path to model checkpoint
17
  :param model_config_path: Path to model config
18
  """
19
+
20
  # Instantiate dataset object
21
  dataset = CellLoader(
22
  sequence_mode="embedding",
 
30
  threshold="median",
31
  )
32
 
33
+ # Check if sequence is provided and valid
34
+ if len(sequence_input) == 0:
35
+ raise ValueError("Sequence must be provided.")
36
+
37
+ if "<mask>" not in sequence_input:
38
+ print("Warning: Sequence does not contain any masked positions to predict.")
39
+
40
  # Convert SEQUENCE to sequence using dataset.tokenize_sequence()
41
  sequence = dataset.tokenize_sequence(sequence_input)
42
 
43
  # Sample from model using provided sequence and nucleus image
44
+ _, predicted_sequence, _ = model.celle.sample_text(
45
  text=sequence.to(device),
46
  condition=nucleus_image.to(device),
47
+ image=protein_image.to(device),
48
+ force_aas=True,
49
  temperature=1,
50
  progress=False,
51
  )
52
+
53
+ os.chdir(base_dir)
54
 
55
+ return predicted_sequence