Upload 7 files
Browse files- Inference.py +53 -0
- basics.py +74 -0
- data_loader_cache.py +385 -0
- hce_metric_main.py +188 -0
- pytorch18.yml +92 -0
- requirements.txt +87 -4
- train_valid_inference_main.py +731 -0
Inference.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import numpy as np
|
4 |
+
from skimage import io
|
5 |
+
import time
|
6 |
+
from glob import glob
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
import torch, gc
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.autograd import Variable
|
12 |
+
import torch.optim as optim
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from torchvision.transforms.functional import normalize
|
15 |
+
|
16 |
+
from models import *
|
17 |
+
|
18 |
+
|
19 |
+
if __name__ == "__main__":
|
20 |
+
dataset_path="../demo_datasets/your_dataset" #Your dataset path
|
21 |
+
model_path="../saved_models/IS-Net/isnet-general-use.pth" # the model path
|
22 |
+
result_path="../demo_datasets/your_dataset_result" #The folder path that you want to save the results
|
23 |
+
input_size=[1024,1024]
|
24 |
+
net=ISNetDIS()
|
25 |
+
|
26 |
+
if torch.cuda.is_available():
|
27 |
+
net.load_state_dict(torch.load(model_path))
|
28 |
+
net=net.cuda()
|
29 |
+
else:
|
30 |
+
net.load_state_dict(torch.load(model_path,map_location="cpu"))
|
31 |
+
net.eval()
|
32 |
+
im_list = glob(dataset_path+"/*.jpg")+glob(dataset_path+"/*.JPG")+glob(dataset_path+"/*.jpeg")+glob(dataset_path+"/*.JPEG")+glob(dataset_path+"/*.png")+glob(dataset_path+"/*.PNG")+glob(dataset_path+"/*.bmp")+glob(dataset_path+"/*.BMP")+glob(dataset_path+"/*.tiff")+glob(dataset_path+"/*.TIFF")
|
33 |
+
with torch.no_grad():
|
34 |
+
for i, im_path in tqdm(enumerate(im_list), total=len(im_list)):
|
35 |
+
print("im_path: ", im_path)
|
36 |
+
im = io.imread(im_path)
|
37 |
+
if len(im.shape) < 3:
|
38 |
+
im = im[:, :, np.newaxis]
|
39 |
+
im_shp=im.shape[0:2]
|
40 |
+
im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
|
41 |
+
im_tensor = F.upsample(torch.unsqueeze(im_tensor,0), input_size, mode="bilinear").type(torch.uint8)
|
42 |
+
image = torch.divide(im_tensor,255.0)
|
43 |
+
image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
|
44 |
+
|
45 |
+
if torch.cuda.is_available():
|
46 |
+
image=image.cuda()
|
47 |
+
result=net(image)
|
48 |
+
result=torch.squeeze(F.upsample(result[0][0],im_shp,mode='bilinear'),0)
|
49 |
+
ma = torch.max(result)
|
50 |
+
mi = torch.min(result)
|
51 |
+
result = (result-mi)/(ma-mi)
|
52 |
+
im_name=im_path.split('/')[-1].split('.')[0]
|
53 |
+
io.imsave(os.path.join(result_path,im_name+".png"),(result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8))
|
basics.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
# os.environ['CUDA_VISIBLE_DEVICES'] = '2'
|
3 |
+
from skimage import io, transform
|
4 |
+
import torch
|
5 |
+
import torchvision
|
6 |
+
from torch.autograd import Variable
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch.utils.data import Dataset, DataLoader
|
10 |
+
from torchvision import transforms, utils
|
11 |
+
import torch.optim as optim
|
12 |
+
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
import numpy as np
|
15 |
+
from PIL import Image
|
16 |
+
import glob
|
17 |
+
|
18 |
+
def mae_torch(pred,gt):
|
19 |
+
|
20 |
+
h,w = gt.shape[0:2]
|
21 |
+
sumError = torch.sum(torch.absolute(torch.sub(pred.float(), gt.float())))
|
22 |
+
maeError = torch.divide(sumError,float(h)*float(w)*255.0+1e-4)
|
23 |
+
|
24 |
+
return maeError
|
25 |
+
|
26 |
+
def f1score_torch(pd,gt):
|
27 |
+
|
28 |
+
# print(gt.shape)
|
29 |
+
gtNum = torch.sum((gt>128).float()*1) ## number of ground truth pixels
|
30 |
+
|
31 |
+
pp = pd[gt>128]
|
32 |
+
nn = pd[gt<=128]
|
33 |
+
|
34 |
+
pp_hist =torch.histc(pp,bins=255,min=0,max=255)
|
35 |
+
nn_hist = torch.histc(nn,bins=255,min=0,max=255)
|
36 |
+
|
37 |
+
|
38 |
+
pp_hist_flip = torch.flipud(pp_hist)
|
39 |
+
nn_hist_flip = torch.flipud(nn_hist)
|
40 |
+
|
41 |
+
pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)
|
42 |
+
nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)
|
43 |
+
|
44 |
+
precision = (pp_hist_flip_cum)/(pp_hist_flip_cum + nn_hist_flip_cum + 1e-4)#torch.divide(pp_hist_flip_cum,torch.sum(torch.sum(pp_hist_flip_cum, nn_hist_flip_cum), 1e-4))
|
45 |
+
recall = (pp_hist_flip_cum)/(gtNum + 1e-4)
|
46 |
+
f1 = (1+0.3)*precision*recall/(0.3*precision+recall + 1e-4)
|
47 |
+
|
48 |
+
return torch.reshape(precision,(1,precision.shape[0])),torch.reshape(recall,(1,recall.shape[0])),torch.reshape(f1,(1,f1.shape[0]))
|
49 |
+
|
50 |
+
|
51 |
+
def f1_mae_torch(pred, gt, valid_dataset, idx, mybins, hypar):
|
52 |
+
|
53 |
+
import time
|
54 |
+
tic = time.time()
|
55 |
+
|
56 |
+
if(len(gt.shape)>2):
|
57 |
+
gt = gt[:,:,0]
|
58 |
+
|
59 |
+
pre, rec, f1 = f1score_torch(pred,gt)
|
60 |
+
mae = mae_torch(pred,gt)
|
61 |
+
|
62 |
+
|
63 |
+
# hypar["valid_out_dir"] = hypar["valid_out_dir"]+"-eval" ###
|
64 |
+
if(hypar["valid_out_dir"]!=""):
|
65 |
+
if(not os.path.exists(hypar["valid_out_dir"])):
|
66 |
+
os.mkdir(hypar["valid_out_dir"])
|
67 |
+
dataset_folder = os.path.join(hypar["valid_out_dir"],valid_dataset.dataset["data_name"][idx])
|
68 |
+
if(not os.path.exists(dataset_folder)):
|
69 |
+
os.mkdir(dataset_folder)
|
70 |
+
io.imsave(os.path.join(dataset_folder,valid_dataset.dataset["im_name"][idx]+".png"),pred.cpu().data.numpy().astype(np.uint8))
|
71 |
+
print(valid_dataset.dataset["im_name"][idx]+".png")
|
72 |
+
print("time for evaluation : ", time.time()-tic)
|
73 |
+
|
74 |
+
return pre.cpu().data.numpy(), rec.cpu().data.numpy(), f1.cpu().data.numpy(), mae.cpu().data.numpy()
|
data_loader_cache.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## data loader
|
2 |
+
## Ackownledgement:
|
3 |
+
## We would like to thank Dr. Ibrahim Almakky (https://scholar.google.co.uk/citations?user=T9MTcK0AAAAJ&hl=en)
|
4 |
+
## for his helps in implementing cache machanism of our DIS dataloader.
|
5 |
+
from __future__ import print_function, division
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import random
|
9 |
+
from copy import deepcopy
|
10 |
+
import json
|
11 |
+
from tqdm import tqdm
|
12 |
+
from skimage import io
|
13 |
+
import os
|
14 |
+
from glob import glob
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from torch.utils.data import Dataset, DataLoader
|
18 |
+
from torchvision import transforms, utils
|
19 |
+
from torchvision.transforms.functional import normalize
|
20 |
+
import torch.nn.functional as F
|
21 |
+
|
22 |
+
#### --------------------- DIS dataloader cache ---------------------####
|
23 |
+
|
24 |
+
def get_im_gt_name_dict(datasets, flag='valid'):
|
25 |
+
print("------------------------------", flag, "--------------------------------")
|
26 |
+
name_im_gt_list = []
|
27 |
+
for i in range(len(datasets)):
|
28 |
+
print("--->>>", flag, " dataset ",i,"/",len(datasets)," ",datasets[i]["name"],"<<<---")
|
29 |
+
tmp_im_list, tmp_gt_list = [], []
|
30 |
+
tmp_im_list = glob(datasets[i]["im_dir"]+os.sep+'*'+datasets[i]["im_ext"])
|
31 |
+
|
32 |
+
# img_name_dict[im_dirs[i][0]] = tmp_im_list
|
33 |
+
print('-im-',datasets[i]["name"],datasets[i]["im_dir"], ': ',len(tmp_im_list))
|
34 |
+
|
35 |
+
if(datasets[i]["gt_dir"]==""):
|
36 |
+
print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found')
|
37 |
+
tmp_gt_list = []
|
38 |
+
else:
|
39 |
+
tmp_gt_list = [datasets[i]["gt_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["gt_ext"] for x in tmp_im_list]
|
40 |
+
|
41 |
+
# lbl_name_dict[im_dirs[i][0]] = tmp_gt_list
|
42 |
+
print('-gt-', datasets[i]["name"],datasets[i]["gt_dir"], ': ',len(tmp_gt_list))
|
43 |
+
|
44 |
+
|
45 |
+
if flag=="train": ## combine multiple training sets into one dataset
|
46 |
+
if len(name_im_gt_list)==0:
|
47 |
+
name_im_gt_list.append({"dataset_name":datasets[i]["name"],
|
48 |
+
"im_path":tmp_im_list,
|
49 |
+
"gt_path":tmp_gt_list,
|
50 |
+
"im_ext":datasets[i]["im_ext"],
|
51 |
+
"gt_ext":datasets[i]["gt_ext"],
|
52 |
+
"cache_dir":datasets[i]["cache_dir"]})
|
53 |
+
else:
|
54 |
+
name_im_gt_list[0]["dataset_name"] = name_im_gt_list[0]["dataset_name"] + "_" + datasets[i]["name"]
|
55 |
+
name_im_gt_list[0]["im_path"] = name_im_gt_list[0]["im_path"] + tmp_im_list
|
56 |
+
name_im_gt_list[0]["gt_path"] = name_im_gt_list[0]["gt_path"] + tmp_gt_list
|
57 |
+
if datasets[i]["im_ext"]!=".jpg" or datasets[i]["gt_ext"]!=".png":
|
58 |
+
print("Error: Please make sure all you images and ground truth masks are in jpg and png format respectively !!!")
|
59 |
+
exit()
|
60 |
+
name_im_gt_list[0]["im_ext"] = ".jpg"
|
61 |
+
name_im_gt_list[0]["gt_ext"] = ".png"
|
62 |
+
name_im_gt_list[0]["cache_dir"] = os.sep.join(datasets[i]["cache_dir"].split(os.sep)[0:-1])+os.sep+name_im_gt_list[0]["dataset_name"]
|
63 |
+
else: ## keep different validation or inference datasets as separate ones
|
64 |
+
name_im_gt_list.append({"dataset_name":datasets[i]["name"],
|
65 |
+
"im_path":tmp_im_list,
|
66 |
+
"gt_path":tmp_gt_list,
|
67 |
+
"im_ext":datasets[i]["im_ext"],
|
68 |
+
"gt_ext":datasets[i]["gt_ext"],
|
69 |
+
"cache_dir":datasets[i]["cache_dir"]})
|
70 |
+
|
71 |
+
return name_im_gt_list
|
72 |
+
|
73 |
+
def create_dataloaders(name_im_gt_list, cache_size=[], cache_boost=True, my_transforms=[], batch_size=1, shuffle=False):
|
74 |
+
## model="train": return one dataloader for training
|
75 |
+
## model="valid": return a list of dataloaders for validation or testing
|
76 |
+
|
77 |
+
gos_dataloaders = []
|
78 |
+
gos_datasets = []
|
79 |
+
|
80 |
+
if(len(name_im_gt_list)==0):
|
81 |
+
return gos_dataloaders, gos_datasets
|
82 |
+
|
83 |
+
num_workers_ = 1
|
84 |
+
if(batch_size>1):
|
85 |
+
num_workers_ = 2
|
86 |
+
if(batch_size>4):
|
87 |
+
num_workers_ = 4
|
88 |
+
if(batch_size>8):
|
89 |
+
num_workers_ = 8
|
90 |
+
|
91 |
+
for i in range(0,len(name_im_gt_list)):
|
92 |
+
gos_dataset = GOSDatasetCache([name_im_gt_list[i]],
|
93 |
+
cache_size = cache_size,
|
94 |
+
cache_path = name_im_gt_list[i]["cache_dir"],
|
95 |
+
cache_boost = cache_boost,
|
96 |
+
transform = transforms.Compose(my_transforms))
|
97 |
+
gos_dataloaders.append(DataLoader(gos_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers_))
|
98 |
+
gos_datasets.append(gos_dataset)
|
99 |
+
|
100 |
+
return gos_dataloaders, gos_datasets
|
101 |
+
|
102 |
+
def im_reader(im_path):
|
103 |
+
return io.imread(im_path)
|
104 |
+
|
105 |
+
def im_preprocess(im,size):
|
106 |
+
if len(im.shape) < 3:
|
107 |
+
im = im[:, :, np.newaxis]
|
108 |
+
if im.shape[2] == 1:
|
109 |
+
im = np.repeat(im, 3, axis=2)
|
110 |
+
im_tensor = torch.tensor(im.copy(), dtype=torch.float32)
|
111 |
+
im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1)
|
112 |
+
if(len(size)<2):
|
113 |
+
return im_tensor, im.shape[0:2]
|
114 |
+
else:
|
115 |
+
im_tensor = torch.unsqueeze(im_tensor,0)
|
116 |
+
im_tensor = F.upsample(im_tensor, size, mode="bilinear")
|
117 |
+
im_tensor = torch.squeeze(im_tensor,0)
|
118 |
+
|
119 |
+
return im_tensor.type(torch.uint8), im.shape[0:2]
|
120 |
+
|
121 |
+
def gt_preprocess(gt,size):
|
122 |
+
if len(gt.shape) > 2:
|
123 |
+
gt = gt[:, :, 0]
|
124 |
+
|
125 |
+
gt_tensor = torch.unsqueeze(torch.tensor(gt, dtype=torch.uint8),0)
|
126 |
+
|
127 |
+
if(len(size)<2):
|
128 |
+
return gt_tensor.type(torch.uint8), gt.shape[0:2]
|
129 |
+
else:
|
130 |
+
gt_tensor = torch.unsqueeze(torch.tensor(gt_tensor, dtype=torch.float32),0)
|
131 |
+
gt_tensor = F.upsample(gt_tensor, size, mode="bilinear")
|
132 |
+
gt_tensor = torch.squeeze(gt_tensor,0)
|
133 |
+
|
134 |
+
return gt_tensor.type(torch.uint8), gt.shape[0:2]
|
135 |
+
# return gt_tensor, gt.shape[0:2]
|
136 |
+
|
137 |
+
class GOSRandomHFlip(object):
|
138 |
+
def __init__(self,prob=0.5):
|
139 |
+
self.prob = prob
|
140 |
+
def __call__(self,sample):
|
141 |
+
imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
|
142 |
+
|
143 |
+
# random horizontal flip
|
144 |
+
if random.random() >= self.prob:
|
145 |
+
image = torch.flip(image,dims=[2])
|
146 |
+
label = torch.flip(label,dims=[2])
|
147 |
+
|
148 |
+
return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
|
149 |
+
|
150 |
+
class GOSResize(object):
|
151 |
+
def __init__(self,size=[320,320]):
|
152 |
+
self.size = size
|
153 |
+
def __call__(self,sample):
|
154 |
+
imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
|
155 |
+
|
156 |
+
# import time
|
157 |
+
# start = time.time()
|
158 |
+
|
159 |
+
image = torch.squeeze(F.upsample(torch.unsqueeze(image,0),self.size,mode='bilinear'),dim=0)
|
160 |
+
label = torch.squeeze(F.upsample(torch.unsqueeze(label,0),self.size,mode='bilinear'),dim=0)
|
161 |
+
|
162 |
+
# print("time for resize: ", time.time()-start)
|
163 |
+
|
164 |
+
return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
|
165 |
+
|
166 |
+
class GOSRandomCrop(object):
|
167 |
+
def __init__(self,size=[288,288]):
|
168 |
+
self.size = size
|
169 |
+
def __call__(self,sample):
|
170 |
+
imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
|
171 |
+
|
172 |
+
h, w = image.shape[1:]
|
173 |
+
new_h, new_w = self.size
|
174 |
+
|
175 |
+
top = np.random.randint(0, h - new_h)
|
176 |
+
left = np.random.randint(0, w - new_w)
|
177 |
+
|
178 |
+
image = image[:,top:top+new_h,left:left+new_w]
|
179 |
+
label = label[:,top:top+new_h,left:left+new_w]
|
180 |
+
|
181 |
+
return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
|
182 |
+
|
183 |
+
|
184 |
+
class GOSNormalize(object):
|
185 |
+
def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
|
186 |
+
self.mean = mean
|
187 |
+
self.std = std
|
188 |
+
|
189 |
+
def __call__(self,sample):
|
190 |
+
|
191 |
+
imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
|
192 |
+
image = normalize(image,self.mean,self.std)
|
193 |
+
|
194 |
+
return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
|
195 |
+
|
196 |
+
|
197 |
+
class GOSDatasetCache(Dataset):
|
198 |
+
|
199 |
+
def __init__(self, name_im_gt_list, cache_size=[], cache_path='./cache', cache_file_name='dataset.json', cache_boost=False, transform=None):
|
200 |
+
|
201 |
+
|
202 |
+
self.cache_size = cache_size
|
203 |
+
self.cache_path = cache_path
|
204 |
+
self.cache_file_name = cache_file_name
|
205 |
+
self.cache_boost_name = ""
|
206 |
+
|
207 |
+
self.cache_boost = cache_boost
|
208 |
+
# self.ims_npy = None
|
209 |
+
# self.gts_npy = None
|
210 |
+
|
211 |
+
## cache all the images and ground truth into a single pytorch tensor
|
212 |
+
self.ims_pt = None
|
213 |
+
self.gts_pt = None
|
214 |
+
|
215 |
+
## we will cache the npy as well regardless of the cache_boost
|
216 |
+
# if(self.cache_boost):
|
217 |
+
self.cache_boost_name = cache_file_name.split('.json')[0]
|
218 |
+
|
219 |
+
self.transform = transform
|
220 |
+
|
221 |
+
self.dataset = {}
|
222 |
+
|
223 |
+
## combine different datasets into one
|
224 |
+
dataset_names = []
|
225 |
+
dt_name_list = [] # dataset name per image
|
226 |
+
im_name_list = [] # image name
|
227 |
+
im_path_list = [] # im path
|
228 |
+
gt_path_list = [] # gt path
|
229 |
+
im_ext_list = [] # im ext
|
230 |
+
gt_ext_list = [] # gt ext
|
231 |
+
for i in range(0,len(name_im_gt_list)):
|
232 |
+
dataset_names.append(name_im_gt_list[i]["dataset_name"])
|
233 |
+
# dataset name repeated based on the number of images in this dataset
|
234 |
+
dt_name_list.extend([name_im_gt_list[i]["dataset_name"] for x in name_im_gt_list[i]["im_path"]])
|
235 |
+
im_name_list.extend([x.split(os.sep)[-1].split(name_im_gt_list[i]["im_ext"])[0] for x in name_im_gt_list[i]["im_path"]])
|
236 |
+
im_path_list.extend(name_im_gt_list[i]["im_path"])
|
237 |
+
gt_path_list.extend(name_im_gt_list[i]["gt_path"])
|
238 |
+
im_ext_list.extend([name_im_gt_list[i]["im_ext"] for x in name_im_gt_list[i]["im_path"]])
|
239 |
+
gt_ext_list.extend([name_im_gt_list[i]["gt_ext"] for x in name_im_gt_list[i]["gt_path"]])
|
240 |
+
|
241 |
+
|
242 |
+
self.dataset["data_name"] = dt_name_list
|
243 |
+
self.dataset["im_name"] = im_name_list
|
244 |
+
self.dataset["im_path"] = im_path_list
|
245 |
+
self.dataset["ori_im_path"] = deepcopy(im_path_list)
|
246 |
+
self.dataset["gt_path"] = gt_path_list
|
247 |
+
self.dataset["ori_gt_path"] = deepcopy(gt_path_list)
|
248 |
+
self.dataset["im_shp"] = []
|
249 |
+
self.dataset["gt_shp"] = []
|
250 |
+
self.dataset["im_ext"] = im_ext_list
|
251 |
+
self.dataset["gt_ext"] = gt_ext_list
|
252 |
+
|
253 |
+
|
254 |
+
self.dataset["ims_pt_dir"] = ""
|
255 |
+
self.dataset["gts_pt_dir"] = ""
|
256 |
+
|
257 |
+
self.dataset = self.manage_cache(dataset_names)
|
258 |
+
|
259 |
+
def manage_cache(self,dataset_names):
|
260 |
+
if not os.path.exists(self.cache_path): # create the folder for cache
|
261 |
+
os.makedirs(self.cache_path)
|
262 |
+
cache_folder = os.path.join(self.cache_path, "_".join(dataset_names)+"_"+"x".join([str(x) for x in self.cache_size]))
|
263 |
+
if not os.path.exists(cache_folder): # check if the cache files are there, if not then cache
|
264 |
+
return self.cache(cache_folder)
|
265 |
+
return self.load_cache(cache_folder)
|
266 |
+
|
267 |
+
def cache(self,cache_folder):
|
268 |
+
os.mkdir(cache_folder)
|
269 |
+
cached_dataset = deepcopy(self.dataset)
|
270 |
+
|
271 |
+
# ims_list = []
|
272 |
+
# gts_list = []
|
273 |
+
ims_pt_list = []
|
274 |
+
gts_pt_list = []
|
275 |
+
for i, im_path in tqdm(enumerate(self.dataset["im_path"]), total=len(self.dataset["im_path"])):
|
276 |
+
|
277 |
+
im_id = cached_dataset["im_name"][i]
|
278 |
+
print("im_path: ", im_path)
|
279 |
+
im = im_reader(im_path)
|
280 |
+
im, im_shp = im_preprocess(im,self.cache_size)
|
281 |
+
im_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_im.pt")
|
282 |
+
torch.save(im,im_cache_file)
|
283 |
+
|
284 |
+
cached_dataset["im_path"][i] = im_cache_file
|
285 |
+
if(self.cache_boost):
|
286 |
+
ims_pt_list.append(torch.unsqueeze(im,0))
|
287 |
+
# ims_list.append(im.cpu().data.numpy().astype(np.uint8))
|
288 |
+
|
289 |
+
gt = np.zeros(im.shape[0:2])
|
290 |
+
if len(self.dataset["gt_path"])!=0:
|
291 |
+
gt = im_reader(self.dataset["gt_path"][i])
|
292 |
+
gt, gt_shp = gt_preprocess(gt,self.cache_size)
|
293 |
+
gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt")
|
294 |
+
torch.save(gt,gt_cache_file)
|
295 |
+
if len(self.dataset["gt_path"])>0:
|
296 |
+
cached_dataset["gt_path"][i] = gt_cache_file
|
297 |
+
else:
|
298 |
+
cached_dataset["gt_path"].append(gt_cache_file)
|
299 |
+
if(self.cache_boost):
|
300 |
+
gts_pt_list.append(torch.unsqueeze(gt,0))
|
301 |
+
# gts_list.append(gt.cpu().data.numpy().astype(np.uint8))
|
302 |
+
|
303 |
+
# im_shp_cache_file = os.path.join(cache_folder,im_id + "_im_shp.pt")
|
304 |
+
# torch.save(gt_shp, shp_cache_file)
|
305 |
+
cached_dataset["im_shp"].append(im_shp)
|
306 |
+
# self.dataset["im_shp"].append(im_shp)
|
307 |
+
|
308 |
+
# shp_cache_file = os.path.join(cache_folder,im_id + "_gt_shp.pt")
|
309 |
+
# torch.save(gt_shp, shp_cache_file)
|
310 |
+
cached_dataset["gt_shp"].append(gt_shp)
|
311 |
+
# self.dataset["gt_shp"].append(gt_shp)
|
312 |
+
|
313 |
+
if(self.cache_boost):
|
314 |
+
cached_dataset["ims_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_ims.pt')
|
315 |
+
cached_dataset["gts_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_gts.pt')
|
316 |
+
self.ims_pt = torch.cat(ims_pt_list,dim=0)
|
317 |
+
self.gts_pt = torch.cat(gts_pt_list,dim=0)
|
318 |
+
torch.save(torch.cat(ims_pt_list,dim=0),cached_dataset["ims_pt_dir"])
|
319 |
+
torch.save(torch.cat(gts_pt_list,dim=0),cached_dataset["gts_pt_dir"])
|
320 |
+
|
321 |
+
try:
|
322 |
+
json_file = open(os.path.join(cache_folder, self.cache_file_name),"w")
|
323 |
+
json.dump(cached_dataset, json_file)
|
324 |
+
json_file.close()
|
325 |
+
except Exception:
|
326 |
+
raise FileNotFoundError("Cannot create JSON")
|
327 |
+
return cached_dataset
|
328 |
+
|
329 |
+
def load_cache(self, cache_folder):
|
330 |
+
json_file = open(os.path.join(cache_folder,self.cache_file_name),"r")
|
331 |
+
dataset = json.load(json_file)
|
332 |
+
json_file.close()
|
333 |
+
## if cache_boost is true, we will load the image npy and ground truth npy into the RAM
|
334 |
+
## otherwise the pytorch tensor will be loaded
|
335 |
+
if(self.cache_boost):
|
336 |
+
# self.ims_npy = np.load(dataset["ims_npy_dir"])
|
337 |
+
# self.gts_npy = np.load(dataset["gts_npy_dir"])
|
338 |
+
self.ims_pt = torch.load(dataset["ims_pt_dir"], map_location='cpu')
|
339 |
+
self.gts_pt = torch.load(dataset["gts_pt_dir"], map_location='cpu')
|
340 |
+
return dataset
|
341 |
+
|
342 |
+
def __len__(self):
|
343 |
+
return len(self.dataset["im_path"])
|
344 |
+
|
345 |
+
def __getitem__(self, idx):
|
346 |
+
|
347 |
+
im = None
|
348 |
+
gt = None
|
349 |
+
if(self.cache_boost and self.ims_pt is not None):
|
350 |
+
|
351 |
+
# start = time.time()
|
352 |
+
im = self.ims_pt[idx]#.type(torch.float32)
|
353 |
+
gt = self.gts_pt[idx]#.type(torch.float32)
|
354 |
+
# print(idx, 'time for pt loading: ', time.time()-start)
|
355 |
+
|
356 |
+
else:
|
357 |
+
# import time
|
358 |
+
# start = time.time()
|
359 |
+
# print("tensor***")
|
360 |
+
im_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["im_path"][idx].split(os.sep)[-2:]))
|
361 |
+
im = torch.load(im_pt_path)#(self.dataset["im_path"][idx])
|
362 |
+
gt_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["gt_path"][idx].split(os.sep)[-2:]))
|
363 |
+
gt = torch.load(gt_pt_path)#(self.dataset["gt_path"][idx])
|
364 |
+
# print(idx,'time for tensor loading: ', time.time()-start)
|
365 |
+
|
366 |
+
|
367 |
+
im_shp = self.dataset["im_shp"][idx]
|
368 |
+
# print("time for loading im and gt: ", time.time()-start)
|
369 |
+
|
370 |
+
# start_time = time.time()
|
371 |
+
im = torch.divide(im,255.0)
|
372 |
+
gt = torch.divide(gt,255.0)
|
373 |
+
# print(idx, 'time for normalize torch divide: ', time.time()-start_time)
|
374 |
+
|
375 |
+
sample = {
|
376 |
+
"imidx": torch.from_numpy(np.array(idx)),
|
377 |
+
"image": im,
|
378 |
+
"label": gt,
|
379 |
+
"shape": torch.from_numpy(np.array(im_shp)),
|
380 |
+
}
|
381 |
+
|
382 |
+
if self.transform:
|
383 |
+
sample = self.transform(sample)
|
384 |
+
|
385 |
+
return sample
|
hce_metric_main.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## hce_metric.py
|
2 |
+
import numpy as np
|
3 |
+
from skimage import io
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import cv2 as cv
|
6 |
+
from skimage.morphology import skeletonize
|
7 |
+
from skimage.morphology import erosion, dilation, disk
|
8 |
+
from skimage.measure import label
|
9 |
+
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
from tqdm import tqdm
|
13 |
+
from glob import glob
|
14 |
+
import pickle as pkl
|
15 |
+
|
16 |
+
def filter_bdy_cond(bdy_, mask, cond):
|
17 |
+
|
18 |
+
cond = cv.dilate(cond.astype(np.uint8),disk(1))
|
19 |
+
labels = label(mask) # find the connected regions
|
20 |
+
lbls = np.unique(labels) # the indices of the connected regions
|
21 |
+
indep = np.ones(lbls.shape[0]) # the label of each connected regions
|
22 |
+
indep[0] = 0 # 0 indicate the background region
|
23 |
+
|
24 |
+
boundaries = []
|
25 |
+
h,w = cond.shape[0:2]
|
26 |
+
ind_map = np.zeros((h,w))
|
27 |
+
indep_cnt = 0
|
28 |
+
|
29 |
+
for i in range(0,len(bdy_)):
|
30 |
+
tmp_bdies = []
|
31 |
+
tmp_bdy = []
|
32 |
+
for j in range(0,bdy_[i].shape[0]):
|
33 |
+
r, c = bdy_[i][j,0,1],bdy_[i][j,0,0]
|
34 |
+
|
35 |
+
if(np.sum(cond[r,c])==0 or ind_map[r,c]!=0):
|
36 |
+
if(len(tmp_bdy)>0):
|
37 |
+
tmp_bdies.append(tmp_bdy)
|
38 |
+
tmp_bdy = []
|
39 |
+
continue
|
40 |
+
tmp_bdy.append([c,r])
|
41 |
+
ind_map[r,c] = ind_map[r,c] + 1
|
42 |
+
indep[labels[r,c]] = 0 # indicates part of the boundary of this region needs human correction
|
43 |
+
if(len(tmp_bdy)>0):
|
44 |
+
tmp_bdies.append(tmp_bdy)
|
45 |
+
|
46 |
+
# check if the first and the last boundaries are connected
|
47 |
+
# if yes, invert the first boundary and attach it after the last boundary
|
48 |
+
if(len(tmp_bdies)>1):
|
49 |
+
first_x, first_y = tmp_bdies[0][0]
|
50 |
+
last_x, last_y = tmp_bdies[-1][-1]
|
51 |
+
if((abs(first_x-last_x)==1 and first_y==last_y) or
|
52 |
+
(first_x==last_x and abs(first_y-last_y)==1) or
|
53 |
+
(abs(first_x-last_x)==1 and abs(first_y-last_y)==1)
|
54 |
+
):
|
55 |
+
tmp_bdies[-1].extend(tmp_bdies[0][::-1])
|
56 |
+
del tmp_bdies[0]
|
57 |
+
|
58 |
+
for k in range(0,len(tmp_bdies)):
|
59 |
+
tmp_bdies[k] = np.array(tmp_bdies[k])[:,np.newaxis,:]
|
60 |
+
if(len(tmp_bdies)>0):
|
61 |
+
boundaries.extend(tmp_bdies)
|
62 |
+
|
63 |
+
return boundaries, np.sum(indep)
|
64 |
+
|
65 |
+
# this function approximate each boundary by DP algorithm
|
66 |
+
# https://en.wikipedia.org/wiki/Ramer%E2%80%93Douglas%E2%80%93Peucker_algorithm
|
67 |
+
def approximate_RDP(boundaries,epsilon=1.0):
|
68 |
+
|
69 |
+
boundaries_ = []
|
70 |
+
boundaries_len_ = []
|
71 |
+
pixel_cnt_ = 0
|
72 |
+
|
73 |
+
# polygon approximate of each boundary
|
74 |
+
for i in range(0,len(boundaries)):
|
75 |
+
boundaries_.append(cv.approxPolyDP(boundaries[i],epsilon,False))
|
76 |
+
|
77 |
+
# count the control points number of each boundary and the total control points number of all the boundaries
|
78 |
+
for i in range(0,len(boundaries_)):
|
79 |
+
boundaries_len_.append(len(boundaries_[i]))
|
80 |
+
pixel_cnt_ = pixel_cnt_ + len(boundaries_[i])
|
81 |
+
|
82 |
+
return boundaries_, boundaries_len_, pixel_cnt_
|
83 |
+
|
84 |
+
|
85 |
+
def relax_HCE(gt, rs, gt_ske, relax=5, epsilon=2.0):
|
86 |
+
# print("max(gt_ske): ", np.amax(gt_ske))
|
87 |
+
# gt_ske = gt_ske>128
|
88 |
+
# print("max(gt_ske): ", np.amax(gt_ske))
|
89 |
+
|
90 |
+
# Binarize gt
|
91 |
+
if(len(gt.shape)>2):
|
92 |
+
gt = gt[:,:,0]
|
93 |
+
|
94 |
+
epsilon_gt = 128#(np.amin(gt)+np.amax(gt))/2.0
|
95 |
+
gt = (gt>epsilon_gt).astype(np.uint8)
|
96 |
+
|
97 |
+
# Binarize rs
|
98 |
+
if(len(rs.shape)>2):
|
99 |
+
rs = rs[:,:,0]
|
100 |
+
epsilon_rs = 128#(np.amin(rs)+np.amax(rs))/2.0
|
101 |
+
rs = (rs>epsilon_rs).astype(np.uint8)
|
102 |
+
|
103 |
+
Union = np.logical_or(gt,rs)
|
104 |
+
TP = np.logical_and(gt,rs)
|
105 |
+
FP = rs - TP
|
106 |
+
FN = gt - TP
|
107 |
+
|
108 |
+
# relax the Union of gt and rs
|
109 |
+
Union_erode = Union.copy()
|
110 |
+
Union_erode = cv.erode(Union_erode.astype(np.uint8),disk(1),iterations=relax)
|
111 |
+
|
112 |
+
# --- get the relaxed False Positive regions for computing the human efforts in correcting them ---
|
113 |
+
FP_ = np.logical_and(FP,Union_erode) # get the relaxed FP
|
114 |
+
for i in range(0,relax):
|
115 |
+
FP_ = cv.dilate(FP_.astype(np.uint8),disk(1))
|
116 |
+
FP_ = np.logical_and(FP_, 1-np.logical_or(TP,FN))
|
117 |
+
FP_ = np.logical_and(FP, FP_)
|
118 |
+
|
119 |
+
# --- get the relaxed False Negative regions for computing the human efforts in correcting them ---
|
120 |
+
FN_ = np.logical_and(FN,Union_erode) # preserve the structural components of FN
|
121 |
+
## recover the FN, where pixels are not close to the TP borders
|
122 |
+
for i in range(0,relax):
|
123 |
+
FN_ = cv.dilate(FN_.astype(np.uint8),disk(1))
|
124 |
+
FN_ = np.logical_and(FN_,1-np.logical_or(TP,FP))
|
125 |
+
FN_ = np.logical_and(FN,FN_)
|
126 |
+
FN_ = np.logical_or(FN_, np.logical_xor(gt_ske,np.logical_and(TP,gt_ske))) # preserve the structural components of FN
|
127 |
+
|
128 |
+
## 2. =============Find exact polygon control points and independent regions==============
|
129 |
+
## find contours from FP_
|
130 |
+
ctrs_FP, hier_FP = cv.findContours(FP_.astype(np.uint8), cv.RETR_TREE, cv.CHAIN_APPROX_NONE)
|
131 |
+
## find control points and independent regions for human correction
|
132 |
+
bdies_FP, indep_cnt_FP = filter_bdy_cond(ctrs_FP, FP_, np.logical_or(TP,FN_))
|
133 |
+
## find contours from FN_
|
134 |
+
ctrs_FN, hier_FN = cv.findContours(FN_.astype(np.uint8), cv.RETR_TREE, cv.CHAIN_APPROX_NONE)
|
135 |
+
## find control points and independent regions for human correction
|
136 |
+
bdies_FN, indep_cnt_FN = filter_bdy_cond(ctrs_FN, FN_, 1-np.logical_or(np.logical_or(TP,FP_),FN_))
|
137 |
+
|
138 |
+
poly_FP, poly_FP_len, poly_FP_point_cnt = approximate_RDP(bdies_FP,epsilon=epsilon)
|
139 |
+
poly_FN, poly_FN_len, poly_FN_point_cnt = approximate_RDP(bdies_FN,epsilon=epsilon)
|
140 |
+
|
141 |
+
return poly_FP_point_cnt, indep_cnt_FP, poly_FN_point_cnt, indep_cnt_FN
|
142 |
+
|
143 |
+
def compute_hce(pred_root,gt_root,gt_ske_root):
|
144 |
+
|
145 |
+
gt_name_list = glob(pred_root+'/*.png')
|
146 |
+
gt_name_list = sorted([x.split('/')[-1] for x in gt_name_list])
|
147 |
+
|
148 |
+
hces = []
|
149 |
+
for gt_name in tqdm(gt_name_list, total=len(gt_name_list)):
|
150 |
+
gt_path = os.path.join(gt_root, gt_name)
|
151 |
+
pred_path = os.path.join(pred_root, gt_name)
|
152 |
+
|
153 |
+
gt = cv.imread(gt_path, cv.IMREAD_GRAYSCALE)
|
154 |
+
pred = cv.imread(pred_path, cv.IMREAD_GRAYSCALE)
|
155 |
+
|
156 |
+
ske_path = os.path.join(gt_ske_root,gt_name)
|
157 |
+
if os.path.exists(ske_path):
|
158 |
+
ske = cv.imread(ske_path,cv.IMREAD_GRAYSCALE)
|
159 |
+
ske = ske>128
|
160 |
+
else:
|
161 |
+
ske = skeletonize(gt>128)
|
162 |
+
|
163 |
+
FP_points, FP_indep, FN_points, FN_indep = relax_HCE(gt, pred,ske)
|
164 |
+
print(gt_path.split('/')[-1],FP_points, FP_indep, FN_points, FN_indep)
|
165 |
+
hces.append([FP_points, FP_indep, FN_points, FN_indep, FP_points+FP_indep+FN_points+FN_indep])
|
166 |
+
|
167 |
+
hce_metric ={'names': gt_name_list,
|
168 |
+
'hces': hces}
|
169 |
+
|
170 |
+
|
171 |
+
file_metric = open(pred_root+'/hce_metric.pkl','wb')
|
172 |
+
pkl.dump(hce_metric,file_metric)
|
173 |
+
# file_metrics.write(cmn_metrics)
|
174 |
+
file_metric.close()
|
175 |
+
|
176 |
+
return np.mean(np.array(hces)[:,-1])
|
177 |
+
|
178 |
+
def main():
|
179 |
+
|
180 |
+
gt_root = "../DIS5K/DIS-VD/gt"
|
181 |
+
gt_ske_root = ""
|
182 |
+
pred_root = "../Results/isnet(ours)/DIS-VD"
|
183 |
+
|
184 |
+
print("The average HCE metric: ", compute_hce(pred_root,gt_root,gt_ske_root))
|
185 |
+
|
186 |
+
|
187 |
+
if __name__ == '__main__':
|
188 |
+
main()
|
pytorch18.yml
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: pytorch18
|
2 |
+
channels:
|
3 |
+
- conda-forge
|
4 |
+
- anaconda
|
5 |
+
- pytorch
|
6 |
+
- defaults
|
7 |
+
dependencies:
|
8 |
+
- _libgcc_mutex=0.1=main
|
9 |
+
- _openmp_mutex=4.5=1_gnu
|
10 |
+
- blas=1.0=mkl
|
11 |
+
- brotli=1.0.9=he6710b0_2
|
12 |
+
- bzip2=1.0.8=h7b6447c_0
|
13 |
+
- ca-certificates=2022.2.1=h06a4308_0
|
14 |
+
- certifi=2021.10.8=py37h06a4308_2
|
15 |
+
- cloudpickle=2.0.0=pyhd3eb1b0_0
|
16 |
+
- colorama=0.4.4=pyhd3eb1b0_0
|
17 |
+
- cudatoolkit=10.2.89=hfd86e86_1
|
18 |
+
- cycler=0.11.0=pyhd3eb1b0_0
|
19 |
+
- cytoolz=0.11.0=py37h7b6447c_0
|
20 |
+
- dask-core=2021.10.0=pyhd3eb1b0_0
|
21 |
+
- ffmpeg=4.3=hf484d3e_0
|
22 |
+
- fonttools=4.25.0=pyhd3eb1b0_0
|
23 |
+
- freetype=2.11.0=h70c0345_0
|
24 |
+
- fsspec=2022.2.0=pyhd3eb1b0_0
|
25 |
+
- gmp=6.2.1=h2531618_2
|
26 |
+
- gnutls=3.6.15=he1e5248_0
|
27 |
+
- imageio=2.9.0=pyhd3eb1b0_0
|
28 |
+
- intel-openmp=2021.4.0=h06a4308_3561
|
29 |
+
- jpeg=9b=h024ee3a_2
|
30 |
+
- kiwisolver=1.3.2=py37h295c915_0
|
31 |
+
- lame=3.100=h7b6447c_0
|
32 |
+
- lcms2=2.12=h3be6417_0
|
33 |
+
- ld_impl_linux-64=2.35.1=h7274673_9
|
34 |
+
- libffi=3.3=he6710b0_2
|
35 |
+
- libgcc-ng=9.3.0=h5101ec6_17
|
36 |
+
- libgfortran-ng=7.5.0=ha8ba4b0_17
|
37 |
+
- libgfortran4=7.5.0=ha8ba4b0_17
|
38 |
+
- libgomp=9.3.0=h5101ec6_17
|
39 |
+
- libiconv=1.15=h63c8f33_5
|
40 |
+
- libidn2=2.3.2=h7f8727e_0
|
41 |
+
- libpng=1.6.37=hbc83047_0
|
42 |
+
- libstdcxx-ng=9.3.0=hd4cf53a_17
|
43 |
+
- libtasn1=4.16.0=h27cfd23_0
|
44 |
+
- libtiff=4.2.0=h85742a9_0
|
45 |
+
- libunistring=0.9.10=h27cfd23_0
|
46 |
+
- libuv=1.40.0=h7b6447c_0
|
47 |
+
- libwebp-base=1.2.2=h7f8727e_0
|
48 |
+
- locket=0.2.1=py37h06a4308_2
|
49 |
+
- lz4-c=1.9.3=h295c915_1
|
50 |
+
- matplotlib-base=3.5.1=py37ha18d171_1
|
51 |
+
- mkl=2021.4.0=h06a4308_640
|
52 |
+
- mkl-service=2.4.0=py37h7f8727e_0
|
53 |
+
- mkl_fft=1.3.1=py37hd3c417c_0
|
54 |
+
- mkl_random=1.2.2=py37h51133e4_0
|
55 |
+
- munkres=1.1.4=py_0
|
56 |
+
- ncurses=6.3=h7f8727e_2
|
57 |
+
- nettle=3.7.3=hbbd107a_1
|
58 |
+
- networkx=2.6.3=pyhd3eb1b0_0
|
59 |
+
- ninja=1.10.2=py37hd09550d_3
|
60 |
+
- numpy=1.21.2=py37h20f2e39_0
|
61 |
+
- numpy-base=1.21.2=py37h79a1101_0
|
62 |
+
- olefile=0.46=py37_0
|
63 |
+
- openh264=2.1.1=h4ff587b_0
|
64 |
+
- openssl=1.1.1n=h7f8727e_0
|
65 |
+
- packaging=21.3=pyhd3eb1b0_0
|
66 |
+
- partd=1.2.0=pyhd3eb1b0_1
|
67 |
+
- pillow=8.0.0=py37h9a89aac_0
|
68 |
+
- pip=21.2.2=py37h06a4308_0
|
69 |
+
- pyparsing=3.0.4=pyhd3eb1b0_0
|
70 |
+
- python=3.7.11=h12debd9_0
|
71 |
+
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
72 |
+
- pytorch=1.8.0=py3.7_cuda10.2_cudnn7.6.5_0
|
73 |
+
- pywavelets=1.1.1=py37h7b6447c_2
|
74 |
+
- pyyaml=6.0=py37h7f8727e_1
|
75 |
+
- readline=8.1.2=h7f8727e_1
|
76 |
+
- scikit-image=0.15.0=py37hb3f55d8_2
|
77 |
+
- scipy=1.7.3=py37hc147768_0
|
78 |
+
- setuptools=58.0.4=py37h06a4308_0
|
79 |
+
- six=1.16.0=pyhd3eb1b0_1
|
80 |
+
- sqlite=3.38.0=hc218d9a_0
|
81 |
+
- tk=8.6.11=h1ccaba5_0
|
82 |
+
- toolz=0.11.2=pyhd3eb1b0_0
|
83 |
+
- torchaudio=0.8.0=py37
|
84 |
+
- torchvision=0.9.0=py37_cu102
|
85 |
+
- tqdm=4.63.0=pyhd8ed1ab_0
|
86 |
+
- typing_extensions=3.10.0.2=pyh06a4308_0
|
87 |
+
- wheel=0.37.1=pyhd3eb1b0_0
|
88 |
+
- xz=5.2.5=h7b6447c_0
|
89 |
+
- yaml=0.2.5=h7b6447c_0
|
90 |
+
- zlib=1.2.11=h7f8727e_4
|
91 |
+
- zstd=1.4.9=haebb681_0
|
92 |
+
prefix: /home/solar/anaconda3/envs/pytorch18
|
requirements.txt
CHANGED
@@ -1,4 +1,87 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file may be used to create an environment using:
|
2 |
+
# $ conda create --name <env> --file <this file>
|
3 |
+
# platform: linux-64
|
4 |
+
_libgcc_mutex=0.1=main
|
5 |
+
_openmp_mutex=4.5=1_gnu
|
6 |
+
blas=1.0=mkl
|
7 |
+
brotli=1.0.9=he6710b0_2
|
8 |
+
bzip2=1.0.8=h7b6447c_0
|
9 |
+
ca-certificates=2022.2.1=h06a4308_0
|
10 |
+
certifi=2021.10.8=py37h06a4308_2
|
11 |
+
cloudpickle=2.0.0=pyhd3eb1b0_0
|
12 |
+
colorama=0.4.4=pyhd3eb1b0_0
|
13 |
+
cudatoolkit=10.2.89=hfd86e86_1
|
14 |
+
cycler=0.11.0=pyhd3eb1b0_0
|
15 |
+
cytoolz=0.11.0=py37h7b6447c_0
|
16 |
+
dask-core=2021.10.0=pyhd3eb1b0_0
|
17 |
+
ffmpeg=4.3=hf484d3e_0
|
18 |
+
fonttools=4.25.0=pyhd3eb1b0_0
|
19 |
+
freetype=2.11.0=h70c0345_0
|
20 |
+
fsspec=2022.2.0=pyhd3eb1b0_0
|
21 |
+
gmp=6.2.1=h2531618_2
|
22 |
+
gnutls=3.6.15=he1e5248_0
|
23 |
+
imageio=2.9.0=pyhd3eb1b0_0
|
24 |
+
intel-openmp=2021.4.0=h06a4308_3561
|
25 |
+
jpeg=9b=h024ee3a_2
|
26 |
+
kiwisolver=1.3.2=py37h295c915_0
|
27 |
+
lame=3.100=h7b6447c_0
|
28 |
+
lcms2=2.12=h3be6417_0
|
29 |
+
ld_impl_linux-64=2.35.1=h7274673_9
|
30 |
+
libffi=3.3=he6710b0_2
|
31 |
+
libgcc-ng=9.3.0=h5101ec6_17
|
32 |
+
libgfortran-ng=7.5.0=ha8ba4b0_17
|
33 |
+
libgfortran4=7.5.0=ha8ba4b0_17
|
34 |
+
libgomp=9.3.0=h5101ec6_17
|
35 |
+
libiconv=1.15=h63c8f33_5
|
36 |
+
libidn2=2.3.2=h7f8727e_0
|
37 |
+
libpng=1.6.37=hbc83047_0
|
38 |
+
libstdcxx-ng=9.3.0=hd4cf53a_17
|
39 |
+
libtasn1=4.16.0=h27cfd23_0
|
40 |
+
libtiff=4.2.0=h85742a9_0
|
41 |
+
libunistring=0.9.10=h27cfd23_0
|
42 |
+
libuv=1.40.0=h7b6447c_0
|
43 |
+
libwebp-base=1.2.2=h7f8727e_0
|
44 |
+
locket=0.2.1=py37h06a4308_2
|
45 |
+
lz4-c=1.9.3=h295c915_1
|
46 |
+
matplotlib-base=3.5.1=py37ha18d171_1
|
47 |
+
mkl=2021.4.0=h06a4308_640
|
48 |
+
mkl-service=2.4.0=py37h7f8727e_0
|
49 |
+
mkl_fft=1.3.1=py37hd3c417c_0
|
50 |
+
mkl_random=1.2.2=py37h51133e4_0
|
51 |
+
munkres=1.1.4=py_0
|
52 |
+
ncurses=6.3=h7f8727e_2
|
53 |
+
nettle=3.7.3=hbbd107a_1
|
54 |
+
networkx=2.6.3=pyhd3eb1b0_0
|
55 |
+
ninja=1.10.2=py37hd09550d_3
|
56 |
+
numpy=1.21.2=py37h20f2e39_0
|
57 |
+
numpy-base=1.21.2=py37h79a1101_0
|
58 |
+
olefile=0.46=py37_0
|
59 |
+
openh264=2.1.1=h4ff587b_0
|
60 |
+
openssl=1.1.1n=h7f8727e_0
|
61 |
+
packaging=21.3=pyhd3eb1b0_0
|
62 |
+
partd=1.2.0=pyhd3eb1b0_1
|
63 |
+
pillow=8.0.0=py37h9a89aac_0
|
64 |
+
pip=21.2.2=py37h06a4308_0
|
65 |
+
pyparsing=3.0.4=pyhd3eb1b0_0
|
66 |
+
python=3.7.11=h12debd9_0
|
67 |
+
python-dateutil=2.8.2=pyhd3eb1b0_0
|
68 |
+
pytorch=1.8.0=py3.7_cuda10.2_cudnn7.6.5_0
|
69 |
+
pywavelets=1.1.1=py37h7b6447c_2
|
70 |
+
pyyaml=6.0=py37h7f8727e_1
|
71 |
+
readline=8.1.2=h7f8727e_1
|
72 |
+
scikit-image=0.15.0=py37hb3f55d8_2
|
73 |
+
scipy=1.7.3=py37hc147768_0
|
74 |
+
setuptools=58.0.4=py37h06a4308_0
|
75 |
+
six=1.16.0=pyhd3eb1b0_1
|
76 |
+
sqlite=3.38.0=hc218d9a_0
|
77 |
+
tk=8.6.11=h1ccaba5_0
|
78 |
+
toolz=0.11.2=pyhd3eb1b0_0
|
79 |
+
torchaudio=0.8.0=py37
|
80 |
+
torchvision=0.9.0=py37_cu102
|
81 |
+
tqdm=4.63.0=pyhd8ed1ab_0
|
82 |
+
typing_extensions=3.10.0.2=pyh06a4308_0
|
83 |
+
wheel=0.37.1=pyhd3eb1b0_0
|
84 |
+
xz=5.2.5=h7b6447c_0
|
85 |
+
yaml=0.2.5=h7b6447c_0
|
86 |
+
zlib=1.2.11=h7f8727e_4
|
87 |
+
zstd=1.4.9=haebb681_0
|
train_valid_inference_main.py
ADDED
@@ -0,0 +1,731 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import numpy as np
|
4 |
+
from skimage import io
|
5 |
+
import time
|
6 |
+
|
7 |
+
import torch, gc
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch.autograd import Variable
|
10 |
+
import torch.optim as optim
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandomHFlip, GOSResize, GOSRandomCrop, GOSNormalize #GOSDatasetCache,
|
14 |
+
from basics import f1_mae_torch #normPRED, GOSPRF1ScoresCache,f1score_torch,
|
15 |
+
from models import *
|
16 |
+
|
17 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
18 |
+
|
19 |
+
def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000):
|
20 |
+
|
21 |
+
# train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
|
22 |
+
# cache_size = hypar["cache_size"],
|
23 |
+
# cache_boost = hypar["cache_boost_train"],
|
24 |
+
# my_transforms = [
|
25 |
+
# GOSRandomHFlip(),
|
26 |
+
# # GOSResize(hypar["input_size"]),
|
27 |
+
# # GOSRandomCrop(hypar["crop_size"]),
|
28 |
+
# GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
29 |
+
# ],
|
30 |
+
# batch_size = hypar["batch_size_train"],
|
31 |
+
# shuffle = True)
|
32 |
+
|
33 |
+
torch.manual_seed(hypar["seed"])
|
34 |
+
if torch.cuda.is_available():
|
35 |
+
torch.cuda.manual_seed(hypar["seed"])
|
36 |
+
|
37 |
+
print("define gt encoder ...")
|
38 |
+
net = ISNetGTEncoder() #UNETGTENCODERCombine()
|
39 |
+
## load the existing model gt encoder
|
40 |
+
if(hypar["gt_encoder_model"]!=""):
|
41 |
+
model_path = hypar["model_path"]+"/"+hypar["gt_encoder_model"]
|
42 |
+
if torch.cuda.is_available():
|
43 |
+
net.load_state_dict(torch.load(model_path))
|
44 |
+
net.cuda()
|
45 |
+
else:
|
46 |
+
net.load_state_dict(torch.load(model_path,map_location="cpu"))
|
47 |
+
print("gt encoder restored from the saved weights ...")
|
48 |
+
return net ############
|
49 |
+
|
50 |
+
if torch.cuda.is_available():
|
51 |
+
net.cuda()
|
52 |
+
|
53 |
+
print("--- define optimizer for GT Encoder---")
|
54 |
+
optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
55 |
+
|
56 |
+
model_path = hypar["model_path"]
|
57 |
+
model_save_fre = hypar["model_save_fre"]
|
58 |
+
max_ite = hypar["max_ite"]
|
59 |
+
batch_size_train = hypar["batch_size_train"]
|
60 |
+
batch_size_valid = hypar["batch_size_valid"]
|
61 |
+
|
62 |
+
if(not os.path.exists(model_path)):
|
63 |
+
os.mkdir(model_path)
|
64 |
+
|
65 |
+
ite_num = hypar["start_ite"] # count the total iteration number
|
66 |
+
ite_num4val = 0 #
|
67 |
+
running_loss = 0.0 # count the toal loss
|
68 |
+
running_tar_loss = 0.0 # count the target output loss
|
69 |
+
last_f1 = [0 for x in range(len(valid_dataloaders))]
|
70 |
+
|
71 |
+
train_num = train_datasets[0].__len__()
|
72 |
+
|
73 |
+
net.train()
|
74 |
+
|
75 |
+
start_last = time.time()
|
76 |
+
gos_dataloader = train_dataloaders[0]
|
77 |
+
epoch_num = hypar["max_epoch_num"]
|
78 |
+
notgood_cnt = 0
|
79 |
+
for epoch in range(epoch_num): ## set the epoch num as 100000
|
80 |
+
|
81 |
+
for i, data in enumerate(gos_dataloader):
|
82 |
+
|
83 |
+
if(ite_num >= max_ite):
|
84 |
+
print("Training Reached the Maximal Iteration Number ", max_ite)
|
85 |
+
exit()
|
86 |
+
|
87 |
+
# start_read = time.time()
|
88 |
+
ite_num = ite_num + 1
|
89 |
+
ite_num4val = ite_num4val + 1
|
90 |
+
|
91 |
+
# get the inputs
|
92 |
+
labels = data['label']
|
93 |
+
|
94 |
+
if(hypar["model_digit"]=="full"):
|
95 |
+
labels = labels.type(torch.FloatTensor)
|
96 |
+
else:
|
97 |
+
labels = labels.type(torch.HalfTensor)
|
98 |
+
|
99 |
+
# wrap them in Variable
|
100 |
+
if torch.cuda.is_available():
|
101 |
+
labels_v = Variable(labels.cuda(), requires_grad=False)
|
102 |
+
else:
|
103 |
+
labels_v = Variable(labels, requires_grad=False)
|
104 |
+
|
105 |
+
# print("time lapse for data preparation: ", time.time()-start_read, ' s')
|
106 |
+
|
107 |
+
# y zero the parameter gradients
|
108 |
+
start_inf_loss_back = time.time()
|
109 |
+
optimizer.zero_grad()
|
110 |
+
|
111 |
+
ds, fs = net(labels_v)#net(inputs_v)
|
112 |
+
loss2, loss = net.compute_loss(ds, labels_v)
|
113 |
+
|
114 |
+
loss.backward()
|
115 |
+
optimizer.step()
|
116 |
+
|
117 |
+
running_loss += loss.item()
|
118 |
+
running_tar_loss += loss2.item()
|
119 |
+
|
120 |
+
# del outputs, loss
|
121 |
+
del ds, loss2, loss
|
122 |
+
end_inf_loss_back = time.time()-start_inf_loss_back
|
123 |
+
|
124 |
+
print("GT Encoder Training>>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
|
125 |
+
epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back))
|
126 |
+
start_last = time.time()
|
127 |
+
|
128 |
+
if ite_num % model_save_fre == 0: # validate every 2000 iterations
|
129 |
+
notgood_cnt += 1
|
130 |
+
# net.eval()
|
131 |
+
# tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch)
|
132 |
+
tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, train_dataloaders_val, train_datasets_val, hypar, epoch)
|
133 |
+
|
134 |
+
net.train() # resume train
|
135 |
+
|
136 |
+
tmp_out = 0
|
137 |
+
print("last_f1:",last_f1)
|
138 |
+
print("tmp_f1:",tmp_f1)
|
139 |
+
for fi in range(len(last_f1)):
|
140 |
+
if(tmp_f1[fi]>last_f1[fi]):
|
141 |
+
tmp_out = 1
|
142 |
+
print("tmp_out:",tmp_out)
|
143 |
+
if(tmp_out):
|
144 |
+
notgood_cnt = 0
|
145 |
+
last_f1 = tmp_f1
|
146 |
+
tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1]
|
147 |
+
tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae]
|
148 |
+
maxf1 = '_'.join(tmp_f1_str)
|
149 |
+
meanM = '_'.join(tmp_mae_str)
|
150 |
+
# .cpu().detach().numpy()
|
151 |
+
model_name = "/GTENCODER-gpu_itr_"+str(ite_num)+\
|
152 |
+
"_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\
|
153 |
+
"_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\
|
154 |
+
"_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\
|
155 |
+
"_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \
|
156 |
+
"_maxF1_" + maxf1 + \
|
157 |
+
"_mae_" + meanM + \
|
158 |
+
"_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth"
|
159 |
+
torch.save(net.state_dict(), model_path + model_name)
|
160 |
+
|
161 |
+
running_loss = 0.0
|
162 |
+
running_tar_loss = 0.0
|
163 |
+
ite_num4val = 0
|
164 |
+
|
165 |
+
if(tmp_f1[0]>0.99):
|
166 |
+
print("GT encoder is well-trained and obtained...")
|
167 |
+
return net
|
168 |
+
|
169 |
+
if(notgood_cnt >= hypar["early_stop"]):
|
170 |
+
print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped !")
|
171 |
+
exit()
|
172 |
+
|
173 |
+
print("Training Reaches The Maximum Epoch Number")
|
174 |
+
return net
|
175 |
+
|
176 |
+
def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
177 |
+
net.eval()
|
178 |
+
print("Validating...")
|
179 |
+
epoch_num = hypar["max_epoch_num"]
|
180 |
+
|
181 |
+
val_loss = 0.0
|
182 |
+
tar_loss = 0.0
|
183 |
+
|
184 |
+
|
185 |
+
tmp_f1 = []
|
186 |
+
tmp_mae = []
|
187 |
+
tmp_time = []
|
188 |
+
|
189 |
+
start_valid = time.time()
|
190 |
+
for k in range(len(valid_dataloaders)):
|
191 |
+
|
192 |
+
valid_dataloader = valid_dataloaders[k]
|
193 |
+
valid_dataset = valid_datasets[k]
|
194 |
+
|
195 |
+
val_num = valid_dataset.__len__()
|
196 |
+
mybins = np.arange(0,256)
|
197 |
+
PRE = np.zeros((val_num,len(mybins)-1))
|
198 |
+
REC = np.zeros((val_num,len(mybins)-1))
|
199 |
+
F1 = np.zeros((val_num,len(mybins)-1))
|
200 |
+
MAE = np.zeros((val_num))
|
201 |
+
|
202 |
+
val_cnt = 0.0
|
203 |
+
i_val = None
|
204 |
+
|
205 |
+
for i_val, data_val in enumerate(valid_dataloader):
|
206 |
+
|
207 |
+
# imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']
|
208 |
+
imidx_val, labels_val, shapes_val = data_val['imidx'], data_val['label'], data_val['shape']
|
209 |
+
|
210 |
+
if(hypar["model_digit"]=="full"):
|
211 |
+
labels_val = labels_val.type(torch.FloatTensor)
|
212 |
+
else:
|
213 |
+
labels_val = labels_val.type(torch.HalfTensor)
|
214 |
+
|
215 |
+
# wrap them in Variable
|
216 |
+
if torch.cuda.is_available():
|
217 |
+
labels_val_v = Variable(labels_val.cuda(), requires_grad=False)
|
218 |
+
else:
|
219 |
+
labels_val_v = Variable(labels_val,requires_grad=False)
|
220 |
+
|
221 |
+
t_start = time.time()
|
222 |
+
ds_val = net(labels_val_v)[0]
|
223 |
+
t_end = time.time()-t_start
|
224 |
+
tmp_time.append(t_end)
|
225 |
+
|
226 |
+
# loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
|
227 |
+
loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
|
228 |
+
|
229 |
+
# compute F measure
|
230 |
+
for t in range(hypar["batch_size_valid"]):
|
231 |
+
val_cnt = val_cnt + 1.0
|
232 |
+
print("num of val: ", val_cnt)
|
233 |
+
i_test = imidx_val[t].data.numpy()
|
234 |
+
|
235 |
+
pred_val = ds_val[0][t,:,:,:] # B x 1 x H x W
|
236 |
+
|
237 |
+
## recover the prediction spatial size to the orignal image size
|
238 |
+
pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear'))
|
239 |
+
|
240 |
+
ma = torch.max(pred_val)
|
241 |
+
mi = torch.min(pred_val)
|
242 |
+
pred_val = (pred_val-mi)/(ma-mi) # max = 1
|
243 |
+
# pred_val = normPRED(pred_val)
|
244 |
+
|
245 |
+
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
246 |
+
if gt.max()==1:
|
247 |
+
gt=gt*255
|
248 |
+
with torch.no_grad():
|
249 |
+
gt = torch.tensor(gt).to(device)
|
250 |
+
|
251 |
+
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
|
252 |
+
|
253 |
+
PRE[i_test,:]=pre
|
254 |
+
REC[i_test,:] = rec
|
255 |
+
F1[i_test,:] = f1
|
256 |
+
MAE[i_test] = mae
|
257 |
+
|
258 |
+
del ds_val, gt
|
259 |
+
gc.collect()
|
260 |
+
torch.cuda.empty_cache()
|
261 |
+
|
262 |
+
# if(loss_val.data[0]>1):
|
263 |
+
val_loss += loss_val.item()#data[0]
|
264 |
+
tar_loss += loss2_val.item()#data[0]
|
265 |
+
|
266 |
+
print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end))
|
267 |
+
|
268 |
+
del loss2_val, loss_val
|
269 |
+
|
270 |
+
print('============================')
|
271 |
+
PRE_m = np.mean(PRE,0)
|
272 |
+
REC_m = np.mean(REC,0)
|
273 |
+
f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8)
|
274 |
+
# print('--------------:', np.mean(f1_m))
|
275 |
+
tmp_f1.append(np.amax(f1_m))
|
276 |
+
tmp_mae.append(np.mean(MAE))
|
277 |
+
print("The max F1 Score: %f"%(np.max(f1_m)))
|
278 |
+
print("MAE: ", np.mean(MAE))
|
279 |
+
|
280 |
+
# print('[epoch: %3d/%3d, ite: %5d] tra_ls: %3f, val_ls: %3f, tar_ls: %3f, maxf1: %3f, val_time: %6f'% (epoch + 1, epoch_num, ite_num, running_loss / ite_num4val, val_loss/val_cnt, tar_loss/val_cnt, tmp_f1[-1], time.time()-start_valid))
|
281 |
+
|
282 |
+
return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
|
283 |
+
|
284 |
+
def train(net, optimizer, train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000):
|
285 |
+
|
286 |
+
if hypar["interm_sup"]:
|
287 |
+
print("Get the gt encoder ...")
|
288 |
+
featurenet = get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val)
|
289 |
+
## freeze the weights of gt encoder
|
290 |
+
for param in featurenet.parameters():
|
291 |
+
param.requires_grad=False
|
292 |
+
|
293 |
+
|
294 |
+
model_path = hypar["model_path"]
|
295 |
+
model_save_fre = hypar["model_save_fre"]
|
296 |
+
max_ite = hypar["max_ite"]
|
297 |
+
batch_size_train = hypar["batch_size_train"]
|
298 |
+
batch_size_valid = hypar["batch_size_valid"]
|
299 |
+
|
300 |
+
if(not os.path.exists(model_path)):
|
301 |
+
os.mkdir(model_path)
|
302 |
+
|
303 |
+
ite_num = hypar["start_ite"] # count the toal iteration number
|
304 |
+
ite_num4val = 0 #
|
305 |
+
running_loss = 0.0 # count the toal loss
|
306 |
+
running_tar_loss = 0.0 # count the target output loss
|
307 |
+
last_f1 = [0 for x in range(len(valid_dataloaders))]
|
308 |
+
|
309 |
+
train_num = train_datasets[0].__len__()
|
310 |
+
|
311 |
+
net.train()
|
312 |
+
|
313 |
+
start_last = time.time()
|
314 |
+
gos_dataloader = train_dataloaders[0]
|
315 |
+
epoch_num = hypar["max_epoch_num"]
|
316 |
+
notgood_cnt = 0
|
317 |
+
for epoch in range(epoch_num): ## set the epoch num as 100000
|
318 |
+
|
319 |
+
for i, data in enumerate(gos_dataloader):
|
320 |
+
|
321 |
+
if(ite_num >= max_ite):
|
322 |
+
print("Training Reached the Maximal Iteration Number ", max_ite)
|
323 |
+
exit()
|
324 |
+
|
325 |
+
# start_read = time.time()
|
326 |
+
ite_num = ite_num + 1
|
327 |
+
ite_num4val = ite_num4val + 1
|
328 |
+
|
329 |
+
# get the inputs
|
330 |
+
inputs, labels = data['image'], data['label']
|
331 |
+
|
332 |
+
if(hypar["model_digit"]=="full"):
|
333 |
+
inputs = inputs.type(torch.FloatTensor)
|
334 |
+
labels = labels.type(torch.FloatTensor)
|
335 |
+
else:
|
336 |
+
inputs = inputs.type(torch.HalfTensor)
|
337 |
+
labels = labels.type(torch.HalfTensor)
|
338 |
+
|
339 |
+
# wrap them in Variable
|
340 |
+
if torch.cuda.is_available():
|
341 |
+
inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False)
|
342 |
+
else:
|
343 |
+
inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)
|
344 |
+
|
345 |
+
# print("time lapse for data preparation: ", time.time()-start_read, ' s')
|
346 |
+
|
347 |
+
# y zero the parameter gradients
|
348 |
+
start_inf_loss_back = time.time()
|
349 |
+
optimizer.zero_grad()
|
350 |
+
|
351 |
+
if hypar["interm_sup"]:
|
352 |
+
# forward + backward + optimize
|
353 |
+
ds,dfs = net(inputs_v)
|
354 |
+
_,fs = featurenet(labels_v) ## extract the gt encodings
|
355 |
+
loss2, loss = net.compute_loss_kl(ds, labels_v, dfs, fs, mode='MSE')
|
356 |
+
else:
|
357 |
+
# forward + backward + optimize
|
358 |
+
ds,_ = net(inputs_v)
|
359 |
+
loss2, loss = net.compute_loss(ds, labels_v)
|
360 |
+
|
361 |
+
loss.backward()
|
362 |
+
optimizer.step()
|
363 |
+
|
364 |
+
# # print statistics
|
365 |
+
running_loss += loss.item()
|
366 |
+
running_tar_loss += loss2.item()
|
367 |
+
|
368 |
+
# del outputs, loss
|
369 |
+
del ds, loss2, loss
|
370 |
+
end_inf_loss_back = time.time()-start_inf_loss_back
|
371 |
+
|
372 |
+
print(">>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
|
373 |
+
epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back))
|
374 |
+
start_last = time.time()
|
375 |
+
|
376 |
+
if ite_num % model_save_fre == 0: # validate every 2000 iterations
|
377 |
+
notgood_cnt += 1
|
378 |
+
net.eval()
|
379 |
+
tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid(net, valid_dataloaders, valid_datasets, hypar, epoch)
|
380 |
+
net.train() # resume train
|
381 |
+
|
382 |
+
tmp_out = 0
|
383 |
+
print("last_f1:",last_f1)
|
384 |
+
print("tmp_f1:",tmp_f1)
|
385 |
+
for fi in range(len(last_f1)):
|
386 |
+
if(tmp_f1[fi]>last_f1[fi]):
|
387 |
+
tmp_out = 1
|
388 |
+
print("tmp_out:",tmp_out)
|
389 |
+
if(tmp_out):
|
390 |
+
notgood_cnt = 0
|
391 |
+
last_f1 = tmp_f1
|
392 |
+
tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1]
|
393 |
+
tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae]
|
394 |
+
maxf1 = '_'.join(tmp_f1_str)
|
395 |
+
meanM = '_'.join(tmp_mae_str)
|
396 |
+
# .cpu().detach().numpy()
|
397 |
+
model_name = "/gpu_itr_"+str(ite_num)+\
|
398 |
+
"_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\
|
399 |
+
"_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\
|
400 |
+
"_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\
|
401 |
+
"_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \
|
402 |
+
"_maxF1_" + maxf1 + \
|
403 |
+
"_mae_" + meanM + \
|
404 |
+
"_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth"
|
405 |
+
torch.save(net.state_dict(), model_path + model_name)
|
406 |
+
|
407 |
+
running_loss = 0.0
|
408 |
+
running_tar_loss = 0.0
|
409 |
+
ite_num4val = 0
|
410 |
+
|
411 |
+
if(notgood_cnt >= hypar["early_stop"]):
|
412 |
+
print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped !")
|
413 |
+
exit()
|
414 |
+
|
415 |
+
print("Training Reaches The Maximum Epoch Number")
|
416 |
+
|
417 |
+
def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
418 |
+
net.eval()
|
419 |
+
print("Validating...")
|
420 |
+
epoch_num = hypar["max_epoch_num"]
|
421 |
+
|
422 |
+
val_loss = 0.0
|
423 |
+
tar_loss = 0.0
|
424 |
+
val_cnt = 0.0
|
425 |
+
|
426 |
+
tmp_f1 = []
|
427 |
+
tmp_mae = []
|
428 |
+
tmp_time = []
|
429 |
+
|
430 |
+
start_valid = time.time()
|
431 |
+
|
432 |
+
for k in range(len(valid_dataloaders)):
|
433 |
+
|
434 |
+
valid_dataloader = valid_dataloaders[k]
|
435 |
+
valid_dataset = valid_datasets[k]
|
436 |
+
|
437 |
+
val_num = valid_dataset.__len__()
|
438 |
+
mybins = np.arange(0,256)
|
439 |
+
PRE = np.zeros((val_num,len(mybins)-1))
|
440 |
+
REC = np.zeros((val_num,len(mybins)-1))
|
441 |
+
F1 = np.zeros((val_num,len(mybins)-1))
|
442 |
+
MAE = np.zeros((val_num))
|
443 |
+
|
444 |
+
for i_val, data_val in enumerate(valid_dataloader):
|
445 |
+
val_cnt = val_cnt + 1.0
|
446 |
+
imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']
|
447 |
+
|
448 |
+
if(hypar["model_digit"]=="full"):
|
449 |
+
inputs_val = inputs_val.type(torch.FloatTensor)
|
450 |
+
labels_val = labels_val.type(torch.FloatTensor)
|
451 |
+
else:
|
452 |
+
inputs_val = inputs_val.type(torch.HalfTensor)
|
453 |
+
labels_val = labels_val.type(torch.HalfTensor)
|
454 |
+
|
455 |
+
# wrap them in Variable
|
456 |
+
if torch.cuda.is_available():
|
457 |
+
inputs_val_v, labels_val_v = Variable(inputs_val.cuda(), requires_grad=False), Variable(labels_val.cuda(), requires_grad=False)
|
458 |
+
else:
|
459 |
+
inputs_val_v, labels_val_v = Variable(inputs_val, requires_grad=False), Variable(labels_val,requires_grad=False)
|
460 |
+
|
461 |
+
t_start = time.time()
|
462 |
+
ds_val = net(inputs_val_v)[0]
|
463 |
+
t_end = time.time()-t_start
|
464 |
+
tmp_time.append(t_end)
|
465 |
+
|
466 |
+
# loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
|
467 |
+
loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
|
468 |
+
|
469 |
+
# compute F measure
|
470 |
+
for t in range(hypar["batch_size_valid"]):
|
471 |
+
i_test = imidx_val[t].data.numpy()
|
472 |
+
|
473 |
+
pred_val = ds_val[0][t,:,:,:] # B x 1 x H x W
|
474 |
+
|
475 |
+
## recover the prediction spatial size to the orignal image size
|
476 |
+
pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear'))
|
477 |
+
|
478 |
+
# pred_val = normPRED(pred_val)
|
479 |
+
ma = torch.max(pred_val)
|
480 |
+
mi = torch.min(pred_val)
|
481 |
+
pred_val = (pred_val-mi)/(ma-mi) # max = 1
|
482 |
+
|
483 |
+
if len(valid_dataset.dataset["ori_gt_path"]) != 0:
|
484 |
+
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
485 |
+
if gt.max()==1:
|
486 |
+
gt=gt*255
|
487 |
+
else:
|
488 |
+
gt = np.zeros((shapes_val[t][0],shapes_val[t][1]))
|
489 |
+
with torch.no_grad():
|
490 |
+
gt = torch.tensor(gt).to(device)
|
491 |
+
|
492 |
+
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
|
493 |
+
|
494 |
+
|
495 |
+
PRE[i_test,:]=pre
|
496 |
+
REC[i_test,:] = rec
|
497 |
+
F1[i_test,:] = f1
|
498 |
+
MAE[i_test] = mae
|
499 |
+
|
500 |
+
del ds_val, gt
|
501 |
+
gc.collect()
|
502 |
+
torch.cuda.empty_cache()
|
503 |
+
|
504 |
+
# if(loss_val.data[0]>1):
|
505 |
+
val_loss += loss_val.item()#data[0]
|
506 |
+
tar_loss += loss2_val.item()#data[0]
|
507 |
+
|
508 |
+
print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end))
|
509 |
+
|
510 |
+
del loss2_val, loss_val
|
511 |
+
|
512 |
+
print('============================')
|
513 |
+
PRE_m = np.mean(PRE,0)
|
514 |
+
REC_m = np.mean(REC,0)
|
515 |
+
f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8)
|
516 |
+
|
517 |
+
tmp_f1.append(np.amax(f1_m))
|
518 |
+
tmp_mae.append(np.mean(MAE))
|
519 |
+
|
520 |
+
return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
|
521 |
+
|
522 |
+
def main(train_datasets,
|
523 |
+
valid_datasets,
|
524 |
+
hypar): # model: "train", "test"
|
525 |
+
|
526 |
+
### --- Step 1: Build datasets and dataloaders ---
|
527 |
+
dataloaders_train = []
|
528 |
+
dataloaders_valid = []
|
529 |
+
|
530 |
+
if(hypar["mode"]=="train"):
|
531 |
+
print("--- create training dataloader ---")
|
532 |
+
## collect training dataset
|
533 |
+
train_nm_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")
|
534 |
+
## build dataloader for training datasets
|
535 |
+
train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
|
536 |
+
cache_size = hypar["cache_size"],
|
537 |
+
cache_boost = hypar["cache_boost_train"],
|
538 |
+
my_transforms = [
|
539 |
+
GOSRandomHFlip(), ## this line can be uncommented for horizontal flip augmetation
|
540 |
+
# GOSResize(hypar["input_size"]),
|
541 |
+
# GOSRandomCrop(hypar["crop_size"]), ## this line can be uncommented for randomcrop augmentation
|
542 |
+
GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
543 |
+
],
|
544 |
+
batch_size = hypar["batch_size_train"],
|
545 |
+
shuffle = True)
|
546 |
+
train_dataloaders_val, train_datasets_val = create_dataloaders(train_nm_im_gt_list,
|
547 |
+
cache_size = hypar["cache_size"],
|
548 |
+
cache_boost = hypar["cache_boost_train"],
|
549 |
+
my_transforms = [
|
550 |
+
GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
551 |
+
],
|
552 |
+
batch_size = hypar["batch_size_valid"],
|
553 |
+
shuffle = False)
|
554 |
+
print(len(train_dataloaders), " train dataloaders created")
|
555 |
+
|
556 |
+
print("--- create valid dataloader ---")
|
557 |
+
## build dataloader for validation or testing
|
558 |
+
valid_nm_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")
|
559 |
+
## build dataloader for training datasets
|
560 |
+
valid_dataloaders, valid_datasets = create_dataloaders(valid_nm_im_gt_list,
|
561 |
+
cache_size = hypar["cache_size"],
|
562 |
+
cache_boost = hypar["cache_boost_valid"],
|
563 |
+
my_transforms = [
|
564 |
+
GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
565 |
+
# GOSResize(hypar["input_size"])
|
566 |
+
],
|
567 |
+
batch_size=hypar["batch_size_valid"],
|
568 |
+
shuffle=False)
|
569 |
+
print(len(valid_dataloaders), " valid dataloaders created")
|
570 |
+
# print(valid_datasets[0]["data_name"])
|
571 |
+
|
572 |
+
### --- Step 2: Build Model and Optimizer ---
|
573 |
+
print("--- build model ---")
|
574 |
+
net = hypar["model"]#GOSNETINC(3,1)
|
575 |
+
|
576 |
+
# convert to half precision
|
577 |
+
if(hypar["model_digit"]=="half"):
|
578 |
+
net.half()
|
579 |
+
for layer in net.modules():
|
580 |
+
if isinstance(layer, nn.BatchNorm2d):
|
581 |
+
layer.float()
|
582 |
+
|
583 |
+
if torch.cuda.is_available():
|
584 |
+
net.cuda()
|
585 |
+
|
586 |
+
if(hypar["restore_model"]!=""):
|
587 |
+
print("restore model from:")
|
588 |
+
print(hypar["model_path"]+"/"+hypar["restore_model"])
|
589 |
+
if torch.cuda.is_available():
|
590 |
+
net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"]))
|
591 |
+
else:
|
592 |
+
net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"],map_location="cpu"))
|
593 |
+
|
594 |
+
print("--- define optimizer ---")
|
595 |
+
optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
596 |
+
|
597 |
+
### --- Step 3: Train or Valid Model ---
|
598 |
+
if(hypar["mode"]=="train"):
|
599 |
+
train(net,
|
600 |
+
optimizer,
|
601 |
+
train_dataloaders,
|
602 |
+
train_datasets,
|
603 |
+
valid_dataloaders,
|
604 |
+
valid_datasets,
|
605 |
+
hypar,
|
606 |
+
train_dataloaders_val, train_datasets_val)
|
607 |
+
else:
|
608 |
+
valid(net,
|
609 |
+
valid_dataloaders,
|
610 |
+
valid_datasets,
|
611 |
+
hypar)
|
612 |
+
|
613 |
+
|
614 |
+
if __name__ == "__main__":
|
615 |
+
|
616 |
+
### --------------- STEP 1: Configuring the Train, Valid and Test datasets ---------------
|
617 |
+
## configure the train, valid and inference datasets
|
618 |
+
train_datasets, valid_datasets = [], []
|
619 |
+
dataset_1, dataset_1 = {}, {}
|
620 |
+
|
621 |
+
dataset_tr = {"name": "DIS5K-TR",
|
622 |
+
"im_dir": "../DIS5K/DIS-TR/im",
|
623 |
+
"gt_dir": "../DIS5K/DIS-TR/gt",
|
624 |
+
"im_ext": ".jpg",
|
625 |
+
"gt_ext": ".png",
|
626 |
+
"cache_dir":"../DIS5K-Cache/DIS-TR"}
|
627 |
+
|
628 |
+
dataset_vd = {"name": "DIS5K-VD",
|
629 |
+
"im_dir": "../DIS5K/DIS-VD/im",
|
630 |
+
"gt_dir": "../DIS5K/DIS-VD/gt",
|
631 |
+
"im_ext": ".jpg",
|
632 |
+
"gt_ext": ".png",
|
633 |
+
"cache_dir":"../DIS5K-Cache/DIS-VD"}
|
634 |
+
|
635 |
+
dataset_te1 = {"name": "DIS5K-TE1",
|
636 |
+
"im_dir": "../DIS5K/DIS-TE1/im",
|
637 |
+
"gt_dir": "../DIS5K/DIS-TE1/gt",
|
638 |
+
"im_ext": ".jpg",
|
639 |
+
"gt_ext": ".png",
|
640 |
+
"cache_dir":"../DIS5K-Cache/DIS-TE1"}
|
641 |
+
|
642 |
+
dataset_te2 = {"name": "DIS5K-TE2",
|
643 |
+
"im_dir": "../DIS5K/DIS-TE2/im",
|
644 |
+
"gt_dir": "../DIS5K/DIS-TE2/gt",
|
645 |
+
"im_ext": ".jpg",
|
646 |
+
"gt_ext": ".png",
|
647 |
+
"cache_dir":"../DIS5K-Cache/DIS-TE2"}
|
648 |
+
|
649 |
+
dataset_te3 = {"name": "DIS5K-TE3",
|
650 |
+
"im_dir": "../DIS5K/DIS-TE3/im",
|
651 |
+
"gt_dir": "../DIS5K/DIS-TE3/gt",
|
652 |
+
"im_ext": ".jpg",
|
653 |
+
"gt_ext": ".png",
|
654 |
+
"cache_dir":"../DIS5K-Cache/DIS-TE3"}
|
655 |
+
|
656 |
+
dataset_te4 = {"name": "DIS5K-TE4",
|
657 |
+
"im_dir": "../DIS5K/DIS-TE4/im",
|
658 |
+
"gt_dir": "../DIS5K/DIS-TE4/gt",
|
659 |
+
"im_ext": ".jpg",
|
660 |
+
"gt_ext": ".png",
|
661 |
+
"cache_dir":"../DIS5K-Cache/DIS-TE4"}
|
662 |
+
### test your own dataset
|
663 |
+
dataset_demo = {"name": "your-dataset",
|
664 |
+
"im_dir": "../your-dataset/im",
|
665 |
+
"gt_dir": "",
|
666 |
+
"im_ext": ".jpg",
|
667 |
+
"gt_ext": "",
|
668 |
+
"cache_dir":"../your-dataset/cache"}
|
669 |
+
|
670 |
+
train_datasets = [dataset_tr] ## users can create mutiple dictionary for setting a list of datasets as training set
|
671 |
+
# valid_datasets = [dataset_vd] ## users can create mutiple dictionary for setting a list of datasets as vaidation sets or inference sets
|
672 |
+
valid_datasets = [dataset_vd] # dataset_vd, dataset_te1, dataset_te2, dataset_te3, dataset_te4] # and hypar["mode"] = "valid" for inference,
|
673 |
+
|
674 |
+
### --------------- STEP 2: Configuring the hyperparamters for Training, validation and inferencing ---------------
|
675 |
+
hypar = {}
|
676 |
+
|
677 |
+
## -- 2.1. configure the model saving or restoring path --
|
678 |
+
hypar["mode"] = "train"
|
679 |
+
## "train": for training,
|
680 |
+
## "valid": for validation and inferening,
|
681 |
+
## in "valid" mode, it will calculate the accuracy as well as save the prediciton results into the "hypar["valid_out_dir"]", which shouldn't be ""
|
682 |
+
## otherwise only accuracy will be calculated and no predictions will be saved
|
683 |
+
hypar["interm_sup"] = False ## in-dicate if activate intermediate feature supervision
|
684 |
+
|
685 |
+
if hypar["mode"] == "train":
|
686 |
+
hypar["valid_out_dir"] = "" ## for "train" model leave it as "", for "valid"("inference") mode: set it according to your local directory
|
687 |
+
hypar["model_path"] ="../saved_models/IS-Net-test" ## model weights saving (or restoring) path
|
688 |
+
hypar["restore_model"] = "" ## name of the segmentation model weights .pth for resume training process from last stop or for the inferencing
|
689 |
+
hypar["start_ite"] = 0 ## start iteration for the training, can be changed to match the restored training process
|
690 |
+
hypar["gt_encoder_model"] = ""
|
691 |
+
else: ## configure the segmentation output path and the to-be-used model weights path
|
692 |
+
hypar["valid_out_dir"] = "../your-results/"##"../DIS5K-Results-test" ## output inferenced segmentation maps into this fold
|
693 |
+
hypar["model_path"] = "../saved_models/IS-Net" ## load trained weights from this path
|
694 |
+
hypar["restore_model"] = "isnet.pth"##"isnet.pth" ## name of the to-be-loaded weights
|
695 |
+
|
696 |
+
# if hypar["restore_model"]!="":
|
697 |
+
# hypar["start_ite"] = int(hypar["restore_model"].split("_")[2])
|
698 |
+
|
699 |
+
## -- 2.2. choose floating point accuracy --
|
700 |
+
hypar["model_digit"] = "full" ## indicates "half" or "full" accuracy of float number
|
701 |
+
hypar["seed"] = 0
|
702 |
+
|
703 |
+
## -- 2.3. cache data spatial size --
|
704 |
+
## To handle large size input images, which take a lot of time for loading in training,
|
705 |
+
# we introduce the cache mechanism for pre-convering and resizing the jpg and png images into .pt file
|
706 |
+
hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured into different size
|
707 |
+
hypar["cache_boost_train"] = False ## "True" or "False", indicates wheather to load all the training datasets into RAM, True will greatly speed the training process while requires more RAM
|
708 |
+
hypar["cache_boost_valid"] = False ## "True" or "False", indicates wheather to load all the validation datasets into RAM, True will greatly speed the training process while requires more RAM
|
709 |
+
|
710 |
+
## --- 2.4. data augmentation parameters ---
|
711 |
+
hypar["input_size"] = [1024, 1024] ## mdoel input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images
|
712 |
+
hypar["crop_size"] = [1024, 1024] ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation
|
713 |
+
hypar["random_flip_h"] = 1 ## horizontal flip, currently hard coded in the dataloader and it is not in use
|
714 |
+
hypar["random_flip_v"] = 0 ## vertical flip , currently not in use
|
715 |
+
|
716 |
+
## --- 2.5. define model ---
|
717 |
+
print("building model...")
|
718 |
+
hypar["model"] = ISNetDIS() #U2NETFASTFEATURESUP()
|
719 |
+
hypar["early_stop"] = 20 ## stop the training when no improvement in the past 20 validation periods, smaller numbers can be used here e.g., 5 or 10.
|
720 |
+
hypar["model_save_fre"] = 2000 ## valid and save model weights every 2000 iterations
|
721 |
+
|
722 |
+
hypar["batch_size_train"] = 8 ## batch size for training
|
723 |
+
hypar["batch_size_valid"] = 1 ## batch size for validation and inferencing
|
724 |
+
print("batch size: ", hypar["batch_size_train"])
|
725 |
+
|
726 |
+
hypar["max_ite"] = 10000000 ## if early stop couldn't stop the training process, stop it by the max_ite_num
|
727 |
+
hypar["max_epoch_num"] = 1000000 ## if early stop and max_ite couldn't stop the training process, stop it by the max_epoch_num
|
728 |
+
|
729 |
+
main(train_datasets,
|
730 |
+
valid_datasets,
|
731 |
+
hypar=hypar)
|