File size: 7,449 Bytes
52da96f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import gradio as gr
import torch

from .init_model import model
from .blocks import upload_pdb_button, parse_pdb_file


# Modalities offered in the "Input type" dropdowns; also the keys of `input_examples`.
input_types = [
    "sequence",
    "structure",
    "text",
]

# Example inputs for each modality, keyed by the names listed in `input_types`.
# Every list holds three entries; entries at the same index are paired together
# when the example rows are built (see the `zip` calls that populate `samples`).
input_examples = {
    # Protein sequences (uppercase strings, presumably one-letter amino-acid codes).
    "sequence": [
        "MQLQRLGAPLLKRLVGGCIRQSTAPIMPCVVVSGSGGFLTPVRTYMPLPNDQSDFSPYIEIDLPSESRIQSLHKSGLAAQEWVACEKVHGTNFGIYLINQGDHEVVRFAKRSGIMDPNENFFGYHILIDEFTAQIRILNDLLKQKYGLSRVGRLVLNGELFGAKYKHPLVPKSEKWCTLPNGKKFPIAGVQIQREPFPQYSPELHFFAFDIKYSVSGAEEDFVLLGYDEFVEFSSKVPNLLYARALVRGTLDECLAFDVENFMTPLPALLGLGNYPLEGNLAEGVVIRHVRRGDPAVEKHNVSTIIKLRCSSFMELKHPGKQKELKETFIDTVRSGALRRVRGNVTVISDSMLPQVEAAANDLLLNNVSDGRLSNVLSKIGREPLLSGEVSQVDVALMLAKDALKDFLKEVDSLVLNTTLAFRKLLITNVYFESKRLVEQKWKELMQEEAAAQSEAIPPLSPAAPTKGE",
        "MSLSTEQMLRDYPRSMQINGQIPKNAIHETYGNDGVDVFIAGSGPIGATYAKLCVEAGLRVVMVEIGAADSFYAVNAEEGTAVPYVPGYHKKNEIEFQKDIDRFVNVIKGALQQVSVPVRNQNVPTLDPGAWSAPPGSSAISNGKNPHQREFENLSAEAVTRGVGGMSTHWTCSTPRIHPPMESLPGIGRPKLSNDPAEDDKEWNELYSEAERLIGTSTKEFDESIRHTLVLRSLQDAYKDRQRIFRPLPLACHRLKNAPEYVEWHSAENLFHSIYNDDKQKKLFTLLTNHRCTRLALTGGYEKKIGAAEVRNLLATRNPSSQLDSYIMAKVYVLASGAIGNPQILYNSGFSGLQVTPRNDSLIPNLGRYITEQPMAFCQIVLRQEFVDSVRDDPYGLPWWKEAVAQHIAKNPTDALPIPFRDPEPQVTTPFTEEHPWHTQIHRDAFSYGAVGPEVDSRVIVDLRWFGATDPEANNLLVFQNDVQDGYSMPQPTFRYRPSTASNVRARKMMADMCEVASNLGGYLPTSPPQFMDPGLALHLAGTTRIGFDKATTVADNNSLVWDFANLYVAGNGTIRTGFGENPTLTSMCHAIKSARSIINTLKGGTDGKNTGEHRNL",
        "MGVHECPAWLWLLLSLLSLPLGLPVLGAPPRLICDSRVLERYLLEAKEAENITTGCAEHCSLNENITVPDTKVNFYAWKRMEVGQQAVEVWQGLALLSEAVLRGQALLVNSSQPWEPLQLHVDKAVSGLRSLTTLLRALGAQKEAISPPDAASAAPLRTITADTFRKLFRVYSNFLRGKLKLYTGEACRTGDR"
    ],
    
    # Structure strings (lowercase). NOTE(review): presumably a per-residue structural
    # alphabet such as Foldseek 3Di tokens for the proteins above — confirm.
    "structure": [
        "ddddddddddddddddddddddddddddddddpdpddpddpqpdddfddpdqqlddadddfaaddpvqvvlcvvvvvlqakkfkwfdadffkkkwkwadpdpdidifidtnvgtdglqpddllclvcvvlsvqlvvllqvvvcvvvvapafrmkmfiwgkdalddpfppadadpdwhagsvgdidgsvpgdrdddpaqhahsdiaietewiwiarnsdpvriqtafqvvvcvsqvprpphhyidgqfmggnllnlldpqqpaaqlrnqqvvnqvgddpprggqfikmfrrpprppvvcvsvrhgihtdghlvnvcvvdppcsvvcccnrcvprnvvscvvvvndhdtdvlsrhhpvlsvllvqllvlldpvllvvldvvvdlpclqvvvqdllnsllsslvvsvvvsvvpddpvnvpgdpvsvvvssvsssvsssvvsvvcvvvvnvvsvvvvvvvddppdpdddpddd",
        "dpdplvvqppdddplqappppfaadpvcvlvdpvaaaeeeeaqallsllllllclvlvgfyeyefqaeqpdwdddpddvpdddftqtqfapcqppvclqpqqvllvvqvvfwdwqeaefdqpppvpddppddhddppdgdddqqhdppfdpqqdlgqatwgghrntcqnhdpqfddawadadpvahqgtfdaldpdpvvrvvlvvvllvvlcvqlvkdqclqvpflqqcllqvllcvvcvvppwhkgggtgswhadpvhsldirhttsssscvvqrvdpssvssydyhyskhqqewhaghdpfgetawtkiarnccvvpvpdrgihigghrfyeypralprvllrcvssvqalqdpggdprhnqdqffalkwfwwkkkfkfffdpvsqvcqcvppppdpssnvqlvvqcvvcvpdpgsgdssrakhfmwtdadpvqqktktwidghhndddddppddpsrmimimiihwafrdrqfgwgfdppgdhpvrttrihtrddgdpvsvvsvvvrlvvsvvssvstgdtdprgpididrrnsvnlieqrqaedddsvngqayqlqhgpsyphygyfdrnhrngigngdcvsvrssssvsnsvvsscvvvvdpdddppdddddd",
        "ddppppdcvvvvvvvvvppppppvppldplvvlldvvllvvqlvllvvllvvcvvpdpnfflqdwqkafdlddpvvvvvpddlllllqlllvrlvsllvrlvsslvslvpdpdrdvvnnvssvvlnvssvvvnvssvslvsvvsnppddppprdddgdididrgssvssvsvssnsvgsvvvssvvssvvvvd"
    ],
    
    # Free-text descriptions; these appear to describe the example proteins at the same index.
    "text": [
        "RNA-editing ligase in kinetoplastid mitochondrial.",
        "Oxidase which catalyzes the oxidation of various aldopyranoses and disaccharides.",
        "Erythropoietin for regulation of erythrocyte proliferation and differentiation."
    ]
}

