Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from .init_model import model | |
| from .blocks import upload_pdb_button, parse_pdb_file | |
# Input kinds offered by the "Input type" dropdowns, in display order.
input_types = ["sequence", "structure", "text"]

# Example inputs for the "sequence" type (upper-case one-letter strings).
_SEQUENCE_EXAMPLES = [
    "MQLQRLGAPLLKRLVGGCIRQSTAPIMPCVVVSGSGGFLTPVRTYMPLPNDQSDFSPYIEIDLPSESRIQSLHKSGLAAQEWVACEKVHGTNFGIYLINQGDHEVVRFAKRSGIMDPNENFFGYHILIDEFTAQIRILNDLLKQKYGLSRVGRLVLNGELFGAKYKHPLVPKSEKWCTLPNGKKFPIAGVQIQREPFPQYSPELHFFAFDIKYSVSGAEEDFVLLGYDEFVEFSSKVPNLLYARALVRGTLDECLAFDVENFMTPLPALLGLGNYPLEGNLAEGVVIRHVRRGDPAVEKHNVSTIIKLRCSSFMELKHPGKQKELKETFIDTVRSGALRRVRGNVTVISDSMLPQVEAAANDLLLNNVSDGRLSNVLSKIGREPLLSGEVSQVDVALMLAKDALKDFLKEVDSLVLNTTLAFRKLLITNVYFESKRLVEQKWKELMQEEAAAQSEAIPPLSPAAPTKGE",
    "MSLSTEQMLRDYPRSMQINGQIPKNAIHETYGNDGVDVFIAGSGPIGATYAKLCVEAGLRVVMVEIGAADSFYAVNAEEGTAVPYVPGYHKKNEIEFQKDIDRFVNVIKGALQQVSVPVRNQNVPTLDPGAWSAPPGSSAISNGKNPHQREFENLSAEAVTRGVGGMSTHWTCSTPRIHPPMESLPGIGRPKLSNDPAEDDKEWNELYSEAERLIGTSTKEFDESIRHTLVLRSLQDAYKDRQRIFRPLPLACHRLKNAPEYVEWHSAENLFHSIYNDDKQKKLFTLLTNHRCTRLALTGGYEKKIGAAEVRNLLATRNPSSQLDSYIMAKVYVLASGAIGNPQILYNSGFSGLQVTPRNDSLIPNLGRYITEQPMAFCQIVLRQEFVDSVRDDPYGLPWWKEAVAQHIAKNPTDALPIPFRDPEPQVTTPFTEEHPWHTQIHRDAFSYGAVGPEVDSRVIVDLRWFGATDPEANNLLVFQNDVQDGYSMPQPTFRYRPSTASNVRARKMMADMCEVASNLGGYLPTSPPQFMDPGLALHLAGTTRIGFDKATTVADNNSLVWDFANLYVAGNGTIRTGFGENPTLTSMCHAIKSARSIINTLKGGTDGKNTGEHRNL",
    "MGVHECPAWLWLLLSLLSLPLGLPVLGAPPRLICDSRVLERYLLEAKEAENITTGCAEHCSLNENITVPDTKVNFYAWKRMEVGQQAVEVWQGLALLSEAVLRGQALLVNSSQPWEPLQLHVDKAVSGLRSLTTLLRALGAQKEAISPPDAASAAPLRTITADTFRKLFRVYSNFLRGKLKLYTGEACRTGDR",
]

