LTEnjoy committed on
Commit
2bd60c8
·
verified ·
1 Parent(s): 9606143

Delete demo

Browse files
demo/__init__.py DELETED
File without changes
demo/config.yaml DELETED
@@ -1,22 +0,0 @@
1
- model_dir: /data/ProTrek_650M_UniRef50
2
- faiss_config:
3
- IO_FLAG_MMAP: True
4
- sequence_index_dir:
5
- - name: Swiss-Prot
6
- index_dir: /data/ProTrek-faiss-index/ProTrek_650M_UniRef50/Swiss-Prot/sequence
7
- - name: UniRef50
8
- index_dir: /data/ProTrek-faiss-index/ProTrek_650M_UniRef50/UniRef50/sequence
9
- - name: Uncharacterized
10
- index_dir: /data/ProTrek-faiss-index/ProTrek_650M_UniRef50/Uncharacterized/sequence
11
- - name: PDB
12
- index_dir: /data/ProTrek-faiss-index/ProTrek_650M_UniRef50/PDB/sequence
13
-
14
- structure_index_dir:
15
- - name: Swiss-Prot
16
- index_dir: /data/ProTrek-faiss-index/ProTrek_650M_UniRef50/Swiss-Prot/structure
17
- - name: PDB
18
- index_dir: /data/ProTrek-faiss-index/ProTrek_650M_UniRef50/PDB/structure
19
-
20
- text_index_dir:
21
- - name: Swiss-Prot
22
- index_dir: /data/ProTrek-faiss-index/ProTrek_650M_UniRef50/Swiss-Prot/text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
demo/modules/__init__.py DELETED
@@ -1,19 +0,0 @@
1
- import sys
2
-
3
- sys.path += []
4
-
5
- import argparse
6
-
7
-
8
def main():
    """Entry point for the demo module CLI. Currently a no-op placeholder."""
    pass


def get_args():
    """Build the (currently empty) argument parser and parse sys.argv."""
    arg_parser = argparse.ArgumentParser()
    return arg_parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
demo/modules/blocks.py DELETED
@@ -1,66 +0,0 @@
1
- import gradio as gr
2
-
3
- from utils.foldseek_util import get_struc_seq
4
-
5
-
6
####################################################
#                  gradio blocks                   #
####################################################
def upload_pdb_button(visible: bool = True, chain_visible: bool = True):
    """
    Create an upload button for structure files plus a chain-selector textbox.

    Args:
        visible: Whether the upload button is visible.
        chain_visible: Whether the chain textbox is visible.

    Returns:
        Tuple of (upload button, chain textbox) gradio components.
    """
    with gr.Column(scale=0):
        # Which chain should be pulled out of the uploaded structure
        chain_box = gr.Textbox(
            label="Chain (to be extracted from the pdb file)",
            value="A",
            visible=chain_visible,
            interactive=True,
        )
        upload_btn = gr.UploadButton(label="Upload .pdb/.cif file", visible=visible)

    return upload_btn, chain_box
25
-
26
-
27
####################################################
#                Trigger functions                 #
####################################################
def parse_pdb_file(input_type: str, file: str, chain: str) -> str:
    """
    Parse the uploaded structure file with foldseek.

    Args:
        input_type: Type of input. Must be one of ["protein sequence", "protein structure"]

        file: Path to the uploaded file

        chain: Chain to be extracted from the pdb file

    Returns:
        Protein sequence or Foldseek sequence (lower-cased 3Di string).
    """
    try:
        # get_struc_seq returns (aa_seq, foldseek_seq, ...) per chain
        parsed = get_struc_seq("/tmp/foldseek", file, [chain])[chain]
        if input_type == "sequence":
            return parsed[0]
        return parsed[1].lower()

    except Exception as e:
        # Surface any parsing failure to the UI instead of crashing the app
        raise gr.Error(f"{e}")
53
-
54
-
55
def set_upload_visible(visible: bool) -> gr.Interface:
    """
    Toggle the visibility of the upload button.

    Args:
        visible: Whether the block is visible or not.

    Returns:
        A gradio update applying the new visibility.
    """
    return gr.update(visible=visible)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
