Update prediction.py
Browse files- prediction.py +11 -18
prediction.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import os
|
2 |
os.chdir('..')
|
|
|
3 |
from dataloader import CellLoader
|
4 |
from celle_main import instantiate_from_config
|
5 |
from omegaconf import OmegaConf
|
@@ -53,30 +54,22 @@ def run_sequence_prediction(
|
|
53 |
# Set condition_model_path and vqgan_model_path to None
|
54 |
config["model"]["params"]["condition_model_path"] = None
|
55 |
config["model"]["params"]["vqgan_model_path"] = None
|
|
|
|
|
56 |
|
57 |
# Instantiate model from config and move to device
|
58 |
-
model = instantiate_from_config(config).to(device)
|
59 |
|
60 |
# Sample from model using provided sequence and nucleus image
|
61 |
_, predicted_sequence, _ = model.celle.sample_text(
|
62 |
-
text=sequence,
|
63 |
-
condition=nucleus_image,
|
64 |
-
image=protein_image,
|
65 |
force_aas=True,
|
66 |
-
timesteps=1,
|
67 |
temperature=1,
|
68 |
-
progress=
|
69 |
)
|
|
|
|
|
70 |
|
71 |
-
|
72 |
-
|
73 |
-
for i in range(min(len(predicted_sequence), len(sequence))):
|
74 |
-
if predicted_sequence[i] != sequence[i]:
|
75 |
-
formatted_predicted_sequence += f"**{predicted_sequence[i]}**"
|
76 |
-
else:
|
77 |
-
formatted_predicted_sequence += predicted_sequence[i]
|
78 |
-
|
79 |
-
if len(predicted_sequence) > len(sequence):
|
80 |
-
formatted_predicted_sequence += f"**{predicted_sequence[len(sequence):]}**"
|
81 |
-
|
82 |
-
return formatted_predicted_sequence
|
|
|
1 |
import os
|
2 |
os.chdir('..')
|
3 |
+
base_dir = os.getcwd()
|
4 |
from dataloader import CellLoader
|
5 |
from celle_main import instantiate_from_config
|
6 |
from omegaconf import OmegaConf
|
|
|
54 |
# Set condition_model_path and vqgan_model_path to None
|
55 |
config["model"]["params"]["condition_model_path"] = None
|
56 |
config["model"]["params"]["vqgan_model_path"] = None
|
57 |
+
|
58 |
+
os.chdir(os.path.dirname(model_ckpt_path))
|
59 |
|
60 |
# Instantiate model from config and move to device
|
61 |
+
model = instantiate_from_config(config.model).to(device)
|
62 |
|
63 |
# Sample from model using provided sequence and nucleus image
|
64 |
_, predicted_sequence, _ = model.celle.sample_text(
|
65 |
+
text=sequence.to(device),
|
66 |
+
condition=nucleus_image.to(device),
|
67 |
+
image=protein_image.to(device),
|
68 |
force_aas=True,
|
|
|
69 |
temperature=1,
|
70 |
+
progress=False,
|
71 |
)
|
72 |
+
|
73 |
+
os.chdir(base_dir)
|
74 |
|
75 |
+
return predicted_sequence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|