|
from dataloader import CellLoader |
|
|
|
def run_sequence_prediction(
    sequence_input,
    nucleus_image,
    protein_image,
    model,
    device
):
    """
    Predict a sequence with a Celle model from a sequence and two images.

    The input sequence is tokenized with ``CellLoader.tokenize_sequence`` and
    passed, together with both images, to ``model.celle.sample_text``. The
    second element of the triple returned by that call is returned.

    :param sequence_input: Sequence to run prediction on. Positions to be
        predicted are expected to be marked with ``<mask>``; a warning is
        printed when no such marker is present.
    :param nucleus_image: Nucleus image, passed to the model as ``condition``.
        ``.to(device)`` is called on it (presumably a torch tensor -- confirm
        against the caller).
    :param protein_image: Protein image, passed to the model as ``image``.
        ``.to(device)`` is called on it unconditionally, so ``None`` is not
        accepted by this function.
    :param model: Object exposing ``model.celle.sample_text``.
    :param device: Device the tokenized sequence and both images are moved to.
    :return: The predicted sequence produced by ``sample_text``.
    :raises ValueError: If ``sequence_input`` is ``None`` or empty.
    """
    # Validate first so a bad call fails before any dataset object is built.
    if sequence_input is None or len(sequence_input) == 0:
        raise ValueError("Sequence must be provided.")

    if "<mask>" not in sequence_input:
        print("Warning: Sequence does not contain any masked positions to predict.")

    # Only the loader's tokenize_sequence method is used in this function.
    dataset = CellLoader(
        sequence_mode="embedding",
        vocab="esm2",
        split_key="val",
        crop_method="center",
        resize=600,
        crop_size=256,
        text_seq_len=1000,
        pad_mode="end",
        threshold="median",
    )

    sequence = dataset.tokenize_sequence(sequence_input)

    # sample_text returns a triple; only the middle element is needed here.
    _, predicted_sequence, _ = model.celle.sample_text(
        text=sequence.to(device),
        condition=nucleus_image.to(device),
        image=protein_image.to(device),
        force_aas=True,
        temperature=1,
        progress=False,
    )

    return predicted_sequence