demo/modules/compute_score.py DELETED
@@ -1,127 +0,0 @@
1
- import gradio as gr
2
- import torch
3
-
4
- from .init_model import model
5
- from .blocks import upload_pdb_button, parse_pdb_file
6
-
7
-
8
- input_types = ["sequence", "structure", "text"]
9
-
10
- input_examples = {
11
- "sequence": [
12
- "MQLQRLGAPLLKRLVGGCIRQSTAPIMPCVVVSGSGGFLTPVRTYMPLPNDQSDFSPYIEIDLPSESRIQSLHKSGLAAQEWVACEKVHGTNFGIYLINQGDHEVVRFAKRSGIMDPNENFFGYHILIDEFTAQIRILNDLLKQKYGLSRVGRLVLNGELFGAKYKHPLVPKSEKWCTLPNGKKFPIAGVQIQREPFPQYSPELHFFAFDIKYSVSGAEEDFVLLGYDEFVEFSSKVPNLLYARALVRGTLDECLAFDVENFMTPLPALLGLGNYPLEGNLAEGVVIRHVRRGDPAVEKHNVSTIIKLRCSSFMELKHPGKQKELKETFIDTVRSGALRRVRGNVTVISDSMLPQVEAAANDLLLNNVSDGRLSNVLSKIGREPLLSGEVSQVDVALMLAKDALKDFLKEVDSLVLNTTLAFRKLLITNVYFESKRLVEQKWKELMQEEAAAQSEAIPPLSPAAPTKGE",
13
- "MSLSTEQMLRDYPRSMQINGQIPKNAIHETYGNDGVDVFIAGSGPIGATYAKLCVEAGLRVVMVEIGAADSFYAVNAEEGTAVPYVPGYHKKNEIEFQKDIDRFVNVIKGALQQVSVPVRNQNVPTLDPGAWSAPPGSSAISNGKNPHQREFENLSAEAVTRGVGGMSTHWTCSTPRIHPPMESLPGIGRPKLSNDPAEDDKEWNELYSEAERLIGTSTKEFDESIRHTLVLRSLQDAYKDRQRIFRPLPLACHRLKNAPEYVEWHSAENLFHSIYNDDKQKKLFTLLTNHRCTRLALTGGYEKKIGAAEVRNLLATRNPSSQLDSYIMAKVYVLASGAIGNPQILYNSGFSGLQVTPRNDSLIPNLGRYITEQPMAFCQIVLRQEFVDSVRDDPYGLPWWKEAVAQHIAKNPTDALPIPFRDPEPQVTTPFTEEHPWHTQIHRDAFSYGAVGPEVDSRVIVDLRWFGATDPEANNLLVFQNDVQDGYSMPQPTFRYRPSTASNVRARKMMADMCEVASNLGGYLPTSPPQFMDPGLALHLAGTTRIGFDKATTVADNNSLVWDFANLYVAGNGTIRTGFGENPTLTSMCHAIKSARSIINTLKGGTDGKNTGEHRNL",
14
- "MGVHECPAWLWLLLSLLSLPLGLPVLGAPPRLICDSRVLERYLLEAKEAENITTGCAEHCSLNENITVPDTKVNFYAWKRMEVGQQAVEVWQGLALLSEAVLRGQALLVNSSQPWEPLQLHVDKAVSGLRSLTTLLRALGAQKEAISPPDAASAAPLRTITADTFRKLFRVYSNFLRGKLKLYTGEACRTGDR"
15
- ],
16
-
17
- "structure": [
18
- "ddddddddddddddddddddddddddddddddpdpddpddpqpdddfddpdqqlddadddfaaddpvqvvlcvvvvvlqakkfkwfdadffkkkwkwadpdpdidifidtnvgtdglqpddllclvcvvlsvqlvvllqvvvcvvvvapafrmkmfiwgkdalddpfppadadpdwhagsvgdidgsvpgdrdddpaqhahsdiaietewiwiarnsdpvriqtafqvvvcvsqvprpphhyidgqfmggnllnlldpqqpaaqlrnqqvvnqvgddpprggqfikmfrrpprppvvcvsvrhgihtdghlvnvcvvdppcsvvcccnrcvprnvvscvvvvndhdtdvlsrhhpvlsvllvqllvlldpvllvvldvvvdlpclqvvvqdllnsllsslvvsvvvsvvpddpvnvpgdpvsvvvssvsssvsssvvsvvcvvvvnvvsvvvvvvvddppdpdddpddd",
19
- "dpdplvvqppdddplqappppfaadpvcvlvdpvaaaeeeeaqallsllllllclvlvgfyeyefqaeqpdwdddpddvpdddftqtqfapcqppvclqpqqvllvvqvvfwdwqeaefdqpppvpddppddhddppdgdddqqhdppfdpqqdlgqatwgghrntcqnhdpqfddawadadpvahqgtfdaldpdpvvrvvlvvvllvvlcvqlvkdqclqvpflqqcllqvllcvvcvvppwhkgggtgswhadpvhsldirhttsssscvvqrvdpssvssydyhyskhqqewhaghdpfgetawtkiarnccvvpvpdrgihigghrfyeypralprvllrcvssvqalqdpggdprhnqdqffalkwfwwkkkfkfffdpvsqvcqcvppppdpssnvqlvvqcvvcvpdpgsgdssrakhfmwtdadpvqqktktwidghhndddddppddpsrmimimiihwafrdrqfgwgfdppgdhpvrttrihtrddgdpvsvvsvvvrlvvsvvssvstgdtdprgpididrrnsvnlieqrqaedddsvngqayqlqhgpsyphygyfdrnhrngigngdcvsvrssssvsnsvvsscvvvvdpdddppdddddd",
20
- "ddppppdcvvvvvvvvvppppppvppldplvvlldvvllvvqlvllvvllvvcvvpdpnfflqdwqkafdlddpvvvvvpddlllllqlllvrlvsllvrlvsslvslvpdpdrdvvnnvssvvlnvssvvvnvssvslvsvvsnppddppprdddgdididrgssvssvsvssnsvgsvvvssvvssvvvvd"
21
- ],
22
-
23
- "text": [
24
- "RNA-editing ligase in kinetoplastid mitochondrial.",
25
- "Oxidase which catalyzes the oxidation of various aldopyranoses and disaccharides.",
26
- "Erythropoietin for regulation of erythrocyte proliferation and differentiation."
27
- ]
28
- }
29
-
30
- samples = [[s1, s2] for s1, s2 in zip(input_examples["sequence"], input_examples["text"])]
31
-
32
-
33
def compute_score(input_type_1: str, input_1: str, input_type_2: str, input_2: str):
    """
    Compute the ProTrek similarity score between two inputs of given modalities.

    Args:
        input_type_1: Modality of the first input ("sequence", "structure" or "text").
        input_1: First input value.
        input_type_2: Modality of the second input.
        input_2: Second input value.

    Returns:
        Similarity score formatted to four decimal places.
    """
    with torch.no_grad():
        reprs = []
        for modality, value in ((input_type_1, input_1), (input_type_2, input_2)):
            if modality == "sequence":
                reprs.append(model.get_protein_repr([value]))
            elif modality == "structure":
                reprs.append(model.get_structure_repr([value]))
            else:
                reprs.append(model.get_text_repr([value]))

        # Temperature-scaled dot product between the two embeddings
        score = reprs[0] @ reprs[1].T / model.temperature

    return f"{score.item():.4f}"
