# credit: https://huggingface.co/spaces/simonduerr/3dmol.js/blob/main/app.py

import os
import sys
from urllib import request

import gradio as gr
import requests
from transformers import AutoTokenizer, AutoModelForMaskedLM, EsmModel, AutoModel
import torch
import progres as pg


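# Load the tokenizers and pretrained models once at startup; inference below runs on CPU
# (the .to('cuda') calls are left commented out in the embedding functions).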
tokenizer_nt = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
model_nt = AutoModelForMaskedLM.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
model_nt.eval()

tokenizer_aa = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
model_aa = EsmModel.from_pretrained("facebook/esm2_t12_35M_UR50D")
model_aa.eval()

tokenizer_se = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model_se = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model_se.eval()


def nt_embed(sequence: str):
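    """Return the CLS-token embedding from the last hidden layer of the nucleotide transformer."""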
    tokens_ids = tokenizer_nt.batch_encode_plus([sequence], return_tensors="pt")["input_ids"]
    attention_mask = tokens_ids != tokenizer_nt.pad_token_id
    with torch.no_grad():
        torch_outs = model_nt(
            tokens_ids,#.to('cuda'),
            attention_mask=attention_mask,#.to('cuda'),
            output_hidden_states=True
        )
    last_layer_CLS = torch_outs.hidden_states[-1].detach()[:, 0, :][0]
    return last_layer_CLS


def aa_embed(sequence: str):
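    """Run ESM-2 on an amino acid sequence and return the full model output (per-residue hidden states)."""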
    tokens = tokenizer_aa([sequence], return_tensors="pt")
    with torch.no_grad():
        torch_outs = model_aa(**tokens)
    return torch_outs


def se_embed(sentence: str):
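    """Run all-mpnet-base-v2 on a sentence and return the full model output (per-token hidden states)."""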
    encoded_input = tokenizer_se([sentence], return_tensors='pt')
    with torch.no_grad():
        model_output = model_se(**encoded_input)
    return model_output
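

# The all-mpnet-base-v2 model card recommends mean pooling of the token
# embeddings, weighted by the attention mask, to get a single sentence vector.
# A minimal sketch of that pooling step (not wired into the UI below):
def mean_pool(model_output, attention_mask):
    token_embeddings = model_output.last_hidden_state  # (batch, tokens, dim)
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Sum the real-token embeddings and divide by the number of real tokens
    return (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
# e.g. enc = tokenizer_se([text], return_tensors='pt'); vec = mean_pool(model_se(**enc), enc['attention_mask'])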


def download_data_if_required():
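    """Download the trained progres model from Zenodo on first run if it is not already present."""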
    url_base = f"https://zenodo.org/record/{pg.zenodo_record}/files"
    fps = [pg.trained_model_fp]
    urls = [f"{url_base}/trained_model.pt"]
    #for targetdb in pre_embedded_dbs:
    #    fps.append(os.path.join(database_dir, targetdb + ".pt"))
    #    urls.append(f"{url_base}/{targetdb}.pt")

    if not os.path.isdir(pg.trained_model_dir):
        os.makedirs(pg.trained_model_dir)
    #if not os.path.isdir(database_dir):
    #    os.makedirs(database_dir)

    printed = False
    for fp, url in zip(fps, urls):
        if not os.path.isfile(fp):
            if not printed:
                print("Downloading data as first time setup (~340 MB) to ", pg.progres_dir,
                      ", internet connection required, this can take a few minutes",
                      sep="", file=sys.stderr)
                printed = True
            try:
                request.urlretrieve(url, fp)
                d = torch.load(fp, map_location="cpu")
                if fp == pg.trained_model_fp:
                    assert "model" in d
                else:
                    assert "embeddings" in d
            except Exception:
                if os.path.isfile(fp):
                    os.remove(fp)
                print("Failed to download from", url, "and save to", fp, file=sys.stderr)
                print("Exiting", file=sys.stderr)
                sys.exit(1)

    if printed:
        print("Data downloaded successfully", file=sys.stderr)


def get_pdb(pdb_code="", filepath=""):
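    """Return PDB text, either fetched from RCSB by PDB code or read from an uploaded file."""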
    if pdb_code is None or pdb_code == "":
        try:
            with open(filepath.name) as f:
                return f.read()
        except AttributeError:
            return None
    else:
        return requests.get(f"https://files.rcsb.org/view/{pdb_code}.pdb").content.decode()


def molecule(pdb):
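    """Wrap the PDB text in a self-contained 3Dmol.js viewer page and return it embedded in an iframe."""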

    x = (
        """<!DOCTYPE html>
        <html>
        <head>    
    <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
    <style>
    body{
        font-family:sans-serif
    }
    .mol-container {
    width: 100%;
    height: 600px;
    position: relative;
    }
    .mol-container select{
        background-image:None;
    }
    </style>
     <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
    <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
    </head>
    <body>  
    <div id="container" class="mol-container"></div>
  
            <script>
               let pdb = `"""
        + pdb
        + """`  
      
             $(document).ready(function () {
                let element = $("#container");
                let config = { backgroundColor: "black" };
                let viewer = $3Dmol.createViewer(element, config);
                viewer.addModel(pdb, "pdb");
                viewer.getModel(0).setStyle({}, { cartoon: { color:"spectrum" } });
                viewer.addSurface("MS", { opacity: .5, color: "white" });
                viewer.zoomTo();
                viewer.render();
                viewer.zoom(0.8, 2000);
              })
        </script>
        </body></html>"""
    )

    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera; 
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms 
    allow-scripts allow-same-origin allow-popups 
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" 
    allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""


def str2coords(s):
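    """Extract C-alpha coordinates from the first model of a PDB-format string."""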
    coords = []
    for line in s.split('\n'):
        if (line.startswith("ATOM  ") or line.startswith("HETATM")) and line[12:16].strip() == "CA":
            coords.append([float(line[30:38]), float(line[38:46]), float(line[46:54])])
        elif line.startswith("ENDMDL"):
            break
    return coords


def update_st(inp, file):
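    """Render the structure with 3Dmol.js and compute its progres structural embedding."""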
    pdb = get_pdb(inp, file)
    return (molecule(pdb), pg.embed_coords(str2coords(pdb)))


def update_nt(inp):
    return str(nt_embed(inp or ''))


def update_aa(inp):
    return str(aa_embed(inp or ''))


def update_se(inp):
    return str(se_embed(inp or ''))


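# Gradio UI: one tab per embedding type (structure, nucleotide, amino acid, sentence)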
demo = gr.Blocks()

with demo:
    with gr.Tabs():
        with gr.TabItem("PDB Structural Embeddings"):
            with gr.Row():
                with gr.Box():
                    inp = gr.Textbox(
                        placeholder="PDB Code or upload file below", label="Input structure"
                    )
                    file = gr.File(file_count="single")
                    gr.Examples(["2CBA", "6VXX"], inp)
                    btn = gr.Button("View structure")
            gr.Markdown("# PDB viewer using 3Dmol.js")
            mol = gr.HTML()
            emb = gr.Textbox(interactive=False)
            btn.click(fn=update_st, inputs=[inp, file], outputs=[mol, emb])
        with gr.TabItem("Nucleotide Sequence Embeddings"):
            with gr.Box():
                inp = gr.Textbox(
                    placeholder="ATCGCTGCCCGTAGATAATAAGAGACACTGAGGCC", label="Input Nucleotide Sequence"
                )
                btn = gr.Button("View embeddings")
                emb = gr.Textbox(interactive=False)
                btn.click(fn=update_nt, inputs=[inp], outputs=emb)
        with gr.TabItem("Amino Acid Sequence Embeddings"):
            with gr.Box():
                inp = gr.Textbox(
                    placeholder="AAGQCYRGRCSGGLCCSKYGYCGSGPAYCG", label="Input Amino Acid Sequence"
                )
                btn = gr.Button("View embeddings")
                emb = gr.Textbox(interactive=False)
                btn.click(fn=update_aa, inputs=[inp], outputs=emb)
        with gr.TabItem("Sentence Embeddings"):
            with gr.Box():
                inp = gr.Textbox(
                    placeholder="Your text here", label="Input Sentence"
                )
                btn = gr.Button("View embeddings")
                emb = gr.Textbox(interactive=False)
                btn.click(fn=update_se, inputs=[inp], outputs=emb)

if __name__ == "__main__":
    download_data_if_required()
    demo.launch()