"""Gradio block that scores the similarity between two inputs of selectable modality."""

import gradio as gr
import torch

from .init_model import model
from .blocks import upload_pdb_button, parse_pdb_file


# Modalities offered in the two "Input type" dropdowns.
input_types = ["sequence", "structure", "text"]

# Example inputs per modality. Entries at the same index are paired with each
# other when the example table is built (see ``_build_samples``).
input_examples = {
    "sequence": [
        "MQLQRLGAPLLKRLVGGCIRQSTAPIMPCVVVSGSGGFLTPVRTYMPLPNDQSDFSPYIEIDLPSESRIQSLHKSGLAAQEWVACEKVHGTNFGIYLINQGDHEVVRFAKRSGIMDPNENFFGYHILIDEFTAQIRILNDLLKQKYGLSRVGRLVLNGELFGAKYKHPLVPKSEKWCTLPNGKKFPIAGVQIQREPFPQYSPELHFFAFDIKYSVSGAEEDFVLLGYDEFVEFSSKVPNLLYARALVRGTLDECLAFDVENFMTPLPALLGLGNYPLEGNLAEGVVIRHVRRGDPAVEKHNVSTIIKLRCSSFMELKHPGKQKELKETFIDTVRSGALRRVRGNVTVISDSMLPQVEAAANDLLLNNVSDGRLSNVLSKIGREPLLSGEVSQVDVALMLAKDALKDFLKEVDSLVLNTTLAFRKLLITNVYFESKRLVEQKWKELMQEEAAAQSEAIPPLSPAAPTKGE",
        "MSLSTEQMLRDYPRSMQINGQIPKNAIHETYGNDGVDVFIAGSGPIGATYAKLCVEAGLRVVMVEIGAADSFYAVNAEEGTAVPYVPGYHKKNEIEFQKDIDRFVNVIKGALQQVSVPVRNQNVPTLDPGAWSAPPGSSAISNGKNPHQREFENLSAEAVTRGVGGMSTHWTCSTPRIHPPMESLPGIGRPKLSNDPAEDDKEWNELYSEAERLIGTSTKEFDESIRHTLVLRSLQDAYKDRQRIFRPLPLACHRLKNAPEYVEWHSAENLFHSIYNDDKQKKLFTLLTNHRCTRLALTGGYEKKIGAAEVRNLLATRNPSSQLDSYIMAKVYVLASGAIGNPQILYNSGFSGLQVTPRNDSLIPNLGRYITEQPMAFCQIVLRQEFVDSVRDDPYGLPWWKEAVAQHIAKNPTDALPIPFRDPEPQVTTPFTEEHPWHTQIHRDAFSYGAVGPEVDSRVIVDLRWFGATDPEANNLLVFQNDVQDGYSMPQPTFRYRPSTASNVRARKMMADMCEVASNLGGYLPTSPPQFMDPGLALHLAGTTRIGFDKATTVADNNSLVWDFANLYVAGNGTIRTGFGENPTLTSMCHAIKSARSIINTLKGGTDGKNTGEHRNL",
        "MGVHECPAWLWLLLSLLSLPLGLPVLGAPPRLICDSRVLERYLLEAKEAENITTGCAEHCSLNENITVPDTKVNFYAWKRMEVGQQAVEVWQGLALLSEAVLRGQALLVNSSQPWEPLQLHVDKAVSGLRSLTTLLRALGAQKEAISPPDAASAAPLRTITADTFRKLFRVYSNFLRGKLKLYTGEACRTGDR",
    ],
    "structure": [
        "ddddddddddddddddddddddddddddddddpdpddpddpqpdddfddpdqqlddadddfaaddpvqvvlcvvvvvlqakkfkwfdadffkkkwkwadpdpdidifidtnvgtdglqpddllclvcvvlsvqlvvllqvvvcvvvvapafrmkmfiwgkdalddpfppadadpdwhagsvgdidgsvpgdrdddpaqhahsdiaietewiwiarnsdpvriqtafqvvvcvsqvprpphhyidgqfmggnllnlldpqqpaaqlrnqqvvnqvgddpprggqfikmfrrpprppvvcvsvrhgihtdghlvnvcvvdppcsvvcccnrcvprnvvscvvvvndhdtdvlsrhhpvlsvllvqllvlldpvllvvldvvvdlpclqvvvqdllnsllsslvvsvvvsvvpddpvnvpgdpvsvvvssvsssvsssvvsvvcvvvvnvvsvvvvvvvddppdpdddpddd",
        "dpdplvvqppdddplqappppfaadpvcvlvdpvaaaeeeeaqallsllllllclvlvgfyeyefqaeqpdwdddpddvpdddftqtqfapcqppvclqpqqvllvvqvvfwdwqeaefdqpppvpddppddhddppdgdddqqhdppfdpqqdlgqatwgghrntcqnhdpqfddawadadpvahqgtfdaldpdpvvrvvlvvvllvvlcvqlvkdqclqvpflqqcllqvllcvvcvvppwhkgggtgswhadpvhsldirhttsssscvvqrvdpssvssydyhyskhqqewhaghdpfgetawtkiarnccvvpvpdrgihigghrfyeypralprvllrcvssvqalqdpggdprhnqdqffalkwfwwkkkfkfffdpvsqvcqcvppppdpssnvqlvvqcvvcvpdpgsgdssrakhfmwtdadpvqqktktwidghhndddddppddpsrmimimiihwafrdrqfgwgfdppgdhpvrttrihtrddgdpvsvvsvvvrlvvsvvssvstgdtdprgpididrrnsvnlieqrqaedddsvngqayqlqhgpsyphygyfdrnhrngigngdcvsvrssssvsnsvvsscvvvvdpdddppdddddd",
        "ddppppdcvvvvvvvvvppppppvppldplvvlldvvllvvqlvllvvllvvcvvpdpnfflqdwqkafdlddpvvvvvpddlllllqlllvrlvsllvrlvsslvslvpdpdrdvvnnvssvvlnvssvvvnvssvslvsvvsnppddppprdddgdididrgssvssvsvssnsvgsvvvssvvssvvvvd",
    ],
    "text": [
        "RNA-editing ligase in kinetoplastid mitochondrial.",
        "Oxidase which catalyzes the oxidation of various aldopyranoses and disaccharides.",
        "Erythropoietin for regulation of erythrocyte proliferation and differentiation.",
    ],
}


