File size: 3,632 Bytes
ac117b5
71382c0
19dfa7a
 
 
f98cc68
0c8cec9
71382c0
19dfa7a
 
71382c0
f98cc68
 
 
 
19dfa7a
 
 
 
 
 
f98cc68
 
 
 
0c8cec9
 
 
 
f98cc68
 
19dfa7a
f98cc68
19dfa7a
 
71382c0
19dfa7a
 
f98cc68
19dfa7a
 
71382c0
f98cc68
 
 
 
 
71382c0
0c8cec9
 
 
 
 
 
 
19dfa7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71382c0
19dfa7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71382c0
ac117b5
 
19dfa7a
 
0c8cec9
 
ac117b5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr

from mammal_demo.demo_framework import MammalObjectBroker, MammalTask
from mammal_demo.dti_task import DtiTask
from mammal_demo.ppi_task import PpiTask
from mammal_demo.tcr_task import TcrTask
from mammal_demo.ps_task import PsTask

all_tasks: dict[str, MammalTask] = {}
all_models: dict[str, MammalObjectBroker] = {}


def _register_task(task: MammalTask) -> MammalTask:
    """Store ``task`` in ``all_tasks`` under ``task.name`` and return it."""
    all_tasks[task.name] = task
    return task


def _register_model(model_path: str, task_list: list[str]) -> MammalObjectBroker:
    """Create a ``MammalObjectBroker`` for ``model_path`` serving ``task_list``,
    store it in ``all_models`` under its ``name`` and return it."""
    broker = MammalObjectBroker(model_path=model_path, task_list=task_list)
    all_models[broker.name] = broker
    return broker


# Tasks are created first. The model a task runs with depends on the widget
# state, so each task keeps a reference to the shared ``all_models`` dict;
# that dict is still empty here and is filled in by the model section below.
# (The ``tdi_*`` names pair with ``DtiTask``; they are kept for compatibility.)
ppi_task = _register_task(PpiTask(model_dict=all_models))
tdi_task = _register_task(DtiTask(model_dict=all_models))
tcr_task = _register_task(TcrTask(model_dict=all_models))
ps_task = _register_task(PsTask(model_dict=all_models))

# Model holders: each wraps a model and its tokenizer (downloaded lazily) and
# must be told explicitly which task names it can serve.
# The base model is offered for both the PPI and the TCR demos.
ppi_model = _register_model(
    "ibm/biomed.omics.bl.sm.ma-ted-458m",
    [ppi_task.name, tcr_task.name],
)
tdi_model = _register_model(
    "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd",
    [tdi_task.name],
)
tcr_model = _register_model(
    "ibm/biomed.omics.bl.sm.ma-ted-458m.tcr_epitope_bind",
    [tcr_task.name],
)
ps_model = _register_model(
    "ibm/biomed.omics.bl.sm.ma-ted-458m.protein_solubility",
    [ps_task.name],
)


def create_application():
    """Build the Gradio ``Blocks`` UI for all registered MAMMAL demos.

    The UI consists of a task selector, a model selector listing the entries
    of ``all_models`` whose ``tasks`` contain the selected task, and one demo
    panel per entry of ``all_tasks`` (built by that task's ``demo`` method).
    Selecting a task shows its panel and hides all the others.

    Returns:
        The constructed ``gr.Blocks`` application (not yet launched).
    """

    def task_change(value):
        """Compute UI updates after the task selector changes to ``value``.

        Returns a tuple whose first item updates the model dropdown and whose
        remaining items are one visibility update per task panel, in
        ``all_tasks`` order (matching the ``outputs`` wired below).
        """
        visibility = [gr.update(visible=(task == value)) for task in all_tasks]
        choices = [
            model_name
            for model_name, model in all_models.items()
            if value in model.tasks
        ]
        if choices:
            model_update = gr.update(choices=choices, value=choices[0], visible=True)
        else:
            # No model serves this selection (e.g. the "select demo" placeholder),
            # so leave the model dropdown untouched. ``gr.skip`` must be called:
            # returning the bare function object is not a valid component value.
            model_update = gr.skip()
        return (model_update, *visibility)

    with gr.Blocks() as application:
        task_dropdown = gr.Dropdown(
            choices=["select demo"] + list(all_tasks.keys()),
            label="Mammal Task",
            interactive=True,
        )
        model_name_dropdown = gr.Dropdown(
            choices=[
                model_name
                for model_name, model in all_models.items()
                if task_dropdown.value in model.tasks
            ],
            interactive=True,
            label="Matching Mammal models",
            visible=False,
        )

        # Each task builds its demo inside this Blocks context. The value that
        # ``demo`` returns is the output whose visibility task_change toggles;
        # the list order must match the ``visibility`` list in task_change.
        demo_panels = [
            task.demo(model_name_widgit=model_name_dropdown)
            for task in all_tasks.values()
        ]

        task_dropdown.change(
            task_change,
            inputs=[task_dropdown],
            outputs=[model_name_dropdown] + demo_panels,
        )

    return application


# Module-level handle to the application built by ``main``; ``None`` until
# ``main`` has run.
full_demo = None


def main():
    """Build the demo application, keep it in ``full_demo`` and launch it.

    Errors are shown in the UI and no public share link is created; pass
    ``share=True`` to ``launch`` instead to get a public link.
    """
    global full_demo
    app = create_application()
    full_demo = app
    app.launch(show_error=True, share=False)


if __name__ == "__main__":
    main()