# RadRotator / app.py
# Author: Pouriarouzrokh
# Commit 2993f76: "Added the gradio demo files" (9.56 kB)
import os
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import osail_utils
import pandas as pd
import skimage
from mediffusion import DiffusionModule
import monai as mn
import torch
# Build the diffusion model from its YAML config, restore the trained
# checkpoint, then move it to the GPU in half precision and switch it to
# evaluation mode for inference.
model = DiffusionModule("./diffusion_configs.yaml")
model.load_ckpt("./data/model.ckpt")
model.cuda().half().eval()
# Seed every random number generator used here (NumPy, torch CPU, torch CUDA),
# then draw one fixed noise tensor. make_predictions repeats this tensor for
# every request, so each request starts from the same initial noise.
seed = 3407
for _seed_rng in (
    np.random.seed,
    torch.random.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(seed)
del _seed_rng
torch.backends.cudnn.deterministic = True
BASELINE_NOISE = torch.randn(1, 1, 256, 256).cuda().half()
# Model helper functions
def create_ds(img_paths):
    """Build a MONAI dataset that loads and preprocesses images for the model.

    Args:
        img_paths: A single image path (``str`` or ``os.PathLike``) or a
            sequence of paths.

    Returns:
        ``monai.data.Dataset`` whose items are dicts holding only the ``"img"``
        entry. Each entry is the image loaded by ``osail_utils.io.LoadImageD``,
        given a leading channel axis, bicubically resized to 256x256,
        intensity-scaled to [0, 1] and converted to a tensor.
    """
    # Accept a bare path as well as a sequence of paths. A Path object used to
    # fall through to the list comprehension and fail because it is not iterable.
    if isinstance(img_paths, (str, os.PathLike)):
        img_paths = [img_paths]
    data_list = [{"img": path} for path in img_paths]

    transforms = mn.transforms.Compose([
        # NOTE(review): the exact meaning of transpose/normalize is defined by
        # osail_utils; confirm there.
        osail_utils.io.LoadImageD(keys=["img"], transpose=True, normalize=True),
        # The loaded image is declared to have no channel axis; add one in front.
        mn.transforms.EnsureChannelFirstD(keys=["img"], channel_dim="no_channel"),
        mn.transforms.ResizeD(keys=["img"], spatial_size=(256, 256), mode=["bicubic"]),
        mn.transforms.ScaleIntensityD(keys=["img"], minv=0, maxv=1),
        mn.transforms.ToTensorD(keys=["img"], track_meta=None),
        mn.transforms.SelectItemsD(keys=["img"]),
    ])
    return mn.data.Dataset(data_list, transform=transforms)
def make_predictions(img_path, angles=None, cls_batch=None, rotate_to_standard=False, sampler="DDIM100"):
    """Run the diffusion model to generate rotated versions of an input image.

    Args:
        img_path: Input image path (or sequence of paths, as accepted by
            ``create_ds``). When ``cls_batch`` is given together with a single
            path, that one image is reused for every condition row.
        angles: ``[x, y, z]`` rotation angles used to build the condition when
            ``cls_batch`` is None. Ignored if ``rotate_to_standard`` is True.
        cls_batch: Optional pre-built condition tensor, one row per output
            image. It may live on any device/dtype; it is moved to match
            ``BASELINE_NOISE``.
        rotate_to_standard: If True, or if ``angles`` is None, build the
            mode-2 condition with placeholder angles ``[1000, 1000, 1000]``
            instead of the mode-1 "explicit angles" condition.
        sampler: Inference protocol forwarded to ``model.predict``.

    Returns:
        A list of numpy arrays, one per generated image. Each prediction is
        squeezed and histogram-matched to its (squeezed) input image.
    """
    # Load and preprocess the input image(s) as one batch.
    ds = create_ds(img_path)
    dl = mn.data.DataLoader(ds, batch_size=len(ds), num_workers=0 if len(ds) == 1 else 4, shuffle=False)
    input_imgs = next(iter(dl))["img"]
    if cls_batch is not None and input_imgs.shape[0] == 1:
        # One image, many conditions. Preprocessing is deterministic, so load
        # the image once and repeat it instead of re-reading it for every row.
        input_imgs = input_imgs.repeat(cls_batch.shape[0], *([1] * (input_imgs.ndim - 1)))
    original_imgs = input_imgs.detach().cpu().numpy()
    batch_size = input_imgs.shape[0]

    if cls_batch is None:
        # Condition vector layout: [mode, x, y, z] followed by 768 zeros.
        if rotate_to_standard or angles is None:
            mode, angles = 2, [1000, 1000, 1000]
        else:
            mode = 1
        cls_value = torch.cat([torch.tensor([mode, *angles], dtype=torch.float32), torch.zeros(768)])
        cls_batch = cls_value.unsqueeze(0).repeat(batch_size, 1)

    # Everything fed to the model must share the noise's device and dtype
    # (CUDA fp16). Caller-supplied cls_batch tensors used to arrive as CPU float32.
    target = {"device": BASELINE_NOISE.device, "dtype": BASELINE_NOISE.dtype}
    noise = BASELINE_NOISE.repeat(batch_size, 1, 1, 1)
    model_kwargs = {
        "cls": cls_batch.to(**target),
        "concat": input_imgs.to(**target),
    }

    preds = model.predict(
        noise, model_kwargs=model_kwargs, classifier_cond_scale=4, inference_protocol=sampler
    )

    # Match each prediction's intensity histogram to its own input image.
    return [
        skimage.exposure.match_histograms(pred.detach().cpu().numpy().squeeze(), original.squeeze())
        for pred, original in zip(preds, original_imgs)
    ]
# Gradio helper functions
# Module-level state shared by the Gradio callbacks:
#   current_img - written by the single-rotation button callbacks
#   live_preds  - angle -> frame dict written by make_live_btn_fn
current_img, live_preds = None, None
def rotate_btn_fn(img_path, xt, yt, zt, add_bone_cmap=False):
    """Rotate the input image by the slider angles and return the result.

    Args:
        img_path: Path of the input image, or None when nothing is selected.
        xt, yt, zt: Rotation angles about the x, y and z axes (cast to float).
        add_bone_cmap: If True, map the output through matplotlib's 'bone'
            colormap and return an RGB ``uint8`` array.

    Returns:
        The generated image as a numpy array. The same value is stored in the
        module-level ``current_img``.

    Raises:
        gr.Error: If no input image has been selected.
    """
    global current_img
    if img_path is None:
        raise gr.Error("Please select an input image first.")
    angles = [float(xt), float(yt), float(zt)]
    out_img = make_predictions(img_path, angles)[0]
    if add_bone_cmap:
        # The colormap returns RGBA floats in [0, 1]; keep RGB and scale to uint8.
        rgba = plt.get_cmap('bone')(out_img)
        out_img = (rgba[..., :3] * 255).astype(np.uint8)
    # Record the latest output on both paths (it used to be set only when colormapped).
    current_img = out_img
    return out_img
def rotate_to_standard_btn_fn(img_path, add_bone_cmap=False):
    """Generate the image using make_predictions' ``rotate_to_standard`` condition.

    Args:
        img_path: Path of the input image, or None when nothing is selected.
        add_bone_cmap: If True, map the output through matplotlib's 'bone'
            colormap and return an RGB ``uint8`` array.

    Returns:
        The generated image as a numpy array. The same value is stored in the
        module-level ``current_img``.

    Raises:
        gr.Error: If no input image has been selected.
    """
    global current_img
    if img_path is None:
        raise gr.Error("Please select an input image first.")
    out_img = make_predictions(img_path, rotate_to_standard=True)[0]
    if add_bone_cmap:
        # The colormap returns RGBA floats in [0, 1]; keep RGB and scale to uint8.
        rgba = plt.get_cmap('bone')(out_img)
        out_img = (rgba[..., :3] * 255).astype(np.uint8)
    # Record the latest output on both paths (it used to be set only when colormapped).
    current_img = out_img
    return out_img
def use_current_btn_fn(input_img):
    """Return the given image unchanged.

    The UI wires this so the current output image is copied into the input slot.
    """
    passthrough = input_img
    return passthrough
def make_live_btn_fn(img_path, axis, add_bone_cmap=False):
    """Precompute frames of ``img_path`` rotated about one axis for the live slider.

    Runs one batched model call covering every non-zero integer angle in
    [-20, 20] about the selected axis. The frames are stored in the
    module-level ``live_preds`` dict, keyed by angle as a float. The 0.0 entry
    is the unmodified input path, so the slider's neutral position shows the
    original image.

    Args:
        img_path: Path of the image to animate, or None when nothing is selected.
        axis: "Axis X", "Axis Y" or "Axis Z" (case-insensitive).
        add_bone_cmap: Unused; kept for interface compatibility.

    Returns:
        ``img_path`` unchanged, so the live image keeps showing the input.

    Raises:
        gr.Error: If no input image has been selected.
        ValueError: If ``axis`` is not one of the recognised axis names.
    """
    global live_preds
    if img_path is None:
        raise gr.Error("Please select an input image first.")
    axis_index = {"axis x": 0, "axis y": 1, "axis z": 2}.get((axis or "").lower())
    if axis_index is None:
        # Previously an unknown axis surfaced as an UnboundLocalError.
        raise ValueError(f"Unknown axis {axis!r}; expected 'Axis X', 'Axis Y' or 'Axis Z'.")

    # The input itself serves as the 0-degree frame, so only non-zero angles need the model.
    model_angles = [float(a) for a in range(-20, 21) if a != 0]
    angle_rows = []
    for angle in model_angles:
        row = [0.0, 0.0, 0.0]
        row[axis_index] = angle
        angle_rows.append(row)

    # Same condition layout as make_predictions: [mode=1, x, y, z] + 768 zeros.
    # Move it to CUDA fp16 to match the model inputs.
    cls_batch = torch.cat(
        [torch.tensor([[1.0, *row] for row in angle_rows]), torch.zeros(len(angle_rows), 768)],
        dim=1,
    ).cuda().half()

    frames = make_predictions(img_path, cls_batch=cls_batch)
    live_preds = dict(zip(model_angles, frames))
    live_preds[0.0] = img_path
    return img_path
def rotate_live_img_fn(angle, add_bone_cmap=False):
    """Return the precomputed live-rotation frame for the slider ``angle``.

    Frames are read from the module-level ``live_preds`` dict, which
    ``make_live_btn_fn`` fills. The old code checked the Gradio component
    ``live_img`` (never None) and returned that component at angle 0, which
    left a stale frame on screen. It also raised TypeError when the slider
    moved before "Make Live!" was pressed.

    Args:
        angle: Slider value; looked up as ``float(angle)``.
        add_bone_cmap: Unused; kept for interface compatibility.

    Returns:
        The stored frame for ``angle``. If no frames exist yet, or none exists
        for this angle, returns a no-argument ``gr.update()`` so the displayed
        image is left unchanged.
    """
    frame = live_preds.get(float(angle)) if live_preds is not None else None
    if frame is None:
        return gr.update()
    return frame
css_style = "./style.css"
callback = gr.CSVLogger()
EXAMPLES_DIR = "./data/examples"


def _list_examples(tag):
    """Return the sorted paths of files in ``EXAMPLES_DIR`` whose name contains ``tag``."""
    # Sort for a stable gallery order; os.listdir order is arbitrary.
    return sorted(
        os.path.join(EXAMPLES_DIR, name)
        for name in os.listdir(EXAMPLES_DIR)
        if tag in name
    )


# Listed once and shared by both tabs (previously os.listdir ran four times).
XRAY_EXAMPLES = _list_examples("xr")
DRR_EXAMPLES = _list_examples("drr")


def _add_example_galleries(target):
    """Add the X-ray and DRR example galleries that load into ``target``."""
    gr.Examples(examples=XRAY_EXAMPLES, inputs=[target], label="Xray Examples", elem_id='examples')
    gr.Examples(examples=DRR_EXAMPLES, inputs=[target], label="DRR Examples", elem_id='examples')


with gr.Blocks(css=css_style) as app:
    gr.HTML("VCNet: A Deep Learning Solution for Rotating Radiographs in 3D Space", elem_classes="title")
    gr.HTML("Developed by the Orthopedics Surgery Artificial Intelligence Lab (OSAIL)", elem_classes="note")
    gr.HTML("Note: This is a proof-of-concept demo of an AI tool that is not yet finalized. Please interpret with care!", elem_classes="note")

    with gr.TabItem("Single Rotation"):
        with gr.Row():
            input_img = gr.Image(type='filepath', label='Input image', sources='upload', interactive=False, elem_classes='imgs')
            output_img = gr.Image(type='pil', label='Output image', interactive=False, elem_classes='imgs')
        with gr.Row():
            _add_example_galleries(input_img)
        with gr.Row():
            gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')
        with gr.Row():
            with gr.Column(scale=1):
                xt = gr.Slider(label='Rotation angle in x axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
            with gr.Column(scale=1):
                yt = gr.Slider(label='Rotation angle in y axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
            with gr.Column(scale=1):
                zt = gr.Slider(label='Rotation angle in z axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
        with gr.Row():
            rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
        with gr.Row():
            rotate_to_standard_btn = gr.Button("Rotate to standard view!", elem_classes='rotate_to_standard_button')
        with gr.Row():
            use_current_btn = gr.Button("Use the current output as the new input!", elem_classes='use_current_button')
        rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
        rotate_to_standard_btn.click(fn=rotate_to_standard_btn_fn, inputs=[input_img], outputs=output_img)
        use_current_btn.click(fn=use_current_btn_fn, inputs=[output_img], outputs=input_img)

    with gr.TabItem("Live Rotation"):
        with gr.Row():
            live_img = gr.Image(type='filepath', label='Live Image', sources='upload', interactive=False, elem_classes='imgs')
        with gr.Row():
            _add_example_galleries(live_img)
        with gr.Row():
            gr.Markdown('Please select an example image, an axis, and then press Make Live!', elem_classes='text')
        with gr.Row():
            axis = gr.Dropdown(choices=['Axis X', 'Axis Y', 'Axis Z'], show_label=False, elem_classes='angle', value='Axis X')
            live_btn = gr.Button("Make Live!", elem_classes='make_live_button')
        with gr.Row():
            gr.Markdown('You can now rotate the radiograph in your selected axis using the slider.', elem_classes='text')
        with gr.Row():
            slider = gr.Slider(show_label=False, minimum=-20, maximum=20, step=1, value=0, elem_classes='slider', interactive=True)
        live_btn.click(fn=make_live_btn_fn, inputs=[live_img, axis], outputs=live_img)
        slider.change(fn=rotate_live_img_fn, inputs=[slider], outputs=live_img)
# Best-effort cleanup of any previous launch (e.g. when this script is re-run
# in the same interpreter). Each step is attempted independently, so a
# failing app.close() no longer skips gr.close_all(). Catch Exception rather
# than using a bare except, so KeyboardInterrupt/SystemExit still propagate.
for _shutdown in (app.close, gr.close_all):
    try:
        _shutdown()
    except Exception:
        pass
del _shutdown

demo = app.launch(
    max_threads=4,
    share=True,
    inline=False,
    show_api=False,
    show_error=True,
    server_port=1902,
    server_name="0.0.0.0",
)