def _build_samples(type_1: str, type_2: str):
    """Pair the examples of two modalities index by index as ``[first, second]`` rows."""
    return [list(pair) for pair in zip(input_examples[type_1], input_examples[type_2])]


# Example rows for the initial dropdown selection ("sequence" vs. "text").
samples = _build_samples("sequence", "text")


def compute_score(input_type_1: str, input_1: str, input_type_2: str, input_2: str) -> str:
    """Return the similarity score of two inputs, formatted with four decimals.

    Each input is encoded with the model method matching its type:
    ``get_protein_repr`` for "sequence", ``get_structure_repr`` for
    "structure" and ``get_text_repr`` for any other value. The score is the
    matrix product of the first representation with the transpose of the
    second, divided by ``model.temperature``.

    Args:
        input_type_1: Modality of ``input_1``.
        input_1: First input string.
        input_type_2: Modality of ``input_2``.
        input_2: Second input string.

    Returns:
        The score rendered as a string with four decimal places.
    """
    with torch.no_grad():
        input_reprs = []
        # ``value`` rather than ``input`` so the builtin is not shadowed.
        for input_type, value in ((input_type_1, input_1), (input_type_2, input_2)):
            if input_type == "sequence":
                input_reprs.append(model.get_protein_repr([value]))
            elif input_type == "structure":
                input_reprs.append(model.get_structure_repr([value]))
            else:
                input_reprs.append(model.get_text_repr([value]))

        score = input_reprs[0] @ input_reprs[1].T / model.temperature

    return f"{score.item():.4f}"


