Vijish commited on
Commit
221def3
·
1 Parent(s): 9ed5837

Upload 7 files

Browse files
Inference.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import numpy as np
4
+ from skimage import io
5
+ import time
6
+ from glob import glob
7
+ from tqdm import tqdm
8
+
9
+ import torch, gc
10
+ import torch.nn as nn
11
+ from torch.autograd import Variable
12
+ import torch.optim as optim
13
+ import torch.nn.functional as F
14
+ from torchvision.transforms.functional import normalize
15
+
16
+ from models import *
17
+
18
+
19
+ if __name__ == "__main__":
20
+ dataset_path="../demo_datasets/your_dataset" #Your dataset path
21
+ model_path="../saved_models/IS-Net/isnet-general-use.pth" # the model path
22
+ result_path="../demo_datasets/your_dataset_result" #The folder path that you want to save the results
23
+ input_size=[1024,1024]
24
+ net=ISNetDIS()
25
+
26
+ if torch.cuda.is_available():
27
+ net.load_state_dict(torch.load(model_path))
28
+ net=net.cuda()
29
+ else:
30
+ net.load_state_dict(torch.load(model_path,map_location="cpu"))
31
+ net.eval()
32
+ im_list = glob(dataset_path+"/*.jpg")+glob(dataset_path+"/*.JPG")+glob(dataset_path+"/*.jpeg")+glob(dataset_path+"/*.JPEG")+glob(dataset_path+"/*.png")+glob(dataset_path+"/*.PNG")+glob(dataset_path+"/*.bmp")+glob(dataset_path+"/*.BMP")+glob(dataset_path+"/*.tiff")+glob(dataset_path+"/*.TIFF")
33
+ with torch.no_grad():
34
+ for i, im_path in tqdm(enumerate(im_list), total=len(im_list)):
35
+ print("im_path: ", im_path)
36
+ im = io.imread(im_path)
37
+ if len(im.shape) < 3:
38
+ im = im[:, :, np.newaxis]
39
+ im_shp=im.shape[0:2]
40
+ im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
41
+ im_tensor = F.upsample(torch.unsqueeze(im_tensor,0), input_size, mode="bilinear").type(torch.uint8)
42
+ image = torch.divide(im_tensor,255.0)
43
+ image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
44
+
45
+ if torch.cuda.is_available():
46
+ image=image.cuda()
47
+ result=net(image)
48
+ result=torch.squeeze(F.upsample(result[0][0],im_shp,mode='bilinear'),0)
49
+ ma = torch.max(result)
50
+ mi = torch.min(result)
51
+ result = (result-mi)/(ma-mi)
52
+ im_name=im_path.split('/')[-1].split('.')[0]
53
+ io.imsave(os.path.join(result_path,im_name+".png"),(result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8))
basics.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ # os.environ['CUDA_VISIBLE_DEVICES'] = '2'
3
+ from skimage import io, transform
4
+ import torch
5
+ import torchvision
6
+ from torch.autograd import Variable
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.utils.data import Dataset, DataLoader
10
+ from torchvision import transforms, utils
11
+ import torch.optim as optim
12
+
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ from PIL import Image
16
+ import glob
17
+
18
+ def mae_torch(pred,gt):
19
+
20
+ h,w = gt.shape[0:2]
21
+ sumError = torch.sum(torch.absolute(torch.sub(pred.float(), gt.float())))
22
+ maeError = torch.divide(sumError,float(h)*float(w)*255.0+1e-4)
23
+
24
+ return maeError
25
+
26
+ def f1score_torch(pd,gt):
27
+
28
+ # print(gt.shape)
29
+ gtNum = torch.sum((gt>128).float()*1) ## number of ground truth pixels
30
+
31
+ pp = pd[gt>128]
32
+ nn = pd[gt<=128]
33
+
34
+ pp_hist =torch.histc(pp,bins=255,min=0,max=255)
35
+ nn_hist = torch.histc(nn,bins=255,min=0,max=255)
36
+
37
+
38
+ pp_hist_flip = torch.flipud(pp_hist)
39
+ nn_hist_flip = torch.flipud(nn_hist)
40
+
41
+ pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)
42
+ nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)
43
+
44
+ precision = (pp_hist_flip_cum)/(pp_hist_flip_cum + nn_hist_flip_cum + 1e-4)#torch.divide(pp_hist_flip_cum,torch.sum(torch.sum(pp_hist_flip_cum, nn_hist_flip_cum), 1e-4))
45
+ recall = (pp_hist_flip_cum)/(gtNum + 1e-4)
46
+ f1 = (1+0.3)*precision*recall/(0.3*precision+recall + 1e-4)
47
+
48
+ return torch.reshape(precision,(1,precision.shape[0])),torch.reshape(recall,(1,recall.shape[0])),torch.reshape(f1,(1,f1.shape[0]))
49
+
50
+
51
+ def f1_mae_torch(pred, gt, valid_dataset, idx, mybins, hypar):
52
+
53
+ import time
54
+ tic = time.time()
55
+
56
+ if(len(gt.shape)>2):
57
+ gt = gt[:,:,0]
58
+
59
+ pre, rec, f1 = f1score_torch(pred,gt)
60
+ mae = mae_torch(pred,gt)
61
+
62
+
63
+ # hypar["valid_out_dir"] = hypar["valid_out_dir"]+"-eval" ###
64
+ if(hypar["valid_out_dir"]!=""):
65
+ if(not os.path.exists(hypar["valid_out_dir"])):
66
+ os.mkdir(hypar["valid_out_dir"])
67
+ dataset_folder = os.path.join(hypar["valid_out_dir"],valid_dataset.dataset["data_name"][idx])
68
+ if(not os.path.exists(dataset_folder)):
69
+ os.mkdir(dataset_folder)
70
+ io.imsave(os.path.join(dataset_folder,valid_dataset.dataset["im_name"][idx]+".png"),pred.cpu().data.numpy().astype(np.uint8))
71
+ print(valid_dataset.dataset["im_name"][idx]+".png")
72
+ print("time for evaluation : ", time.time()-tic)
73
+
74
+ return pre.cpu().data.numpy(), rec.cpu().data.numpy(), f1.cpu().data.numpy(), mae.cpu().data.numpy()
data_loader_cache.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## data loader
2
+ ## Ackownledgement:
3
+ ## We would like to thank Dr. Ibrahim Almakky (https://scholar.google.co.uk/citations?user=T9MTcK0AAAAJ&hl=en)
4
+ ## for his helps in implementing cache machanism of our DIS dataloader.
5
+ from __future__ import print_function, division
6
+
7
+ import numpy as np
8
+ import random
9
+ from copy import deepcopy
10
+ import json
11
+ from tqdm import tqdm
12
+ from skimage import io
13
+ import os
14
+ from glob import glob
15
+
16
+ import torch
17
+ from torch.utils.data import Dataset, DataLoader
18
+ from torchvision import transforms, utils
19
+ from torchvision.transforms.functional import normalize
20
+ import torch.nn.functional as F
21
+
22
+ #### --------------------- DIS dataloader cache ---------------------####
23
+
24
+ def get_im_gt_name_dict(datasets, flag='valid'):
25
+ print("------------------------------", flag, "--------------------------------")
26
+ name_im_gt_list = []
27
+ for i in range(len(datasets)):
28
+ print("--->>>", flag, " dataset ",i,"/",len(datasets)," ",datasets[i]["name"],"<<<---")
29
+ tmp_im_list, tmp_gt_list = [], []
30
+ tmp_im_list = glob(datasets[i]["im_dir"]+os.sep+'*'+datasets[i]["im_ext"])
31
+
32
+ # img_name_dict[im_dirs[i][0]] = tmp_im_list
33
+ print('-im-',datasets[i]["name"],datasets[i]["im_dir"], ': ',len(tmp_im_list))
34
+
35
+ if(datasets[i]["gt_dir"]==""):
36
+ print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found')
37
+ tmp_gt_list = []
38
+ else:
39
+ tmp_gt_list = [datasets[i]["gt_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["gt_ext"] for x in tmp_im_list]
40
+
41
+ # lbl_name_dict[im_dirs[i][0]] = tmp_gt_list
42
+ print('-gt-', datasets[i]["name"],datasets[i]["gt_dir"], ': ',len(tmp_gt_list))
43
+
44
+
45
+ if flag=="train": ## combine multiple training sets into one dataset
46
+ if len(name_im_gt_list)==0:
47
+ name_im_gt_list.append({"dataset_name":datasets[i]["name"],
48
+ "im_path":tmp_im_list,
49
+ "gt_path":tmp_gt_list,
50
+ "im_ext":datasets[i]["im_ext"],
51
+ "gt_ext":datasets[i]["gt_ext"],
52
+ "cache_dir":datasets[i]["cache_dir"]})
53
+ else:
54
+ name_im_gt_list[0]["dataset_name"] = name_im_gt_list[0]["dataset_name"] + "_" + datasets[i]["name"]
55
+ name_im_gt_list[0]["im_path"] = name_im_gt_list[0]["im_path"] + tmp_im_list
56
+ name_im_gt_list[0]["gt_path"] = name_im_gt_list[0]["gt_path"] + tmp_gt_list
57
+ if datasets[i]["im_ext"]!=".jpg" or datasets[i]["gt_ext"]!=".png":
58
+ print("Error: Please make sure all you images and ground truth masks are in jpg and png format respectively !!!")
59
+ exit()
60
+ name_im_gt_list[0]["im_ext"] = ".jpg"
61
+ name_im_gt_list[0]["gt_ext"] = ".png"
62
+ name_im_gt_list[0]["cache_dir"] = os.sep.join(datasets[i]["cache_dir"].split(os.sep)[0:-1])+os.sep+name_im_gt_list[0]["dataset_name"]
63
+ else: ## keep different validation or inference datasets as separate ones
64
+ name_im_gt_list.append({"dataset_name":datasets[i]["name"],
65
+ "im_path":tmp_im_list,
66
+ "gt_path":tmp_gt_list,
67
+ "im_ext":datasets[i]["im_ext"],
68
+ "gt_ext":datasets[i]["gt_ext"],
69
+ "cache_dir":datasets[i]["cache_dir"]})
70
+
71
+ return name_im_gt_list
72
+
73
+ def create_dataloaders(name_im_gt_list, cache_size=[], cache_boost=True, my_transforms=[], batch_size=1, shuffle=False):
74
+ ## model="train": return one dataloader for training
75
+ ## model="valid": return a list of dataloaders for validation or testing
76
+
77
+ gos_dataloaders = []
78
+ gos_datasets = []
79
+
80
+ if(len(name_im_gt_list)==0):
81
+ return gos_dataloaders, gos_datasets
82
+
83
+ num_workers_ = 1
84
+ if(batch_size>1):
85
+ num_workers_ = 2
86
+ if(batch_size>4):
87
+ num_workers_ = 4
88
+ if(batch_size>8):
89
+ num_workers_ = 8
90
+
91
+ for i in range(0,len(name_im_gt_list)):
92
+ gos_dataset = GOSDatasetCache([name_im_gt_list[i]],
93
+ cache_size = cache_size,
94
+ cache_path = name_im_gt_list[i]["cache_dir"],
95
+ cache_boost = cache_boost,
96
+ transform = transforms.Compose(my_transforms))
97
+ gos_dataloaders.append(DataLoader(gos_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers_))
98
+ gos_datasets.append(gos_dataset)
99
+
100
+ return gos_dataloaders, gos_datasets
101
+
102
+ def im_reader(im_path):
103
+ return io.imread(im_path)
104
+
105
+ def im_preprocess(im,size):
106
+ if len(im.shape) < 3:
107
+ im = im[:, :, np.newaxis]
108
+ if im.shape[2] == 1:
109
+ im = np.repeat(im, 3, axis=2)
110
+ im_tensor = torch.tensor(im.copy(), dtype=torch.float32)
111
+ im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1)
112
+ if(len(size)<2):
113
+ return im_tensor, im.shape[0:2]
114
+ else:
115
+ im_tensor = torch.unsqueeze(im_tensor,0)
116
+ im_tensor = F.upsample(im_tensor, size, mode="bilinear")
117
+ im_tensor = torch.squeeze(im_tensor,0)
118
+
119
+ return im_tensor.type(torch.uint8), im.shape[0:2]
120
+
121
+ def gt_preprocess(gt,size):
122
+ if len(gt.shape) > 2:
123
+ gt = gt[:, :, 0]
124
+
125
+ gt_tensor = torch.unsqueeze(torch.tensor(gt, dtype=torch.uint8),0)
126
+
127
+ if(len(size)<2):
128
+ return gt_tensor.type(torch.uint8), gt.shape[0:2]
129
+ else:
130
+ gt_tensor = torch.unsqueeze(torch.tensor(gt_tensor, dtype=torch.float32),0)
131
+ gt_tensor = F.upsample(gt_tensor, size, mode="bilinear")
132
+ gt_tensor = torch.squeeze(gt_tensor,0)
133
+
134
+ return gt_tensor.type(torch.uint8), gt.shape[0:2]
135
+ # return gt_tensor, gt.shape[0:2]
136
+
137
+ class GOSRandomHFlip(object):
138
+ def __init__(self,prob=0.5):
139
+ self.prob = prob
140
+ def __call__(self,sample):
141
+ imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
142
+
143
+ # random horizontal flip
144
+ if random.random() >= self.prob:
145
+ image = torch.flip(image,dims=[2])
146
+ label = torch.flip(label,dims=[2])
147
+
148
+ return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
149
+
150
+ class GOSResize(object):
151
+ def __init__(self,size=[320,320]):
152
+ self.size = size
153
+ def __call__(self,sample):
154
+ imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
155
+
156
+ # import time
157
+ # start = time.time()
158
+
159
+ image = torch.squeeze(F.upsample(torch.unsqueeze(image,0),self.size,mode='bilinear'),dim=0)
160
+ label = torch.squeeze(F.upsample(torch.unsqueeze(label,0),self.size,mode='bilinear'),dim=0)
161
+
162
+ # print("time for resize: ", time.time()-start)
163
+
164
+ return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
165
+
166
+ class GOSRandomCrop(object):
167
+ def __init__(self,size=[288,288]):
168
+ self.size = size
169
+ def __call__(self,sample):
170
+ imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
171
+
172
+ h, w = image.shape[1:]
173
+ new_h, new_w = self.size
174
+
175
+ top = np.random.randint(0, h - new_h)
176
+ left = np.random.randint(0, w - new_w)
177
+
178
+ image = image[:,top:top+new_h,left:left+new_w]
179
+ label = label[:,top:top+new_h,left:left+new_w]
180
+
181
+ return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
182
+
183
+
184
+ class GOSNormalize(object):
185
+ def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
186
+ self.mean = mean
187
+ self.std = std
188
+
189
+ def __call__(self,sample):
190
+
191
+ imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
192
+ image = normalize(image,self.mean,self.std)
193
+
194
+ return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
195
+
196
+
197
+ class GOSDatasetCache(Dataset):
198
+
199
+ def __init__(self, name_im_gt_list, cache_size=[], cache_path='./cache', cache_file_name='dataset.json', cache_boost=False, transform=None):
200
+
201
+
202
+ self.cache_size = cache_size
203
+ self.cache_path = cache_path
204
+ self.cache_file_name = cache_file_name
205
+ self.cache_boost_name = ""
206
+
207
+ self.cache_boost = cache_boost
208
+ # self.ims_npy = None
209
+ # self.gts_npy = None
210
+
211
+ ## cache all the images and ground truth into a single pytorch tensor
212
+ self.ims_pt = None
213
+ self.gts_pt = None
214
+
215
+ ## we will cache the npy as well regardless of the cache_boost
216
+ # if(self.cache_boost):
217
+ self.cache_boost_name = cache_file_name.split('.json')[0]
218
+
219
+ self.transform = transform
220
+
221
+ self.dataset = {}
222
+
223
+ ## combine different datasets into one
224
+ dataset_names = []
225
+ dt_name_list = [] # dataset name per image
226
+ im_name_list = [] # image name
227
+ im_path_list = [] # im path
228
+ gt_path_list = [] # gt path
229
+ im_ext_list = [] # im ext
230
+ gt_ext_list = [] # gt ext
231
+ for i in range(0,len(name_im_gt_list)):
232
+ dataset_names.append(name_im_gt_list[i]["dataset_name"])
233
+ # dataset name repeated based on the number of images in this dataset
234
+ dt_name_list.extend([name_im_gt_list[i]["dataset_name"] for x in name_im_gt_list[i]["im_path"]])
235
+ im_name_list.extend([x.split(os.sep)[-1].split(name_im_gt_list[i]["im_ext"])[0] for x in name_im_gt_list[i]["im_path"]])
236
+ im_path_list.extend(name_im_gt_list[i]["im_path"])
237
+ gt_path_list.extend(name_im_gt_list[i]["gt_path"])
238
+ im_ext_list.extend([name_im_gt_list[i]["im_ext"] for x in name_im_gt_list[i]["im_path"]])
239
+ gt_ext_list.extend([name_im_gt_list[i]["gt_ext"] for x in name_im_gt_list[i]["gt_path"]])
240
+
241
+
242
+ self.dataset["data_name"] = dt_name_list
243
+ self.dataset["im_name"] = im_name_list
244
+ self.dataset["im_path"] = im_path_list
245
+ self.dataset["ori_im_path"] = deepcopy(im_path_list)
246
+ self.dataset["gt_path"] = gt_path_list
247
+ self.dataset["ori_gt_path"] = deepcopy(gt_path_list)
248
+ self.dataset["im_shp"] = []
249
+ self.dataset["gt_shp"] = []
250
+ self.dataset["im_ext"] = im_ext_list
251
+ self.dataset["gt_ext"] = gt_ext_list
252
+
253
+
254
+ self.dataset["ims_pt_dir"] = ""
255
+ self.dataset["gts_pt_dir"] = ""
256
+
257
+ self.dataset = self.manage_cache(dataset_names)
258
+
259
+ def manage_cache(self,dataset_names):
260
+ if not os.path.exists(self.cache_path): # create the folder for cache
261
+ os.makedirs(self.cache_path)
262
+ cache_folder = os.path.join(self.cache_path, "_".join(dataset_names)+"_"+"x".join([str(x) for x in self.cache_size]))
263
+ if not os.path.exists(cache_folder): # check if the cache files are there, if not then cache
264
+ return self.cache(cache_folder)
265
+ return self.load_cache(cache_folder)
266
+
267
+ def cache(self,cache_folder):
268
+ os.mkdir(cache_folder)
269
+ cached_dataset = deepcopy(self.dataset)
270
+
271
+ # ims_list = []
272
+ # gts_list = []
273
+ ims_pt_list = []
274
+ gts_pt_list = []
275
+ for i, im_path in tqdm(enumerate(self.dataset["im_path"]), total=len(self.dataset["im_path"])):
276
+
277
+ im_id = cached_dataset["im_name"][i]
278
+ print("im_path: ", im_path)
279
+ im = im_reader(im_path)
280
+ im, im_shp = im_preprocess(im,self.cache_size)
281
+ im_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_im.pt")
282
+ torch.save(im,im_cache_file)
283
+
284
+ cached_dataset["im_path"][i] = im_cache_file
285
+ if(self.cache_boost):
286
+ ims_pt_list.append(torch.unsqueeze(im,0))
287
+ # ims_list.append(im.cpu().data.numpy().astype(np.uint8))
288
+
289
+ gt = np.zeros(im.shape[0:2])
290
+ if len(self.dataset["gt_path"])!=0:
291
+ gt = im_reader(self.dataset["gt_path"][i])
292
+ gt, gt_shp = gt_preprocess(gt,self.cache_size)
293
+ gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt")
294
+ torch.save(gt,gt_cache_file)
295
+ if len(self.dataset["gt_path"])>0:
296
+ cached_dataset["gt_path"][i] = gt_cache_file
297
+ else:
298
+ cached_dataset["gt_path"].append(gt_cache_file)
299
+ if(self.cache_boost):
300
+ gts_pt_list.append(torch.unsqueeze(gt,0))
301
+ # gts_list.append(gt.cpu().data.numpy().astype(np.uint8))
302
+
303
+ # im_shp_cache_file = os.path.join(cache_folder,im_id + "_im_shp.pt")
304
+ # torch.save(gt_shp, shp_cache_file)
305
+ cached_dataset["im_shp"].append(im_shp)
306
+ # self.dataset["im_shp"].append(im_shp)
307
+
308
+ # shp_cache_file = os.path.join(cache_folder,im_id + "_gt_shp.pt")
309
+ # torch.save(gt_shp, shp_cache_file)
310
+ cached_dataset["gt_shp"].append(gt_shp)
311
+ # self.dataset["gt_shp"].append(gt_shp)
312
+
313
+ if(self.cache_boost):
314
+ cached_dataset["ims_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_ims.pt')
315
+ cached_dataset["gts_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_gts.pt')
316
+ self.ims_pt = torch.cat(ims_pt_list,dim=0)
317
+ self.gts_pt = torch.cat(gts_pt_list,dim=0)
318
+ torch.save(torch.cat(ims_pt_list,dim=0),cached_dataset["ims_pt_dir"])
319
+ torch.save(torch.cat(gts_pt_list,dim=0),cached_dataset["gts_pt_dir"])
320
+
321
+ try:
322
+ json_file = open(os.path.join(cache_folder, self.cache_file_name),"w")
323
+ json.dump(cached_dataset, json_file)
324
+ json_file.close()
325
+ except Exception:
326
+ raise FileNotFoundError("Cannot create JSON")
327
+ return cached_dataset
328
+
329
+ def load_cache(self, cache_folder):
330
+ json_file = open(os.path.join(cache_folder,self.cache_file_name),"r")
331
+ dataset = json.load(json_file)
332
+ json_file.close()
333
+ ## if cache_boost is true, we will load the image npy and ground truth npy into the RAM
334
+ ## otherwise the pytorch tensor will be loaded
335
+ if(self.cache_boost):
336
+ # self.ims_npy = np.load(dataset["ims_npy_dir"])
337
+ # self.gts_npy = np.load(dataset["gts_npy_dir"])
338
+ self.ims_pt = torch.load(dataset["ims_pt_dir"], map_location='cpu')
339
+ self.gts_pt = torch.load(dataset["gts_pt_dir"], map_location='cpu')
340
+ return dataset
341
+
342
+ def __len__(self):
343
+ return len(self.dataset["im_path"])
344
+
345
+ def __getitem__(self, idx):
346
+
347
+ im = None
348
+ gt = None
349
+ if(self.cache_boost and self.ims_pt is not None):
350
+
351
+ # start = time.time()
352
+ im = self.ims_pt[idx]#.type(torch.float32)
353
+ gt = self.gts_pt[idx]#.type(torch.float32)
354
+ # print(idx, 'time for pt loading: ', time.time()-start)
355
+
356
+ else:
357
+ # import time
358
+ # start = time.time()
359
+ # print("tensor***")
360
+ im_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["im_path"][idx].split(os.sep)[-2:]))
361
+ im = torch.load(im_pt_path)#(self.dataset["im_path"][idx])
362
+ gt_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["gt_path"][idx].split(os.sep)[-2:]))
363
+ gt = torch.load(gt_pt_path)#(self.dataset["gt_path"][idx])
364
+ # print(idx,'time for tensor loading: ', time.time()-start)
365
+
366
+
367
+ im_shp = self.dataset["im_shp"][idx]
368
+ # print("time for loading im and gt: ", time.time()-start)
369
+
370
+ # start_time = time.time()
371
+ im = torch.divide(im,255.0)
372
+ gt = torch.divide(gt,255.0)
373
+ # print(idx, 'time for normalize torch divide: ', time.time()-start_time)
374
+
375
+ sample = {
376
+ "imidx": torch.from_numpy(np.array(idx)),
377
+ "image": im,
378
+ "label": gt,
379
+ "shape": torch.from_numpy(np.array(im_shp)),
380
+ }
381
+
382
+ if self.transform:
383
+ sample = self.transform(sample)
384
+
385
+ return sample
hce_metric_main.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## hce_metric.py
2
+ import numpy as np
3
+ from skimage import io
4
+ import matplotlib.pyplot as plt
5
+ import cv2 as cv
6
+ from skimage.morphology import skeletonize
7
+ from skimage.morphology import erosion, dilation, disk
8
+ from skimage.measure import label
9
+
10
+ import os
11
+ import sys
12
+ from tqdm import tqdm
13
+ from glob import glob
14
+ import pickle as pkl
15
+
16
+ def filter_bdy_cond(bdy_, mask, cond):
17
+
18
+ cond = cv.dilate(cond.astype(np.uint8),disk(1))
19
+ labels = label(mask) # find the connected regions
20
+ lbls = np.unique(labels) # the indices of the connected regions
21
+ indep = np.ones(lbls.shape[0]) # the label of each connected regions
22
+ indep[0] = 0 # 0 indicate the background region
23
+
24
+ boundaries = []
25
+ h,w = cond.shape[0:2]
26
+ ind_map = np.zeros((h,w))
27
+ indep_cnt = 0
28
+
29
+ for i in range(0,len(bdy_)):
30
+ tmp_bdies = []
31
+ tmp_bdy = []
32
+ for j in range(0,bdy_[i].shape[0]):
33
+ r, c = bdy_[i][j,0,1],bdy_[i][j,0,0]
34
+
35
+ if(np.sum(cond[r,c])==0 or ind_map[r,c]!=0):
36
+ if(len(tmp_bdy)>0):
37
+ tmp_bdies.append(tmp_bdy)
38
+ tmp_bdy = []
39
+ continue
40
+ tmp_bdy.append([c,r])
41
+ ind_map[r,c] = ind_map[r,c] + 1
42
+ indep[labels[r,c]] = 0 # indicates part of the boundary of this region needs human correction
43
+ if(len(tmp_bdy)>0):
44
+ tmp_bdies.append(tmp_bdy)
45
+
46
+ # check if the first and the last boundaries are connected
47
+ # if yes, invert the first boundary and attach it after the last boundary
48
+ if(len(tmp_bdies)>1):
49
+ first_x, first_y = tmp_bdies[0][0]
50
+ last_x, last_y = tmp_bdies[-1][-1]
51
+ if((abs(first_x-last_x)==1 and first_y==last_y) or
52
+ (first_x==last_x and abs(first_y-last_y)==1) or
53
+ (abs(first_x-last_x)==1 and abs(first_y-last_y)==1)
54
+ ):
55
+ tmp_bdies[-1].extend(tmp_bdies[0][::-1])
56
+ del tmp_bdies[0]
57
+
58
+ for k in range(0,len(tmp_bdies)):
59
+ tmp_bdies[k] = np.array(tmp_bdies[k])[:,np.newaxis,:]
60
+ if(len(tmp_bdies)>0):
61
+ boundaries.extend(tmp_bdies)
62
+
63
+ return boundaries, np.sum(indep)
64
+
65
+ # this function approximate each boundary by DP algorithm
66
+ # https://en.wikipedia.org/wiki/Ramer%E2%80%93Douglas%E2%80%93Peucker_algorithm
67
+ def approximate_RDP(boundaries,epsilon=1.0):
68
+
69
+ boundaries_ = []
70
+ boundaries_len_ = []
71
+ pixel_cnt_ = 0
72
+
73
+ # polygon approximate of each boundary
74
+ for i in range(0,len(boundaries)):
75
+ boundaries_.append(cv.approxPolyDP(boundaries[i],epsilon,False))
76
+
77
+ # count the control points number of each boundary and the total control points number of all the boundaries
78
+ for i in range(0,len(boundaries_)):
79
+ boundaries_len_.append(len(boundaries_[i]))
80
+ pixel_cnt_ = pixel_cnt_ + len(boundaries_[i])
81
+
82
+ return boundaries_, boundaries_len_, pixel_cnt_
83
+
84
+
85
+ def relax_HCE(gt, rs, gt_ske, relax=5, epsilon=2.0):
86
+ # print("max(gt_ske): ", np.amax(gt_ske))
87
+ # gt_ske = gt_ske>128
88
+ # print("max(gt_ske): ", np.amax(gt_ske))
89
+
90
+ # Binarize gt
91
+ if(len(gt.shape)>2):
92
+ gt = gt[:,:,0]
93
+
94
+ epsilon_gt = 128#(np.amin(gt)+np.amax(gt))/2.0
95
+ gt = (gt>epsilon_gt).astype(np.uint8)
96
+
97
+ # Binarize rs
98
+ if(len(rs.shape)>2):
99
+ rs = rs[:,:,0]
100
+ epsilon_rs = 128#(np.amin(rs)+np.amax(rs))/2.0
101
+ rs = (rs>epsilon_rs).astype(np.uint8)
102
+
103
+ Union = np.logical_or(gt,rs)
104
+ TP = np.logical_and(gt,rs)
105
+ FP = rs - TP
106
+ FN = gt - TP
107
+
108
+ # relax the Union of gt and rs
109
+ Union_erode = Union.copy()
110
+ Union_erode = cv.erode(Union_erode.astype(np.uint8),disk(1),iterations=relax)
111
+
112
+ # --- get the relaxed False Positive regions for computing the human efforts in correcting them ---
113
+ FP_ = np.logical_and(FP,Union_erode) # get the relaxed FP
114
+ for i in range(0,relax):
115
+ FP_ = cv.dilate(FP_.astype(np.uint8),disk(1))
116
+ FP_ = np.logical_and(FP_, 1-np.logical_or(TP,FN))
117
+ FP_ = np.logical_and(FP, FP_)
118
+
119
+ # --- get the relaxed False Negative regions for computing the human efforts in correcting them ---
120
+ FN_ = np.logical_and(FN,Union_erode) # preserve the structural components of FN
121
+ ## recover the FN, where pixels are not close to the TP borders
122
+ for i in range(0,relax):
123
+ FN_ = cv.dilate(FN_.astype(np.uint8),disk(1))
124
+ FN_ = np.logical_and(FN_,1-np.logical_or(TP,FP))
125
+ FN_ = np.logical_and(FN,FN_)
126
+ FN_ = np.logical_or(FN_, np.logical_xor(gt_ske,np.logical_and(TP,gt_ske))) # preserve the structural components of FN
127
+
128
+ ## 2. =============Find exact polygon control points and independent regions==============
129
+ ## find contours from FP_
130
+ ctrs_FP, hier_FP = cv.findContours(FP_.astype(np.uint8), cv.RETR_TREE, cv.CHAIN_APPROX_NONE)
131
+ ## find control points and independent regions for human correction
132
+ bdies_FP, indep_cnt_FP = filter_bdy_cond(ctrs_FP, FP_, np.logical_or(TP,FN_))
133
+ ## find contours from FN_
134
+ ctrs_FN, hier_FN = cv.findContours(FN_.astype(np.uint8), cv.RETR_TREE, cv.CHAIN_APPROX_NONE)
135
+ ## find control points and independent regions for human correction
136
+ bdies_FN, indep_cnt_FN = filter_bdy_cond(ctrs_FN, FN_, 1-np.logical_or(np.logical_or(TP,FP_),FN_))
137
+
138
+ poly_FP, poly_FP_len, poly_FP_point_cnt = approximate_RDP(bdies_FP,epsilon=epsilon)
139
+ poly_FN, poly_FN_len, poly_FN_point_cnt = approximate_RDP(bdies_FN,epsilon=epsilon)
140
+
141
+ return poly_FP_point_cnt, indep_cnt_FP, poly_FN_point_cnt, indep_cnt_FN
142
+
143
+ def compute_hce(pred_root,gt_root,gt_ske_root):
144
+
145
+ gt_name_list = glob(pred_root+'/*.png')
146
+ gt_name_list = sorted([x.split('/')[-1] for x in gt_name_list])
147
+
148
+ hces = []
149
+ for gt_name in tqdm(gt_name_list, total=len(gt_name_list)):
150
+ gt_path = os.path.join(gt_root, gt_name)
151
+ pred_path = os.path.join(pred_root, gt_name)
152
+
153
+ gt = cv.imread(gt_path, cv.IMREAD_GRAYSCALE)
154
+ pred = cv.imread(pred_path, cv.IMREAD_GRAYSCALE)
155
+
156
+ ske_path = os.path.join(gt_ske_root,gt_name)
157
+ if os.path.exists(ske_path):
158
+ ske = cv.imread(ske_path,cv.IMREAD_GRAYSCALE)
159
+ ske = ske>128
160
+ else:
161
+ ske = skeletonize(gt>128)
162
+
163
+ FP_points, FP_indep, FN_points, FN_indep = relax_HCE(gt, pred,ske)
164
+ print(gt_path.split('/')[-1],FP_points, FP_indep, FN_points, FN_indep)
165
+ hces.append([FP_points, FP_indep, FN_points, FN_indep, FP_points+FP_indep+FN_points+FN_indep])
166
+
167
+ hce_metric ={'names': gt_name_list,
168
+ 'hces': hces}
169
+
170
+
171
+ file_metric = open(pred_root+'/hce_metric.pkl','wb')
172
+ pkl.dump(hce_metric,file_metric)
173
+ # file_metrics.write(cmn_metrics)
174
+ file_metric.close()
175
+
176
+ return np.mean(np.array(hces)[:,-1])
177
+
178
+ def main():
179
+
180
+ gt_root = "../DIS5K/DIS-VD/gt"
181
+ gt_ske_root = ""
182
+ pred_root = "../Results/isnet(ours)/DIS-VD"
183
+
184
+ print("The average HCE metric: ", compute_hce(pred_root,gt_root,gt_ske_root))
185
+
186
+
187
+ if __name__ == '__main__':
188
+ main()
pytorch18.yml ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: pytorch18
2
+ channels:
3
+ - conda-forge
4
+ - anaconda
5
+ - pytorch
6
+ - defaults
7
+ dependencies:
8
+ - _libgcc_mutex=0.1=main
9
+ - _openmp_mutex=4.5=1_gnu
10
+ - blas=1.0=mkl
11
+ - brotli=1.0.9=he6710b0_2
12
+ - bzip2=1.0.8=h7b6447c_0
13
+ - ca-certificates=2022.2.1=h06a4308_0
14
+ - certifi=2021.10.8=py37h06a4308_2
15
+ - cloudpickle=2.0.0=pyhd3eb1b0_0
16
+ - colorama=0.4.4=pyhd3eb1b0_0
17
+ - cudatoolkit=10.2.89=hfd86e86_1
18
+ - cycler=0.11.0=pyhd3eb1b0_0
19
+ - cytoolz=0.11.0=py37h7b6447c_0
20
+ - dask-core=2021.10.0=pyhd3eb1b0_0
21
+ - ffmpeg=4.3=hf484d3e_0
22
+ - fonttools=4.25.0=pyhd3eb1b0_0
23
+ - freetype=2.11.0=h70c0345_0
24
+ - fsspec=2022.2.0=pyhd3eb1b0_0
25
+ - gmp=6.2.1=h2531618_2
26
+ - gnutls=3.6.15=he1e5248_0
27
+ - imageio=2.9.0=pyhd3eb1b0_0
28
+ - intel-openmp=2021.4.0=h06a4308_3561
29
+ - jpeg=9b=h024ee3a_2
30
+ - kiwisolver=1.3.2=py37h295c915_0
31
+ - lame=3.100=h7b6447c_0
32
+ - lcms2=2.12=h3be6417_0
33
+ - ld_impl_linux-64=2.35.1=h7274673_9
34
+ - libffi=3.3=he6710b0_2
35
+ - libgcc-ng=9.3.0=h5101ec6_17
36
+ - libgfortran-ng=7.5.0=ha8ba4b0_17
37
+ - libgfortran4=7.5.0=ha8ba4b0_17
38
+ - libgomp=9.3.0=h5101ec6_17
39
+ - libiconv=1.15=h63c8f33_5
40
+ - libidn2=2.3.2=h7f8727e_0
41
+ - libpng=1.6.37=hbc83047_0
42
+ - libstdcxx-ng=9.3.0=hd4cf53a_17
43
+ - libtasn1=4.16.0=h27cfd23_0
44
+ - libtiff=4.2.0=h85742a9_0
45
+ - libunistring=0.9.10=h27cfd23_0
46
+ - libuv=1.40.0=h7b6447c_0
47
+ - libwebp-base=1.2.2=h7f8727e_0
48
+ - locket=0.2.1=py37h06a4308_2
49
+ - lz4-c=1.9.3=h295c915_1
50
+ - matplotlib-base=3.5.1=py37ha18d171_1
51
+ - mkl=2021.4.0=h06a4308_640
52
+ - mkl-service=2.4.0=py37h7f8727e_0
53
+ - mkl_fft=1.3.1=py37hd3c417c_0
54
+ - mkl_random=1.2.2=py37h51133e4_0
55
+ - munkres=1.1.4=py_0
56
+ - ncurses=6.3=h7f8727e_2
57
+ - nettle=3.7.3=hbbd107a_1
58
+ - networkx=2.6.3=pyhd3eb1b0_0
59
+ - ninja=1.10.2=py37hd09550d_3
60
+ - numpy=1.21.2=py37h20f2e39_0
61
+ - numpy-base=1.21.2=py37h79a1101_0
62
+ - olefile=0.46=py37_0
63
+ - openh264=2.1.1=h4ff587b_0
64
+ - openssl=1.1.1n=h7f8727e_0
65
+ - packaging=21.3=pyhd3eb1b0_0
66
+ - partd=1.2.0=pyhd3eb1b0_1
67
+ - pillow=8.0.0=py37h9a89aac_0
68
+ - pip=21.2.2=py37h06a4308_0
69
+ - pyparsing=3.0.4=pyhd3eb1b0_0
70
+ - python=3.7.11=h12debd9_0
71
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
72
+ - pytorch=1.8.0=py3.7_cuda10.2_cudnn7.6.5_0
73
+ - pywavelets=1.1.1=py37h7b6447c_2
74
+ - pyyaml=6.0=py37h7f8727e_1
75
+ - readline=8.1.2=h7f8727e_1
76
+ - scikit-image=0.15.0=py37hb3f55d8_2
77
+ - scipy=1.7.3=py37hc147768_0
78
+ - setuptools=58.0.4=py37h06a4308_0
79
+ - six=1.16.0=pyhd3eb1b0_1
80
+ - sqlite=3.38.0=hc218d9a_0
81
+ - tk=8.6.11=h1ccaba5_0
82
+ - toolz=0.11.2=pyhd3eb1b0_0
83
+ - torchaudio=0.8.0=py37
84
+ - torchvision=0.9.0=py37_cu102
85
+ - tqdm=4.63.0=pyhd8ed1ab_0
86
+ - typing_extensions=3.10.0.2=pyh06a4308_0
87
+ - wheel=0.37.1=pyhd3eb1b0_0
88
+ - xz=5.2.5=h7b6447c_0
89
+ - yaml=0.2.5=h7b6447c_0
90
+ - zlib=1.2.11=h7f8727e_4
91
+ - zstd=1.4.9=haebb681_0
92
+ prefix: /home/solar/anaconda3/envs/pytorch18
requirements.txt CHANGED
@@ -1,4 +1,87 @@
1
- diffusers==0.14.0
2
- safetensors
3
- opencv-python
4
- controlnet_hinter==0.0.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file may be used to create an environment using:
2
+ # $ conda create --name <env> --file <this file>
3
+ # platform: linux-64
4
+ _libgcc_mutex=0.1=main
5
+ _openmp_mutex=4.5=1_gnu
6
+ blas=1.0=mkl
7
+ brotli=1.0.9=he6710b0_2
8
+ bzip2=1.0.8=h7b6447c_0
9
+ ca-certificates=2022.2.1=h06a4308_0
10
+ certifi=2021.10.8=py37h06a4308_2
11
+ cloudpickle=2.0.0=pyhd3eb1b0_0
12
+ colorama=0.4.4=pyhd3eb1b0_0
13
+ cudatoolkit=10.2.89=hfd86e86_1
14
+ cycler=0.11.0=pyhd3eb1b0_0
15
+ cytoolz=0.11.0=py37h7b6447c_0
16
+ dask-core=2021.10.0=pyhd3eb1b0_0
17
+ ffmpeg=4.3=hf484d3e_0
18
+ fonttools=4.25.0=pyhd3eb1b0_0
19
+ freetype=2.11.0=h70c0345_0
20
+ fsspec=2022.2.0=pyhd3eb1b0_0
21
+ gmp=6.2.1=h2531618_2
22
+ gnutls=3.6.15=he1e5248_0
23
+ imageio=2.9.0=pyhd3eb1b0_0
24
+ intel-openmp=2021.4.0=h06a4308_3561
25
+ jpeg=9b=h024ee3a_2
26
+ kiwisolver=1.3.2=py37h295c915_0
27
+ lame=3.100=h7b6447c_0
28
+ lcms2=2.12=h3be6417_0
29
+ ld_impl_linux-64=2.35.1=h7274673_9
30
+ libffi=3.3=he6710b0_2
31
+ libgcc-ng=9.3.0=h5101ec6_17
32
+ libgfortran-ng=7.5.0=ha8ba4b0_17
33
+ libgfortran4=7.5.0=ha8ba4b0_17
34
+ libgomp=9.3.0=h5101ec6_17
35
+ libiconv=1.15=h63c8f33_5
36
+ libidn2=2.3.2=h7f8727e_0
37
+ libpng=1.6.37=hbc83047_0
38
+ libstdcxx-ng=9.3.0=hd4cf53a_17
39
+ libtasn1=4.16.0=h27cfd23_0
40
+ libtiff=4.2.0=h85742a9_0
41
+ libunistring=0.9.10=h27cfd23_0
42
+ libuv=1.40.0=h7b6447c_0
43
+ libwebp-base=1.2.2=h7f8727e_0
44
+ locket=0.2.1=py37h06a4308_2
45
+ lz4-c=1.9.3=h295c915_1
46
+ matplotlib-base=3.5.1=py37ha18d171_1
47
+ mkl=2021.4.0=h06a4308_640
48
+ mkl-service=2.4.0=py37h7f8727e_0
49
+ mkl_fft=1.3.1=py37hd3c417c_0
50
+ mkl_random=1.2.2=py37h51133e4_0
51
+ munkres=1.1.4=py_0
52
+ ncurses=6.3=h7f8727e_2
53
+ nettle=3.7.3=hbbd107a_1
54
+ networkx=2.6.3=pyhd3eb1b0_0
55
+ ninja=1.10.2=py37hd09550d_3
56
+ numpy=1.21.2=py37h20f2e39_0
57
+ numpy-base=1.21.2=py37h79a1101_0
58
+ olefile=0.46=py37_0
59
+ openh264=2.1.1=h4ff587b_0
60
+ openssl=1.1.1n=h7f8727e_0
61
+ packaging=21.3=pyhd3eb1b0_0
62
+ partd=1.2.0=pyhd3eb1b0_1
63
+ pillow=8.0.0=py37h9a89aac_0
64
+ pip=21.2.2=py37h06a4308_0
65
+ pyparsing=3.0.4=pyhd3eb1b0_0
66
+ python=3.7.11=h12debd9_0
67
+ python-dateutil=2.8.2=pyhd3eb1b0_0
68
+ pytorch=1.8.0=py3.7_cuda10.2_cudnn7.6.5_0
69
+ pywavelets=1.1.1=py37h7b6447c_2
70
+ pyyaml=6.0=py37h7f8727e_1
71
+ readline=8.1.2=h7f8727e_1
72
+ scikit-image=0.15.0=py37hb3f55d8_2
73
+ scipy=1.7.3=py37hc147768_0
74
+ setuptools=58.0.4=py37h06a4308_0
75
+ six=1.16.0=pyhd3eb1b0_1
76
+ sqlite=3.38.0=hc218d9a_0
77
+ tk=8.6.11=h1ccaba5_0
78
+ toolz=0.11.2=pyhd3eb1b0_0
79
+ torchaudio=0.8.0=py37
80
+ torchvision=0.9.0=py37_cu102
81
+ tqdm=4.63.0=pyhd8ed1ab_0
82
+ typing_extensions=3.10.0.2=pyh06a4308_0
83
+ wheel=0.37.1=pyhd3eb1b0_0
84
+ xz=5.2.5=h7b6447c_0
85
+ yaml=0.2.5=h7b6447c_0
86
+ zlib=1.2.11=h7f8727e_4
87
+ zstd=1.4.9=haebb681_0
train_valid_inference_main.py ADDED
@@ -0,0 +1,731 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import numpy as np
4
+ from skimage import io
5
+ import time
6
+
7
+ import torch, gc
8
+ import torch.nn as nn
9
+ from torch.autograd import Variable
10
+ import torch.optim as optim
11
+ import torch.nn.functional as F
12
+
13
+ from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandomHFlip, GOSResize, GOSRandomCrop, GOSNormalize #GOSDatasetCache,
14
+ from basics import f1_mae_torch #normPRED, GOSPRF1ScoresCache,f1score_torch,
15
+ from models import *
16
+
17
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
18
+
19
+ def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000):
20
+
21
+ # train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
22
+ # cache_size = hypar["cache_size"],
23
+ # cache_boost = hypar["cache_boost_train"],
24
+ # my_transforms = [
25
+ # GOSRandomHFlip(),
26
+ # # GOSResize(hypar["input_size"]),
27
+ # # GOSRandomCrop(hypar["crop_size"]),
28
+ # GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
29
+ # ],
30
+ # batch_size = hypar["batch_size_train"],
31
+ # shuffle = True)
32
+
33
+ torch.manual_seed(hypar["seed"])
34
+ if torch.cuda.is_available():
35
+ torch.cuda.manual_seed(hypar["seed"])
36
+
37
+ print("define gt encoder ...")
38
+ net = ISNetGTEncoder() #UNETGTENCODERCombine()
39
+ ## load the existing model gt encoder
40
+ if(hypar["gt_encoder_model"]!=""):
41
+ model_path = hypar["model_path"]+"/"+hypar["gt_encoder_model"]
42
+ if torch.cuda.is_available():
43
+ net.load_state_dict(torch.load(model_path))
44
+ net.cuda()
45
+ else:
46
+ net.load_state_dict(torch.load(model_path,map_location="cpu"))
47
+ print("gt encoder restored from the saved weights ...")
48
+ return net ############
49
+
50
+ if torch.cuda.is_available():
51
+ net.cuda()
52
+
53
+ print("--- define optimizer for GT Encoder---")
54
+ optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
55
+
56
+ model_path = hypar["model_path"]
57
+ model_save_fre = hypar["model_save_fre"]
58
+ max_ite = hypar["max_ite"]
59
+ batch_size_train = hypar["batch_size_train"]
60
+ batch_size_valid = hypar["batch_size_valid"]
61
+
62
+ if(not os.path.exists(model_path)):
63
+ os.mkdir(model_path)
64
+
65
+ ite_num = hypar["start_ite"] # count the total iteration number
66
+ ite_num4val = 0 #
67
+ running_loss = 0.0 # count the toal loss
68
+ running_tar_loss = 0.0 # count the target output loss
69
+ last_f1 = [0 for x in range(len(valid_dataloaders))]
70
+
71
+ train_num = train_datasets[0].__len__()
72
+
73
+ net.train()
74
+
75
+ start_last = time.time()
76
+ gos_dataloader = train_dataloaders[0]
77
+ epoch_num = hypar["max_epoch_num"]
78
+ notgood_cnt = 0
79
+ for epoch in range(epoch_num): ## set the epoch num as 100000
80
+
81
+ for i, data in enumerate(gos_dataloader):
82
+
83
+ if(ite_num >= max_ite):
84
+ print("Training Reached the Maximal Iteration Number ", max_ite)
85
+ exit()
86
+
87
+ # start_read = time.time()
88
+ ite_num = ite_num + 1
89
+ ite_num4val = ite_num4val + 1
90
+
91
+ # get the inputs
92
+ labels = data['label']
93
+
94
+ if(hypar["model_digit"]=="full"):
95
+ labels = labels.type(torch.FloatTensor)
96
+ else:
97
+ labels = labels.type(torch.HalfTensor)
98
+
99
+ # wrap them in Variable
100
+ if torch.cuda.is_available():
101
+ labels_v = Variable(labels.cuda(), requires_grad=False)
102
+ else:
103
+ labels_v = Variable(labels, requires_grad=False)
104
+
105
+ # print("time lapse for data preparation: ", time.time()-start_read, ' s')
106
+
107
+ # y zero the parameter gradients
108
+ start_inf_loss_back = time.time()
109
+ optimizer.zero_grad()
110
+
111
+ ds, fs = net(labels_v)#net(inputs_v)
112
+ loss2, loss = net.compute_loss(ds, labels_v)
113
+
114
+ loss.backward()
115
+ optimizer.step()
116
+
117
+ running_loss += loss.item()
118
+ running_tar_loss += loss2.item()
119
+
120
+ # del outputs, loss
121
+ del ds, loss2, loss
122
+ end_inf_loss_back = time.time()-start_inf_loss_back
123
+
124
+ print("GT Encoder Training>>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
125
+ epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back))
126
+ start_last = time.time()
127
+
128
+ if ite_num % model_save_fre == 0: # validate every 2000 iterations
129
+ notgood_cnt += 1
130
+ # net.eval()
131
+ # tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch)
132
+ tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, train_dataloaders_val, train_datasets_val, hypar, epoch)
133
+
134
+ net.train() # resume train
135
+
136
+ tmp_out = 0
137
+ print("last_f1:",last_f1)
138
+ print("tmp_f1:",tmp_f1)
139
+ for fi in range(len(last_f1)):
140
+ if(tmp_f1[fi]>last_f1[fi]):
141
+ tmp_out = 1
142
+ print("tmp_out:",tmp_out)
143
+ if(tmp_out):
144
+ notgood_cnt = 0
145
+ last_f1 = tmp_f1
146
+ tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1]
147
+ tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae]
148
+ maxf1 = '_'.join(tmp_f1_str)
149
+ meanM = '_'.join(tmp_mae_str)
150
+ # .cpu().detach().numpy()
151
+ model_name = "/GTENCODER-gpu_itr_"+str(ite_num)+\
152
+ "_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\
153
+ "_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\
154
+ "_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\
155
+ "_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \
156
+ "_maxF1_" + maxf1 + \
157
+ "_mae_" + meanM + \
158
+ "_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth"
159
+ torch.save(net.state_dict(), model_path + model_name)
160
+
161
+ running_loss = 0.0
162
+ running_tar_loss = 0.0
163
+ ite_num4val = 0
164
+
165
+ if(tmp_f1[0]>0.99):
166
+ print("GT encoder is well-trained and obtained...")
167
+ return net
168
+
169
+ if(notgood_cnt >= hypar["early_stop"]):
170
+ print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped !")
171
+ exit()
172
+
173
+ print("Training Reaches The Maximum Epoch Number")
174
+ return net
175
+
176
+ def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
177
+ net.eval()
178
+ print("Validating...")
179
+ epoch_num = hypar["max_epoch_num"]
180
+
181
+ val_loss = 0.0
182
+ tar_loss = 0.0
183
+
184
+
185
+ tmp_f1 = []
186
+ tmp_mae = []
187
+ tmp_time = []
188
+
189
+ start_valid = time.time()
190
+ for k in range(len(valid_dataloaders)):
191
+
192
+ valid_dataloader = valid_dataloaders[k]
193
+ valid_dataset = valid_datasets[k]
194
+
195
+ val_num = valid_dataset.__len__()
196
+ mybins = np.arange(0,256)
197
+ PRE = np.zeros((val_num,len(mybins)-1))
198
+ REC = np.zeros((val_num,len(mybins)-1))
199
+ F1 = np.zeros((val_num,len(mybins)-1))
200
+ MAE = np.zeros((val_num))
201
+
202
+ val_cnt = 0.0
203
+ i_val = None
204
+
205
+ for i_val, data_val in enumerate(valid_dataloader):
206
+
207
+ # imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']
208
+ imidx_val, labels_val, shapes_val = data_val['imidx'], data_val['label'], data_val['shape']
209
+
210
+ if(hypar["model_digit"]=="full"):
211
+ labels_val = labels_val.type(torch.FloatTensor)
212
+ else:
213
+ labels_val = labels_val.type(torch.HalfTensor)
214
+
215
+ # wrap them in Variable
216
+ if torch.cuda.is_available():
217
+ labels_val_v = Variable(labels_val.cuda(), requires_grad=False)
218
+ else:
219
+ labels_val_v = Variable(labels_val,requires_grad=False)
220
+
221
+ t_start = time.time()
222
+ ds_val = net(labels_val_v)[0]
223
+ t_end = time.time()-t_start
224
+ tmp_time.append(t_end)
225
+
226
+ # loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
227
+ loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
228
+
229
+ # compute F measure
230
+ for t in range(hypar["batch_size_valid"]):
231
+ val_cnt = val_cnt + 1.0
232
+ print("num of val: ", val_cnt)
233
+ i_test = imidx_val[t].data.numpy()
234
+
235
+ pred_val = ds_val[0][t,:,:,:] # B x 1 x H x W
236
+
237
+ ## recover the prediction spatial size to the orignal image size
238
+ pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear'))
239
+
240
+ ma = torch.max(pred_val)
241
+ mi = torch.min(pred_val)
242
+ pred_val = (pred_val-mi)/(ma-mi) # max = 1
243
+ # pred_val = normPRED(pred_val)
244
+
245
+ gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
246
+ if gt.max()==1:
247
+ gt=gt*255
248
+ with torch.no_grad():
249
+ gt = torch.tensor(gt).to(device)
250
+
251
+ pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
252
+
253
+ PRE[i_test,:]=pre
254
+ REC[i_test,:] = rec
255
+ F1[i_test,:] = f1
256
+ MAE[i_test] = mae
257
+
258
+ del ds_val, gt
259
+ gc.collect()
260
+ torch.cuda.empty_cache()
261
+
262
+ # if(loss_val.data[0]>1):
263
+ val_loss += loss_val.item()#data[0]
264
+ tar_loss += loss2_val.item()#data[0]
265
+
266
+ print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end))
267
+
268
+ del loss2_val, loss_val
269
+
270
+ print('============================')
271
+ PRE_m = np.mean(PRE,0)
272
+ REC_m = np.mean(REC,0)
273
+ f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8)
274
+ # print('--------------:', np.mean(f1_m))
275
+ tmp_f1.append(np.amax(f1_m))
276
+ tmp_mae.append(np.mean(MAE))
277
+ print("The max F1 Score: %f"%(np.max(f1_m)))
278
+ print("MAE: ", np.mean(MAE))
279
+
280
+ # print('[epoch: %3d/%3d, ite: %5d] tra_ls: %3f, val_ls: %3f, tar_ls: %3f, maxf1: %3f, val_time: %6f'% (epoch + 1, epoch_num, ite_num, running_loss / ite_num4val, val_loss/val_cnt, tar_loss/val_cnt, tmp_f1[-1], time.time()-start_valid))
281
+
282
+ return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
283
+
284
+ def train(net, optimizer, train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000):
285
+
286
+ if hypar["interm_sup"]:
287
+ print("Get the gt encoder ...")
288
+ featurenet = get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val)
289
+ ## freeze the weights of gt encoder
290
+ for param in featurenet.parameters():
291
+ param.requires_grad=False
292
+
293
+
294
+ model_path = hypar["model_path"]
295
+ model_save_fre = hypar["model_save_fre"]
296
+ max_ite = hypar["max_ite"]
297
+ batch_size_train = hypar["batch_size_train"]
298
+ batch_size_valid = hypar["batch_size_valid"]
299
+
300
+ if(not os.path.exists(model_path)):
301
+ os.mkdir(model_path)
302
+
303
+ ite_num = hypar["start_ite"] # count the toal iteration number
304
+ ite_num4val = 0 #
305
+ running_loss = 0.0 # count the toal loss
306
+ running_tar_loss = 0.0 # count the target output loss
307
+ last_f1 = [0 for x in range(len(valid_dataloaders))]
308
+
309
+ train_num = train_datasets[0].__len__()
310
+
311
+ net.train()
312
+
313
+ start_last = time.time()
314
+ gos_dataloader = train_dataloaders[0]
315
+ epoch_num = hypar["max_epoch_num"]
316
+ notgood_cnt = 0
317
+ for epoch in range(epoch_num): ## set the epoch num as 100000
318
+
319
+ for i, data in enumerate(gos_dataloader):
320
+
321
+ if(ite_num >= max_ite):
322
+ print("Training Reached the Maximal Iteration Number ", max_ite)
323
+ exit()
324
+
325
+ # start_read = time.time()
326
+ ite_num = ite_num + 1
327
+ ite_num4val = ite_num4val + 1
328
+
329
+ # get the inputs
330
+ inputs, labels = data['image'], data['label']
331
+
332
+ if(hypar["model_digit"]=="full"):
333
+ inputs = inputs.type(torch.FloatTensor)
334
+ labels = labels.type(torch.FloatTensor)
335
+ else:
336
+ inputs = inputs.type(torch.HalfTensor)
337
+ labels = labels.type(torch.HalfTensor)
338
+
339
+ # wrap them in Variable
340
+ if torch.cuda.is_available():
341
+ inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False)
342
+ else:
343
+ inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)
344
+
345
+ # print("time lapse for data preparation: ", time.time()-start_read, ' s')
346
+
347
+ # y zero the parameter gradients
348
+ start_inf_loss_back = time.time()
349
+ optimizer.zero_grad()
350
+
351
+ if hypar["interm_sup"]:
352
+ # forward + backward + optimize
353
+ ds,dfs = net(inputs_v)
354
+ _,fs = featurenet(labels_v) ## extract the gt encodings
355
+ loss2, loss = net.compute_loss_kl(ds, labels_v, dfs, fs, mode='MSE')
356
+ else:
357
+ # forward + backward + optimize
358
+ ds,_ = net(inputs_v)
359
+ loss2, loss = net.compute_loss(ds, labels_v)
360
+
361
+ loss.backward()
362
+ optimizer.step()
363
+
364
+ # # print statistics
365
+ running_loss += loss.item()
366
+ running_tar_loss += loss2.item()
367
+
368
+ # del outputs, loss
369
+ del ds, loss2, loss
370
+ end_inf_loss_back = time.time()-start_inf_loss_back
371
+
372
+ print(">>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
373
+ epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back))
374
+ start_last = time.time()
375
+
376
+ if ite_num % model_save_fre == 0: # validate every 2000 iterations
377
+ notgood_cnt += 1
378
+ net.eval()
379
+ tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid(net, valid_dataloaders, valid_datasets, hypar, epoch)
380
+ net.train() # resume train
381
+
382
+ tmp_out = 0
383
+ print("last_f1:",last_f1)
384
+ print("tmp_f1:",tmp_f1)
385
+ for fi in range(len(last_f1)):
386
+ if(tmp_f1[fi]>last_f1[fi]):
387
+ tmp_out = 1
388
+ print("tmp_out:",tmp_out)
389
+ if(tmp_out):
390
+ notgood_cnt = 0
391
+ last_f1 = tmp_f1
392
+ tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1]
393
+ tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae]
394
+ maxf1 = '_'.join(tmp_f1_str)
395
+ meanM = '_'.join(tmp_mae_str)
396
+ # .cpu().detach().numpy()
397
+ model_name = "/gpu_itr_"+str(ite_num)+\
398
+ "_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\
399
+ "_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\
400
+ "_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\
401
+ "_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \
402
+ "_maxF1_" + maxf1 + \
403
+ "_mae_" + meanM + \
404
+ "_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth"
405
+ torch.save(net.state_dict(), model_path + model_name)
406
+
407
+ running_loss = 0.0
408
+ running_tar_loss = 0.0
409
+ ite_num4val = 0
410
+
411
+ if(notgood_cnt >= hypar["early_stop"]):
412
+ print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped !")
413
+ exit()
414
+
415
+ print("Training Reaches The Maximum Epoch Number")
416
+
417
+ def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
418
+ net.eval()
419
+ print("Validating...")
420
+ epoch_num = hypar["max_epoch_num"]
421
+
422
+ val_loss = 0.0
423
+ tar_loss = 0.0
424
+ val_cnt = 0.0
425
+
426
+ tmp_f1 = []
427
+ tmp_mae = []
428
+ tmp_time = []
429
+
430
+ start_valid = time.time()
431
+
432
+ for k in range(len(valid_dataloaders)):
433
+
434
+ valid_dataloader = valid_dataloaders[k]
435
+ valid_dataset = valid_datasets[k]
436
+
437
+ val_num = valid_dataset.__len__()
438
+ mybins = np.arange(0,256)
439
+ PRE = np.zeros((val_num,len(mybins)-1))
440
+ REC = np.zeros((val_num,len(mybins)-1))
441
+ F1 = np.zeros((val_num,len(mybins)-1))
442
+ MAE = np.zeros((val_num))
443
+
444
+ for i_val, data_val in enumerate(valid_dataloader):
445
+ val_cnt = val_cnt + 1.0
446
+ imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']
447
+
448
+ if(hypar["model_digit"]=="full"):
449
+ inputs_val = inputs_val.type(torch.FloatTensor)
450
+ labels_val = labels_val.type(torch.FloatTensor)
451
+ else:
452
+ inputs_val = inputs_val.type(torch.HalfTensor)
453
+ labels_val = labels_val.type(torch.HalfTensor)
454
+
455
+ # wrap them in Variable
456
+ if torch.cuda.is_available():
457
+ inputs_val_v, labels_val_v = Variable(inputs_val.cuda(), requires_grad=False), Variable(labels_val.cuda(), requires_grad=False)
458
+ else:
459
+ inputs_val_v, labels_val_v = Variable(inputs_val, requires_grad=False), Variable(labels_val,requires_grad=False)
460
+
461
+ t_start = time.time()
462
+ ds_val = net(inputs_val_v)[0]
463
+ t_end = time.time()-t_start
464
+ tmp_time.append(t_end)
465
+
466
+ # loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
467
+ loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
468
+
469
+ # compute F measure
470
+ for t in range(hypar["batch_size_valid"]):
471
+ i_test = imidx_val[t].data.numpy()
472
+
473
+ pred_val = ds_val[0][t,:,:,:] # B x 1 x H x W
474
+
475
+ ## recover the prediction spatial size to the orignal image size
476
+ pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear'))
477
+
478
+ # pred_val = normPRED(pred_val)
479
+ ma = torch.max(pred_val)
480
+ mi = torch.min(pred_val)
481
+ pred_val = (pred_val-mi)/(ma-mi) # max = 1
482
+
483
+ if len(valid_dataset.dataset["ori_gt_path"]) != 0:
484
+ gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
485
+ if gt.max()==1:
486
+ gt=gt*255
487
+ else:
488
+ gt = np.zeros((shapes_val[t][0],shapes_val[t][1]))
489
+ with torch.no_grad():
490
+ gt = torch.tensor(gt).to(device)
491
+
492
+ pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
493
+
494
+
495
+ PRE[i_test,:]=pre
496
+ REC[i_test,:] = rec
497
+ F1[i_test,:] = f1
498
+ MAE[i_test] = mae
499
+
500
+ del ds_val, gt
501
+ gc.collect()
502
+ torch.cuda.empty_cache()
503
+
504
+ # if(loss_val.data[0]>1):
505
+ val_loss += loss_val.item()#data[0]
506
+ tar_loss += loss2_val.item()#data[0]
507
+
508
+ print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end))
509
+
510
+ del loss2_val, loss_val
511
+
512
+ print('============================')
513
+ PRE_m = np.mean(PRE,0)
514
+ REC_m = np.mean(REC,0)
515
+ f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8)
516
+
517
+ tmp_f1.append(np.amax(f1_m))
518
+ tmp_mae.append(np.mean(MAE))
519
+
520
+ return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
521
+
522
+ def main(train_datasets,
523
+ valid_datasets,
524
+ hypar): # model: "train", "test"
525
+
526
+ ### --- Step 1: Build datasets and dataloaders ---
527
+ dataloaders_train = []
528
+ dataloaders_valid = []
529
+
530
+ if(hypar["mode"]=="train"):
531
+ print("--- create training dataloader ---")
532
+ ## collect training dataset
533
+ train_nm_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")
534
+ ## build dataloader for training datasets
535
+ train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
536
+ cache_size = hypar["cache_size"],
537
+ cache_boost = hypar["cache_boost_train"],
538
+ my_transforms = [
539
+ GOSRandomHFlip(), ## this line can be uncommented for horizontal flip augmetation
540
+ # GOSResize(hypar["input_size"]),
541
+ # GOSRandomCrop(hypar["crop_size"]), ## this line can be uncommented for randomcrop augmentation
542
+ GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
543
+ ],
544
+ batch_size = hypar["batch_size_train"],
545
+ shuffle = True)
546
+ train_dataloaders_val, train_datasets_val = create_dataloaders(train_nm_im_gt_list,
547
+ cache_size = hypar["cache_size"],
548
+ cache_boost = hypar["cache_boost_train"],
549
+ my_transforms = [
550
+ GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
551
+ ],
552
+ batch_size = hypar["batch_size_valid"],
553
+ shuffle = False)
554
+ print(len(train_dataloaders), " train dataloaders created")
555
+
556
+ print("--- create valid dataloader ---")
557
+ ## build dataloader for validation or testing
558
+ valid_nm_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")
559
+ ## build dataloader for training datasets
560
+ valid_dataloaders, valid_datasets = create_dataloaders(valid_nm_im_gt_list,
561
+ cache_size = hypar["cache_size"],
562
+ cache_boost = hypar["cache_boost_valid"],
563
+ my_transforms = [
564
+ GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
565
+ # GOSResize(hypar["input_size"])
566
+ ],
567
+ batch_size=hypar["batch_size_valid"],
568
+ shuffle=False)
569
+ print(len(valid_dataloaders), " valid dataloaders created")
570
+ # print(valid_datasets[0]["data_name"])
571
+
572
+ ### --- Step 2: Build Model and Optimizer ---
573
+ print("--- build model ---")
574
+ net = hypar["model"]#GOSNETINC(3,1)
575
+
576
+ # convert to half precision
577
+ if(hypar["model_digit"]=="half"):
578
+ net.half()
579
+ for layer in net.modules():
580
+ if isinstance(layer, nn.BatchNorm2d):
581
+ layer.float()
582
+
583
+ if torch.cuda.is_available():
584
+ net.cuda()
585
+
586
+ if(hypar["restore_model"]!=""):
587
+ print("restore model from:")
588
+ print(hypar["model_path"]+"/"+hypar["restore_model"])
589
+ if torch.cuda.is_available():
590
+ net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"]))
591
+ else:
592
+ net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"],map_location="cpu"))
593
+
594
+ print("--- define optimizer ---")
595
+ optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
596
+
597
+ ### --- Step 3: Train or Valid Model ---
598
+ if(hypar["mode"]=="train"):
599
+ train(net,
600
+ optimizer,
601
+ train_dataloaders,
602
+ train_datasets,
603
+ valid_dataloaders,
604
+ valid_datasets,
605
+ hypar,
606
+ train_dataloaders_val, train_datasets_val)
607
+ else:
608
+ valid(net,
609
+ valid_dataloaders,
610
+ valid_datasets,
611
+ hypar)
612
+
613
+
614
+ if __name__ == "__main__":
615
+
616
+ ### --------------- STEP 1: Configuring the Train, Valid and Test datasets ---------------
617
+ ## configure the train, valid and inference datasets
618
+ train_datasets, valid_datasets = [], []
619
+ dataset_1, dataset_1 = {}, {}
620
+
621
+ dataset_tr = {"name": "DIS5K-TR",
622
+ "im_dir": "../DIS5K/DIS-TR/im",
623
+ "gt_dir": "../DIS5K/DIS-TR/gt",
624
+ "im_ext": ".jpg",
625
+ "gt_ext": ".png",
626
+ "cache_dir":"../DIS5K-Cache/DIS-TR"}
627
+
628
+ dataset_vd = {"name": "DIS5K-VD",
629
+ "im_dir": "../DIS5K/DIS-VD/im",
630
+ "gt_dir": "../DIS5K/DIS-VD/gt",
631
+ "im_ext": ".jpg",
632
+ "gt_ext": ".png",
633
+ "cache_dir":"../DIS5K-Cache/DIS-VD"}
634
+
635
+ dataset_te1 = {"name": "DIS5K-TE1",
636
+ "im_dir": "../DIS5K/DIS-TE1/im",
637
+ "gt_dir": "../DIS5K/DIS-TE1/gt",
638
+ "im_ext": ".jpg",
639
+ "gt_ext": ".png",
640
+ "cache_dir":"../DIS5K-Cache/DIS-TE1"}
641
+
642
+ dataset_te2 = {"name": "DIS5K-TE2",
643
+ "im_dir": "../DIS5K/DIS-TE2/im",
644
+ "gt_dir": "../DIS5K/DIS-TE2/gt",
645
+ "im_ext": ".jpg",
646
+ "gt_ext": ".png",
647
+ "cache_dir":"../DIS5K-Cache/DIS-TE2"}
648
+
649
+ dataset_te3 = {"name": "DIS5K-TE3",
650
+ "im_dir": "../DIS5K/DIS-TE3/im",
651
+ "gt_dir": "../DIS5K/DIS-TE3/gt",
652
+ "im_ext": ".jpg",
653
+ "gt_ext": ".png",
654
+ "cache_dir":"../DIS5K-Cache/DIS-TE3"}
655
+
656
+ dataset_te4 = {"name": "DIS5K-TE4",
657
+ "im_dir": "../DIS5K/DIS-TE4/im",
658
+ "gt_dir": "../DIS5K/DIS-TE4/gt",
659
+ "im_ext": ".jpg",
660
+ "gt_ext": ".png",
661
+ "cache_dir":"../DIS5K-Cache/DIS-TE4"}
662
+ ### test your own dataset
663
+ dataset_demo = {"name": "your-dataset",
664
+ "im_dir": "../your-dataset/im",
665
+ "gt_dir": "",
666
+ "im_ext": ".jpg",
667
+ "gt_ext": "",
668
+ "cache_dir":"../your-dataset/cache"}
669
+
670
+ train_datasets = [dataset_tr] ## users can create mutiple dictionary for setting a list of datasets as training set
671
+ # valid_datasets = [dataset_vd] ## users can create mutiple dictionary for setting a list of datasets as vaidation sets or inference sets
672
+ valid_datasets = [dataset_vd] # dataset_vd, dataset_te1, dataset_te2, dataset_te3, dataset_te4] # and hypar["mode"] = "valid" for inference,
673
+
674
+ ### --------------- STEP 2: Configuring the hyperparamters for Training, validation and inferencing ---------------
675
+ hypar = {}
676
+
677
+ ## -- 2.1. configure the model saving or restoring path --
678
+ hypar["mode"] = "train"
679
+ ## "train": for training,
680
+ ## "valid": for validation and inferening,
681
+ ## in "valid" mode, it will calculate the accuracy as well as save the prediciton results into the "hypar["valid_out_dir"]", which shouldn't be ""
682
+ ## otherwise only accuracy will be calculated and no predictions will be saved
683
+ hypar["interm_sup"] = False ## in-dicate if activate intermediate feature supervision
684
+
685
+ if hypar["mode"] == "train":
686
+ hypar["valid_out_dir"] = "" ## for "train" model leave it as "", for "valid"("inference") mode: set it according to your local directory
687
+ hypar["model_path"] ="../saved_models/IS-Net-test" ## model weights saving (or restoring) path
688
+ hypar["restore_model"] = "" ## name of the segmentation model weights .pth for resume training process from last stop or for the inferencing
689
+ hypar["start_ite"] = 0 ## start iteration for the training, can be changed to match the restored training process
690
+ hypar["gt_encoder_model"] = ""
691
+ else: ## configure the segmentation output path and the to-be-used model weights path
692
+ hypar["valid_out_dir"] = "../your-results/"##"../DIS5K-Results-test" ## output inferenced segmentation maps into this fold
693
+ hypar["model_path"] = "../saved_models/IS-Net" ## load trained weights from this path
694
+ hypar["restore_model"] = "isnet.pth"##"isnet.pth" ## name of the to-be-loaded weights
695
+
696
+ # if hypar["restore_model"]!="":
697
+ # hypar["start_ite"] = int(hypar["restore_model"].split("_")[2])
698
+
699
+ ## -- 2.2. choose floating point accuracy --
700
+ hypar["model_digit"] = "full" ## indicates "half" or "full" accuracy of float number
701
+ hypar["seed"] = 0
702
+
703
+ ## -- 2.3. cache data spatial size --
704
+ ## To handle large size input images, which take a lot of time for loading in training,
705
+ # we introduce the cache mechanism for pre-convering and resizing the jpg and png images into .pt file
706
+ hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured into different size
707
+ hypar["cache_boost_train"] = False ## "True" or "False", indicates wheather to load all the training datasets into RAM, True will greatly speed the training process while requires more RAM
708
+ hypar["cache_boost_valid"] = False ## "True" or "False", indicates wheather to load all the validation datasets into RAM, True will greatly speed the training process while requires more RAM
709
+
710
+ ## --- 2.4. data augmentation parameters ---
711
+ hypar["input_size"] = [1024, 1024] ## mdoel input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images
712
+ hypar["crop_size"] = [1024, 1024] ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation
713
+ hypar["random_flip_h"] = 1 ## horizontal flip, currently hard coded in the dataloader and it is not in use
714
+ hypar["random_flip_v"] = 0 ## vertical flip , currently not in use
715
+
716
+ ## --- 2.5. define model ---
717
+ print("building model...")
718
+ hypar["model"] = ISNetDIS() #U2NETFASTFEATURESUP()
719
+ hypar["early_stop"] = 20 ## stop the training when no improvement in the past 20 validation periods, smaller numbers can be used here e.g., 5 or 10.
720
+ hypar["model_save_fre"] = 2000 ## valid and save model weights every 2000 iterations
721
+
722
+ hypar["batch_size_train"] = 8 ## batch size for training
723
+ hypar["batch_size_valid"] = 1 ## batch size for validation and inferencing
724
+ print("batch size: ", hypar["batch_size_train"])
725
+
726
+ hypar["max_ite"] = 10000000 ## if early stop couldn't stop the training process, stop it by the max_ite_num
727
+ hypar["max_epoch_num"] = 1000000 ## if early stop and max_ite couldn't stop the training process, stop it by the max_epoch_num
728
+
729
+ main(train_datasets,
730
+ valid_datasets,
731
+ hypar=hypar)