50
-
51
-
52
def change_input_type(choice_1: str, choice_2: str):
    """Refresh the example pairs and upload-button visibility after a type change."""
    # Rebuild the shared example list for the newly selected modalities
    global samples
    samples = [list(pair) for pair in zip(input_examples[choice_1], input_examples[choice_2])]

    # Upload buttons only make sense for non-text inputs
    visible_1 = choice_1 != "text"
    visible_2 = choice_2 != "text"

    return (
        gr.update(samples=samples),
        "",
        "",
        gr.update(visible=visible_1),
        gr.update(visible=visible_1),
        gr.update(visible=visible_2),
        gr.update(visible=visible_2),
    )
73
-
74
-
75
# Load example from dataset
def load_example(example_id):
    """Return the example pair stored at *example_id* in the current samples."""
    return samples[example_id]
78
-
79
-
80
# Build the block for computing similarity between two modalities
def build_score_computation():
    """Assemble the gradio UI for pairwise similarity scoring."""
    gr.Markdown("# Compute similarity score between two modalities")
    with gr.Row(equal_height=True):
        with gr.Column():
            # First input and its modality selector
            with gr.Row():
                input_1 = gr.Textbox(label="Input 1")
                input_type_1 = gr.Dropdown(input_types, label="Input type", value="sequence",
                                           interactive=True, visible=True)

                # Structure upload for input 1
                upload_btn_1, chain_box_1 = upload_pdb_button(visible=True)
                upload_btn_1.upload(parse_pdb_file,
                                    inputs=[input_type_1, upload_btn_1, chain_box_1],
                                    outputs=[input_1])

            # Second input and its modality selector
            with gr.Row():
                input_2 = gr.Textbox(label="Input 2")
                input_type_2 = gr.Dropdown(input_types, label="Input type", value="text",
                                           interactive=True, visible=True)

                # Structure upload for input 2 (hidden while the default type is "text")
                upload_btn_2, chain_box_2 = upload_pdb_button(visible=False)
                upload_btn_2.upload(parse_pdb_file,
                                    inputs=[input_type_2, upload_btn_2, chain_box_2],
                                    outputs=[input_2])

            # Clickable examples
            examples = gr.Dataset(samples=samples, type="index",
                                  components=[input_1, input_2], label="Input examples")
            examples.click(fn=load_example, inputs=[examples], outputs=[input_1, input_2])

            compute_btn = gr.Button(value="Compute")

            # Swapping a modality refreshes examples and upload-button visibility
            type_change_outputs = [examples, input_1, input_2, upload_btn_1, chain_box_1,
                                   upload_btn_2, chain_box_2]
            input_type_1.change(fn=change_input_type, inputs=[input_type_1, input_type_2],
                                outputs=type_change_outputs)
            input_type_2.change(fn=change_input_type, inputs=[input_type_1, input_type_2],
                                outputs=type_change_outputs)

        similarity_score = gr.Label(label="similarity score")
        compute_btn.click(fn=compute_score,
                          inputs=[input_type_1, input_1, input_type_2, input_2],
                          outputs=[similarity_score])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