def change_input_type(choice_1: str, choice_2: str):
    """Refresh the example table, inputs and upload widgets after a type change.

    Args:
        choice_1: Selected type of input 1.
        choice_2: Selected type of input 2.

    Returns:
        A 7-tuple matching the outputs wired in ``build_score_computation``:
        the example-table update, two empty strings that clear both text
        boxes, and visibility updates for (upload button 1, chain box 1,
        upload button 2, chain box 2).
    """
    # Change examples if input type is changed
    new_samples = _build_samples(choice_1, choice_2)

    # The module-level ``samples`` is still updated for code that reads it
    # directly, but the UI no longer relies on it: it is shared by every
    # session, so ``load_example`` receives the current types instead.
    global samples
    samples = new_samples

    # Set visibility of upload button: free text has no file to upload
    visible_1 = choice_1 != "text"
    visible_2 = choice_2 != "text"

    return (
        gr.update(samples=new_samples),
        "",
        "",
        gr.update(visible=visible_1),
        gr.update(visible=visible_1),
        gr.update(visible=visible_2),
        gr.update(visible=visible_2),
    )


# Load example from dataset
def load_example(example_id, input_type_1=None, input_type_2=None):
    """Return the ``[input_1, input_2]`` example row at index ``example_id``.

    Args:
        example_id: Index of the clicked example row.
        input_type_1: Current type of input 1. Optional.
        input_type_2: Current type of input 2. Optional.

    Returns:
        The example pair for the given types. When either type is omitted the
        module-level ``samples`` list is used instead (the original
        behaviour); passing both types keeps the result independent of what
        other sessions have selected.
    """
    if input_type_1 is None or input_type_2 is None:
        return samples[example_id]
    return _build_samples(input_type_1, input_type_2)[example_id]


# Build the block for computing protein-text similarity
def build_score_computation():
    """Create the similarity-score widgets and wire their events.

    Must be called inside an active Gradio layout context, since the
    components are created but not returned.
    """
    # NOTE(review): the source was flattened, so the nesting of the ``with``
    # blocks below is reconstructed - confirm against the intended layout.
    gr.Markdown("# Compute similarity score between two modalities")
    with gr.Row(equal_height=True):
        with gr.Column():
            # Compute similarity score between sequence and text
            with gr.Row():
                input_1 = gr.Textbox(label="Input 1")

                # Choose the type of input 1
                input_type_1 = gr.Dropdown(
                    input_types,
                    label="Input type",
                    value="sequence",
                    interactive=True,
                    visible=True,
                )

                # Provide an upload button to upload a pdb file
                upload_btn_1, chain_box_1 = upload_pdb_button(visible=True)
                upload_btn_1.upload(
                    parse_pdb_file,
                    inputs=[input_type_1, upload_btn_1, chain_box_1],
                    outputs=[input_1],
                )

            with gr.Row():
                input_2 = gr.Textbox(label="Input 2")

                # Choose the type of input 2
                input_type_2 = gr.Dropdown(
                    input_types,
                    label="Input type",
                    value="text",
                    interactive=True,
                    visible=True,
                )

                # Provide an upload button to upload a pdb file
                # (hidden at start because input 2 defaults to "text")
                upload_btn_2, chain_box_2 = upload_pdb_button(visible=False)
                upload_btn_2.upload(
                    parse_pdb_file,
                    inputs=[input_type_2, upload_btn_2, chain_box_2],
                    outputs=[input_2],
                )

            # Provide examples
            examples = gr.Dataset(
                samples=samples,
                type="index",
                components=[input_1, input_2],
                label="Input examples",
            )

            # Add click event to examples. The current input types are passed
            # along so the row is resolved per session rather than through
            # shared module state.
            examples.click(
                fn=load_example,
                inputs=[examples, input_type_1, input_type_2],
                outputs=[input_1, input_2],
            )

            compute_btn = gr.Button(value="Compute")

            # Change examples based on input type
            for type_dropdown in (input_type_1, input_type_2):
                type_dropdown.change(
                    fn=change_input_type,
                    inputs=[input_type_1, input_type_2],
                    outputs=[
                        examples,
                        input_1,
                        input_2,
                        upload_btn_1,
                        chain_box_1,
                        upload_btn_2,
                        chain_box_2,
                    ],
                )

        similarity_score = gr.Label(label="similarity score")
        compute_btn.click(
            fn=compute_score,
            inputs=[input_type_1, input_1, input_type_2, input_2],
            outputs=[similarity_score],
        )