File size: 3,849 Bytes
548170b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3933a0
 
 
 
 
 
 
 
 
 
 
 
 
548170b
b3933a0
548170b
b3933a0
 
548170b
 
 
 
 
 
 
 
b3933a0
548170b
 
 
 
 
b3933a0
548170b
 
b3933a0
548170b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3933a0
548170b
 
 
 
 
 
b3933a0
548170b
 
 
 
 
 
 
 
b3933a0
548170b
 
b3933a0
 
548170b
 
 
 
 
 
 
 
 
 
b3933a0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
from huggingface_hub import hf_hub_download
from prediction import run_sequence_prediction
import torch
import torchvision.transforms as T
from celle.utils import process_image
from PIL import Image
from matplotlib import pyplot as plt


def gradio_demo(model_name, sequence_input, image):
    """Predict a protein sequence for a user-drawn localization on a nucleus image.

    Downloads the selected CELL-E 2 checkpoint and its YAML configs from the
    ``HuangLab/<model_name>`` Hugging Face repo. It preprocesses the uploaded
    nucleus image together with the sketched mask and binarizes the mask into
    the protein-localization input. It then runs sequence prediction.

    Args:
        model_name: Model repo name under the ``HuangLab`` organization,
            e.g. ``"CELL-E_2_HPA_2560"`` or ``"CELL-E_2_OpenCell_2560"``.
        sequence_input: Amino-acid sequence passed to the predictor.
        image: Value of the sketch-tool ``gr.Image`` component. This is a
            dict whose ``"image"`` entry is the uploaded nucleus image and
            whose ``"mask"`` entry holds the drawn strokes (both PIL images,
            since the component uses ``type="pil"``). It may be ``None`` when
            nothing was uploaded.

    Returns:
        The value returned by ``run_sequence_prediction``; the UI displays
        it in the "Predicted Sequence" textbox.

    Raises:
        gr.Error: If no nucleus image was uploaded.
    """
    # Surface a readable message in the UI instead of a TypeError on None.
    if image is None or image.get("image") is None:
        raise gr.Error("Please upload a nucleus image.")

    repo_id = f"HuangLab/{model_name}"
    model = hf_hub_download(repo_id=repo_id, filename="model.ckpt")
    config = hf_hub_download(repo_id=repo_id, filename="config.yaml")
    # The VQGAN configs are not used directly in this function; they are only
    # fetched into the local cache (presumably referenced from config.yaml).
    hf_hub_download(repo_id=repo_id, filename="nucleus_vqgan.yaml")
    hf_hub_download(repo_id=repo_id, filename="threshold_vqgan.yaml")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pick the preprocessing recipe passed to process_image. The offered
    # OpenCell model name contains "OpenCell" but not "Finetuned". Checking
    # only "Finetuned" therefore sent it down the HPA path.
    if "Finetuned" in model_name or "OpenCell" in model_name:
        dataset = "OpenCell"
    else:
        dataset = "HPA"

    nucleus_image = image["image"].convert("L")
    # The sketched strokes are used as the protein localization image.
    protein_image = image["mask"].convert("L")

    to_tensor = T.ToTensor()
    nucleus_tensor = to_tensor(nucleus_image)
    protein_tensor = to_tensor(protein_image)
    # Stack both channels so process_image treats them together
    # (presumably the same crop/resize for both — confirm in celle.utils).
    stacked_images = torch.stack([nucleus_tensor, protein_tensor], dim=0)
    processed_images = process_image(stacked_images, dataset)

    nucleus_image = processed_images[0].unsqueeze(0)
    protein_image = processed_images[1].unsqueeze(0)
    # Binarize: any drawn (non-zero) pixel counts as localized signal.
    protein_image = 1.0 * (protein_image > 0)

    formatted_predicted_sequence = run_sequence_prediction(
        sequence_input=sequence_input,
        nucleus_image=nucleus_image,
        protein_image=protein_image,
        model_ckpt_path=model,
        model_config_path=config,
        device=device,
    )

    return formatted_predicted_sequence


# Gradio UI with four parts:
#   - a model picker
#   - an amino-acid sequence box (GFP by default)
#   - a nucleus-image sketch pad, whose drawn strokes gradio_demo uses as the
#     protein localization mask
#   - a textbox showing the predicted sequence
# `.style()`, `source=`, `tool="sketch"` and `enable_queue=` appear to be
# Gradio 3.x-era APIs; check them before upgrading Gradio.
with gr.Blocks() as demo:
    gr.Markdown("Select the prediction model.")
    gr.Markdown(
        "- CELL-E_2_HPA_2560 is a good general purpose model for various cell types using ICC-IF."
    )
    gr.Markdown(
        "- CELL-E_2_OpenCell_2560 is trained on OpenCell and is good for live-cell predictions on HEK cells."
    )
    with gr.Row():
        model_name = gr.Dropdown(
            ["CELL-E_2_HPA_2560", "CELL-E_2_OpenCell_2560"],
            value="CELL-E_2_HPA_2560",
            label="Model Name",
        )
    with gr.Row():
        gr.Markdown(
            "Input the desired amino acid sequence. GFP is shown below by default."
        )

    with gr.Row():
        sequence_input = gr.Textbox(
            value="MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            label="Sequence",
        )
    with gr.Row():
        gr.Markdown(
            "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images). Draw the desired localization on top of the nucleus image."
        )

    with gr.Row().style(equal_height=True):
        # Sketch tool: the component value carries both the uploaded image
        # and the drawn mask (read as image['image'] / image['mask']).
        nucleus_image = gr.Image(
            source="upload",
            tool="sketch",
            invert_colors=True,
            label="Nucleus Image",
            line_color="white",
            interactive=True,
            image_mode="L",
            type="pil"
        )

    with gr.Row():
        gr.Markdown("Sequence predictions are shown below.")

    with gr.Row().style(equal_height=True):
        predicted_sequence = gr.Textbox(label='Predicted Sequence')

    with gr.Row():
        button = gr.Button("Run Model")

        # Order must match gradio_demo(model_name, sequence_input, image).
        inputs = [model_name, sequence_input, nucleus_image]

        outputs = [predicted_sequence]

        button.click(gradio_demo, inputs, outputs)

demo.launch(enable_queue=True)