demo/modules/init_model.py DELETED
@@ -1,118 +0,0 @@
1
- import faiss
2
- import numpy as np
3
- import pandas as pd
4
- import os
5
- import yaml
6
- import glob
7
-
8
- from easydict import EasyDict
9
- from utils.constants import sequence_level
10
- from model.ProTrek.protrek_trimodal_model import ProTrekTrimodalModel
11
- from tqdm import tqdm
12
-
13
# Debug aid: show what is mounted under /data. Guarded so a missing mount
# does not crash the whole module at import time.
if os.path.isdir("/data"):
    print(os.listdir("/data"))
14
def load_model():
    """
    Instantiate the ProTrek tri-modal model from the configured model directory.

    Sub-model paths (ESM2 protein encoder, PubMedBERT text encoder, foldseek
    structure encoder, checkpoint) are discovered inside ``config.model_dir``.

    Returns:
        ProTrekTrimodalModel in eval mode.
    """
    model_dir = config.model_dir
    model = ProTrekTrimodalModel(
        protein_config=glob.glob(f"{model_dir}/esm2_*")[0],
        text_config=f"{model_dir}/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        structure_config=glob.glob(f"{model_dir}/foldseek_*")[0],
        load_protein_pretrained=False,
        load_text_pretrained=False,
        from_checkpoint=glob.glob(f"{model_dir}/*.pt")[0],
    )
    model.eval()
    return model
27
-
28
-
29
def load_faiss_index(index_path: str):
    """
    Read a faiss index from disk, optionally memory-mapped per config,
    and force the inner-product metric used by ProTrek embeddings.
    """
    read_args = [index_path]
    if config.faiss_config.IO_FLAG_MMAP:
        # Memory-map large indexes instead of loading them fully into RAM
        read_args.append(faiss.IO_FLAG_MMAP)
    index = faiss.read_index(*read_args)

    index.metric_type = faiss.METRIC_INNER_PRODUCT
    return index
37
-
38
-
39
def load_index():
    """
    Load every faiss index declared in the config.

    Returns:
        all_index: Nested dict {"sequence"|"structure"|"text": {db_name: ...}}
            where each leaf holds {"index": faiss index, "ids": id array};
            text entries have one extra nesting level per subsection.
        valid_subsections: {db_name: sorted list of subsections that actually
            have an index on disk}.
    """

    def read_ids(id_path):
        # One id per row in a headerless tsv, flattened to a 1-D array
        return pd.read_csv(id_path, sep="\t", header=None).values.flatten()

    all_index = {"sequence": {}, "structure": {}, "text": {}}

    # Protein sequence indexes
    for db in tqdm(config.sequence_index_dir, desc="Loading sequence index..."):
        index_dir = db["index_dir"]
        all_index["sequence"][db["name"]] = {
            "index": load_faiss_index(f"{index_dir}/sequence.index"),
            "ids": read_ids(f"{index_dir}/ids.tsv"),
        }

    # Protein structure indexes
    print("Loading structure index...")
    for db in tqdm(config.structure_index_dir, desc="Loading structure index..."):
        index_dir = db["index_dir"]
        all_index["structure"][db["name"]] = {
            "index": load_faiss_index(f"{index_dir}/structure.index"),
            "ids": read_ids(f"{index_dir}/ids.tsv"),
        }

    # Text indexes: one per (db, subsection) pair that exists on disk
    valid_subsections = {}
    for db in tqdm(config.text_index_dir, desc="Loading text index..."):
        db_name = db["name"]
        text_dir = f"{db['index_dir']}/subsections"
        all_index["text"][db_name] = {}
        valid_subsections[db_name] = set()

        # "Global" embeddings are stored alongside the per-subsection indexes
        sequence_level.add("Global")
        for subsection in tqdm(sequence_level):
            stem = subsection.replace(' ', '_')
            index_path = f"{text_dir}/{stem}.index"
            # Skip subsections with no index on disk (e.g. "Taxonomic lineage")
            if not os.path.exists(index_path):
                continue

            all_index["text"][db_name][subsection] = {
                "index": load_faiss_index(index_path),
                "ids": read_ids(f"{text_dir}/{stem}_ids.tsv"),
            }
            valid_subsections[db_name].add(subsection)

    # Deterministic ordering for the UI dropdowns
    for db_name in valid_subsections:
        valid_subsections[db_name] = sorted(valid_subsections[db_name])

    return all_index, valid_subsections
