"""Gradio demo for CELL-E_2 sequence prediction.

Given an amino acid sequence and a nucleus image with a hand-drawn protein
localization mask, the selected CELL-E_2 checkpoint returns a predicted
sequence.
"""

import gradio as gr
import torch
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from matplotlib import pyplot as plt
from PIL import Image

from celle.utils import process_image
from prediction import run_sequence_prediction


def gradio_demo(model_name, sequence_input, image):
    # Download the checkpoint and configs from the Hugging Face Hub. The two
    # VQGAN configs are fetched only so they end up in the local cache; their
    # paths are not used directly here.
    model = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
    config = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
    hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
    hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model names containing "Finetuned" use OpenCell preprocessing; all others use HPA.
    if "Finetuned" in model_name:
        dataset = "OpenCell"
    else:
        dataset = "HPA"

    # The sketch-tool Image component provides a dict: the uploaded nucleus image
    # under "image" and the hand-drawn localization mask under "mask". Both are
    # converted to grayscale.
    nucleus_image = image["image"].convert("L")
    protein_image = image["mask"].convert("L")

    # Stack the two channels and apply the dataset-specific preprocessing
    # (a random 256 x 256 crop is applied to larger images).
    to_tensor = T.ToTensor()
    nucleus_tensor = to_tensor(nucleus_image)
    protein_tensor = to_tensor(protein_image)
    stacked_images = torch.stack([nucleus_tensor, protein_tensor], dim=0)
    processed_images = process_image(stacked_images, dataset)

    # Re-add a batch dimension and binarize the drawn mask.
    nucleus_image = processed_images[0].unsqueeze(0)
    protein_image = processed_images[1].unsqueeze(0)
    protein_image = 1.0 * (protein_image > 0)

    # Debug output: number of nonzero pixels in the binarized mask.
    print(f"{protein_image.sum()}")

    formatted_predicted_sequence = run_sequence_prediction(
        sequence_input=sequence_input,
        nucleus_image=nucleus_image,
        protein_image=protein_image,
        model_ckpt_path=model,
        model_config_path=config,
        device=device,
    )

    return formatted_predicted_sequence
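# Example (hedged sketch, not executed by the app): calling gradio_demo() directly
# with locally loaded images. The sketch-tool gr.Image component passes a dict
# holding the uploaded nucleus image under "image" and the drawn mask under
# "mask", so a direct call has to mimic that structure. The file names below are
# hypothetical.
#
#   from PIL import Image
#   nucleus = Image.open("images/nucleus.png")
#   mask = Image.open("images/mask.png")
#   prediction = gradio_demo(
#       "CELL-E_2_HPA_2560",
#       "MSKGEELFTGVVPILVELDGDVNGHK...",  # the full GFP sequence used as the textbox default
#       {"image": nucleus, "mask": mask},
#   )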


with gr.Blocks() as demo:
    gr.Markdown("Select the prediction model.")
    gr.Markdown(
        "- CELL-E_2_HPA_2560 is a good general-purpose model for various cell types using ICC-IF."
    )
    gr.Markdown(
        "- CELL-E_2_OpenCell_2560 is trained on OpenCell and is better for live-cell predictions on HEK cells."
    )
    with gr.Row():
        model_name = gr.Dropdown(
            ["CELL-E_2_HPA_2560", "CELL-E_2_OpenCell_2560"],
            value="CELL-E_2_HPA_2560",
            label="Model Name",
        )
    with gr.Row():
        gr.Markdown(
            "Input the desired amino acid sequence. GFP is shown below by default."
        )

    with gr.Row():
        sequence_input = gr.Textbox(
            value="MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            label="Sequence",
        )
    with gr.Row():
        gr.Markdown(
            "Uploading a nucleus image is necessary. A random 256 x 256 crop will be applied to larger images. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images). Draw the desired localization on top of the nucleus image."
        )

    with gr.Row().style(equal_height=True):
        nucleus_image = gr.Image(
            source="upload",
            tool="sketch",
            invert_colors=True,
            label="Nucleus Image",
            line_color="white",
            interactive=True,
            image_mode="L",
            type="pil",
        )

    with gr.Row():
        gr.Markdown("Sequence predictions are shown below.")

    with gr.Row().style(equal_height=True):
        predicted_sequence = gr.Textbox(label="Predicted Sequence")

    with gr.Row():
        button = gr.Button("Run Model")

    inputs = [model_name, sequence_input, nucleus_image]
    outputs = [predicted_sequence]

    button.click(gradio_demo, inputs, outputs)

demo.launch(enable_queue=True)
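# To run the demo locally (assuming this file is saved as app.py and the
# CELL-E_2 dependencies are installed):
#
#   python app.py
#
# Gradio serves the interface at http://127.0.0.1:7860 by default.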