Commit 2993f76 · 1 parent: c5d2adf
Added the gradio demo files
Files changed:
- .gitignore +2 -0
- app.py +248 -0
- data/examples/drr_0.png +0 -0
- data/examples/drr_2.png +0 -0
- data/examples/drr_3.png +0 -0
- data/examples/drr_4.png +0 -0
- data/examples/drr_5.png +0 -0
- data/examples/drr_6.png +0 -0
- data/examples/drr_7.png +0 -0
- data/examples/drr_8.png +0 -0
- data/examples/xr_1.png +0 -0
- data/examples/xr_2.png +0 -0
- data/examples/xr_3.png +0 -0
- data/examples/xr_4.png +0 -0
- data/examples/xr_5.png +0 -0
- data/examples/xr_6.png +0 -0
- data/examples/xr_7.png +0 -0
- data/examples/xr_8.png +0 -0
- diffusion_configs.yaml +49 -0
- requirements.txt +6 -0
- style.css +53 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
flagged/
*.ckpt
app.py
ADDED
@@ -0,0 +1,248 @@
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import osail_utils
import pandas as pd
import skimage
from mediffusion import DiffusionModule
import monai as mn
import torch

# Loading the model for inference

model = DiffusionModule("./diffusion_configs.yaml")
model.load_ckpt("./data/model.ckpt")
model.cuda().half()
model.eval()

# Loading a baseline noise for making predictions

seed = 3407
np.random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
BASELINE_NOISE = torch.randn(1, 1, 256, 256).cuda().half()
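
# Reusing one fixed noise tensor (together with the seeding above) keeps the
# sampler deterministic, so repeated calls on the same input start from the
# same latent and are directly comparable rather than varying run to run.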

# Model helper functions

def create_ds(img_paths):
    # Accept a single path or a list of paths
    if isinstance(img_paths, str):
        img_paths = [img_paths]
    data_list = [{"img": img_path} for img_path in img_paths]

    # Get the transforms
    Ts_list = [
        osail_utils.io.LoadImageD(keys=["img"], transpose=True, normalize=True),
        mn.transforms.EnsureChannelFirstD(keys=["img"], channel_dim="no_channel"),
        mn.transforms.ResizeD(
            keys=["img"],
            spatial_size=(256, 256),
            mode=["bicubic"],
        ),
        mn.transforms.ScaleIntensityD(keys=["img"], minv=0, maxv=1),
        mn.transforms.ToTensorD(keys=["img"], track_meta=None),
        mn.transforms.SelectItemsD(keys=["img"]),
    ]
    return mn.data.Dataset(data_list, transform=mn.transforms.Compose(Ts_list))


def make_predictions(img_path, angles=None, cls_batch=None, rotate_to_standard=False, sampler="DDIM100"):
    global model
    global BASELINE_NOISE

    # Create the image dataset
    if cls_batch is not None:
        ds = create_ds([img_path] * len(cls_batch))
    else:
        ds = create_ds(img_path)
    dl = mn.data.DataLoader(ds, batch_size=len(ds), num_workers=0 if len(ds) == 1 else 4, shuffle=False)
    input_batch = next(iter(dl))
    original_imgs = input_batch["img"].detach().cpu().numpy()

    # Create the classifier condition if not provided
    if cls_batch is None:
        fp = torch.zeros(768)
        if rotate_to_standard or angles is None:
            angles = [1000, 1000, 1000]
            cls_value = torch.tensor([2, *angles, *fp])
        else:
            cls_value = torch.tensor([1, *angles, *fp])
        cls_batch = cls_value.unsqueeze(0).repeat(input_batch["img"].shape[0], 1).cuda().half()

    # Generate noise
    noise = BASELINE_NOISE.repeat(input_batch["img"].shape[0], 1, 1, 1)
    model_kwargs = {
        "cls": cls_batch,
        "concat": input_batch["img"].cuda().half(),
    }

    # Make predictions
    preds = model.predict(
        noise, model_kwargs=model_kwargs, classifier_cond_scale=4, inference_protocol=sampler
    )
    adjusted_preds = list()
    for pred, original_img in zip(preds, original_imgs):
        adjusted_pred = pred.detach().cpu().numpy().squeeze()
        original_img = original_img.squeeze()
        # Match the prediction's histogram to the input so intensities stay comparable
        adjusted_pred = skimage.exposure.match_histograms(adjusted_pred, original_img)
        adjusted_preds.append(adjusted_pred)
    return adjusted_preds
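
# A note on the conditioning layout used above: each cls row holds
# 1 (mode) + 3 (angles) + 768 (fp) = 772 values, matching num_classes: 772
# in diffusion_configs.yaml. Mode 1 conditions on explicit angles; mode 2
# requests the standard view, with the angles set to the out-of-range
# sentinel 1000.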

# Gradio helper functions

current_img = None
live_preds = None


def rotate_btn_fn(img_path, xt, yt, zt, add_bone_cmap=False):
    global current_img

    angles = [float(xt), float(yt), float(zt)]
    out_img = make_predictions(img_path, angles)[0]
    if not add_bone_cmap:
        return out_img
    cmap = plt.get_cmap('bone')
    out_img = cmap(out_img)
    out_img = (out_img[..., :3] * 255).astype(np.uint8)
    current_img = out_img
    return out_img


def rotate_to_standard_btn_fn(img_path, add_bone_cmap=False):
    global current_img

    out_img = make_predictions(img_path, rotate_to_standard=True)[0]
    if not add_bone_cmap:
        return out_img
    cmap = plt.get_cmap('bone')
    out_img = cmap(out_img)
    out_img = (out_img[..., :3] * 255).astype(np.uint8)
    current_img = out_img
    return out_img


def use_current_btn_fn(input_img):
    return input_img


def make_live_btn_fn(img_path, axis, add_bone_cmap=False):
    global live_preds

    base_angles = [float(i) for i in range(-20, 21, 1)]
    if axis.lower() == "axis x":
        all_angles = [[i, 0, 0] for i in base_angles]
    elif axis.lower() == "axis y":
        all_angles = [[0, i, 0] for i in base_angles]
    else:  # "axis z" (the dropdown only offers these three values)
        all_angles = [[0, 0, i] for i in base_angles]
    fp = torch.zeros(768)
    cls_batch = torch.tensor([[1, *angles, *fp] for angles in all_angles])

    live_preds = make_predictions(img_path, cls_batch=cls_batch)
    live_preds = {angle: live_preds[i] for i, angle in enumerate(base_angles)}
    return img_path


def rotate_live_img_fn(angle, add_bone_cmap=False):
    global live_preds

    # Serve the cached prediction for the requested angle; 0.0 is part of the
    # precomputed grid, so the neutral position is covered as well.
    if live_preds is not None:
        return live_preds[float(angle)]
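
# Live mode precomputes predictions for all 41 angles (-20..20 degrees) in a
# single batched call and caches them in live_preds, so moving the slider is
# a dictionary lookup rather than a fresh diffusion run.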

css_style = "./style.css"
callback = gr.CSVLogger()

with gr.Blocks(css=css_style) as app:
    gr.HTML("VCNet: A Deep Learning Solution for Rotating Radiographs in 3D Space", elem_classes="title")
    gr.HTML("Developed by the Orthopedics Surgery Artificial Intelligence Lab (OSAIL)", elem_classes="note")
    gr.HTML("Note: This is a proof-of-concept demo of an AI tool that is not yet finalized. Please interpret with care!", elem_classes="note")

    with gr.TabItem("Single Rotation"):
        with gr.Row():
            input_img = gr.Image(type='filepath', label='Input image', sources='upload', interactive=False, elem_classes='imgs')
            output_img = gr.Image(type='pil', label='Output image', interactive=False, elem_classes='imgs')
        with gr.Row():
            gr.Examples(
                examples=[os.path.join("./data/examples", f) for f in os.listdir("./data/examples") if "xr" in f],
                inputs=[input_img],
                label="X-ray Examples",
                elem_id='examples',
            )
            gr.Examples(
                examples=[os.path.join("./data/examples", f) for f in os.listdir("./data/examples") if "drr" in f],
                inputs=[input_img],
                label="DRR Examples",
                elem_id='examples',
            )
        with gr.Row():
            gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')
        with gr.Row():
            with gr.Column(scale=1):
                xt = gr.Slider(label='Rotation angle around the x axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
            with gr.Column(scale=1):
                yt = gr.Slider(label='Rotation angle around the y axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
            with gr.Column(scale=1):
                zt = gr.Slider(label='Rotation angle around the z axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
        with gr.Row():
            rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
        with gr.Row():
            rotate_to_standard_btn = gr.Button("Rotate to standard view!", elem_classes='rotate_to_standard_button')
        with gr.Row():
            use_current_btn = gr.Button("Use the current output as the new input!", elem_classes='use_current_button')
        rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
        rotate_to_standard_btn.click(fn=rotate_to_standard_btn_fn, inputs=[input_img], outputs=output_img)
        use_current_btn.click(fn=use_current_btn_fn, inputs=[output_img], outputs=input_img)

    with gr.TabItem("Live Rotation"):
        with gr.Row():
            live_img = gr.Image(type='filepath', label='Live Image', sources='upload', interactive=False, elem_classes='imgs')
        with gr.Row():
            gr.Examples(
                examples=[os.path.join("./data/examples", f) for f in os.listdir("./data/examples") if "xr" in f],
                inputs=[live_img],
                label="X-ray Examples",
                elem_id='examples',
            )
            gr.Examples(
                examples=[os.path.join("./data/examples", f) for f in os.listdir("./data/examples") if "drr" in f],
                inputs=[live_img],
                label="DRR Examples",
                elem_id='examples',
            )
        with gr.Row():
            gr.Markdown('Please select an example image, an axis, and then press Make Live!', elem_classes='text')
        with gr.Row():
            axis = gr.Dropdown(choices=['Axis X', 'Axis Y', 'Axis Z'], show_label=False, elem_classes='angle', value='Axis X')
            live_btn = gr.Button("Make Live!", elem_classes='make_live_button')
        with gr.Row():
            gr.Markdown('You can now rotate the radiograph around the selected axis using the slider.', elem_classes='text')
        with gr.Row():
            slider = gr.Slider(show_label=False, minimum=-20, maximum=20, step=1, value=0, elem_classes='slider', interactive=True)
        live_btn.click(fn=make_live_btn_fn, inputs=[live_img, axis], outputs=live_img)
        slider.change(fn=rotate_live_img_fn, inputs=[slider], outputs=live_img)

try:
    app.close()
    gr.close_all()
except Exception:
    pass

demo = app.launch(
    max_threads=4,
    share=True,
    inline=False,
    show_api=False,
    show_error=True,
    server_port=1902,
    server_name="0.0.0.0",
)
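For reference, a minimal sketch of driving the helpers above without the UI. This is illustrative only, assuming the checkpoint and example images above are in place; PIL comes from the pinned pillow dependency:

import numpy as np
from PIL import Image

# Rotate an example radiograph by +10 degrees around the x axis and save it.
pred = make_predictions("./data/examples/xr_1.png", angles=[10.0, 0.0, 0.0])[0]
Image.fromarray((np.clip(pred, 0.0, 1.0) * 255).astype(np.uint8)).save("rotated.png")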
data/examples/drr_0.png
ADDED
data/examples/drr_2.png
ADDED
data/examples/drr_3.png
ADDED
data/examples/drr_4.png
ADDED
data/examples/drr_5.png
ADDED
data/examples/drr_6.png
ADDED
data/examples/drr_7.png
ADDED
data/examples/drr_8.png
ADDED
data/examples/xr_1.png
ADDED
data/examples/xr_2.png
ADDED
data/examples/xr_3.png
ADDED
data/examples/xr_4.png
ADDED
data/examples/xr_5.png
ADDED
data/examples/xr_6.png
ADDED
data/examples/xr_7.png
ADDED
data/examples/xr_8.png
ADDED
diffusion_configs.yaml
ADDED
@@ -0,0 +1,49 @@
diffusion:
  timesteps: 1000
  schedule_name: cosine
  enforce_zero_terminal_snr: true
  schedule_params:
    beta_start: 0.0001
    beta_end: 0.02
    cosine_s: 0.008
  timestep_respacing: null
  mean_type: VELOCITY
  var_type: LEARNED_RANGE
  loss_type: MSE

optimizer:
  lr: 0.00001
  type: bkh_pytorch_utils.Lion

validation:
  classifier_cond_scale: 4
  protocol: DDPM
  log_original: true
  log_concat: true
  cls_log_indices: [0, 1, 2, 3]

model:
  input_size: 256
  dims: 2
  attention_resolutions: [8, 16, 32]
  channel_mult: [1, 1, 2, 2, 4, 4]
  dropout: 0.0
  in_channels: 2
  out_channels: 2
  model_channels: 128
  num_head_channels: -1
  num_heads: 4
  num_heads_upsample: -1
  num_res_blocks: [2, 2, 2, 2, 2, 2]
  resblock_updown: false
  use_checkpoint: false
  use_new_attention_order: false
  use_scale_shift_norm: true
  scale_skip_connection: false

  # conditions
  num_classes: 772
  # num_classes: 4
  concat_channels: 1
  guidance_drop_prob: 0.1
  missing_class_value: null
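A small sanity check tying this config to the conditioning code in app.py; a sketch, assuming PyYAML is available in the environment:

import yaml

with open("./diffusion_configs.yaml") as f:
    cfg = yaml.safe_load(f)

# app.py builds cls vectors of length 1 (mode) + 3 (angles) + 768 (fp) = 772,
# which has to match the model's class-conditioning width.
assert cfg["model"]["num_classes"] == 1 + 3 + 768
print(cfg["diffusion"]["schedule_name"], cfg["validation"]["protocol"])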
requirements.txt
ADDED
@@ -0,0 +1,6 @@
gradio==4.19.2
mediffusion==0.7.2
monai==1.3.0
scikit-image==0.22.0
numpy==1.24.1
pillow==9.3.0
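A quick way to confirm these pins resolved correctly inside the Space (a sketch; mediffusion may not expose __version__, hence the fallback):

import gradio, mediffusion, monai, numpy, skimage, PIL

for mod in (gradio, mediffusion, monai, numpy, skimage, PIL):
    print(mod.__name__, getattr(mod, "__version__", "unknown"))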
style.css
ADDED
@@ -0,0 +1,53 @@
.title {
    text-align: center;
    display: block;
    font-size: 30px;
    font-weight: bold;
}
.note {
    text-align: center;
    min-height: 0;
}
#examples {
    justify-content: center;
    align-items: center;
}
.angle {
    background: #F2F1EB !important;
}
.slider {
    background: #88AB8E !important;
    font-weight: bold;
}
.text {
    display: flex;
    text-align: center;
    justify-content: center;
}
.normal {
    font-weight: normal;
}
.rotate_button {
    background: #88AB8E !important;
    font-size: 20px;
}
.rotate_to_standard_button {
    background: #AFC8AD !important;
    font-size: 20px;
}
.use_current_button {
    background: #EEE7DA !important;
    font-size: 20px;
}
.make_live_button {
    background: #EEE7DA !important;
    font-size: 20px;
}
.imgs {
    justify-content: center;
    align-items: center;
    display: grid;
    margin: auto;
    width: 256px;
    height: 256px;
}