101
-
102
-
103
# Load the demo config, then initialise the model and faiss indexes at import time
root_dir = __file__.rsplit("/", 3)[0]
config_path = f"{root_dir}/demo/config.yaml"
with open(config_path, 'r', encoding='utf-8') as fh:
    config = EasyDict(yaml.safe_load(fh))

# Target device for the model (moving the model to GPU is currently disabled)
device = "cuda"

print("Loading model...")
model = load_model()

all_index, valid_subsections = load_index()
print("Done...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
demo/modules/search.py DELETED
@@ -1,301 +0,0 @@
1
- import gradio as gr
2
- import torch
3
- import pandas as pd
4
- import matplotlib.pyplot as plt
5
- import numpy as np
6
-
7
- from scipy.stats import norm
8
- from .init_model import model, all_index, valid_subsections
9
- from .blocks import upload_pdb_button, parse_pdb_file
10
-
11
-
12
- tmp_file_path = "/tmp/results.tsv"
13
- tmp_plot_path = "/tmp/histogram.svg"
14
-
15
- # Samples for input
16
- samples = {
17
- "sequence": [
18
- ["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK"],
19
- ["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT"],
20
- ["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE"]
21
- ],
22
-
23
- "structure": [
24
- ["dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd"],
25
- ["dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd"],
26
- ["dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd"]
27
- ],
28
-
29
- "text": [
30
- ["Proteins with zinc bindings."],
31
- ["Proteins locating at cell membrane."],
32
- ["Protein that serves as an enzyme."]
33
- ],
34
- }
35
-
36
-
37
def clear_results():
    """Reset the results pane and hide the download button and histogram."""
    return "", gr.update(visible=False), gr.update(visible=False)
39
-
40
-
41
def plot(scores) -> None:
    """
    Plot the distribution of scores and fit a normal distribution.

    The figure is written to ``tmp_plot_path`` and the axes are cleared
    afterwards so module-level pyplot state stays clean between queries.

    Args:
        scores: List of scores
    """
    plt.hist(scores, bins=100, density=True, alpha=0.6)
    plt.title('Distribution of similarity scores in the database', fontsize=15)
    plt.xlabel('Similarity score', fontsize=15)
    plt.ylabel('Density', fontsize=15)

    # Leave a little room below zero so the note text fits inside the axes
    y_ed = plt.gca().get_ylim()[-1]
    plt.ylim(-0.05, y_ed)

    # Explanatory note in the lower-left corner
    x_st = plt.gca().get_xlim()[0]
    note = ("Note: For the \"UniRef50\" and \"Uncharacterized\" databases, the figure illustrates\n "
            "only top-ranked clusters (identified using Faiss), whereas for other databases, it\n "
            "displays the distribution across all samples.")
    plt.text(x_st, -0.04, note, fontsize=8)

    # Fit and overlay a Gaussian on the histogram
    mu, std = norm.fit(scores)
    xmin, xmax = plt.xlim()
    _, ymax = plt.ylim()
    xs = np.linspace(xmin, xmax, 100)
    plt.plot(xs, norm.pdf(xs, mu, std))

    # Total number of scores, top-right corner
    plt.text(xmax, 0.9 * ymax, f"Total number: {len(scores)}", ha='right', fontsize=12)

    # Persist as svg and clear the current axes
    plt.savefig(tmp_plot_path)
    plt.cla()
75
-
76
-
77
# Search from database
def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str, db: str):
    """
    Retrieve the top-k database entries most similar to the query.

    Args:
        input: Query value (sequence / foldseek string / text).
        nprobe: Number of IVF clusters to probe (ignored for flat indexes).
        topk: Number of results to return.
        input_type: Query modality ("sequence", "structure" or "text").
        query_type: Result modality.
        subsection_type: Text subsection, used only when query_type is "text".
        db: Database name to search in.

    Returns:
        Markdown results table, download-button update, histogram update.
    """
    print(f"Input type: {input_type}\n Output type: {query_type}\nDatabase: {db}\nSubsection: {subsection_type}")

    # Embed the query with the encoder matching its modality
    input_modality = input_type.replace("sequence", "protein")
    with torch.no_grad():
        input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()

    # Pick the faiss index and id table for the requested output modality
    if query_type == "text":
        entry = all_index["text"][db][subsection_type]
    else:
        entry = all_index[query_type][db]
    index, ids = entry["index"], entry["ids"]

    if hasattr(index, "nprobe"):
        if index.nlist < nprobe:
            raise gr.Error(f"The number of clusters to search must be less than or equal to the number of clusters in the index ({index.nlist}).")
        else:
            index.nprobe = nprobe

    if topk > index.ntotal:
        raise gr.Error(f"You cannot retrieve more than the database size ({index.ntotal}).")

    # Retrieve all scores to plot the distribution
    scores, ranks = index.search(input_embedding, index.ntotal)
    scores, ranks = scores[0], ranks[0]

    # Drop padding entries faiss fills in when fewer results exist
    keep = scores > -1
    scores, ranks = scores[keep], ranks[keep]
    scores = scores / model.temperature.item()
    plot(scores)

    top_scores = scores[:topk]
    top_ranks = ranks[:topk]

    # Write the results to a temporary file for downloading
    with open(tmp_file_path, "w") as w:
        w.write("Id\tMatching score\n")
        for rank, score in zip(top_ranks, top_scores):
            w.write(f"{ids[rank]}\t{score}\n")

    # Render ids, linking protein hits to UniProt / RCSB
    topk_ids = []
    for rank in top_ranks:
        now_id = ids[rank]
        if query_type == "text":
            # Escape pipes so ids don't break the markdown table
            topk_ids.append(now_id.replace("|", "\\|"))
        elif db != "PDB":
            # Provide link to uniprot website
            topk_ids.append(f"[{now_id}](https://www.uniprot.org/uniprotkb/{now_id})")
        else:
            # Provide link to pdb website
            pdb_id = now_id.split("-")[0]
            topk_ids.append(f"[{now_id}](https://www.rcsb.org/structure/{pdb_id})")

    # Only render the first 1000 rows in the UI; the download file has everything
    limit = 1000
    df = pd.DataFrame({"Id": topk_ids[:limit], "Matching score": top_scores[:limit]})
    if len(topk_ids) > limit:
        info_df = pd.DataFrame({"Id": ["Download the file to check all results"], "Matching score": ["..."]},
                               index=[1000])
        df = pd.concat([df, info_df], axis=0)

    output = df.to_markdown()
    return (output,
            gr.DownloadButton(label="Download results", value=tmp_file_path, visible=True, scale=0),
            gr.update(value=tmp_plot_path, visible=True))
