Spaces:
Runtime error
Runtime error
import sys | |
import csv | |
import numpy as np | |
import gradio as gr | |
import nibabel as nib | |
import matplotlib.pyplot as plt | |
from scipy import ndimage | |
from huggingface_hub import from_pretrained_keras | |
# Lift the csv module's per-field size cap as far as the platform allows.
# sys.maxsize does not fit in a C long on some platforms (e.g. 64-bit
# Windows), where csv.field_size_limit raises OverflowError; fall back to
# the largest signed 32-bit value there instead of crashing at import.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)
def read_nifti_file(filepath):
    """Open a NIfTI image and hand back its voxel data.

    Args:
        filepath: Location of the NIfTI file to load.

    Returns:
        The array produced by nibabel's ``get_fdata()`` for that image.
    """
    image = nib.load(filepath)
    return image.get_fdata()
def normalize(volume, hu_min=-1000, hu_max=400):
    """Clip a volume to an intensity window and rescale it to [0, 1].

    Args:
        volume: Array-like of voxel intensities (presumably Hounsfield
            units, given the default window -- confirm against the data).
        hu_min: Lower clipping bound; values at or below it map to 0.0.
            Defaults to -1000.
        hu_max: Upper clipping bound; values at or above it map to 1.0.
            Defaults to 400.

    Returns:
        A new ``float32`` array with the same shape as ``volume``. The
        input is not modified.

    Raises:
        ValueError: If ``hu_max`` is not greater than ``hu_min`` (the
            rescaling would divide by zero or invert the range).
    """
    if hu_max <= hu_min:
        raise ValueError(f"hu_max ({hu_max}) must be greater than hu_min ({hu_min})")
    # np.clip returns a fresh array, so the caller's data is never mutated
    # (masked in-place assignment used to overwrite the input volume).
    clipped = np.clip(volume, hu_min, hu_max)
    return ((clipped - hu_min) / (hu_max - hu_min)).astype("float32")
def resize_volume(img, *, desired_width=128, desired_height=128, desired_depth=64):
    """Rotate a 3-D volume and resample it to a fixed target shape.

    The volume is first rotated by 90 degrees in the plane of its first two
    axes (scipy's default ``axes=(1, 0)``) without changing its shape, then
    resampled so that axes 0, 1 and 2 have sizes ``desired_width``,
    ``desired_height`` and ``desired_depth``.

    Args:
        img: 3-D array; axis 0 is treated as width, axis 1 as height and
            the last axis as depth (TODO confirm this matches the scans).
        desired_width: Target size of axis 0. Defaults to 128.
        desired_height: Target size of axis 1. Defaults to 128.
        desired_depth: Target size of the last axis. Defaults to 64.

    Returns:
        The rotated, resampled array.
    """
    current_width = img.shape[0]
    current_height = img.shape[1]
    current_depth = img.shape[-1]
    # Per-axis zoom factors. target / current is computed directly rather
    # than as 1 / (current / target) to avoid an extra rounding step.
    zoom_factors = (
        desired_width / current_width,
        desired_height / current_height,
        desired_depth / current_depth,
    )
    # reshape=False keeps the array's shape, so the factors computed from
    # the pre-rotation shape still apply afterwards.
    img = ndimage.rotate(img, 90, reshape=False)
    # order=1: linear (first-order spline) interpolation on every axis.
    return ndimage.zoom(img, zoom_factors, order=1)
def process_scan(path):
    """Turn a scan file on disk into a model-ready volume.

    The raw voxels from ``read_nifti_file`` are passed through
    ``normalize`` and the result is resampled by ``resize_volume``.
    """
    return resize_volume(normalize(read_nifti_file(path)))
def plot_slices(num_rows, num_columns, width, height, data):
    """Plot a montage of CT slices on a ``num_rows`` x ``num_columns`` grid.

    Args:
        num_rows: Number of rows in the montage.
        num_columns: Number of columns in the montage.
        width: Size of each slice along the first axis of ``data``.
        height: Size of each slice along the second axis of ``data``.
        data: Array-like that must hold exactly
            ``num_rows * num_columns * width * height`` values; judging by
            the rotate/transpose/reshape below it is expected to be shaped
            (width, height, num_rows * num_columns) with slices on the last
            axis -- confirm with callers.

    Returns:
        The matplotlib ``Figure`` containing the montage.
    """
    data = np.rot90(np.array(data))
    data = np.transpose(data)
    data = np.reshape(data, (num_rows, num_columns, width, height))
    rows_data, columns_data = data.shape[0], data.shape[1]
    heights = [slc[0].shape[0] for slc in data]
    widths = [slc.shape[1] for slc in data[0]]
    fig_width = 12.0
    # Keep the figure's aspect ratio equal to the montage's aspect ratio.
    fig_height = fig_width * sum(heights) / sum(widths)
    f, axarr = plt.subplots(
        rows_data,
        columns_data,
        figsize=(fig_width, fig_height),
        gridspec_kw={"height_ratios": heights},
        # Always get a 2-D array of Axes so axarr[i, j] also works when
        # num_rows or num_columns is 1 (the default squeeze=True would
        # return a 1-D array or a single Axes there).
        squeeze=False,
    )
    for i in range(rows_data):
        for j in range(columns_data):
            axarr[i, j].imshow(data[i][j], cmap="gray")
            axarr[i, j].axis("off")
    return f
def infer(filename):
    """Classify an uploaded CT scan and render a montage of its first slices.

    Args:
        filename: The uploaded scan: either a path (``str`` or
            ``os.PathLike``, which is what ``gr.File`` passes in Gradio 4+)
            or a file wrapper exposing its path as ``.name`` (what older
            Gradio versions pass).

    Returns:
        A pair ``(messages, figure)``. ``messages`` is a list of two strings
        reporting the model's confidence for "normal" and "abnormal".
        ``figure`` is the montage that ``plot_slices`` draws from the first
        20 depth slices of the preprocessed volume.
    """
    import os  # local: only needed to recognise path-like inputs

    # A pathlib.Path also has `.name` (just the basename), so check for
    # path-like objects before falling back to a file wrapper's `.name`.
    path = filename if isinstance(filename, (str, os.PathLike)) else filename.name
    vol = process_scan(path)
    # Add a leading batch axis for model.predict.
    vol = np.expand_dims(vol, axis=0)
    prediction = model.predict(vol)[0]
    # prediction[0] is reported as the "abnormal" score and its complement
    # as the "normal" score.
    scores = [1 - prediction[0], prediction[0]]
    class_names = ["normal", "abnormal"]
    result = [
        f"This model is {(100 * score):.2f} percent confident that CT scan is {name}"
        for score, name in zip(scores, class_names)
    ]
    # A 2 x 10 montage needs exactly 20 slices of 128 x 128.
    return result, plot_slices(2, 10, 128, 128, vol[0, :, :, :20])
# Load the pretrained Keras model from the Hugging Face Hub; `infer` reads
# this module-level name.
model = from_pretrained_keras('jalFaizy/3D_CNN')

# `gr.inputs` / `gr.outputs` were deprecated in Gradio 3 and removed in
# Gradio 4 (an AttributeError at startup there). The top-level component
# classes below exist in both major versions.
inputs = gr.File()
outputs = [gr.Textbox(), gr.Plot()]
iface = gr.Interface(
    infer,
    inputs,
    outputs,
    title='3D CNN for CT scans',
    examples=['example_1_normal.nii.gz'],
)
iface.launch()