# Example rows shown at start-up: the i-th sequence example next to the i-th text example.
# This matches the default dropdown values ("sequence" for input 1, "text" for input 2).
samples = [list(pair) for pair in zip(input_examples["sequence"], input_examples["text"])]


def compute_score(input_type_1: str, input_1: str, input_type_2: str, input_2: str):
    """
    Compute the similarity score between two inputs of possibly different modalities.

    Args:
        input_type_1: Modality of the first input ("sequence", "structure" or "text").
        input_1: Content of the first input.
        input_type_2: Modality of the second input.
        input_2: Content of the second input.

    Returns:
        The score formatted as a string with four decimal places.
    """
    
    def encode(kind, content):
        # Each encoder receives the content wrapped in a one-element list.
        if kind == "sequence":
            return model.get_protein_repr([content])
        if kind == "structure":
            return model.get_structure_repr([content])
        # Any other type falls through to the text encoder
        return model.get_text_repr([content])
    
    # Inference only, so gradient tracking is switched off
    with torch.no_grad():
        repr_1 = encode(input_type_1, input_1)
        repr_2 = encode(input_type_2, input_2)
        
        # Product of the two representations, scaled by the model temperature.
        # NOTE(review): `.item()` below requires the result to hold a single value,
        # i.e. presumably one representation row per input — confirm.
        score = repr_1 @ repr_2.T / model.temperature
    
    return f"{score.item():.4f}"