153
-
154
-
155
def change_input_type(choice: str):
    """Swap the example set and toggle upload controls when the input type changes."""
    # Upload controls are meaningless for free-text queries
    visible = choice != "text"

    return gr.update(samples=samples[choice]), "", gr.update(visible=visible), gr.update(visible=visible)
166
-
167
-
168
# Load example from dataset
def load_example(example_id):
    """Return the single input value stored in a dataset example row."""
    return example_id[0]
171
-
172
-
173
# Change the visibility of subsection type
def change_output_type(query_type: str, subsection_type: str):
    """Adjust subsection/nprobe visibility and database choices for a new output type."""
    # Default to the first database available for the chosen modality
    db_type = next(iter(all_index[query_type]))
    nprobe_visible = check_index_ivf(query_type, db_type, subsection_type)

    return (
        gr.update(visible=query_type == "text"),
        gr.update(visible=nprobe_visible),
        gr.update(choices=list(all_index[query_type].keys()), value=db_type)
    )
184
-
185
-
186
def check_index_ivf(index_type: str, db: str, subsection_type: str = None) -> bool:
    """
    Check if the index is of IVF type.

    Args:
        index_type: Type of index.
        db: Database name.
        subsection_type: If the "index_type" is "text", get the index based on the subsection type.

    Returns:
        Whether the index is of IVF type or not. nprobe tuning is currently
        disabled, so this always reports False after resolving the index.
    """
    if index_type == "text":
        index = all_index["text"][db][subsection_type]["index"]
    elif index_type in ("sequence", "structure"):
        index = all_index[index_type][db]["index"]

    # IVF detection is intentionally disabled: always hide the nprobe slider.
    # (Would otherwise be: hasattr(index, "nprobe"))
    return False
