Update prediction.py
Browse files- prediction.py +16 -8
prediction.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
from dataloader import CellLoader
|
2 |
|
3 |
-
def
|
4 |
sequence_input,
|
5 |
nucleus_image,
|
|
|
6 |
model,
|
7 |
device
|
8 |
):
|
@@ -15,6 +16,7 @@ def run_image_prediction(
|
|
15 |
:param model_ckpt_path: Path to model checkpoint
|
16 |
:param model_config_path: Path to model config
|
17 |
"""
|
|
|
18 |
# Instantiate dataset object
|
19 |
dataset = CellLoader(
|
20 |
sequence_mode="embedding",
|
@@ -28,20 +30,26 @@ def run_image_prediction(
|
|
28 |
threshold="median",
|
29 |
)
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
# Convert SEQUENCE to sequence using dataset.tokenize_sequence()
|
32 |
sequence = dataset.tokenize_sequence(sequence_input)
|
33 |
|
34 |
# Sample from model using provided sequence and nucleus image
|
35 |
-
_,
|
36 |
text=sequence.to(device),
|
37 |
condition=nucleus_image.to(device),
|
38 |
-
|
|
|
39 |
temperature=1,
|
40 |
progress=False,
|
41 |
)
|
|
|
|
|
42 |
|
43 |
-
|
44 |
-
predicted_threshold = predicted_threshold.cpu()[0, 0]
|
45 |
-
predicted_heatmap = predicted_heatmap.cpu()[0, 0]
|
46 |
-
|
47 |
-
return predicted_threshold, predicted_heatmap
|
|
|
1 |
from dataloader import CellLoader
|
2 |
|
3 |
+
def run_sequence_prediction(
|
4 |
sequence_input,
|
5 |
nucleus_image,
|
6 |
+
protein_image,
|
7 |
model,
|
8 |
device
|
9 |
):
|
|
|
16 |
:param model_ckpt_path: Path to model checkpoint
|
17 |
:param model_config_path: Path to model config
|
18 |
"""
|
19 |
+
|
20 |
# Instantiate dataset object
|
21 |
dataset = CellLoader(
|
22 |
sequence_mode="embedding",
|
|
|
30 |
threshold="median",
|
31 |
)
|
32 |
|
33 |
+
# Check if sequence is provided and valid
|
34 |
+
if len(sequence_input) == 0:
|
35 |
+
raise ValueError("Sequence must be provided.")
|
36 |
+
|
37 |
+
if "<mask>" not in sequence_input:
|
38 |
+
print("Warning: Sequence does not contain any masked positions to predict.")
|
39 |
+
|
40 |
# Convert SEQUENCE to sequence using dataset.tokenize_sequence()
|
41 |
sequence = dataset.tokenize_sequence(sequence_input)
|
42 |
|
43 |
# Sample from model using provided sequence and nucleus image
|
44 |
+
_, predicted_sequence, _ = model.celle.sample_text(
|
45 |
text=sequence.to(device),
|
46 |
condition=nucleus_image.to(device),
|
47 |
+
image=protein_image.to(device),
|
48 |
+
force_aas=True,
|
49 |
temperature=1,
|
50 |
progress=False,
|
51 |
)
|
52 |
+
|
53 |
+
os.chdir(base_dir)
|
54 |
|
55 |
+
return predicted_sequence
|
|
|
|
|
|
|
|