Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from scipy.stats import norm | |
| from .init_model import model, all_index, valid_subsections | |
| from .blocks import upload_pdb_button, parse_pdb_file | |
# Files that hold the most recent search results and score histogram
tmp_file_path = "/tmp/results.tsv"
tmp_plot_path = "/tmp/histogram.svg"

# Example queries shown under the input box
samples = [
    ["Proteins with zinc bindings."],
    ["Proteins locating at cell membrane."],
    ["Protein that serves as an enzyme."],
]

# Currently selected database for each modality; each starts at the first
# database listed for that modality in all_index
now_db = {
    modality: list(all_index[modality].keys())[0]
    for modality in ("sequence", "structure", "text")
}
def clear_results():
    """
    Reset the result widgets.

    Returns:
        An empty string for the results text, followed by two updates that
        each set ``visible=False``.
    """
    cleared_text = ""
    hide_download = gr.update(visible=False)
    hide_histogram = gr.update(visible=False)
    return cleared_text, hide_download, hide_histogram
def plot(scores) -> None:
    """
    Plot the distribution of scores and fit a normal distribution.

    The figure is saved to ``tmp_plot_path``, overwriting any previous plot.

    Args:
        scores: List of scores
    """
    # Draw on a dedicated figure rather than pyplot's implicit global one, and
    # always close it: if fitting or saving raises, nothing is left behind to
    # be drawn over by the next call.
    fig, ax = plt.subplots()
    try:
        ax.hist(scores, bins=100, density=True, alpha=0.6)
        ax.set_title('Distribution of similarity scores in the database', fontsize=15)
        ax.set_xlabel('Similarity score', fontsize=15)
        ax.set_ylabel('Density', fontsize=15)

        mu, std = norm.fit(scores)

        # Plot the Gaussian over the x-range chosen by the histogram
        xmin, xmax = ax.get_xlim()
        _, ymax = ax.get_ylim()
        x = np.linspace(xmin, xmax, 100)
        p = norm.pdf(x, mu, std)
        ax.plot(x, p)

        # Plot total number of scores
        ax.text(xmax, 0.9 * ymax, f"Total number: {len(scores)}", ha='right', fontsize=12)

        # Convert the plot to svg format
        fig.savefig(tmp_plot_path)
    finally:
        plt.close(fig)
| # Search from database | |
def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str):
    """
    Embed the query, search the selected index and report the best matches.

    Side effects: writes the top hits to ``tmp_file_path`` and, through
    ``plot``, the score histogram to ``tmp_plot_path``.

    Args:
        input: The query, passed to the model's ``get_<modality>_repr`` method.
        nprobe: Number of clusters to search; only applied to IVF indexes.
        topk: Number of results to retrieve.
        input_type: Modality of the query ("sequence" is mapped to "protein"
            when choosing the model method).
        query_type: Modality of the database to search.
        subsection_type: Text subsection; only used when ``query_type`` is "text".

    Returns:
        A tuple of the results as a Markdown table, a visible download button
        pointing at the results file, and an update showing the histogram.

    Raises:
        gr.Error: If ``nprobe`` exceeds the number of clusters in the index or
            ``topk`` exceeds the database size.
    """
    # Slider values are used for slicing and as an index attribute below,
    # so make sure they are plain integers.
    nprobe = int(nprobe)
    topk = int(topk)

    input_modality = input_type.replace("sequence", "protein")
    with torch.no_grad():
        input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()

    db = now_db[query_type]
    if query_type == "text":
        index = all_index["text"][db][subsection_type]["index"]
        ids = all_index["text"][db][subsection_type]["ids"]
    else:
        index = all_index[query_type][db]["index"]
        ids = all_index[query_type][db]["ids"]

    if check_index_ivf(query_type, subsection_type):
        if index.nlist < nprobe:
            raise gr.Error(f"The number of clusters to search must be less than or equal to the number of clusters in the index ({index.nlist}).")
        index.nprobe = nprobe

    if topk > index.ntotal:
        raise gr.Error(f"You cannot retrieve more than the database size ({index.ntotal}).")

    # Retrieve all scores to plot the distribution
    scores, ranks = index.search(input_embedding, index.ntotal)
    scores, ranks = scores[0], ranks[0]

    # Remove inf values
    # NOTE(review): presumably unfilled result slots carry very negative
    # scores; this threshold would also drop any genuine raw score <= -1 -
    # confirm against the index's score range.
    selector = scores > -1
    scores = scores[selector]
    ranks = ranks[selector]
    scores = scores / model.temperature.item()
    plot(scores)

    # After filtering there may be fewer than topk hits left
    top_scores = scores[:topk]
    top_ranks = ranks[:topk]

    # Write the results to a temporary file for downloading
    with open(tmp_file_path, "w", encoding="utf-8") as w:
        w.write("Id\tMatching score\n")
        w.writelines(f"{ids[rank]}\t{score}\n" for rank, score in zip(top_ranks, top_scores))

    # Get topk ids
    topk_ids = [_format_result_id(ids[rank], query_type, db) for rank in top_ranks]

    # Only the first `limit` rows are rendered; the file holds everything
    limit = 1000
    df = pd.DataFrame({"Id": topk_ids[:limit], "Matching score": top_scores[:limit]})
    if len(topk_ids) > limit:
        info_df = pd.DataFrame({"Id": ["Download the file to check all results"], "Matching score": ["..."]},
                               index=[limit])
        df = pd.concat([df, info_df], axis=0)

    output = df.to_markdown()
    return (output,
            gr.DownloadButton(label="Download results", value=tmp_file_path, visible=True, scale=0),
            gr.update(value=tmp_plot_path, visible=True))


