Pouriarouzrokh commited on
Commit
2993f76
·
1 Parent(s): c5d2adf

Added the gradio demo files

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ flagged/
2
+ *.ckpt
app.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import osail_utils
7
+ import pandas as pd
8
+ import skimage
9
+ from mediffusion import DiffusionModule
10
+ import monai as mn
11
+ import torch
12
+
13
+ # Loading the model for inference
14
+
15
+ model = DiffusionModule("./diffusion_configs.yaml")
16
+ model.load_ckpt("./data/model.ckpt")
17
+ model.cuda().half()
18
+ model.eval();
19
+
20
+ # Loading a baseline noise for making predictions
21
+
22
+ seed = 3407
23
+ np.random.seed(seed)
24
+ torch.random.manual_seed(seed)
25
+ torch.cuda.manual_seed(seed)
26
+ torch.cuda.manual_seed_all(seed)
27
+ torch.backends.cudnn.deterministic = True
28
+ BASELINE_NOISE = torch.randn(1, 1, 256, 256).cuda().half()
29
+
30
+ # Model helper functions
31
+
32
+ def create_ds(img_paths):
33
+ if type(img_paths) == str:
34
+ img_paths = [img_paths]
35
+ data_list = [{"img": img_path} for img_path in img_paths]
36
+
37
+ # Get the transforms
38
+ Ts_list = [
39
+ osail_utils.io.LoadImageD(keys=["img"], transpose=True, normalize=True),
40
+ mn.transforms.EnsureChannelFirstD(
41
+ keys=["img"], channel_dim="no_channel"
42
+ ),
43
+ mn.transforms.ResizeD(
44
+ keys=["img"],
45
+ spatial_size=(256, 256),
46
+ mode=["bicubic"],
47
+ ),
48
+ mn.transforms.ScaleIntensityD(keys=["img"], minv=0, maxv=1),
49
+ mn.transforms.ToTensorD(keys=["img"], track_meta=None),
50
+ mn.transforms.SelectItemsD(keys=["img"]),
51
+ ]
52
+ return mn.data.Dataset(data_list, transform=mn.transforms.Compose(Ts_list))
53
+
54
+ def make_predictions(img_path, angles=None, cls_batch=None, rotate_to_standard=False, sampler="DDIM100"):
55
+
56
+ global model
57
+ global BASELINE_NOISE
58
+
59
+ # Create the image dataset
60
+ if cls_batch is not None:
61
+ ds = create_ds([img_path]*len(cls_batch))
62
+ else:
63
+ ds = create_ds(img_path)
64
+ dl = mn.data.DataLoader(ds, batch_size=len(ds), num_workers=0 if len(ds)==1 else 4, shuffle=False)
65
+ input_batch = next(iter(dl))
66
+ original_imgs = input_batch["img"].detach().cpu().numpy()
67
+
68
+ # Create the classifier condition if not provided
69
+ if cls_batch is None:
70
+ fp = torch.zeros(768)
71
+ if rotate_to_standard or angles is None:
72
+ angles = [1000, 1000, 1000]
73
+ cls_value = torch.tensor([2, *angles, *fp])
74
+ else:
75
+ cls_value = torch.tensor([1, *angles, *fp])
76
+ cls_batch = cls_value.unsqueeze(0).repeat(input_batch["img"].shape[0], 1).cuda().half()
77
+
78
+ # Generate noise
79
+ noise = BASELINE_NOISE.repeat(input_batch["img"].shape[0], 1, 1, 1)
80
+ model_kwargs = {
81
+ "cls": cls_batch,
82
+ "concat": input_batch["img"].cuda().half(),
83
+ }
84
+
85
+ # Make predictions
86
+ preds = model.predict(
87
+ noise, model_kwargs=model_kwargs, classifier_cond_scale=4, inference_protocol=sampler
88
+ )
89
+ adjusted_preds = list()
90
+ for pred, original_img in zip(preds, original_imgs):
91
+ adjusted_pred = pred.detach().cpu().numpy().squeeze()
92
+ original_img = original_img.squeeze()
93
+ adjusted_pred = skimage.exposure.match_histograms(adjusted_pred, original_img)
94
+ adjusted_preds.append(adjusted_pred)
95
+ return adjusted_preds
96
+
97
+ # Gradio helper functions
98
+
99
+ current_img = None
100
+ live_preds = None
101
+
102
+ def rotate_btn_fn(img_path, xt, yt, zt, add_bone_cmap=False):
103
+
104
+ global current_img
105
+
106
+ angles = [float(xt), float(yt), float(zt)]
107
+ out_img = make_predictions(img_path, angles)[0]
108
+ if not add_bone_cmap:
109
+ print(out_img.shape)
110
+ return out_img
111
+ cmap = plt.get_cmap('bone')
112
+ out_img = cmap(out_img)
113
+ out_img = (out_img[..., :3] * 255).astype(np.uint8)
114
+ current_img = out_img
115
+ return out_img
116
+
117
+ def rotate_to_standard_btn_fn(img_path, add_bone_cmap=False):
118
+
119
+ global current_img
120
+
121
+ out_img = make_predictions(img_path, rotate_to_standard=True)[0]
122
+ if not add_bone_cmap:
123
+ return out_img
124
+ cmap = plt.get_cmap('bone')
125
+ out_img = cmap(out_img)
126
+ out_img = (out_img[..., :3] * 255).astype(np.uint8)
127
+ current_img = out_img
128
+ return out_img
129
+
130
+ def use_current_btn_fn(input_img):
131
+ return input_img
132
+
133
+
134
+ def make_live_btn_fn(img_path, axis, add_bone_cmap=False):
135
+
136
+ global live_preds
137
+
138
+ base_angles = list(range(-20, 21, 1))
139
+ base_angles = [float(i) for i in base_angles]
140
+ if axis.lower() == "axis x":
141
+ all_angles = [[i, 0, 0] for i in base_angles]
142
+ elif axis.lower() == "axis y":
143
+ all_angles = [[0, i, 0] for i in base_angles]
144
+ elif axis.lower() == "axis z":
145
+ all_angles = [[0, 0, i] for i in base_angles]
146
+ fp = torch.zeros(768)
147
+ cls_batch = torch.tensor([[1, *angles, *fp] for angles in all_angles])
148
+
149
+ live_preds = make_predictions(img_path, cls_batch=cls_batch)
150
+ live_preds = {angle: live_preds[i] for i, angle in enumerate(base_angles)}
151
+ return img_path
152
+
153
+ def rotate_live_img_fn(angle, add_bone_cmap=False):
154
+
155
+ global live_img
156
+ global live_preds
157
+
158
+ if live_img is not None:
159
+ if angle == 0:
160
+ return live_img
161
+ return live_preds[float(angle)]
162
+
163
+ css_style = "./style.css"
164
+ callback = gr.CSVLogger()
165
+ with gr.Blocks(css=css_style) as app:
166
+ gr.HTML("VCNet: A Deep Learning Solution for Roating RadioGraphs in 3D Space", elem_classes="title")
167
+ gr.HTML("Developed by the Orthopedics Surgery Artificial Intelligence Lab (OSAIL)", elem_classes="note")
168
+ gr.HTML("Note: This is a proof-of-concept demo of an AI tool that is not yet finalized. Please interpret with care!", elem_classes="note")
169
+
170
+ with gr.TabItem("Single Rotation"):
171
+ with gr.Row():
172
+ input_img = gr.Image(type='filepath', label='Input image', sources='upload', interactive=False, elem_classes='imgs')
173
+ output_img = gr.Image(type='pil', label='Output image', interactive=False, elem_classes='imgs')
174
+ with gr.Row():
175
+ gr.Examples(
176
+ examples = [os.path.join("./data/examples", f) for f in os.listdir("./data/examples") if "xr" in f],
177
+ inputs = [input_img],
178
+ label = "Xray Examples",
179
+ elem_id='examples'
180
+ )
181
+ gr.Examples(
182
+ examples = [os.path.join("./data/examples", f) for f in os.listdir("./data/examples") if "drr" in f],
183
+ inputs = [input_img],
184
+ label = "DRR Examples",
185
+ elem_id='examples'
186
+ )
187
+ with gr.Row():
188
+ gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')
189
+ with gr.Row():
190
+ with gr.Column(scale=1):
191
+ xt = gr.Slider(label='Rotation angle in x axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
192
+ with gr.Column(scale=1):
193
+ yt = gr.Slider(label='Rotation angle in y axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
194
+ with gr.Column(scale=1):
195
+ zt = gr.Slider(label='Rotation angle in z axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
196
+ with gr.Row():
197
+ rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
198
+ with gr.Row():
199
+ rotate_to_standard_btn = gr.Button("Rotate to standard view!", elem_classes='rotate_to_standard_button')
200
+ with gr.Row():
201
+ use_current_btn = gr.Button("Use the current output as the new input!", elem_classes='use_current_button')
202
+ rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
203
+ rotate_to_standard_btn.click(fn=rotate_to_standard_btn_fn, inputs=[input_img], outputs=output_img)
204
+ use_current_btn.click(fn=use_current_btn_fn, inputs=[output_img], outputs=input_img)
205
+
206
+ with gr.TabItem("Live Rotation"):
207
+ with gr.Row():
208
+ live_img = gr.Image(type='filepath', label='Live Image', sources='upload', interactive=False, elem_classes='imgs')
209
+ with gr.Row():
210
+ gr.Examples(
211
+ examples = [os.path.join("./data/examples", f) for f in os.listdir("./data/examples") if "xr" in f],
212
+ inputs = [live_img],
213
+ label = "Xray Examples",
214
+ elem_id='examples'
215
+ )
216
+ gr.Examples(
217
+ examples = [os.path.join("./data/examples", f) for f in os.listdir("./data/examples") if "drr" in f],
218
+ inputs = [live_img],
219
+ label = "DRR Examples",
220
+ elem_id='examples'
221
+ )
222
+ with gr.Row():
223
+ gr.Markdown('Please select an example image, an axis, and then press Make Live!', elem_classes='text')
224
+ with gr.Row():
225
+ axis = gr.Dropdown(choices=['Axis X', 'Axis Y', 'Axis Z'], show_label=False, elem_classes='angle', value='Axis X')
226
+ live_btn = gr.Button("Make Live!", elem_classes='make_live_button')
227
+ with gr.Row():
228
+ gr.Markdown('You can now rotate the radiograph in your selected axis using the scaler.', elem_classes='text')
229
+ with gr.Row():
230
+ slider = gr.Slider(show_label=False, minimum=-20, maximum=20, step=1, value=0, elem_classes='slider', interactive=True)
231
+ live_btn.click(fn=make_live_btn_fn, inputs=[live_img, axis], outputs=live_img)
232
+ slider.change(fn=rotate_live_img_fn, inputs=[slider], outputs=live_img)
233
+
234
+ try:
235
+ app.close()
236
+ gr.close_all()
237
+ except:
238
+ pass
239
+
240
+ demo = app.launch(
241
+ max_threads=4,
242
+ share=True,
243
+ inline=False,
244
+ show_api=False,
245
+ show_error=True,
246
+ server_port=1902,
247
+ server_name="0.0.0.0",
248
+ )
data/examples/drr_0.png ADDED
data/examples/drr_2.png ADDED
data/examples/drr_3.png ADDED
data/examples/drr_4.png ADDED
data/examples/drr_5.png ADDED
data/examples/drr_6.png ADDED
data/examples/drr_7.png ADDED
data/examples/drr_8.png ADDED
data/examples/xr_1.png ADDED
data/examples/xr_2.png ADDED
data/examples/xr_3.png ADDED
data/examples/xr_4.png ADDED
data/examples/xr_5.png ADDED
data/examples/xr_6.png ADDED
data/examples/xr_7.png ADDED
data/examples/xr_8.png ADDED
diffusion_configs.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusion:
2
+ timesteps: 1000
3
+ schedule_name: cosine
4
+ enforce_zero_terminal_snr: true
5
+ schedule_params:
6
+ beta_start: 0.0001
7
+ beta_end: 0.02
8
+ cosine_s: 0.008
9
+ timestep_respacing: null
10
+ mean_type: VELOCITY
11
+ var_type: LEARNED_RANGE
12
+ loss_type: MSE
13
+
14
+ optimizer:
15
+ lr: 0.00001
16
+ type: bkh_pytorch_utils.Lion
17
+
18
+ validation:
19
+ classifier_cond_scale: 4
20
+ protocol: DDPM
21
+ log_original: true
22
+ log_concat: true
23
+ cls_log_indices: [0, 1, 2, 3]
24
+
25
+ model:
26
+ input_size: 256
27
+ dims: 2
28
+ attention_resolutions: [8, 16, 32]
29
+ channel_mult: [1, 1, 2, 2, 4, 4]
30
+ dropout: 0.0
31
+ in_channels: 2
32
+ out_channels: 2
33
+ model_channels: 128
34
+ num_head_channels: -1
35
+ num_heads: 4
36
+ num_heads_upsample: -1
37
+ num_res_blocks: [2, 2, 2, 2, 2, 2]
38
+ resblock_updown: false
39
+ use_checkpoint: false
40
+ use_new_attention_order: false
41
+ use_scale_shift_norm: true
42
+ scale_skip_connection: false
43
+
44
+ # conditions
45
+ num_classes: 772
46
+ # num_classes: 4
47
+ concat_channels: 1
48
+ guidance_drop_prob: 0.1
49
+ missing_class_value: null
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio==4.19.2
2
+ mediffusion==0.7.2
3
+ monai==1.3.0
4
+ sckit-image==0.22.0
5
+ numpy==1.24.1
6
+ pillow==9.3.0
style.css ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .title {
2
+ text-align: center;
3
+ display: block;
4
+ font-size: 30px;
5
+ font-weight: bold;
6
+ }
7
+ .note {
8
+ text-align: center;
9
+ min-height: none;
10
+ }
11
+ #examples {
12
+ justify-content: center;
13
+ align-items: center;
14
+ }
15
+ .angle {
16
+ background: #F2F1EB !important;
17
+ }
18
+ .slider {
19
+ background: #88AB8E !important;
20
+ font-weight: bold;
21
+ }
22
+ .text {
23
+ display: flex;
24
+ text-align: center;
25
+ justify-content: center;
26
+ }
27
+ .normal {
28
+ font-weight: normal
29
+ }
30
+ .rotate_button {
31
+ background: #88AB8E !important;
32
+ font-size: 20px;
33
+ }
34
+ .rotate_to_standard_button {
35
+ background: #AFC8AD !important;
36
+ font-size: 20px;
37
+ }
38
+ .use_current_button {
39
+ background: #EEE7DA !important;
40
+ font-size: 20px;
41
+ }
42
+ .make_live_button {
43
+ background: #EEE7DA !important;
44
+ font-size: 20px;
45
+ }
46
+ .imgs{
47
+ justify-content: center;
48
+ align-items: center;
49
+ display: grid;
50
+ margin: auto;
51
+ width: 256px;
52
+ height: 256px;
53
+ }