File size: 3,849 Bytes
548170b b3933a0 548170b b3933a0 548170b b3933a0 548170b b3933a0 548170b b3933a0 548170b b3933a0 548170b b3933a0 548170b b3933a0 548170b b3933a0 548170b b3933a0 548170b b3933a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import gradio as gr
from huggingface_hub import hf_hub_download
from prediction import run_sequence_prediction
import torch
import torchvision.transforms as T
from celle.utils import process_image
from PIL import Image
from matplotlib import pyplot as plt
def gradio_demo(model_name, sequence_input, image):
    """Run sequence prediction for one Gradio request.

    Downloads the checkpoint and configuration files of ``model_name`` from
    the ``HuangLab`` organisation on the Hugging Face Hub, turns the uploaded
    nucleus image and the sketched mask into tensors, and forwards everything
    to ``run_sequence_prediction``.

    Args:
        model_name: Repository name below ``HuangLab/`` (e.g. one of the
            entries of the "Model Name" dropdown).
        sequence_input: Amino-acid sequence text; passed through unchanged.
        image: Mapping holding PIL images under ``'image'`` (the nucleus
            image) and ``'mask'`` (the drawn localization).

    Returns:
        The value returned by ``run_sequence_prediction``.

    Raises:
        ValueError: If no image was supplied (``image`` is ``None``).
    """
    if image is None:
        # Without this guard the subscript below fails with an opaque
        # "'NoneType' object is not subscriptable" TypeError.
        raise ValueError(
            "A nucleus image is required: upload one and draw the desired "
            "localization on top of it."
        )

    repo_id = f"HuangLab/{model_name}"
    model = hf_hub_download(repo_id=repo_id, filename="model.ckpt")
    config = hf_hub_download(repo_id=repo_id, filename="config.yaml")
    # The returned paths are not used here; the files are only fetched.
    # NOTE(review): presumably they are referenced from config.yaml - confirm.
    hf_hub_download(repo_id=repo_id, filename="nucleus_vqgan.yaml")
    hf_hub_download(repo_id=repo_id, filename="threshold_vqgan.yaml")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pick the dataset label handed to process_image. The OpenCell-trained
    # checkpoint ("CELL-E_2_OpenCell_2560") does not contain "Finetuned" in
    # its name, so testing for "Finetuned" alone always selected "HPA" for it.
    opencell_markers = ("Finetuned", "OpenCell")
    if any(marker in model_name for marker in opencell_markers):
        dataset = "OpenCell"
    else:
        dataset = "HPA"

    # Both the uploaded image and the drawn mask are reduced to one
    # grayscale channel before conversion to tensors.
    nucleus_image = image['image'].convert('L')
    protein_image = image['mask'].convert('L')

    to_tensor = T.ToTensor()
    nucleus_tensor = to_tensor(nucleus_image)
    protein_tensor = to_tensor(protein_image)

    # Process both images in a single call; index 0 is the nucleus and
    # index 1 the mask, matching the stacking order.
    stacked_images = torch.stack([nucleus_tensor, protein_tensor], dim=0)
    processed_images = process_image(stacked_images, dataset)

    nucleus_image = processed_images[0].unsqueeze(0)
    protein_image = processed_images[1].unsqueeze(0)

    # Binarise the mask: every strictly positive value becomes 1.0.
    protein_image = protein_image > 0
    protein_image = 1.0 * protein_image

    # Diagnostic output: number of mask elements that are set.
    print(f'{protein_image.sum()}')

    formatted_predicted_sequence = run_sequence_prediction(
        sequence_input=sequence_input,
        nucleus_image=nucleus_image,
        protein_image=protein_image,
        model_ckpt_path=model,
        model_config_path=config,
        device=device,
    )

    return formatted_predicted_sequence
# Gradio interface: model selection, sequence text box, nucleus image with a
# sketch layer, an output text box and a button that triggers gradio_demo.
with gr.Blocks() as demo:
    gr.Markdown("Select the prediction model.")
    gr.Markdown(
        "- CELL-E_2_HPA_2560 is a good general purpose model for various cell types using ICC-IF."
    )
    gr.Markdown(
        "- CELL-E_2_OpenCell_2560 is trained on OpenCell and is good for live-cell predictions on HEK cells."
    )
    with gr.Row():
        model_name = gr.Dropdown(
            ["CELL-E_2_HPA_2560", "CELL-E_2_OpenCell_2560"],
            value="CELL-E_2_HPA_2560",
            label="Model Name",
        )
    with gr.Row():
        gr.Markdown(
            "Input the desired amino acid sequence. GFP is shown below by default."
        )

    with gr.Row():
        sequence_input = gr.Textbox(
            value="MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            label="Sequence",
        )
    with gr.Row():
        gr.Markdown(
            "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images). Draw the desired localization on top of the nucleus image."
        )

    with gr.Row().style(equal_height=True):
        # gradio_demo reads the 'image' and 'mask' entries of this
        # component's value, so the sketch tool and type="pil" are required.
        nucleus_image = gr.Image(
            source="upload",
            tool="sketch",
            invert_colors=True,
            label="Nucleus Image",
            line_color="white",
            interactive=True,
            image_mode="L",
            type="pil"
        )

    with gr.Row():
        gr.Markdown("Sequence predictions are shown below.")

    with gr.Row().style(equal_height=True):
        predicted_sequence = gr.Textbox(label='Predicted Sequence')

    with gr.Row():
        button = gr.Button("Run Model")

    # Order must match the positional parameters of gradio_demo:
    # (model_name, sequence_input, image).
    inputs = [model_name, sequence_input, nucleus_image]
    outputs = [predicted_sequence]

    button.click(gradio_demo, inputs, outputs)

demo.launch(enable_queue=True)
|