208
-
209
-
210
def change_db_type(query_type: str, subsection_type: str, db_type: str):
    """
    Change the database to search.

    Args:
        query_type: The output type.
        subsection_type: Currently selected text subsection.
        db_type: The database to search.
    """
    # Text output exposes a per-database subsection list; otherwise hide it
    if query_type == "text":
        subsection_update = gr.update(choices=list(valid_subsections[db_type]), value="Function")
    else:
        subsection_update = gr.update(visible=False)

    return subsection_update, gr.update(visible=check_index_ivf(query_type, db_type, subsection_type))
224
-
225
-
226
# Build the searching block
def build_search_module():
    """Assemble the gradio UI for database retrieval."""
    gr.Markdown("# Search from database")
    with gr.Row(equal_height=True):
        with gr.Column():
            # Query modality
            input_type = gr.Radio(["sequence", "structure", "text"],
                                  label="Input type (e.g. 'text' means searching based on text descriptions)",
                                  value="text")

            with gr.Row():
                # Result modality
                query_type = gr.Radio(
                    ["sequence", "structure", "text"],
                    label="Output type (e.g. 'sequence' means returning qualified sequences)",
                    value="sequence",
                    scale=2,
                )

                # Subsection selector, shown only when the output type is "text"
                text_db = list(all_index["text"].keys())[0]
                sequence_db = list(all_index["sequence"].keys())[0]
                subsection_type = gr.Dropdown(valid_subsections[text_db], label="Subsection of text",
                                              value="Function", interactive=True, visible=False, scale=0)

                db_type = gr.Dropdown(all_index["sequence"].keys(), label="Database", value=sequence_db,
                                      interactive=True, visible=True, scale=0)

            with gr.Row():
                # Query input plus structure-file upload
                input = gr.Text(label="Input")
                upload_btn, chain_box = upload_pdb_button(visible=False, chain_visible=False)
                upload_btn.upload(parse_pdb_file, inputs=[input_type, upload_btn, chain_box], outputs=[input])

            # nprobe slider is only relevant for IVF indexes
            nprobe_visible = check_index_ivf(query_type.value, db_type.value)
            nprobe = gr.Slider(1, 1000000, 1000, step=1, visible=nprobe_visible,
                               label="Number of clusters to search (lower value for faster search and higher value for more accurate search)")

            # Keep dependent widgets in sync with the selected output type / database
            query_type.change(fn=change_output_type, inputs=[query_type, subsection_type],
                              outputs=[subsection_type, nprobe, db_type])
            db_type.change(fn=change_db_type, inputs=[query_type, subsection_type, db_type],
                           outputs=[subsection_type, nprobe])

            topk = gr.Slider(1, 1000000, 5, step=1, label="Retrieve top k results")

            # Clickable examples, refreshed whenever the input type changes
            examples = gr.Dataset(samples=samples["text"], components=[input], label="Input examples")
            examples.click(fn=load_example, inputs=[examples], outputs=input)
            input_type.change(fn=change_input_type, inputs=[input_type],
                              outputs=[examples, input, upload_btn, chain_box])

            with gr.Row():
                search_btn = gr.Button(value="Search")
                clear_btn = gr.Button(value="Clear")

        with gr.Row():
            with gr.Column():
                results = gr.Markdown(label="results", height=450)
                download_btn = gr.DownloadButton(label="Download results", visible=False)

            # Score distribution for the current query
            histogram = gr.Image(label="Histogram of matching scores", type="filepath", scale=1, visible=False)

    search_btn.click(fn=search, inputs=[input, nprobe, topk, input_type, query_type, subsection_type, db_type],
                     outputs=[results, download_btn, histogram])
    clear_btn.click(fn=clear_results, outputs=[results, download_btn, histogram])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
