File size: 4,435 Bytes
0c8cec9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
import torch
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.examples.protein_solubility.task import ProteinSolubilityTask
from mammal.keys import (
    ENCODER_INPUTS_STR,
    CLS_PRED,
    SCORES,
)
from mammal.model import Mammal

from mammal_demo.demo_framework import MammalObjectBroker, MammalTask


class PsTask(MammalTask):
    """Demo task exposing the Mammal protein solubility example through Gradio.

    The task builds a prompt from a single protein sequence, runs the selected
    Mammal model on it and shows the decoded output, the predicted class and
    the scores reported by ``ProteinSolubilityTask.process_model_output``.
    """

    def __init__(self, model_dict):
        """Set up the task name, description, default example and intro text.

        Args:
            model_dict: mapping from model name to its model holder; it is
                indexed by the model name in ``create_and_run_prompt``.
        """
        super().__init__(name="Protein Solubility", model_dict=model_dict)
        self.description = "Protein Solubility (PS)"
        # Default value pre-filled into the protein sequence textbox.
        self.examples = {
            "protein_seq": "LLQTGIHVRVSQPSL",
        }
        # Markdown shown at the top of the demo group.
        self.markup_text = """
# Mammal based protein solubility demonstration

Given the protein sequence, estimate its solubility.
"""

    def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker):
        """Convert sample_inputs to sample_dict including creating a proper prompt.

        NOTE(review): the method name looks like a typo for "create"; it is kept
        unchanged because it appears to be part of the MammalTask interface.

        Args:
            sample_inputs (dict): dictionary containing the inputs to the model
            model_holder (MammalObjectBroker): model holder
        Returns:
           dict: sample_dict for feeding into model
        """
        sample_dict = dict(sample_inputs)  # shallow copy, caller's dict is not rebound
        sample_dict = ProteinSolubilityTask.data_preprocessing(
            sample_dict=sample_dict,
            protein_sequence_key="protein_seq",
            tokenizer_op=model_holder.tokenizer_op,
            device=model_holder.model.device,
        )
        return sample_dict

    def run_model(self, sample_dict, model: Mammal):
        """Generate a prediction for a single preprocessed sample.

        Args:
            sample_dict: sample produced by ``crate_sample_dict``
            model (Mammal): model used for generation
        Returns:
            the batch dictionary returned by ``model.generate``
        """
        # The sample is wrapped in a list because generate() is given a batch.
        batch_dict = model.generate(
            [sample_dict],
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=5,
        )
        return batch_dict

    def decode_output(self, batch_dict, tokenizer_op: ModularTokenizerOp) -> list:
        """Extract predicted class and scores for the first sample of the batch.

        Args:
            batch_dict: output of ``run_model``
            tokenizer_op (ModularTokenizerOp): tokenizer used to decode the output
        Returns:
            list: ``[decoded output, predicted class, not normalized score,
            normalized score]``, in the order expected by the demo outputs.
        """
        ans_dict = ProteinSolubilityTask.process_model_output(
            tokenizer_op=tokenizer_op,
            decoder_output=batch_dict[CLS_PRED][0],
            decoder_output_scores=batch_dict[SCORES][0],
        )
        # .item() is called on both scores so plain scalars reach the widgets
        # (assumes the scores are single-element tensors/arrays - TODO confirm).
        ans = [
            tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0]),
            ans_dict["pred"],
            ans_dict["not_normalized_scores"].item(),
            ans_dict["normalized_scores"].item(),
        ]
        return ans

    def create_and_run_prompt(self, model_name, protein_seq):
        """Build the prompt for a protein sequence, run the model and decode.

        Args:
            model_name: key of the model holder inside ``self.model_dict``
            protein_seq: protein sequence entered by the user
        Returns:
            tuple: the prompt followed by the items returned by ``decode_output``
        """
        model_holder = self.model_dict[model_name]
        inputs = {
            "protein_seq": protein_seq,
        }
        sample_dict = self.crate_sample_dict(
            sample_inputs=inputs, model_holder=model_holder
        )
        prompt = sample_dict[ENCODER_INPUTS_STR]
        batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
        res = prompt, *self.decode_output(batch_dict, tokenizer_op=model_holder.tokenizer_op)
        return res

    def create_demo(self, model_name_widget):
        """Build the Gradio widgets of this task and wire the run button.

        Args:
            model_name_widget: Gradio component supplying the selected model name
        Returns:
            the ``gr.Group`` holding the task widgets, created with
            ``visible`` set to False
        """
        with gr.Group() as demo:
            gr.Markdown(self.markup_text)
            with gr.Row():
                protein_textbox = gr.Textbox(
                    label="Protein sequence",
                    # info="standard",
                    interactive=True,
                    lines=3,
                    value=self.examples["protein_seq"],
                )
            with gr.Row():
                run_mammal = gr.Button(
                    "Run Mammal prompt for Protein Solubility",
                    variant="primary",
                )
            with gr.Row():
                prompt_box = gr.Textbox(label="Mammal prompt", lines=5)

            with gr.Row():
                decoded = gr.Textbox(label="Mammal output")
                predicted_class = gr.Textbox(label="Mammal prediction")
                with gr.Column():
                    non_norm_score = gr.Number(label="Non-normalized score")
                    norm_score = gr.Number(label="Normalized score")
                # Output order must match the tuple returned by
                # create_and_run_prompt.
                run_mammal.click(
                    fn=self.create_and_run_prompt,
                    inputs=[model_name_widget, protein_textbox],
                    outputs=[
                        prompt_box,
                        decoded,
                        predicted_class,
                        non_norm_score,
                        norm_score,
                    ],
                )
            # The group starts hidden (presumably the surrounding demo framework
            # toggles visibility per selected task - verify against caller).
            demo.visible = False
            return demo