# Example inputs for the "structure" type (lower-case token strings).
# NOTE(review): presumably one structure token per residue of the sequence at
# the same index - confirm against the model's tokenizer.
_STRUCTURE_EXAMPLES = [
    "ddddddddddddddddddddddddddddddddpdpddpddpqpdddfddpdqqlddadddfaaddpvqvvlcvvvvvlqakkfkwfdadffkkkwkwadpdpdidifidtnvgtdglqpddllclvcvvlsvqlvvllqvvvcvvvvapafrmkmfiwgkdalddpfppadadpdwhagsvgdidgsvpgdrdddpaqhahsdiaietewiwiarnsdpvriqtafqvvvcvsqvprpphhyidgqfmggnllnlldpqqpaaqlrnqqvvnqvgddpprggqfikmfrrpprppvvcvsvrhgihtdghlvnvcvvdppcsvvcccnrcvprnvvscvvvvndhdtdvlsrhhpvlsvllvqllvlldpvllvvldvvvdlpclqvvvqdllnsllsslvvsvvvsvvpddpvnvpgdpvsvvvssvsssvsssvvsvvcvvvvnvvsvvvvvvvddppdpdddpddd",
    "dpdplvvqppdddplqappppfaadpvcvlvdpvaaaeeeeaqallsllllllclvlvgfyeyefqaeqpdwdddpddvpdddftqtqfapcqppvclqpqqvllvvqvvfwdwqeaefdqpppvpddppddhddppdgdddqqhdppfdpqqdlgqatwgghrntcqnhdpqfddawadadpvahqgtfdaldpdpvvrvvlvvvllvvlcvqlvkdqclqvpflqqcllqvllcvvcvvppwhkgggtgswhadpvhsldirhttsssscvvqrvdpssvssydyhyskhqqewhaghdpfgetawtkiarnccvvpvpdrgihigghrfyeypralprvllrcvssvqalqdpggdprhnqdqffalkwfwwkkkfkfffdpvsqvcqcvppppdpssnvqlvvqcvvcvpdpgsgdssrakhfmwtdadpvqqktktwidghhndddddppddpsrmimimiihwafrdrqfgwgfdppgdhpvrttrihtrddgdpvsvvsvvvrlvvsvvssvstgdtdprgpididrrnsvnlieqrqaedddsvngqayqlqhgpsyphygyfdrnhrngigngdcvsvrssssvsnsvvsscvvvvdpdddppdddddd",
    "ddppppdcvvvvvvvvvppppppvppldplvvlldvvllvvqlvllvvllvvcvvpdpnfflqdwqkafdlddpvvvvvpddlllllqlllvrlvsllvrlvsslvslvpdpdrdvvnnvssvvlnvssvvvnvssvslvsvvsnppddppprdddgdididrgssvssvsvssnsvgsvvvssvvssvvvvd",
]

# Example inputs for the "text" type (free-text descriptions).
_TEXT_EXAMPLES = [
    "RNA-editing ligase in kinetoplastid mitochondrial.",
    "Oxidase which catalyzes the oxidation of various aldopyranoses and disaccharides.",
    "Erythropoietin for regulation of erythrocyte proliferation and differentiation.",
]

# Example inputs keyed by input type; entries at the same index are paired
# together when example rows are built.
input_examples = {
    "sequence": _SEQUENCE_EXAMPLES,
    "structure": _STRUCTURE_EXAMPLES,
    "text": _TEXT_EXAMPLES,
}

# Example rows currently shown in the UI. The initial pairing is
# sequence/text; change_input_type() rebinds this name when a type changes.
samples = [
    list(pair)
    for pair in zip(input_examples["sequence"], input_examples["text"])
]
def compute_score(input_type_1: str, input_1: str, input_type_2: str, input_2: str):
    """Compute the similarity score between two inputs.

    Each input is encoded with the model method that matches its type, and
    the score is the matrix product of the first representation with the
    transpose of the second, divided by ``model.temperature``.

    Args:
        input_type_1: Type of the first input; "sequence" and "structure"
            select the protein and structure encoders, any other value
            selects the text encoder.
        input_1: Content of the first input.
        input_type_2: Type of the second input (same rules as above).
        input_2: Content of the second input.

    Returns:
        The score formatted as a string with four decimal places.
    """

    def _embed(kind, content):
        # The content is wrapped in a one-element list before encoding.
        if kind == "sequence":
            return model.get_protein_repr([content])
        if kind == "structure":
            return model.get_structure_repr([content])
        # Everything that is not a sequence or a structure is treated as text.
        return model.get_text_repr([content])

    # Gradient tracking is disabled for the whole computation.
    with torch.no_grad():
        first_repr = _embed(input_type_1, input_1)
        second_repr = _embed(input_type_2, input_2)
        score = first_repr @ second_repr.T / model.temperature

    # .item() needs a single-element result - assumes each representation
    # holds exactly one row (TODO confirm against the model API).
    return f"{score.item():.4f}"
