# NOTE: page-scrape residue, not part of the program: "Spaces: Runtime error".
import os

import gradio as gr
import matplotlib.pyplot as plt
import monai as mn
import numpy as np
import osail_utils
import pandas as pd
import skimage
import torch
from mediffusion import DiffusionModule
# Loading the model for inference.
# The module is built from the YAML config, restored from the checkpoint, then
# moved to the GPU in half precision and switched to evaluation mode.
model = DiffusionModule("./diffusion_configs.yaml")
model.load_ckpt("./data/model.ckpt")
model.cuda().half()
model.eval()
# Loading a baseline noise for making predictions.
# All generators are seeded with the same value before the noise is drawn, so
# the tensor below is identical on every start of the app.
seed = 3407
for _seed_fn in (
    np.random.seed,
    torch.random.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(seed)
torch.backends.cudnn.deterministic = True

# Drawn on the CPU, then transferred to the GPU as float16; `make_predictions`
# repeats this single sample along the batch dimension.
BASELINE_NOISE = torch.randn(1, 1, 256, 256).cuda().half()
# Model helper functions
def create_ds(img_paths):
    """Build a MONAI dataset that loads and preprocesses the given image(s).

    Args:
        img_paths: A single image path (``str``) or an iterable of paths.
            A lone string is wrapped into a one-element list.

    Returns:
        A ``monai.data.Dataset`` whose items are dicts with a single ``"img"``
        key. Each image is loaded, given a leading channel axis, resized to
        256x256 (bicubic), and intensity-scaled to the range [0, 1].
    """
    if isinstance(img_paths, str):
        img_paths = [img_paths]
    data_list = [{"img": img_path} for img_path in img_paths]

    # Preprocessing pipeline, applied in order to the "img" entry.
    transforms = mn.transforms.Compose(
        [
            osail_utils.io.LoadImageD(keys=["img"], transpose=True, normalize=True),
            mn.transforms.EnsureChannelFirstD(
                keys=["img"], channel_dim="no_channel"
            ),
            mn.transforms.ResizeD(
                keys=["img"],
                spatial_size=(256, 256),
                mode=["bicubic"],
            ),
            mn.transforms.ScaleIntensityD(keys=["img"], minv=0, maxv=1),
            mn.transforms.ToTensorD(keys=["img"], track_meta=None),
            # Drop every key except "img" so batches only carry the image.
            mn.transforms.SelectItemsD(keys=["img"]),
        ]
    )
    return mn.data.Dataset(data_list, transform=transforms)
def make_predictions(img_path, angles=None, cls_batch=None, rotate_to_standard=False, sampler="DDIM100"):
    """Run the diffusion model on one image and return the generated image(s).

    Args:
        img_path: Path of the input image (or, when ``cls_batch`` is ``None``,
            anything ``create_ds`` accepts).
        angles: Three rotation angles ``[x, y, z]``. Ignored when
            ``rotate_to_standard`` is true or when ``cls_batch`` is given.
        cls_batch: Optional pre-built condition tensor with one row per
            prediction. When given, the input image is replicated once per
            row. May live on any device / dtype; it is moved to the GPU in
            half precision here.
        rotate_to_standard: If true (or if ``angles`` is ``None``), the
            condition vector uses flag ``2`` with all angles set to ``1000``.
        sampler: Value forwarded to the model as ``inference_protocol``.

    Returns:
        A list of NumPy arrays, one per prediction, each histogram-matched to
        its (preprocessed) input image.
    """
    # Imported here because a bare `import skimage` does not load the
    # `exposure` submodule on every scikit-image version.
    from skimage import exposure

    # Create the image dataset (one copy of the image per condition row).
    if cls_batch is not None:
        ds = create_ds([img_path] * len(cls_batch))
    else:
        ds = create_ds(img_path)
    dl = mn.data.DataLoader(
        ds,
        batch_size=len(ds),
        num_workers=0 if len(ds) == 1 else 4,
        shuffle=False,
    )
    input_batch = next(iter(dl))
    imgs = input_batch["img"]
    batch_size = imgs.shape[0]
    original_imgs = imgs.detach().cpu().numpy()

    # Create the classifier condition if not provided.
    # Layout of one row: [flag, angle_x, angle_y, angle_z, 768 zeros].
    if cls_batch is None:
        if rotate_to_standard or angles is None:
            flag, angles = 2, [1000, 1000, 1000]
        else:
            flag = 1
        cls_value = torch.cat(
            [torch.tensor([flag, *angles], dtype=torch.float32), torch.zeros(768)]
        )
        cls_batch = cls_value.unsqueeze(0).repeat(batch_size, 1)
    # Applies to caller-supplied conditions too: they must match the device
    # and precision of the model and of the other inputs.
    cls_batch = torch.as_tensor(cls_batch).cuda().half()

    # Generate noise: the same baseline sample for every item of the batch.
    noise = BASELINE_NOISE.repeat(batch_size, 1, 1, 1)
    model_kwargs = {
        "cls": cls_batch,
        "concat": imgs.cuda().half(),
    }

    # Make predictions
    preds = model.predict(
        noise,
        model_kwargs=model_kwargs,
        classifier_cond_scale=4,
        inference_protocol=sampler,
    )

    # Match each prediction's intensity histogram to its input image.
    return [
        exposure.match_histograms(
            pred.detach().cpu().numpy().squeeze(), original_img.squeeze()
        )
        for pred, original_img in zip(preds, original_imgs)
    ]
# Gradio helper functions
# Module-level state shared between the Gradio callbacks below.
current_img = None  # last image written by the single-rotation callbacks
live_preds = None  # set by make_live_btn_fn; read by rotate_live_img_fn
def _apply_bone_cmap(img):
    """Map a 2-D array through matplotlib's 'bone' colormap; return uint8 RGB."""
    rgba = plt.get_cmap('bone')(img)
    # Drop the alpha channel and rescale the [0, 1] floats to 8-bit values.
    return (rgba[..., :3] * 255).astype(np.uint8)


def _finish_rotation(out_img, add_bone_cmap):
    """Optionally colorize a prediction and record it as the current image."""
    global current_img
    if add_bone_cmap:
        out_img = _apply_bone_cmap(out_img)
    current_img = out_img
    return out_img


def rotate_btn_fn(img_path, xt, yt, zt, add_bone_cmap=False):
    """Callback for "Rotate!": rotate the input image by the given angles.

    Args:
        img_path: Path of the input image, as delivered by the Image component.
        xt, yt, zt: Rotation angles for the x, y and z axes; converted to float.
        add_bone_cmap: If true, the result is returned as a uint8 RGB image
            colored with the 'bone' colormap instead of the raw prediction.

    Returns:
        The predicted image (NumPy array). It is also stored in ``current_img``.

    Raises:
        gr.Error: If no input image has been selected.
    """
    if img_path is None:
        raise gr.Error("Please select an example image first.")
    angles = [float(xt), float(yt), float(zt)]
    out_img = make_predictions(img_path, angles)[0]
    return _finish_rotation(out_img, add_bone_cmap)


def rotate_to_standard_btn_fn(img_path, add_bone_cmap=False):
    """Callback for "Rotate to standard view!".

    Args:
        img_path: Path of the input image, as delivered by the Image component.
        add_bone_cmap: If true, the result is returned as a uint8 RGB image
            colored with the 'bone' colormap instead of the raw prediction.

    Returns:
        The predicted image (NumPy array). It is also stored in ``current_img``.

    Raises:
        gr.Error: If no input image has been selected.
    """
    if img_path is None:
        raise gr.Error("Please select an example image first.")
    out_img = make_predictions(img_path, rotate_to_standard=True)[0]
    return _finish_rotation(out_img, add_bone_cmap)
def use_current_btn_fn(input_img):
    """Callback for "Use the current output as the new input!".

    Returns its argument unchanged; the button wiring routes the output image
    component's value into the input image component.
    """
    return input_img
def make_live_btn_fn(img_path, axis, add_bone_cmap=False):
    """Callback for "Make Live!": precompute rotations from -20 to 20 degrees.

    One prediction is generated per integer angle in [-20, 20] around the
    selected axis and stored in the module-level ``live_preds`` dict, keyed by
    the angle as a float.

    Args:
        img_path: Path of the image to rotate.
        axis: One of ``"Axis X"``, ``"Axis Y"``, ``"Axis Z"`` (case-insensitive).
        add_bone_cmap: Accepted for signature parity with the other callbacks;
            not used here.

    Returns:
        ``img_path`` unchanged, so the image component keeps showing the input.

    Raises:
        ValueError: If ``axis`` is not one of the three supported values.
    """
    global live_preds

    # Column offset of each axis inside the angle part of a condition row.
    axis_index = {"axis x": 0, "axis y": 1, "axis z": 2}
    key = str(axis).strip().lower()
    if key not in axis_index:
        raise ValueError(
            f"Unknown axis {axis!r}; expected one of 'Axis X', 'Axis Y', 'Axis Z'."
        )

    base_angles = [float(i) for i in range(-20, 21)]

    # One row per angle: [1, angle_x, angle_y, angle_z, 768 zeros], with only
    # the selected axis' column non-zero.
    cls_batch = torch.zeros(len(base_angles), 1 + 3 + 768)
    cls_batch[:, 0] = 1
    cls_batch[:, 1 + axis_index[key]] = torch.tensor(base_angles)

    preds = make_predictions(img_path, cls_batch=cls_batch)
    # Assigned in one step so readers never observe a half-built value.
    live_preds = dict(zip(base_angles, preds))
    return img_path
def rotate_live_img_fn(angle, add_bone_cmap=False):
    """Callback for the live slider: return the precomputed frame for ``angle``.

    Args:
        angle: Slider value; converted to float to match the ``live_preds`` keys.
        add_bone_cmap: Accepted for signature parity with the other callbacks;
            not used here.

    Returns:
        The prediction stored in ``live_preds`` for this angle. Angle 0 is
        looked up like any other angle.

    Raises:
        gr.Error: If ``live_preds`` has not been populated yet.
        KeyError: If no prediction exists for ``angle``.
    """
    if live_preds is None:
        raise gr.Error("Please press 'Make Live!' before moving the slider.")
    return live_preds[float(angle)]
css_style = "./style.css"
callback = gr.CSVLogger()

# Example images are listed once and split by filename marker; sorted so the
# galleries have a stable order regardless of the filesystem.
_EXAMPLES_DIR = "./data/examples"
_example_files = sorted(os.listdir(_EXAMPLES_DIR))
_xr_examples = [os.path.join(_EXAMPLES_DIR, f) for f in _example_files if "xr" in f]
_drr_examples = [os.path.join(_EXAMPLES_DIR, f) for f in _example_files if "drr" in f]

# NOTE(review): the nesting below was reconstructed from the statement order;
# confirm it matches the intended layout.
with gr.Blocks(css=css_style) as app:
    gr.HTML("VCNet: A Deep Learning Solution for Rotating RadioGraphs in 3D Space", elem_classes="title")
    gr.HTML("Developed by the Orthopedics Surgery Artificial Intelligence Lab (OSAIL)", elem_classes="note")
    gr.HTML("Note: This is a proof-of-concept demo of an AI tool that is not yet finalized. Please interpret with care!", elem_classes="note")

    # ---- Tab 1: rotate once by user-chosen angles ----
    with gr.TabItem("Single Rotation"):
        with gr.Row():
            input_img = gr.Image(type='filepath', label='Input image', sources=['upload'], interactive=False, elem_classes='imgs')
            output_img = gr.Image(type='pil', label='Output image', interactive=False, elem_classes='imgs')
        with gr.Row():
            # NOTE(review): the same elem_id is reused by every Examples block
            # in this file; kept as-is in case style.css targets it.
            gr.Examples(
                examples=_xr_examples,
                inputs=[input_img],
                label="Xray Examples",
                elem_id='examples',
            )
            gr.Examples(
                examples=_drr_examples,
                inputs=[input_img],
                label="DRR Examples",
                elem_id='examples',
            )
        with gr.Row():
            gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')
        with gr.Row():
            with gr.Column(scale=1):
                xt = gr.Slider(label='Rotation angle in x axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
            with gr.Column(scale=1):
                yt = gr.Slider(label='Rotation angle in y axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
            with gr.Column(scale=1):
                zt = gr.Slider(label='Rotation angle in z axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
        with gr.Row():
            rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
        with gr.Row():
            rotate_to_standard_btn = gr.Button("Rotate to standard view!", elem_classes='rotate_to_standard_button')
        with gr.Row():
            use_current_btn = gr.Button("Use the current output as the new input!", elem_classes='use_current_button')

        rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
        rotate_to_standard_btn.click(fn=rotate_to_standard_btn_fn, inputs=[input_img], outputs=output_img)
        use_current_btn.click(fn=use_current_btn_fn, inputs=[output_img], outputs=input_img)

    # ---- Tab 2: precompute a sweep, then scrub through it with a slider ----
    with gr.TabItem("Live Rotation"):
        with gr.Row():
            live_img = gr.Image(type='filepath', label='Live Image', sources=['upload'], interactive=False, elem_classes='imgs')
        with gr.Row():
            gr.Examples(
                examples=_xr_examples,
                inputs=[live_img],
                label="Xray Examples",
                elem_id='examples',
            )
            gr.Examples(
                examples=_drr_examples,
                inputs=[live_img],
                label="DRR Examples",
                elem_id='examples',
            )
        with gr.Row():
            gr.Markdown('Please select an example image, an axis, and then press Make Live!', elem_classes='text')
        with gr.Row():
            axis = gr.Dropdown(choices=['Axis X', 'Axis Y', 'Axis Z'], show_label=False, elem_classes='angle', value='Axis X')
            live_btn = gr.Button("Make Live!", elem_classes='make_live_button')
        with gr.Row():
            gr.Markdown('You can now rotate the radiograph in your selected axis using the scaler.', elem_classes='text')
        with gr.Row():
            slider = gr.Slider(show_label=False, minimum=-20, maximum=20, step=1, value=0, elem_classes='slider', interactive=True)

        live_btn.click(fn=make_live_btn_fn, inputs=[live_img, axis], outputs=live_img)
        slider.change(fn=rotate_live_img_fn, inputs=[slider], outputs=live_img)
# Best-effort shutdown of any server left over from a previous run (e.g. when
# the script is re-executed in the same interpreter). Failure here must not
# prevent the launch below, but KeyboardInterrupt/SystemExit still propagate.
try:
    app.close()
    gr.close_all()
except Exception:
    pass

# The port defaults to 1902 and can be overridden through the
# GRADIO_SERVER_PORT environment variable without editing this file.
demo = app.launch(
    max_threads=4,
    share=True,
    inline=False,
    show_api=False,
    show_error=True,
    server_port=int(os.environ.get("GRADIO_SERVER_PORT", 1902)),
    server_name="0.0.0.0",
)