import gradio as gr
import numpy as np
import os
import sys
import gc
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go
import torch
from transformers import pipeline as pl
from GPUtil import showUtilization as gpu_usage
from numba import cuda

print('GPU available', torch.cuda.is_available())
print('__CUDA Device Name:', torch.cuda.get_device_name(0))
print(os.getcwd())

if "/home/user/app/alphafold" not in sys.path:
    sys.path.append("/home/user/app/alphafold")

from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.data import templates
from alphafold.model import data
from alphafold.model import config
from alphafold.model import model
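
# This Space runs AlphaFold in single-sequence mode: no MSA search and no real
# structural templates. mk_mock_template below fabricates all-gap template
# features of the right shape so the model still accepts the input.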
def mk_mock_template(query_sequence):
    """Create blank template features so AlphaFold can run without real templates."""
    ln = len(query_sequence)
    output_templates_sequence = "-" * ln
    templates_all_atom_positions = np.zeros(
        (ln, templates.residue_constants.atom_type_num, 3)
    )
    templates_all_atom_masks = np.zeros(
        (ln, templates.residue_constants.atom_type_num)
    )
    templates_aatype = templates.residue_constants.sequence_to_onehot(
        output_templates_sequence, templates.residue_constants.HHBLITS_AA_TO_ID
    )
    template_features = {
        "template_all_atom_positions": templates_all_atom_positions[None],
        "template_all_atom_masks": templates_all_atom_masks[None],
        "template_aatype": np.array(templates_aatype)[None],
        "template_domain_names": ["none".encode()],
    }
    return template_features
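
# predict_structure writes one unrelaxed PDB per model to disk and returns the
# per-residue pLDDT array for each model.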
def predict_structure(prefix, feature_dict, model_runners, random_seed=0):
    """Predicts structure using AlphaFold for the given sequence."""
    # Run the models (currently only model_1).
    plddts = {}
    for model_name, model_runner in model_runners.items():
        processed_feature_dict = model_runner.process_features(
            feature_dict, random_seed=random_seed
        )
        prediction_result = model_runner.predict(processed_feature_dict)
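        # Store pLDDT in the B-factor column of the output PDB so the 3Dmol.js
        # viewer can color the structure by confidence.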
        b_factors = (
            prediction_result["plddt"][:, None]
            * prediction_result["structure_module"]["final_atom_mask"]
        )
        unrelaxed_protein = protein.from_prediction(
            processed_feature_dict, prediction_result, b_factors
        )
        unrelaxed_pdb_path = f"{prefix}_unrelaxed_{model_name}.pdb"
        plddts[model_name] = prediction_result["plddt"]
        print(f"{model_name} {plddts[model_name].mean()}")
        with open(unrelaxed_pdb_path, "w") as f:
            f.write(protein.to_pdb(unrelaxed_protein))
    return plddts
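
# The ProtGPT2 pipeline is created on demand and deleted right after generation
# (empty_cache plus a numba CUDA device reset) so that as much GPU memory as
# possible is left for the AlphaFold run.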
def run_protgpt2(startsequence, length, repetitionPenalty, top_k_poolsize, max_seqs):
    protgpt2 = pl("text-generation", model="nferruz/ProtGPT2")
    sequences = protgpt2(
        startsequence,
        max_length=length,
        do_sample=True,
        top_k=top_k_poolsize,
        repetition_penalty=repetitionPenalty,
        num_return_sequences=max_seqs,
        eos_token_id=0,
    )
    print("Cleaning up after protGPT2")
    print(gpu_usage())
    del protgpt2
    torch.cuda.empty_cache()
    device = cuda.get_current_device()
    device.reset()
    print(gpu_usage())
    return sequences
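
# Example standalone call (sketch; the values mirror the UI defaults below):
#   seqs = run_protgpt2("M", 50, 1.2, 950, 5)
#   print(seqs[0]["generated_text"])


# run_alphafold builds a runner for model_1 (ensemble size 1), assembles
# single-sequence features plus the mock template, and returns the per-residue
# pLDDT of the prediction.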
def run_alphafold(startsequence):
    print(gpu_usage())
    model_runners = {}
    models = ["model_1"]  # , "model_2", "model_3", "model_4", "model_5"
    for model_name in models:
        model_config = config.model_config(model_name)
        model_config.data.eval.num_ensemble = 1
        model_params = data.get_model_haiku_params(model_name=model_name, data_dir=".")
        model_runner = model.RunModel(model_config, model_params)
        model_runners[model_name] = model_runner
    query_sequence = startsequence.replace("\n", "")
    feature_dict = {
        **pipeline.make_sequence_features(
            sequence=query_sequence, description="none", num_res=len(query_sequence)
        ),
        **pipeline.make_msa_features(
            msas=[[query_sequence]], deletion_matrices=[[[0] * len(query_sequence)]]
        ),
        **mk_mock_template(query_sequence),
    }
    plddts = predict_structure("test", feature_dict, model_runners)
    print("Cleaning up after AF2")
    print(gpu_usage())
    device = cuda.get_current_device()
    device.reset()
    print(gpu_usage())
    return plddts["model_1"]
def update_protGPT2(inp, length, repetitionPenalty, top_k_poolsize, max_seqs):
    generated_seqs = run_protgpt2(inp, length, repetitionPenalty, top_k_poolsize, max_seqs)
    gen_seqs = [x["generated_text"] for x in generated_seqs]
    print(gen_seqs)
    sequencestxt = ""
    for i, seq in enumerate(gen_seqs):
        s = seq.replace("\n", "")
        s = "\n".join([s[j:j + 70] for j in range(0, len(s), 70)])
        sequencestxt += f">seq{i}\n{s}\n"
    return sequencestxt
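
# Gradio callback: fold the chosen sequence with AlphaFold and return the
# 3Dmol.js viewer HTML, a Plotly pLDDT trace, and the mean pLDDT as text.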
def update(inp):
    print("Running AF on", inp)
    plddts = run_alphafold(inp)
    print(plddts)
    fig = go.Figure(
        data=go.Scatter(
            x=np.arange(len(plddts)),
            y=plddts,
            hovertemplate="<i>pLDDT</i>: %{y:.2f} <br><i>Residue index:</i> %{x}",
        )
    )
    fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
    )
    return (
        molecule("test_unrelaxed_model_1.pdb"),
        fig,
        f"{np.mean(plddts):.1f} ± {np.std(plddts):.1f}",
    )
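
# The helpers below embed the predicted PDB in a self-contained 3Dmol.js page,
# shown in an iframe and colored by pLDDT following the AlphaFold DB scheme.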
def read_mol(molpath):
    with open(molpath, "r") as fp:
        mol = fp.read()
    return mol
def molecule(pdb):
    mol = read_mol(pdb)
    x = (
        """<!DOCTYPE html>
<html>
<head>
    <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
    <link rel="stylesheet" href="https://unpkg.com/[email protected]/dist/flowbite.min.css" />
    <style>
        body {
            font-family: sans-serif
        }
        .mol-container {
            width: 100%;
            height: 800px;
            position: relative;
        }
        .space-x-2 > * + * {
            margin-left: 0.5rem;
        }
        .p-1 {
            padding: 0.5rem;
        }
        .flex {
            display: flex;
            align-items: center;
        }
        .w-4 {
            width: 1rem;
        }
        .h-4 {
            height: 1rem;
        }
        .mt-4 {
            margin-top: 1rem;
        }
    </style>
    <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
    <div id="container" class="mol-container"></div>
    <div class="flex">
        <div class="px-4">
            <label for="sidechain" class="relative inline-flex items-center mb-4 cursor-pointer">
                <input id="sidechain" type="checkbox" class="sr-only peer">
                <div class="w-11 h-6 bg-gray-200 rounded-full peer peer-focus:ring-4 peer-focus:ring-blue-300 dark:peer-focus:ring-blue-800 dark:bg-gray-700 peer-checked:after:translate-x-full peer-checked:after:border-white after:absolute after:top-0.5 after:left-[2px] after:bg-white after:border-gray-300 after:border after:rounded-full after:h-5 after:w-5 after:transition-all dark:border-gray-600 peer-checked:bg-blue-600"></div>
                <span class="ml-3 text-sm font-medium text-gray-900 dark:text-gray-300">Show side chains</span>
            </label>
        </div>
        <button type="button" class="text-gray-900 bg-white hover:bg-gray-100 border border-gray-200 focus:ring-4 focus:outline-none focus:ring-gray-100 font-medium rounded-lg text-sm px-5 py-2.5 text-center inline-flex items-center dark:focus:ring-gray-600 dark:bg-gray-800 dark:border-gray-700 dark:text-white dark:hover:bg-gray-700 mr-2 mb-2" id="download">
            <svg class="w-6 h-6 mr-2 -ml-1" fill="none" stroke="currentColor" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"></path></svg>
            Download predicted structure
        </button>
    </div>
    <div class="text-sm">
        <div class="font-medium mt-4"><b>AlphaFold model confidence:</b></div>
        <div class="flex space-x-2 py-1"><span class="w-4 h-4" style="background-color: rgb(0, 83, 214);"> </span><span class="legendlabel">Very high (pLDDT &gt; 90)</span></div>
        <div class="flex space-x-2 py-1"><span class="w-4 h-4" style="background-color: rgb(101, 203, 243);"> </span><span class="legendlabel">Confident (90 &gt; pLDDT &gt; 70)</span></div>
        <div class="flex space-x-2 py-1"><span class="w-4 h-4" style="background-color: rgb(255, 219, 19);"> </span><span class="legendlabel">Low (70 &gt; pLDDT &gt; 50)</span></div>
        <div class="flex space-x-2 py-1"><span class="w-4 h-4" style="background-color: rgb(255, 125, 69);"> </span><span class="legendlabel">Very low (pLDDT &lt; 50)</span></div>
        <div class="row column legendDesc">AlphaFold produces a per-residue confidence score (pLDDT) between 0 and 100. Some regions below 50 pLDDT may be unstructured in isolation.</div>
    </div>
    <script>
        let viewer = null;
        let voldata = null;
        $(document).ready(function () {
            let element = $("#container");
            let config = { backgroundColor: "white" };
            viewer = $3Dmol.createViewer(element, config);
            viewer.ui.initiateUI();
            let data = `"""
        + mol
        + """`
            viewer.addModel(data, "pdb");
            // AlphaFold code from https://gist.github.com/piroyon/30d1c1099ad488a7952c3b21a5bebc96
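            // colorAlpha maps the per-atom B-factor (which stores pLDDT here) to the
            // AlphaFold DB confidence colors used in the legend above.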
            let colorAlpha = function (atom) {
                if (atom.b < 50) {
                    return "OrangeRed";
                } else if (atom.b < 70) {
                    return "Gold";
                } else if (atom.b < 90) {
                    return "MediumTurquoise";
                } else {
                    return "Blue";
                }
            };
            viewer.setStyle({}, { cartoon: { colorfunc: colorAlpha } });
            viewer.zoomTo();
            viewer.render();
            viewer.zoom(0.8, 2000);
            viewer.getModel(0).setHoverable({}, true,
                function (atom, viewer, event, container) {
                    console.log(atom)
                    if (!atom.label) {
                        atom.label = viewer.addLabel(atom.resn + atom.resi + " pLDDT=" + atom.b, { position: atom, backgroundColor: "mintcream", fontColor: "black" });
                    }
                },
                function (atom, viewer) {
                    if (atom.label) {
                        viewer.removeLabel(atom.label);
                        delete atom.label;
                    }
                }
            );
            $("#sidechain").change(function () {
                if (this.checked) {
                    BB = ["C", "O", "N"]
                    viewer.setStyle({ "and": [{ resn: ["GLY", "PRO"], invert: true }, { atom: BB, invert: true }] }, { stick: { colorscheme: "WhiteCarbon", radius: 0.3 }, cartoon: { colorfunc: colorAlpha } });
                    viewer.render()
                } else {
                    viewer.setStyle({ cartoon: { colorfunc: colorAlpha } });
                    viewer.render()
                }
            });
            $("#download").click(function () {
                download("gradioFold_model1.pdb", data);
            })
        });
        function download(filename, text) {
            var element = document.createElement("a");
            element.setAttribute("href", "data:text/plain;charset=utf-8," + encodeURIComponent(text));
            element.setAttribute("download", filename);
            element.style.display = "none";
            document.body.appendChild(element);
            element.click();
            document.body.removeChild(element);
        }
    </script>
</body></html>"""
    )
| return f"""<iframe style="width: 800px; height: 1200px" name="result" allow="midi; geolocation; microphone; camera; | |
| display-capture; encrypted-media;" sandbox="allow-modals allow-forms | |
| allow-scripts allow-same-origin allow-popups | |
| allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" | |
| allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>""" | |
def change_sequence(chosenSeq):
    return chosenSeq
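
# Gradio Blocks UI: ProtGPT2 sequence generation on top, AlphaFold structure
# prediction of a chosen sequence below.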
proteindream = gr.Blocks()

with proteindream:
    gr.Markdown("# GradioFold")
    gr.Markdown(
        """GradioFold is a web-based tool that combines a large language model trained on natural protein sequences (ProtGPT2) with structure prediction using AlphaFold.
Type a start sequence or provide a sequence with blanks that ProtGPT2 can complete."""
    )
    gr.Markdown("## ProtGPT2")
    gr.Markdown(
        """
Enter a start sequence and have the language model complete it.
"""
    )
    with gr.Box():
        with gr.Row():
            inp = gr.Textbox(placeholder="M", label="Start sequence")
            length = gr.Number(value=50, label="Max sequence length")
        with gr.Row():
            repetitionPenalty = gr.Slider(minimum=1, maximum=5, value=1.2, label="Repetition penalty")
            top_k_poolsize = gr.Slider(minimum=700, maximum=52056, value=950, label="Top-K sampling pool size")
            max_seqs = gr.Slider(minimum=2, maximum=20, value=5, label="Number of sequences to generate")
        btn = gr.Button("Predict sequences using ProtGPT2")
    results = gr.Textbox(label="Results", lines=15)
    btn.click(fn=update_protGPT2, inputs=[inp, length, repetitionPenalty, top_k_poolsize, max_seqs], outputs=results)
| gr.Markdown("## AlphaFold") | |
| gr.Markdown( | |
| "Select a generated sequence above and copy it in the field below for structure prediction using AlphaFold2." | |
| ) | |
| with gr.Group(): | |
| chosenSeq = gr.Textbox(label="Chosen sequence") | |
| btn2 = gr.Button("Predict structure") | |
| with gr.Group(): | |
| meanpLDDT = gr.Textbox(label="Mean pLDDT of chosen sequence") | |
| with gr.Row(): | |
| mol = gr.HTML() | |
| plot = gr.Plot(label="pLDDT") | |
    gr.Markdown(
        """## Acknowledgements
This is a fun demo built with Gradio and Hugging Face Spaces, with ColabFold as inspiration. More information about the algorithms used can be found below.
All code is available on [Github]() and licensed under the MIT license.
- ProtGPT2: Ferruz et al. [BioRxiv](https://doi.org/10.1101/2022.03.09.483666) [Code](https://huggingface.co/nferruz/ProtGPT2)
- AlphaFold2: Jumper et al. [Paper](https://doi.org/10.1038/s41586-021-03819-2) [Code](https://github.com/deepmind/alphafold) Model parameters released under CC BY 4.0
- ColabFold: Mirdita et al. [Paper](https://doi.org/10.1101/2021.08.15.456425) [Code](https://github.com/sokrypton/ColabFold)

Created by [@simonduerr](https://twitter.com/simonduerr)
"""
    )
    # seqChoice.change(fn=update_seqs, inputs=seqChoice, outputs=chosenSeq)
    btn2.click(fn=update, inputs=chosenSeq, outputs=[mol, plot, meanpLDDT])

proteindream.launch(share=False)