demo/modules/tmalign.py DELETED
@@ -1,78 +0,0 @@
1
- import gradio as gr
2
- import os
3
-
4
- from .blocks import upload_pdb_button
5
- from utils.downloader import download_pdb, download_af2
6
-
7
-
8
- root_dir = __file__.rsplit("/", 3)[0]
9
- structure_types = ["AlphaFoldDB", "PDB"]
10
-
11
-
12
def upload_structure(file: str):
    """Pass the uploaded file path straight through to the textbox."""
    return file
14
-
15
-
16
def get_structure_path(structure: str, structure_type: str) -> str:
    """
    Resolve a structure identifier to a local file path.

    Args:
        structure: Either an absolute path to a manually uploaded file, a
            Uniprot ID (for AlphaFoldDB) or a PDB ID (for PDB).
        structure_type: One of ["AlphaFoldDB", "PDB"]; ignored for uploads.

    Returns:
        Path to the structure file on disk.
    """
    # Manually uploaded files arrive as absolute paths and are passed through.
    # startswith() also safely handles an empty input, which would crash the
    # previous structure[0] check with an IndexError.
    if structure.startswith("/"):
        return structure

    # Make sure the download cache exists before writing into it
    cache_dir = f"{root_dir}/demo/cache"
    os.makedirs(cache_dir, exist_ok=True)

    # A Uniprot ID: download the predicted structure from AlphaFoldDB
    if structure_type == "AlphaFoldDB":
        save_path = f"{cache_dir}/{structure}.pdb"
        if not os.path.exists(save_path):
            download_af2(structure, "pdb", save_path)
        return save_path

    # A PDB ID: download the experimental structure from PDB
    elif structure_type == "PDB":
        save_path = f"{cache_dir}/{structure}.cif"
        if not os.path.exists(save_path):
            download_pdb(structure, "cif", save_path)
        return save_path
34
-
35
-
36
def tmalign(structure_1: str, structure_type_1: str, structure_2: str, structure_type_2: str):
    """
    Run TMalign on two structures and return its raw text output.

    Args:
        structure_1: Identifier or path of the first structure.
        structure_type_1: "AlphaFoldDB" or "PDB" (ignored for uploaded paths).
        structure_2: Identifier or path of the second structure.
        structure_type_2: "AlphaFoldDB" or "PDB" (ignored for uploaded paths).

    Returns:
        Full stdout of the TMalign binary.
    """
    import subprocess

    structure_path_1 = get_structure_path(structure_1, structure_type_1)
    structure_path_2 = get_structure_path(structure_2, structure_type_2)

    # Use an argument list without a shell so paths containing spaces or shell
    # metacharacters cannot break the command or inject into it (the previous
    # os.popen call interpolated the paths into a shell string).
    result = subprocess.run(
        ["/tmp/TMalign", structure_path_1, structure_path_2],
        capture_output=True,
        text=True,
    )
    return result.stdout
45
-
46
-
47
# Build the block for computing TM-score between two protein structures
def build_TMalign():
    """Assemble the gradio UI for pairwise structural alignment with TMalign."""
    gr.Markdown("# Calculate TM-score between two protein structures")
    with gr.Row(equal_height=True):
        with gr.Column():
            # First structure: free-text id or uploaded file
            with gr.Row():
                structure_1 = gr.Textbox(label="Protein structure 1 (input Uniprot ID or PDB ID or upload a pdb file)")
                structure_type_1 = gr.Dropdown(structure_types,
                                               label="Structure type (if the structure is manually uploaded, ignore this field)",
                                               value="AlphaFoldDB", interactive=True, visible=True)
                upload_btn_1, _ = upload_pdb_button(visible=True, chain_visible=False)
                upload_btn_1.upload(upload_structure, inputs=[upload_btn_1], outputs=[structure_1])

            # Second structure
            with gr.Row():
                structure_2 = gr.Textbox(label="Protein structure 2 (input Uniprot ID or PDB ID or upload a pdb file)")
                structure_type_2 = gr.Dropdown(structure_types,
                                               label="Structure type (if the structure is manually uploaded, ignore this field)",
                                               value="AlphaFoldDB", interactive=True, visible=True)
                upload_btn_2, _ = upload_pdb_button(visible=True, chain_visible=False)
                upload_btn_2.upload(upload_structure, inputs=[upload_btn_2], outputs=[structure_2])

            compute_btn = gr.Button(value="Compute TM-score")
            tmscore = gr.TextArea(label="TM-score", interactive=False)

            compute_btn.click(tmalign, inputs=[structure_1, structure_type_1, structure_2, structure_type_2],
                              outputs=[tmscore])
78
-