File size: 5,674 Bytes
08232dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
import os
from npc_bert_models.gradio_demo import *
from npc_bert_models.mlm_module import NpcBertMLM
from npc_bert_models.cls_module import NpcBertCLS
import json


class main_window:
    """Gradio demo front-end for the NPC-BERT models.

    Builds two tabs -- masked-word prediction (``NpcBertMLM``) and report
    classification (``NpcBertCLS``) -- under a shared introduction page.
    Call :meth:`initialize` to load the models and build the UI, then
    :meth:`launch` to serve it.
    """

    def __init__(self, examples_path="examples.json", report_examples_dir="./report_examples"):
        """Load the demo examples; models and UI are built later by :meth:`initialize`.

        Args:
            examples_path: JSON file with the demo examples. The MLM tab reads its
                ``'mlm-inp'`` (masked inputs) and ``'mlm-inp-GT'`` (ground truth) lists.
            report_examples_dir: directory handed to ``gr.Examples`` as the source of
                example report files for the classification tab.
        """
        self.interface = None  # top-level gr.Blocks; set by initialize()
        self.report_examples_dir = report_examples_dir
        # Context manager so the handle is closed immediately instead of leaking.
        with open(examples_path, 'r', encoding='utf-8') as fh:
            self.examples = json.load(fh)

    def initialize(self):
        """Load both models and build every Gradio interface.

        Sets ``npc_mlm``, ``mlm_interface``, ``npc_cls``, ``cls_interface`` and
        ``interface``. Each model is loaded before the UI that calls it is built.
        """
        #! Initialize MLM
        self.npc_mlm = NpcBertMLM()
        self.npc_mlm.load()
        self.mlm_interface = self._build_mlm_interface()

        #! Initialize report classification
        self.npc_cls = NpcBertCLS()
        self.npc_cls.load()
        self.cls_interface = self._build_cls_interface()

        self.interface = self._build_main_interface()

    def _build_mlm_interface(self):
        """Build and return the masked-word prediction tab backed by ``self.npc_mlm``."""
        example_rows = "\n".join(
            f"|{masked}|{truth}|"
            for masked, truth in zip(self.examples['mlm-inp'], self.examples['mlm-inp-GT'])
        )
        with gr.Blocks() as mlm_interface:
            gr.Markdown("# Masked word prediction\n"
                        "Enter any sentence. Use the token `[MASK]` to see what the model predicts.\n"
                        "## Our examples:\n"
                        "|Input masked sequence|Ground-truth masked word|\n"
                        "|---------------------|------------------------|\n"
                        + example_rows)

            with gr.Row():
                with gr.Column():
                    inp = gr.Textbox("The tumor is confined in the [MASK].", label='mlm-inp')
                    btn = gr.Button("Run", variant='primary')

                with gr.Column():
                    # Display at most the 5 top-scoring labels returned by the model.
                    out = gr.Label(num_top_classes=5)

            gr.Examples(self.examples['mlm-inp'], inputs=inp, label='mlm-inp')
            # Both the Run button and submitting the textbox invoke the model.
            btn.click(fn=self.npc_mlm.__call__, inputs=inp, outputs=out)
            inp.submit(fn=self.npc_mlm.__call__, inputs=inp, outputs=out)
        return mlm_interface

    def _build_cls_interface(self):
        """Build and return the report-classification tab backed by ``self.npc_cls``."""
        with gr.Blocks() as cls_interface:
            gr.Markdown("""
                        # Report discrimination

                        In this example we explored how the fine-tuned BERT aids a downstream task. We further trained it
                        to do a simple task of discriminating between reports written for non-NPC patients and NPC patients.

                        # Disclaimer

                        The examples are mock reports that were created with reference to authentic reports; they do not
                        represent any real patients. However, they were written to be authentic representations of NPC
                        patients or patients under investigation for suspected NPC but with negative imaging findings.
                        """)

            with gr.Row():
                with gr.Column():
                    inp = gr.TextArea(placeholder="Use examples at the bottom to load example text reports.")
                    inf = gr.File(file_types=['.txt'], label="Report file (plaintext)", show_label=True, interactive=True)
                    # Copy the selected file's text into the text area.
                    # NOTE(review): an upload probably also fires `change`, so the helper
                    # may run twice per upload -- harmless; confirm before dropping one.
                    inf.upload(fn=self._set_report_file_helper, inputs=inf, outputs=inp)
                    inf.change(fn=self._set_report_file_helper, inputs=inf, outputs=inp)
                    btn = gr.Button("Run", variant='primary')

                with gr.Column():
                    out = gr.Label(num_top_classes=2)

            # Example report files are offered as values for the file widget `inf`.
            gr.Examples(examples=self.report_examples_dir, inputs=inf)
            btn.click(fn=self.npc_cls.__call__, inputs=inp, outputs=out)
            inp.submit(fn=self.npc_cls.__call__, inputs=inp, outputs=out)
        return cls_interface

    def _build_main_interface(self):
        """Build and return the top-level page: introduction text plus one tab per sub-interface."""
        with gr.Blocks() as interface:
            gr.Markdown("""
                        # Introduction

                        This is a demo showing the potential of language models fine-tuned on a carefully curated dataset
                        of structured MRI radiology reports for nasopharyngeal carcinoma (NPC) examination. Our team has an established
                        track record of researching the role of AI in early detection of NPC. We have already developed an AI system
                        with high sensitivity and specificity (> 90%). However, the explainability of the system is currently a major challenge
                        for translation. In fact, this is a general problem for AI development in radiology. Therefore, in this pilot
                        study, we investigate the ability of language models to understand the context of the disease, to explore the possibility
                        of incorporating language models into our existing system for explainability.

                        # Affiliations

                        * Dr. M.Lun Wong, Dept. Imaging and Interventional Radiology. The Chinese University of Hong Kong

                        # Disclaimer

                        This software is provided as is and it is not clinically validated software. The authors disclaim any responsibility
                        arising as a consequence of using this demo.
                        """)
            # Instantiated inside this Blocks context so the tabs are placed on this page.
            gr.TabbedInterface([self.mlm_interface, self.cls_interface],
                               tab_names=["Masked Language Model", "Report classification"])
        return interface

    def launch(self, **kwargs):
        """Serve the demo built by :meth:`initialize`.

        Keyword arguments are forwarded unchanged to the Gradio ``launch`` call,
        whose return value is passed back.

        Raises:
            RuntimeError: if :meth:`initialize` has not been called yet.
        """
        if self.interface is None:
            raise RuntimeError("initialize() must be called before launch()")
        return self.interface.launch(**kwargs)

    def lauch(self):
        """Backward-compatible alias for :meth:`launch` (historical misspelling); returns None."""
        self.launch()

    def _set_report_file_helper(self, file_in):
        """Return the text of the report file selected in the ``gr.File`` widget.

        Args:
            file_in: value passed by ``gr.File``. ``None`` when nothing is selected.
                Depending on the Gradio version this may be a path (str / os.PathLike)
                or a temp-file object exposing ``.name`` -- NOTE(review): confirm for
                the pinned Gradio version.

        Returns:
            The file contents, or ``None`` if there is no file or it cannot be read
            (the failure is printed, as before; reading stays best-effort).
        """
        if file_in is None:
            # Nothing selected (e.g. the widget was cleared): nothing to load.
            return None
        # pathlib.Path also has `.name` (basename only), so only unwrap non-path objects.
        if isinstance(file_in, (str, bytes, os.PathLike)):
            path = file_in
        else:
            path = getattr(file_in, 'name', file_in)
        try:
            with open(path, 'r', encoding='utf-8') as fh:
                return fh.read()
        except (OSError, UnicodeDecodeError, TypeError, ValueError):
            print(f"Cannot read file {file_in}")
            return None

if __name__ == '__main__':
    # Script entry point: load the models, build the UI, then start serving it.
    window = main_window()
    window.initialize()
    window.lauch()