def _format_result_id(now_id: str, query_type: str, db: str) -> str:
    """Return the table cell for one hit: the bare id for text results, otherwise a Markdown link."""
    if query_type == "text":
        return now_id
    if db != "PDB":
        # Provide link to uniprot website
        return f"[{now_id}](https://www.uniprot.org/uniprotkb/{now_id})"
    # Provide link to pdb website; the part of the id before the first "-" is used in the URL
    pdb_id = now_id.split("-")[0]
    return f"[{now_id}](https://www.rcsb.org/structure/{pdb_id})"
def change_input_type(choice: str):
    """
    Swap the example queries and upload widgets to match the input type.

    Rebinds the module-level ``samples`` so that later example lookups use the
    new set.

    Args:
        choice: The selected input type.

    Returns:
        An update carrying the new samples, an empty string, and two
        visibility updates that are hidden only when ``choice`` is "text".
    """
    global samples

    examples_by_type = {
        "text": [
            ["Proteins with zinc bindings."],
            ["Proteins locating at cell membrane."],
            ["Protein that serves as an enzyme."],
        ],
        "sequence": [
            ["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK"],
            ["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT"],
            ["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE"],
        ],
        "structure": [
            ["dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd"],
            ["dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd"],
            ["dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd"],
        ],
    }
    # An unrecognised choice keeps whatever examples are currently active
    samples = examples_by_type.get(choice, samples)

    # Set visibility of upload button: shown for every input type except text
    show_upload = choice != "text"

    return (
        gr.update(samples=samples),
        "",
        gr.update(visible=show_upload),
        gr.update(visible=show_upload),
    )
| # Load example from dataset | |
def load_example(example_id):
    """
    Return the text of one example query.

    Args:
        example_id: Position of the example in the module-level ``samples``.

    Returns:
        The first element of the selected sample row.
    """
    chosen_row = samples[example_id]
    return chosen_row[0]
| # Change the visibility of subsection type | |
def change_output_type(query_type: str, subsection_type: str):
    """
    Refresh the widgets that depend on the output type.

    Args:
        query_type: The selected output type.
        subsection_type: The selected text subsection, forwarded to the IVF check.

    Returns:
        Three updates, in order: subsection visibility (shown only for "text"),
        nprobe visibility (shown only for IVF indexes), and the database
        choices with the current database for ``query_type`` selected.
    """
    show_nprobe = check_index_ivf(query_type, subsection_type)
    show_subsection = query_type == "text"
    available_dbs = list(all_index[query_type].keys())

    subsection_update = gr.update(visible=show_subsection)
    nprobe_update = gr.update(visible=show_nprobe)
    db_update = gr.update(choices=available_dbs, value=now_db[query_type])
    return subsection_update, nprobe_update, db_update
def check_index_ivf(index_type: str, subsection_type: str = None) -> bool:
    """
    Tell whether the currently selected index is of IVF type.

    The index is looked up in the database currently selected for
    ``index_type``, and is treated as IVF when it has an ``nprobe`` attribute.

    Args:
        index_type: Type of index.
        subsection_type: If the "index_type" is "text", get the index based on the subsection type.

    Returns:
        Whether the index is of IVF type or not.
    """
    selected_db = now_db[index_type]

    if index_type == "text":
        # Text indexes are additionally keyed by subsection
        target = all_index["text"][selected_db][subsection_type]["index"]
    elif index_type in ("sequence", "structure"):
        target = all_index[index_type][selected_db]["index"]

    return hasattr(target, "nprobe")
