matanninio committed on
Commit
4c8737b
·
1 Parent(s): 83ccd79

new_app now works for ppi

Browse files
Files changed (1) hide show
  1. new_app.py +297 -0
new_app.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
4
+ from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask
5
+ from mammal.keys import *
6
+ from mammal.model import Mammal
7
+ from abc import ABC, abstractmethod
8
class MammalObjectBroker():
    """Lazily load and cache a Mammal model together with its tokenizer operator.

    Nothing is loaded at construction time. The model and the tokenizer
    operator are each created on first access of the corresponding property
    and reused afterwards.
    """

    def __init__(self, model_path: str, name: str = None, task_list: list[str] = None) -> None:
        """Record where the model lives and which tasks it serves.

        Args:
            model_path: identifier passed to ``from_pretrained`` for both the
                model and the tokenizer operator.
            name: name of this broker; defaults to ``model_path``.
            task_list: names of the tasks this model can be used for;
                defaults to an empty list.
        """
        self.model_path = model_path
        self.name = model_path if name is None else name

        # Always define ``tasks``. The previous fallback assigned to a
        # misspelled ``task`` attribute, so brokers created without a task
        # list raised AttributeError when ``tasks`` was read.
        self.tasks = [] if task_list is None else list(task_list)

        # Filled in on first use by the properties below.
        self._model = None
        self._tokenizer_op = None

    @property
    def model(self) -> "Mammal":
        """Return the model, loading it and switching it to eval mode on first access."""
        if self._model is None:
            self._model = Mammal.from_pretrained(self.model_path)
            self._model.eval()
        return self._model

    @property
    def tokenizer_op(self):
        """Return the tokenizer operator, loading it from ``model_path`` on first access."""
        if self._tokenizer_op is None:
            self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
        return self._tokenizer_op
35
+
36
+
37
+
38
+
39
+
40
class MammalTask(ABC):
    """Abstract base for a task that can be run with a Mammal model and shown as a demo.

    A concrete task provides prompt construction, sample-dict creation,
    output decoding and a Gradio demo. The demo is built once and cached.
    """

    def __init__(self, name: str) -> None:
        """Store the task name and initialise the description and demo cache to None."""
        self.name = name
        self.description = None
        self._demo = None

    @abstractmethod
    def generate_prompt(self, **kwargs) -> str:
        """Build the prompt string for this task from task-specific keyword arguments.

        Returns:
            str: the prompt.
        """
        raise NotImplementedError()

    @abstractmethod
    def crate_sample_dict(self, prompt: str, **kwargs) -> dict:
        """Turn a prompt into the sample dict that is fed into the model.

        Args:
            prompt (str): prompt as produced by ``generate_prompt``.

        Returns:
            dict: sample_dict for feeding into model
        """
        raise NotImplementedError()

    # Not declared abstract: a subclass that does not override this can still
    # be instantiated, and calling it raises NotImplementedError.
    def run_model(self, sample_dict, model: "Mammal"):
        """Run ``model`` on ``sample_dict``; the base implementation only raises."""
        raise NotImplementedError()

    @abstractmethod
    def create_demo(self, model_name_dropdown):
        """Create the Gradio demo group for this task.

        Args:
            model_name_dropdown: component passed through from ``demo``.

        Returns:
            The created demo object.
        """
        raise NotImplementedError()

    def demo(self, model_name_dropdown=None):
        """Return the cached demo, creating it with ``create_demo`` on first use."""
        if self._demo is not None:
            return self._demo
        self._demo = self.create_demo(model_name_dropdown=model_name_dropdown)
        return self._demo

    @abstractmethod
    def decode_output(self, batch_dict, model: "Mammal"):
        """Extract the task result from the model output ``batch_dict``.

        NOTE(review): PpiTask overrides this with a ``model_holder`` parameter
        instead of ``model`` -- confirm which signature is intended.
        """
        raise NotImplementedError()
