Emaad commited on
Commit
80a49a4
·
1 Parent(s): b3933a0

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +11 -18
prediction.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  os.chdir('..')
 
3
  from dataloader import CellLoader
4
  from celle_main import instantiate_from_config
5
  from omegaconf import OmegaConf
@@ -53,30 +54,22 @@ def run_sequence_prediction(
53
  # Set condition_model_path and vqgan_model_path to None
54
  config["model"]["params"]["condition_model_path"] = None
55
  config["model"]["params"]["vqgan_model_path"] = None
 
 
56
 
57
  # Instantiate model from config and move to device
58
- model = instantiate_from_config(config).to(device)
59
 
60
  # Sample from model using provided sequence and nucleus image
61
  _, predicted_sequence, _ = model.celle.sample_text(
62
- text=sequence,
63
- condition=nucleus_image,
64
- image=protein_image,
65
  force_aas=True,
66
- timesteps=1,
67
  temperature=1,
68
- progress=True,
69
  )
 
 
70
 
71
- formatted_predicted_sequence = ""
72
-
73
- for i in range(min(len(predicted_sequence), len(sequence))):
74
- if predicted_sequence[i] != sequence[i]:
75
- formatted_predicted_sequence += f"**{predicted_sequence[i]}**"
76
- else:
77
- formatted_predicted_sequence += predicted_sequence[i]
78
-
79
- if len(predicted_sequence) > len(sequence):
80
- formatted_predicted_sequence += f"**{predicted_sequence[len(sequence):]}**"
81
-
82
- return formatted_predicted_sequence
 
1
  import os
2
  os.chdir('..')
3
+ base_dir = os.getcwd()
4
  from dataloader import CellLoader
5
  from celle_main import instantiate_from_config
6
  from omegaconf import OmegaConf
 
54
  # Set condition_model_path and vqgan_model_path to None
55
  config["model"]["params"]["condition_model_path"] = None
56
  config["model"]["params"]["vqgan_model_path"] = None
57
+
58
+ os.chdir(os.path.dirname(model_ckpt_path))
59
 
60
  # Instantiate model from config and move to device
61
+ model = instantiate_from_config(config.model).to(device)
62
 
63
  # Sample from model using provided sequence and nucleus image
64
  _, predicted_sequence, _ = model.celle.sample_text(
65
+ text=sequence.to(device),
66
+ condition=nucleus_image.to(device),
67
+ image=protein_image.to(device),
68
  force_aas=True,
 
69
  temperature=1,
70
+ progress=False,
71
  )
72
+
73
+ os.chdir(base_dir)
74
 
75
+ return predicted_sequence