File size: 4,300 Bytes
f8080fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask
from mammal.keys import *
from mammal.model import Mammal

from mammal_demo.demo_framework import MammalObjectBroker, MammalTask    

class DtiTask(MammalTask):
    """Gradio demo task that estimates drug-target binding affinity with Mammal.

    The task delegates prompt construction and output post-processing to
    ``DtiBindingdbKdTask`` and exposes a small Gradio panel with two inputs
    (a target sequence and a drug SMILES string).
    """

    # Key under which ``process_model_output`` is asked to store the processed
    # scalar prediction; the same key is read back in ``decode_output``.
    OUTPUT_KEY = "model.out.dti_bindingdb_kd"

    # Statistics handed to ``process_model_output`` as norm_y_mean / norm_y_std.
    # NOTE(review): presumably the label mean/std the model was trained with,
    # used to de-normalize the prediction - confirm against the training config.
    NORM_Y_MEAN = 5.79384684128215
    NORM_Y_STD = 1.33808027428196

    def __init__(self, model_dict):
        """Set up the task name, description, example inputs and intro text.

        Args:
            model_dict: mapping forwarded to ``MammalTask``; it is indexed by
                model name in ``create_and_run_prompt``.
        """
        super().__init__(name="Drug-Target Binding Affinity", model_dict=model_dict)
        self.description = "Drug-Target Binding Affinity (tdi)"
        # Default values pre-filled into the two input text boxes.
        self.examples = {
            "target_seq": "NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC",
            "drug_seq": "CC(=O)NCCC1=CNc2c1cc(OC)cc2",
        }
        self.markup_text = """
# Mammal based Target-Drug binding affinity demonstration

Given a protein sequence and a drug (in SMILES), estimate the binding affinity.
"""

    def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker):
        """Convert sample_inputs to a sample_dict, including a proper prompt.

        The method name keeps its historical spelling so existing callers
        continue to work.

        Args:
            sample_inputs (dict): dictionary containing the inputs to the model
                (keys ``"target_seq"`` and ``"drug_seq"``).
            model_holder (MammalObjectBroker): model holder providing the
                tokenizer op and the model (for its device).

        Returns:
            dict: sample_dict for feeding into model
        """
        # Work on a shallow copy so the caller's dictionary is not modified
        # by the preprocessing step.
        sample_dict = dict(sample_inputs)
        return DtiBindingdbKdTask.data_preprocessing(
            sample_dict=sample_dict,
            tokenizer_op=model_holder.tokenizer_op,
            target_sequence_key="target_seq",
            drug_sequence_key="drug_seq",
            # No normalization statistics are supplied at preprocessing time.
            norm_y_mean=None,
            norm_y_std=None,
            device=model_holder.model.device,
        )

    def run_model(self, sample_dict, model: Mammal):
        """Run the encoder-only forward pass on a single sample.

        Args:
            sample_dict: preprocessed sample as returned by ``crate_sample_dict``.
            model (Mammal): model used for the prediction.

        Returns:
            The batch dictionary returned by ``model.forward_encoder_only``.
        """
        # The model is called with a list of samples; wrap the single sample.
        return model.forward_encoder_only([sample_dict])

    def decode_output(self, batch_dict, model_holder):
        """Post-process the model output into an ``(output key, value)`` pair.

        Args:
            batch_dict: batch dictionary returned by ``run_model``.
            model_holder: unused here; kept so the signature stays compatible
                with existing callers.

        Returns:
            tuple: ``(OUTPUT_KEY, prediction)`` where ``prediction`` is the
            first entry stored under ``OUTPUT_KEY``, converted to ``float``.
        """
        batch_dict = DtiBindingdbKdTask.process_model_output(
            batch_dict,
            scalars_preds_processed_key=self.OUTPUT_KEY,
            norm_y_mean=self.NORM_Y_MEAN,
            norm_y_std=self.NORM_Y_STD,
        )
        # Only one sample is fed to the model, so take the first entry.
        return self.OUTPUT_KEY, float(batch_dict[self.OUTPUT_KEY][0])

    def create_and_run_prompt(self, model_name, target_seq, drug_seq):
        """Build the prompt for the given inputs, run the model and decode.

        Args:
            model_name: key selecting the model holder in ``self.model_dict``.
            target_seq: target sequence text.
            drug_seq: drug sequence text (SMILES).

        Returns:
            tuple: ``(prompt, output key, prediction)`` in the order expected
            by the output widgets wired up in ``create_demo``.
        """
        model_holder = self.model_dict[model_name]
        inputs = {
            "target_seq": target_seq,
            "drug_seq": drug_seq,
        }
        sample_dict = self.crate_sample_dict(
            sample_inputs=inputs, model_holder=model_holder
        )
        prompt = sample_dict[ENCODER_INPUTS_STR]
        batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
        output_key, prediction = self.decode_output(
            batch_dict, model_holder=model_holder
        )
        return prompt, output_key, prediction

    def create_demo(self, model_name_widget):
        """Build the Gradio panel for this task.

        Args:
            model_name_widget: component whose value is passed as
                ``model_name`` to ``create_and_run_prompt``.

        Returns:
            The ``gr.Group`` holding the panel; it is created hidden.
        """
        with gr.Group() as demo:
            gr.Markdown(self.markup_text)
            with gr.Row():
                target_textbox = gr.Textbox(
                    label="target sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["target_seq"],
                )
                drug_textbox = gr.Textbox(
                    label="Drug sequence (in SMILES)",
                    interactive=True,
                    lines=3,
                    value=self.examples["drug_seq"],
                )
            with gr.Row():
                run_mammal = gr.Button(
                    "Run Mammal prompt for Drug-Target Binding Affinity",
                    variant="primary",
                )
            with gr.Row():
                prompt_box = gr.Textbox(label="Mammal prompt", lines=5)

            with gr.Row():
                decoded = gr.Textbox(label="Mammal output key")
                # Created inside the row so it is placed next to the key box.
                affinity_box = gr.Number(label="binding affinity")
                run_mammal.click(
                    fn=self.create_and_run_prompt,
                    inputs=[model_name_widget, target_textbox, drug_textbox],
                    outputs=[prompt_box, decoded, affinity_box],
                )
            # The panel starts hidden.
            demo.visible = False
            return demo