93
+
94
+ #self._setup()
95
+
96
+ # def _setup(self):
97
+ # pass
98
+
99
+
100
+
101
# Module-level registries, filled in further down the file:
# task name -> task object, and model name -> model broker.
all_tasks = {}
all_models = {}
103
+
104
class PpiTask(MammalTask):
    """Protein-Protein Interaction task: prompt building, inference, decoding and demo UI."""

    def __init__(self):
        """Set the task name, description, example sequences and demo header text."""
        super().__init__(name="PPI")
        self.description = "Protein-Protein Interaction (PPI)"
        self.examples = {
            "protein_calmodulin": "MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK",
            "protein_calcineurin": "MSSKLLLAGLDIERVLAEKNFYKEWDTWIIEAMNVGDEEVDRIKEFKEDEIFEEAKTLGTAEMQEYKKQKLEEAIEGAFDIFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIRQMWDQNGDWDRIKELKFGEIKKLSAKDTRGTIFIKVFENLGTGVDSEYEDVSKYMLKHQ",
        }
        # Must be an f-string: the plain string literal used before showed the
        # text "{self.description}" verbatim in the demo header.
        self.markup_text = (
            f"\n# Mammal based {self.description} demonstration\n"
            "\n"
            "Given two protein sequences, estimate if the proteins interact or not."
        )

    @staticmethod
    def positive_token_id(model_holder: MammalObjectBroker):
        """Return the id of the token that stands for positive binding.

        Args:
            model_holder (MammalObjectBroker): broker holding the tokenizer

        Returns:
            int: id of positive binding token
        """
        return model_holder.tokenizer_op.get_token_id("<1>")

    def generate_prompt(self, prot1, prot2):
        """Format the prompt to match the pre-training syntax.

        Args:
            prot1 (str): sequence of protein number 1
            prot2 (str): sequence of protein number 2

        Returns:
            str: prompt
        """
        prompt = (
            "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"
            "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
            f"<SEQUENCE_NATURAL_START>{prot1}<SEQUENCE_NATURAL_END>"
            "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
            f"<SEQUENCE_NATURAL_START>{prot2}<SEQUENCE_NATURAL_END><EOS>"
        )
        return prompt

    def crate_sample_dict(self, prompt: str, model_holder: MammalObjectBroker):
        """Tokenize ``prompt`` into a sample dict ready for the model.

        Args:
            prompt (str): prompt as produced by ``generate_prompt``
            model_holder (MammalObjectBroker): broker whose tokenizer operator is used

        Returns:
            dict: sample dict holding the prompt plus its token ids and
            attention mask as torch tensors
        """
        # Create and load sample
        sample_dict = dict()
        sample_dict[ENCODER_INPUTS_STR] = prompt

        # Tokenize
        sample_dict = model_holder.tokenizer_op(
            sample_dict=sample_dict,
            key_in=ENCODER_INPUTS_STR,
            key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
            key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
        )
        # Convert the tokenizer outputs to tensors.
        sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
            sample_dict[ENCODER_INPUTS_TOKENS]
        )
        sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
            sample_dict[ENCODER_INPUTS_ATTENTION_MASK]
        )
        return sample_dict

    def run_model(self, sample_dict, model: Mammal):
        """Generate a prediction for a single sample.

        Args:
            sample_dict: tokenized sample from ``crate_sample_dict``
            model (Mammal): model to run

        Returns:
            The batch dict returned by ``model.generate``.
        """
        batch_dict = model.generate(
            [sample_dict],
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=5,
        )
        return batch_dict

    def decode_output(self, batch_dict, model_holder):
        """Decode the generated tokens and read the score of the positive token.

        Args:
            batch_dict: output of ``run_model``
            model_holder (MammalObjectBroker): broker whose tokenizer is used for decoding

        Returns:
            tuple: (decoded output string, score of the positive token)
        """
        generated_output = model_holder.tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])
        # presumably indexed as [sample 0][generated position 1][token id]
        # -- TODO confirm the layout of "model.out.scores"
        score = batch_dict["model.out.scores"][0][1][self.positive_token_id(model_holder)].item()

        return generated_output, score

    def create_and_run_prompt(self, model_name, protein1, protein2):
        """Run the full PPI pipeline for two sequences with the named model.

        Args:
            model_name: key of the model in ``all_models``
            protein1 (str): sequence of protein number 1
            protein2 (str): sequence of protein number 2

        Returns:
            tuple: (prompt, decoded output, score)

        Raises:
            gr.Error: if ``model_name`` is not a registered model, e.g. when
                no model has been selected yet.
        """
        model_holder = all_models.get(model_name)
        if model_holder is None:
            # Show a readable message in the UI instead of a bare KeyError.
            raise gr.Error("Please select a model before running the prompt.")
        prompt = self.generate_prompt(protein1, protein2)
        sample_dict = self.crate_sample_dict(prompt=prompt, model_holder=model_holder)
        batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
        generated_output, score = self.decode_output(batch_dict, model_holder=model_holder)
        return prompt, generated_output, score

    def create_demo(self, model_name_dropdown):
        """Create the Gradio group holding the PPI demo.

        Args:
            model_name_dropdown: dropdown whose value selects the model to run

        Returns:
            gr.Group: the demo group
        """
        with gr.Group(visible=True) as demo:
            gr.Markdown(self.markup_text)
            with gr.Row():
                prot1 = gr.Textbox(
                    label="Protein 1 sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["protein_calmodulin"],
                )
                prot2 = gr.Textbox(
                    label="Protein 2 sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["protein_calcineurin"],
                )
            with gr.Row():
                run_mammal = gr.Button(
                    "Run Mammal prompt for Protein-Protein Interaction", variant="primary"
                )
            with gr.Row():
                prompt_box = gr.Textbox(label="Mammal prompt", lines=5)

            with gr.Row():
                decoded = gr.Textbox(label="Mammal output")
                score_box = gr.Number(label="PPI score")
                run_mammal.click(
                    fn=self.create_and_run_prompt,
                    inputs=[model_name_dropdown, prot1, prot2],
                    outputs=[prompt_box, decoded, score_box],
                )
            with gr.Row():
                gr.Markdown(
                    "```<SENTINEL_ID_0>``` contains the binding affinity class, which is ```<1>``` for interacting and ```<0>``` for non-interacting"
                )
        return demo