def change_input_type(choice_1: str, choice_2: str):
    """
    React to a change of either input type dropdown.

    Rebuilds the example pairs for the selected types, clears both input boxes and
    shows the upload widgets only for inputs whose type is not "text".

    Args:
        choice_1: Selected type of input 1 (a key of ``input_examples``).
        choice_2: Selected type of input 2 (a key of ``input_examples``).

    Returns:
        A 7-tuple of updates, in the order the change handlers list their outputs:
        examples dataset, input 1 text, input 2 text, upload button 1, chain box 1,
        upload button 2, chain box 2.
    """
    # Rebind the module-level example pairs so that load_example serves the new types.
    # NOTE(review): `samples` is module-global, so this state is presumably shared by
    # every session served by this process — confirm whether that is intended.
    global samples
    samples = [list(pair) for pair in zip(input_examples[choice_1], input_examples[choice_2])]
    
    # Upload widgets are hidden for text inputs and shown otherwise
    show_upload_1 = choice_1 != "text"
    show_upload_2 = choice_2 != "text"
    
    return (
        gr.update(samples=samples),
        "",
        "",
        gr.update(visible=show_upload_1),
        gr.update(visible=show_upload_1),
        gr.update(visible=show_upload_2),
        gr.update(visible=show_upload_2),
    )


# Look up the example pair that was clicked in the examples dataset
def load_example(example_id):
    """
    Return the example pair stored at position ``example_id``.

    The lookup uses the module-level ``samples`` list, so it reflects whatever
    pairs were most recently stored there.
    """
    selected_pair = samples[example_id]
    return selected_pair


# Build the block for computing the similarity score between two inputs
def build_score_computation():
    """
    Create the widgets and event wiring of the similarity score section.

    The section holds two inputs (each with a text box, a type dropdown and PDB
    upload widgets), a dataset of clickable examples, a compute button and a label
    that displays the score. Nothing is returned.

    NOTE(review): presumably this must be called inside an active ``gr.Blocks``
    context — confirm against the caller.
    """
    gr.Markdown("# Compute similarity score between two modalities")
    with gr.Row(equal_height=True):
        with gr.Column():
            # First input; its type defaults to "sequence", so the upload widgets start visible
            with gr.Row():
                input_1 = gr.Textbox(label="Input 1")
                input_type_1 = gr.Dropdown(
                    input_types,
                    label="Input type",
                    value="sequence",
                    interactive=True,
                    visible=True,
                )
                
                # An uploaded pdb file is parsed and the result is written into input 1
                upload_btn_1, chain_box_1 = upload_pdb_button(visible=True)
                upload_btn_1.upload(
                    parse_pdb_file,
                    inputs=[input_type_1, upload_btn_1, chain_box_1],
                    outputs=[input_1],
                )
            
            # Second input; its type defaults to "text", so the upload widgets start hidden
            with gr.Row():
                input_2 = gr.Textbox(label="Input 2")
                input_type_2 = gr.Dropdown(
                    input_types,
                    label="Input type",
                    value="text",
                    interactive=True,
                    visible=True,
                )
                
                # An uploaded pdb file is parsed and the result is written into input 2
                upload_btn_2, chain_box_2 = upload_pdb_button(visible=False)
                upload_btn_2.upload(
                    parse_pdb_file,
                    inputs=[input_type_2, upload_btn_2, chain_box_2],
                    outputs=[input_2],
                )
            
            # Clickable examples; type="index" makes the click handler receive the row index
            examples = gr.Dataset(
                samples=samples,
                type="index",
                components=[input_1, input_2],
                label="Input examples",
            )
            examples.click(fn=load_example, inputs=[examples], outputs=[input_1, input_2])
            
            compute_btn = gr.Button(value="Compute")
        
        # Both dropdowns share one handler: it refreshes the examples, clears the
        # text boxes and toggles the upload widgets of both inputs
        for type_dropdown in (input_type_1, input_type_2):
            type_dropdown.change(
                fn=change_input_type,
                inputs=[input_type_1, input_type_2],
                outputs=[
                    examples,
                    input_1,
                    input_2,
                    upload_btn_1,
                    chain_box_1,
                    upload_btn_2,
                    chain_box_2,
                ],
            )
        
        # The score label is created in the outer row, after the input column
        similarity_score = gr.Label(label="similarity score")
        compute_btn.click(
            fn=compute_score,
            inputs=[input_type_1, input_1, input_type_2, input_2],
            outputs=[similarity_score],
        )