def change_input_type(choice_1: str, choice_2: str):
    """Refresh the examples and upload widgets after an input type changes.

    Args:
        choice_1: Selected type of input 1; must be a key of
            ``input_examples``.
        choice_2: Selected type of input 2; must be a key of
            ``input_examples``.

    Returns:
        A 7-tuple of updates, in this order: the new example rows, an empty
        string for each of the two input values, two visibility updates for
        the upload widgets of input 1, and two visibility updates for the
        upload widgets of input 2.

    Raises:
        KeyError: If either choice is not a key of ``input_examples``.
    """
    # NOTE(review): `samples` is module-level state that load_example() reads;
    # it is presumably shared by every session served by this process -
    # confirm whether per-session state is needed.
    global samples

    # Pair the examples of the two selected types index by index.
    samples = [
        list(pair)
        for pair in zip(input_examples[choice_1], input_examples[choice_2])
    ]

    # The upload widgets are hidden only for the "text" type.
    show_upload_1 = choice_1 != "text"
    show_upload_2 = choice_2 != "text"

    return (
        gr.update(samples=samples),
        "",
        "",
        gr.update(visible=show_upload_1),
        gr.update(visible=show_upload_1),
        gr.update(visible=show_upload_2),
        gr.update(visible=show_upload_2),
    )
def load_example(example_id):
    """Return the example row stored at index ``example_id``.

    The row is read from the module-level ``samples`` list, so it reflects
    the most recent pairing assigned to that name.

    Args:
        example_id: Index of the selected example row.

    Returns:
        The two-element list ``[input_1_value, input_2_value]`` at that index.
    """
    selected_row = samples[example_id]
    return selected_row
def build_score_computation():
    """Build the UI section that scores the similarity of two inputs.

    Creates, for each of the two inputs, a textbox, an input-type dropdown
    and the PDB upload widgets, followed by an examples dataset, a
    "Compute" button and a label showing the score, then wires their events.
    Must be called while a gradio layout context is active.

    NOTE(review): the nesting of the layout contexts below was reconstructed
    from a source whose indentation was lost - confirm against the original.
    """

    def _input_row(label, default_type, upload_visible):
        # One row per input: value textbox, type dropdown, PDB upload widgets.
        with gr.Row():
            value_box = gr.Textbox(label=label)
            type_dropdown = gr.Dropdown(
                input_types,
                label="Input type",
                value=default_type,
                interactive=True,
                visible=True,
            )
            upload_btn, chain_box = upload_pdb_button(visible=upload_visible)
            # An uploaded file is parsed and the result fills the textbox.
            upload_btn.upload(
                parse_pdb_file,
                inputs=[type_dropdown, upload_btn, chain_box],
                outputs=[value_box],
            )
        return value_box, type_dropdown, upload_btn, chain_box

    gr.Markdown("# Compute similarity score between two modalities")

    with gr.Row(equal_height=True):
        with gr.Column():
            # Input 1 starts as a sequence with its upload widgets visible;
            # input 2 starts as text with its upload widgets hidden.
            input_1, input_type_1, upload_btn_1, chain_box_1 = _input_row(
                "Input 1", "sequence", True
            )
            input_2, input_type_2, upload_btn_2, chain_box_2 = _input_row(
                "Input 2", "text", False
            )

            # Clickable example rows; a click passes the row index on.
            examples = gr.Dataset(
                samples=samples,
                type="index",
                components=[input_1, input_2],
                label="Input examples",
            )
            examples.click(
                fn=load_example,
                inputs=[examples],
                outputs=[input_1, input_2],
            )

            compute_btn = gr.Button(value="Compute")

            # Changing either dropdown triggers the same refresh of the
            # examples, the input values and the upload widgets.
            for type_dropdown in (input_type_1, input_type_2):
                type_dropdown.change(
                    fn=change_input_type,
                    inputs=[input_type_1, input_type_2],
                    outputs=[
                        examples,
                        input_1,
                        input_2,
                        upload_btn_1,
                        chain_box_1,
                        upload_btn_2,
                        chain_box_2,
                    ],
                )

        similarity_score = gr.Label(label="similarity score")
        compute_btn.click(
            fn=compute_score,
            inputs=[input_type_1, input_1, input_type_2, input_2],
            outputs=[similarity_score],
        )