241
+
242
# Register the PPI task under its own name.
ppi_task = PpiTask()
all_tasks[ppi_task.name] = ppi_task

# No explicit name is given, so the broker's name (and registry key) is the
# model path.
ppi_model = MammalObjectBroker(
    model_path="ibm/biomed.omics.bl.sm.ma-ted-458m",
    task_list=["PPI"],
)
all_models[ppi_model.name] = ppi_model

# TODO: also register the model at
# "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd"; its task list is
# still empty.
250
+
251
+
252
def create_application():
    """Build the top-level Gradio application.

    The application has a task dropdown, a model dropdown whose choices
    follow the selected task, and the PPI demo group.

    Returns:
        gr.Blocks: the assembled application (not launched).
    """

    def models_for_task(task_name):
        """Return the names of registered models whose task list contains ``task_name``."""
        return [
            model_name
            for model_name, model in all_models.items()
            if task_name in model.tasks
        ]

    def task_change(value):
        """Refresh the model dropdown after the selected task changed."""
        choices = models_for_task(value)
        # Always send an explicit update. Returning None when nothing matched
        # did not reset the choices, so models of the previously selected
        # task stayed selectable.
        return gr.update(choices=choices, value=choices[0] if choices else None)

    def set_ppi_vis(main_text):
        """Keep the PPI demo group visible whichever task is selected."""
        print(f"main text is {main_text}")
        return gr.Group(visible=True)

    with gr.Blocks() as demo:
        task_dropdown = gr.Dropdown(
            choices=["select demo"] + list(all_tasks.keys()),
            interactive=True,
        )
        model_name_dropdown = gr.Dropdown(
            choices=models_for_task(task_dropdown.value),
            interactive=True,
        )
        task_dropdown.change(
            task_change, inputs=[task_dropdown], outputs=[model_name_dropdown]
        )

        ppi_demo = all_tasks["PPI"].demo(model_name_dropdown=model_name_dropdown)

        task_dropdown.change(
            set_ppi_vis, inputs=task_dropdown, outputs=[ppi_demo]
        )
    return demo
287
+
288
# Module-level handle to the application; set by main().
full_demo = None


def main():
    """Build the application, store it in ``full_demo`` and launch it."""
    global full_demo
    app = create_application()
    full_demo = app
    app.launch(show_error=True, share=False)


if __name__ == "__main__":
    main()
297
+