|
|
|
|
|
import torch |
|
import numpy as np |
|
import nibabel as nib |
|
import os |
|
from options.test_options import TestOptions |
|
from models import create_model |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
from models.inpaint_networks import Generator |
|
import torch.nn.functional as F |
|
import math |
|
from scipy.ndimage import label |
|
|
|
def remove_small_connected_components(input_array, min_size): |
|
|
|
|
|
|
|
structure = np.ones((3, 3), dtype=np.int32) |
|
labeled, ncomponents = label(input_array, structure) |
|
|
|
|
|
for i in range(1, ncomponents + 1): |
|
if np.sum(labeled == i) < min_size: |
|
input_array[labeled == i] = 0 |
|
|
|
|
|
|
|
return input_array |
|
|
|
def load_model(model_path, netG_params, device): |
|
model = Generator(netG_params, True) |
|
if os.path.exists(model_path): |
|
model.load_state_dict(torch.load(model_path, map_location=device)) |
|
model.eval() |
|
model.to(device) |
|
return model |
|
|
|
def numpy_to_pil(img_np): |
|
if img_np.dtype != np.uint8: |
|
raise ValueError("NumPy array should have uint8 data type.") |
|
img_pil = Image.fromarray(img_np) |
|
return img_pil |
|
|
|
def run_model(model,CAM_data,label_data,ct_data,vert_id,index_ratio,A_transform,mask_transform,device,maxheight=40): |
|
vert_label_slice = np.zeros_like(label_data) |
|
vert_label_slice[label_data==vert_id]=1 |
|
|
|
vert_label_slice = remove_small_connected_components(vert_label_slice,50) |
|
coords = np.argwhere(vert_label_slice) |
|
if coords.size==0: |
|
return None |
|
x1, x2 = min(coords[:, 0]), max(coords[:, 0]) |
|
width,length = vert_label_slice.shape |
|
height = x2-x1 |
|
if height>maxheight: |
|
x_mean = int(np.mean(coords[:, 0])) |
|
x1 = x_mean-20 |
|
x2 = x1+40 |
|
|
|
mask_x = (x1+x2)//2 |
|
h2 = maxheight |
|
if mask_x<=h2//2: |
|
min_x = 0 |
|
max_x = min_x + h2 |
|
elif width-mask_x<=h2/2: |
|
max_x = width |
|
min_x = max_x -h2 |
|
else: |
|
min_x = mask_x-h2//2 |
|
max_x = min_x + h2 |
|
|
|
mask_slice = np.zeros_like(vert_label_slice).astype(np.uint8) |
|
mask_slice[min_x:max_x+1] = 255 |
|
ct_data_slice = np.zeros_like(mask_slice).astype(np.uint8) |
|
ct_data_slice[:min_x,:] = ct_data[(x1-min_x):x1,:] |
|
ct_data_slice[max_x:,:] = ct_data[x2:x2+(width-max_x),:] |
|
|
|
CAM_slice = np.zeros_like(mask_slice).astype(np.uint8) |
|
CAM_slice[:min_x,:] = CAM_data[(x1-min_x):x1,:] |
|
CAM_slice[max_x:,:] = CAM_data[x2:x2+(width-max_x),:] |
|
|
|
ct_batch = numpy_to_pil(ct_data_slice) |
|
ct_batch = A_transform(ct_batch) |
|
|
|
ori_ct = numpy_to_pil(ct_data.astype(np.uint8)) |
|
ori_ct = A_transform(ori_ct) |
|
|
|
mask_batch = numpy_to_pil(mask_slice) |
|
mask_batch = mask_transform(mask_batch) |
|
|
|
CAM = numpy_to_pil(CAM_slice) |
|
CAM = mask_transform(CAM) |
|
|
|
ct_batch = ct_batch.unsqueeze(0).to(device) |
|
mask_batch = mask_batch.unsqueeze(0).to(device) |
|
CAM = CAM.unsqueeze(0).to(device) |
|
|
|
with torch.no_grad(): |
|
_, fake_B_mask_sigmoid, _, fake_B_raw, _,_,pred_h = model(ct_batch, mask_batch, 1-CAM,index_ratio) |
|
|
|
pred_h = math.ceil(pred_h[0]*maxheight) |
|
|
|
fake_B_mask_raw = torch.where(fake_B_mask_sigmoid > 0.5, torch.ones_like(fake_B_mask_sigmoid), torch.zeros_like(fake_B_mask_sigmoid)) |
|
|
|
|
|
if pred_h<height: |
|
pred_h = height |
|
height_diff = pred_h-height |
|
x_upper = x1-height_diff//2 |
|
x_bottom = x_upper+pred_h |
|
single_image = torch.zeros_like(fake_B_raw) |
|
single_image[:,:,x_upper:x_bottom,:] = fake_B_raw[:,:,x_upper:x_bottom,:] |
|
ct_upper = torch.zeros_like(single_image) |
|
ct_upper[0,:,:x_upper,:] = ori_ct[:, height_diff//2:x1, :] |
|
ct_bottom = torch.zeros_like(single_image) |
|
ct_bottom[0,:,x_bottom:,:] = ori_ct[:, x2:x2+256-x_bottom, :] |
|
interpolated_image = single_image+ct_upper+ct_bottom |
|
fake_B = interpolated_image.squeeze().cpu().numpy() |
|
fake_B = (fake_B+1)*127.5 |
|
|
|
mid_seg = np.zeros_like(fake_B_mask_raw.squeeze().cpu().numpy()) |
|
mid_seg[x_upper:x_bottom,:] = fake_B_mask_raw[:,:,x_upper:x_bottom,:].squeeze().cpu().numpy()*vert_id |
|
seg_upper = np.zeros_like(mid_seg) |
|
seg_upper[:x_upper,:] = label_data[height_diff//2:x1, :] |
|
seg_bottom = np.zeros_like(mid_seg) |
|
seg_bottom[x_bottom:,:] = label_data[x2:x2+256-x_bottom, :] |
|
interpolated_seg = mid_seg+seg_upper+seg_bottom |
|
fake_B_mask_raw = interpolated_seg |
|
|
|
|
|
return fake_B_mask_raw,fake_B,height |
|
|
|
|
|
def process_nii_files(folder_path,CAM_folder, model, output_folder, device): |
|
A_transform = transforms.Compose([ |
|
transforms.Grayscale(1), |
|
transforms.ToTensor(), |
|
transforms.Normalize((0.5,), (0.5,)) |
|
]) |
|
|
|
mask_transform = transforms.Compose([ |
|
transforms.ToTensor() |
|
]) |
|
|
|
if not os.path.exists(os.path.join(output_folder, 'CT')): |
|
os.makedirs(os.path.join(output_folder, 'CT')) |
|
if not os.path.exists(os.path.join(output_folder, 'label')): |
|
os.makedirs(os.path.join(output_folder, 'label')) |
|
|
|
count = 0 |
|
for file_name in os.listdir(folder_path): |
|
|
|
|
|
if file_name.endswith('.nii.gz'): |
|
if os.path.exists(os.path.join(output_folder, 'CT_fake', file_name)): |
|
continue |
|
|
|
|
|
file_path = os.path.join(folder_path, file_name) |
|
label_path = file_path.replace('CT', 'label') |
|
ct_nii = nib.load(file_path) |
|
ct_data = ct_nii.get_fdata() |
|
label_nii = nib.load(label_path) |
|
label_data = label_nii.get_fdata() |
|
patient_id, vert_id = file_name[:-7].rsplit('_', 1) |
|
vert_id = int(vert_id) |
|
|
|
CAM_path_0 = os.path.join(CAM_folder, file_name[:-7]+'_0.nii.gz') |
|
CAM_path_1 = os.path.join(CAM_folder, file_name[:-7]+'_1.nii.gz') |
|
CAM_path_2 = os.path.join(CAM_folder, file_name[:-7]+'.nii.gz') |
|
if os.path.exists(CAM_path_0): |
|
CAM_path = CAM_path_0 |
|
elif os.path.exists(CAM_path_1): |
|
CAM_path = CAM_path_1 |
|
else: |
|
CAM_path = CAM_path_2 |
|
|
|
|
|
CAM_data = nib.load(CAM_path).get_fdata() * 255 |
|
|
|
vert_label = np.zeros_like(label_data) |
|
vert_label[label_data==vert_id]=1 |
|
|
|
loc = np.where(vert_label) |
|
|
|
z0 = min(loc[2]) |
|
z1 = max(loc[2]) |
|
range_length = z1 - z0 + 1 |
|
new_range_length = int(range_length * 4 / 5) |
|
new_z0 = z0 + (range_length - new_range_length) // 2 |
|
new_z1 = new_z0 + new_range_length - 1 |
|
|
|
output_ct_data = np.zeros_like(ct_data) |
|
output_seg_data = np.zeros_like(ct_data) |
|
center_index = (new_z0 + new_z1) // 2 |
|
|
|
maxheight = 40 |
|
|
|
for z in range(new_z0, new_z1 + 1): |
|
index_ratio = abs(z-center_index)/range_length*2 |
|
index_ratio = torch.tensor([index_ratio]) |
|
if int(vert_id)>8 and np.sum(label_data[:, :, z]==int(vert_id)-1)>200: |
|
|
|
vert_id_upper = int(vert_id)-1 |
|
|
|
fake_B_mask_upper,fake_B_ct_upper,_ = run_model(model,CAM_data[:, :, z],label_data[:, :, z],ct_data[:, :, z],vert_id_upper,index_ratio,\ |
|
A_transform,mask_transform,device,maxheight) |
|
else: |
|
fake_B_mask_upper,fake_B_ct_upper = label_data[:, :, z],ct_data[:, :, z] |
|
|
|
if int(vert_id)<24 and np.sum(label_data[:, :, z]==int(vert_id)+1)>200: |
|
|
|
vert_id_bottom = int(vert_id)+1 |
|
|
|
fake_B_mask_bottom,fake_B_ct_bottom,_ = run_model(model,CAM_data[:, :, z],fake_B_mask_upper,fake_B_ct_upper,vert_id_bottom,index_ratio,\ |
|
A_transform,mask_transform,device,maxheight) |
|
else: |
|
fake_B_mask_bottom,fake_B_ct_bottom = fake_B_mask_upper,fake_B_ct_upper |
|
|
|
|
|
|
|
output = run_model(model,CAM_data[:, :, z],fake_B_mask_bottom,fake_B_ct_bottom,int(vert_id),index_ratio,\ |
|
A_transform,mask_transform,device,maxheight) |
|
if output==None: |
|
continue |
|
else: |
|
fake_B_mask_raw,fake_B,height = output |
|
if height>maxheight: |
|
print("Height exceeds in %s, in slice %d"%(file_name,z)) |
|
|
|
output_seg_data[:, :, z] = fake_B_mask_raw |
|
output_ct_data[:, :, z] = fake_B |
|
|
|
new_ct_nii = nib.Nifti1Image(output_ct_data, ct_nii.affine) |
|
nib.save(new_ct_nii, os.path.join(output_folder, 'CT_fake', file_name)) |
|
new_label_nii = nib.Nifti1Image(output_seg_data, ct_nii.affine) |
|
nib.save(new_label_nii, os.path.join(output_folder, 'label_fake', file_name)) |
|
print(f"Now {file_name} has been generateed in {output_folder}") |
|
count+=1 |
|
|
|
|
|
def main(): |
|
model_path = '/home/zhangqi/Project/pytorch-CycleGAN-and-pix2pix-master/checkpoints/0421_adaptive_sagittal/latest_net_G.pth' |
|
netG_params = {'input_dim': 1, 'ngf': 16} |
|
|
|
|
|
|
|
folder_path = '/home/zhangqi/Project/pytorch-CycleGAN-and-pix2pix-master/datasets/local/straighten/CT' |
|
CAM_folder = '/home/zhangqi/Project/VertebralFractureGrading/heatmap/local_sagittal_0508/binaryclass_1' |
|
output_folder = '/home/zhangqi/Project/pytorch-CycleGAN-and-pix2pix-master/output_3d/local_dataset/sagittal/fine' |
|
if not os.path.exists(output_folder+'/CT_fake'): |
|
os.makedirs(output_folder+'/CT_fake') |
|
if not os.path.exists(output_folder+'/label_fake'): |
|
os.makedirs(output_folder+'/label_fake') |
|
device = 'cuda:0' |
|
|
|
model = load_model(model_path, netG_params, device) |
|
process_nii_files(folder_path,CAM_folder, model, output_folder, device) |
|
|
|
if __name__ == "__main__": |
|
main() |
|
|