def change_db_type(query_type: str, subsection_type: str, db_type: str):
    """
    Change the database to search.

    Records ``db_type`` as the current database for ``query_type`` in the
    module-level ``now_db``.

    Args:
        query_type: The output type.
        subsection_type: The text subsection selected before the change.
        db_type: The database to search.

    Returns:
        Two updates, in order: one for the subsection dropdown and one for the
        visibility of the nprobe slider.
    """
    now_db[query_type] = db_type

    if query_type == "text":
        # The dropdown is reset to "Function" for the new database, so the IVF
        # check must look at that subsection too; the previously selected one
        # may not exist in the new database at all.
        active_subsection = "Function"
        subsection_update = gr.update(choices=list(valid_subsections[now_db["text"]]), value=active_subsection)
    else:
        active_subsection = subsection_type
        subsection_update = gr.update(visible=False)

    nprobe_visible = check_index_ivf(query_type, active_subsection)
    return subsection_update, gr.update(visible=nprobe_visible)
| # Build the searching block | |
def build_search_module():
    """
    Create the search widgets and wire up their event handlers.

    Builds the input/output type selectors, database and subsection dropdowns,
    query box with PDB upload, nprobe and top-k sliders, example dataset,
    search/clear buttons and the result widgets, then connects them to the
    handler functions in this module.

    NOTE(review): presumably called inside an enclosing ``gr.Blocks`` context -
    confirm against the caller.
    NOTE(review): the nesting of the ``with`` blocks below was reconstructed
    from a copy of the file that had lost its indentation - confirm the layout.
    """
    gr.Markdown(f"# Search from Swiss-Prot database (the whole UniProt database will be supported soon)")
    with gr.Row(equal_height=True):
        with gr.Column():
            # Set input type
            input_type = gr.Radio(["sequence", "structure", "text"], label="Input type (e.g. 'text' means searching based on text descriptions)", value="text")

            with gr.Row():
                # Set output type
                query_type = gr.Radio(
                    ["sequence", "structure", "text"],
                    label="Output type (e.g. 'sequence' means returning qualified sequences)",
                    value="sequence",
                    scale=2,
                )

                # If the output type is "text", provide an option to choose the subsection of text
                # (starts hidden because the default output type is "sequence")
                subsection_type = gr.Dropdown(valid_subsections[now_db["text"]], label="Subsection of text", value="Function",
                                              interactive=True, visible=False, scale=0)

                # Databases available for the default output type
                db_type = gr.Dropdown(all_index["sequence"].keys(), label="Database", value=now_db["sequence"],
                                      interactive=True, visible=True, scale=0)

            with gr.Row():
                # Input box
                input = gr.Text(label="Input")

                # Provide an upload button to upload a pdb file
                # (starts hidden because the default input type is "text")
                upload_btn, chain_box = upload_pdb_button(visible=False)
                upload_btn.upload(parse_pdb_file, inputs=[input_type, upload_btn, chain_box], outputs=[input])

            # If the index is of IVF type, provide an option to choose the number of clusters.
            # No subsection is passed: the initial output type is "sequence", which does not use one.
            nprobe_visible = check_index_ivf(query_type.value)
            nprobe = gr.Slider(1, 1000000, 1000, step=1, visible=nprobe_visible,
                               label="Number of clusters to search (lower value for faster search and higher value for more accurate search)")

            # Add event listener to output type
            query_type.change(fn=change_output_type, inputs=[query_type, subsection_type],
                              outputs=[subsection_type, nprobe, db_type])

            # Add event listener to db type
            db_type.change(fn=change_db_type, inputs=[query_type, subsection_type, db_type],
                           outputs=[subsection_type, nprobe])

            # Choose topk results
            topk = gr.Slider(1, 1000000, 5, step=1, label="Retrieve top k results")

            # Provide examples
            examples = gr.Dataset(samples=samples, components=[input], type="index", label="Input examples")

            # Add click event to examples
            examples.click(fn=load_example, inputs=[examples], outputs=input)

            # Change examples based on input type
            input_type.change(fn=change_input_type, inputs=[input_type], outputs=[examples, input, upload_btn, chain_box])

            with gr.Row():
                search_btn = gr.Button(value="Search")
                clear_btn = gr.Button(value="Clear")

    with gr.Row():
        with gr.Column():
            results = gr.Markdown(label="results", height=450)
            # Hidden until a search has produced a results file
            download_btn = gr.DownloadButton(label="Download results", visible=False)

        # Plot the distribution of scores
        histogram = gr.Image(label="Histogram of matching scores", type="filepath", scale=1, visible=False)

    # Both buttons write to the same three result widgets
    search_btn.click(fn=search, inputs=[input, nprobe, topk, input_type, query_type, subsection_type],
                     outputs=[results, download_btn, histogram])
    clear_btn.click(fn=clear_results, outputs=[results, download_btn, histogram])