Spaces:
Runtime error
Runtime error
File size: 7,657 Bytes
1b2a9b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import os
import json
import argparse
class BaseOptions():
    """Command-line + JSON-config option handling.

    Typical use is ``opt = BaseOptions().gather_options()``:

    1. arguments are declared in :meth:`initialize` and parsed from the
       command line;
    2. every key of the JSON file named by ``-f/--config_file`` is merged
       on top of them (JSON values win over command-line values);
    3. a few derived fields (``out_dir``, boolean flags) are filled in;
    4. the options are printed and written to ``<out_dir>/exp_args.txt``.
    """

    def initialize(self, parser):
        """Register every supported argument on ``parser``.

        Args:
            parser: an ``argparse.ArgumentParser`` to populate.

        Returns:
            The same parser, with all arguments added.
        """
        parser.add_argument('--nChannel', metavar='N', default=100, type=int,
                            help='number of channels')
        parser.add_argument('--maxIter', metavar='T', default=1000, type=int,
                            help='number of maximum iterations')
        parser.add_argument('--lr', metavar='LR', default=0.1, type=float,
                            help='learning rate')
        parser.add_argument('--nConv', metavar='M', default=2, type=int,
                            help='number of convolutional layers')
        parser.add_argument("--work_dir", type=str, default="/home/xli/WORKDIR",
                            help='project directory')
        parser.add_argument("--out_dir", type=str, default=None,
                            help='logging output')
        parser.add_argument("--use_wandb", type=int, default=0,
                            help='use wandb or not')
        parser.add_argument("--data_path", type=str, default="/home/xli/DATA/BSR_processed/train_extend",
                            help="data path")
        parser.add_argument("--img_path", type=str, default=None,
                            help="image path")
        parser.add_argument('--crop_size', type=int, default=224,
                            help='crop_size')
        parser.add_argument("--batch_size", type=int, default=1,
                            help='batch size')
        parser.add_argument('--workers', type=int, default=4,
                            help='number of data loading workers')
        parser.add_argument("--use_slic", default=1, type=int,
                            help="choose to use slic or gt label")
        parser.add_argument("-f", "--config_file", type=str, default='models/week0417/json/single_scale_grouping_ft.json',
                            help='json files including all arguments')
        parser.add_argument("--log_freq", type=int, default=10,
                            help='frequency to print log')
        parser.add_argument("--display_freq", type=int, default=100,
                            help='frequency to save visualization')
        parser.add_argument("--pretrained_ae", type=str,
                            default="/home/xli/WORKDIR/07-16/transformer/cpk.pth")
        parser.add_argument("--pretrained_path", type=str, default=None,
                            help='pretrained reconstruction model')
        parser.add_argument('--momentum', type=float, default=0.5,
                            help='momentum for sgd, alpha parameter for adam')
        parser.add_argument('--beta', type=float, default=0.999,
                            help='beta parameter for adam')
        parser.add_argument("--l1_loss_wt", default=1.0, type=float)
        parser.add_argument("--perceptual_loss_wt", default=1.0, type=float)
        parser.add_argument('--project_name', type=str, default='test_time',
                            help='project name')
        parser.add_argument("--save_freq", type=int, default=2000,
                            help='frequency to save model')
        parser.add_argument("--local_rank", type=int)
        parser.add_argument('--lr_decay_freq', type=int, default=3000,
                            help='frequency to decay learning rate')
        parser.add_argument('--no_ganFeat_loss', action='store_true',
                            help='if specified, do *not* use discriminator feature matching loss')
        parser.add_argument('--sp_num', type=int, default=None,
                            help='superpixel number')
        parser.add_argument('--add_self_loops', type=int, default=1,
                            help='set to 1 to add self loops in GCNs')
        # NOTE(review): this help text used to be a copy of --add_self_loops;
        # confirm the exact meaning of test-time mode with the training code.
        parser.add_argument('--test_time', type=int, default=0,
                            help='set to 1 to enable test-time mode')
        parser.add_argument('--add_texture_epoch', type=int, default=1000,
                            help='when to add texture synthesis')
        parser.add_argument('--add_clustering_epoch', type=int, default=1000,
                            help='when to add grouping')
        # Softmax temperature is real-valued; int parsing rejected e.g. 0.5.
        parser.add_argument('--temperature', type=float, default=1.0,
                            help='temperature in SoftMax')
        parser.add_argument('--gumbel', type=int, default=0,
                            help='if use gumbel SoftMax')
        parser.add_argument('--patch_size', type=int, default=40,
                            help='patch size in texture synthesis')
        parser.add_argument("--netE_num_downsampling_sp", default=4, type=int)
        parser.add_argument('--num_classes', type=int, default=0)
        parser.add_argument(
            "--netG_num_base_resnet_layers",
            default=2, type=int,
            help="The number of resnet layers before the upsampling layers."
        )
        parser.add_argument("--netG_scale_capacity", default=1.0, type=float)
        parser.add_argument("--netG_resnet_ch", type=int, default=256)
        parser.add_argument("--spatial_code_ch", default=8, type=int)
        parser.add_argument("--texture_code_ch", default=256, type=int)
        parser.add_argument("--netE_scale_capacity", default=1.0, type=float)
        parser.add_argument("--netE_num_downsampling_gl", default=2, type=int)
        parser.add_argument("--netE_nc_steepness", default=2.0, type=float)
        parser.add_argument("--spatial_code_dim", type=int, default=256, help="codebook entry dimension")
        return parser

    def _format_options(self, opt):
        """Return one formatted line per option in ``opt``, sorted by name.

        Each line is ``name`` right-aligned to 25 columns, ``value``
        left-aligned to 30 columns and, when the value differs from the
        parser default, a ``'\\t[default: X]'`` suffix. Keys unknown to the
        parser (e.g. ones introduced only by the JSON config) have a
        default of ``None``.
        """
        lines = []
        for k, v in sorted(vars(opt).items()):
            default = self.parser.get_default(k)
            comment = '\t[default: %s]' % str(default) if v != default else ''
            lines.append('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))
        return lines

    def print_options(self, opt):
        """Print all options to stdout between header/footer rules.

        Values that differ from the parser defaults are annotated with the
        default. Nothing is written to disk here; see :meth:`save_options`.
        """
        message = '----------------- Options ---------------\n'
        message += ''.join(self._format_options(opt))
        message += '----------------- End -------------------'
        print(message)

    def save_options(self, opt):
        """Write all options to ``<opt.out_dir>/exp_args.txt``.

        Creates ``opt.out_dir`` if needed. The file uses the same per-line
        format as :meth:`print_options`, without the header/footer rules.
        """
        os.makedirs(opt.out_dir, exist_ok=True)
        file_name = os.path.join(opt.out_dir, 'exp_args.txt')
        with open(file_name, 'wt', encoding='utf-8') as opt_file:
            opt_file.writelines(self._format_options(opt))

    def gather_options(self, args=None):
        """Parse, merge, post-process, log and save the options.

        Args:
            args: optional list of argument strings passed to
                ``parse_args``; ``None`` (default) parses ``sys.argv``.

        Returns:
            The final ``argparse.Namespace``. ``use_slic`` and ``use_wandb``
            are converted to booleans. ``out_dir`` is
            ``<work_dir>/<exp_name>`` when ``exp_name`` is provided
            (normally by the JSON config); otherwise it is the ``--out_dir``
            value.

        Raises:
            ValueError: if neither ``exp_name`` nor ``out_dir`` is set.
        """
        parser = argparse.ArgumentParser(description='PyTorch Unsupervised Segmentation')
        self.parser = self.initialize(parser)
        opt = self.parser.parse_args(args)
        opt = self.update_with_json(opt)
        # exp_name is not a declared CLI argument: it only exists when the
        # JSON config supplies it, so do not assume it is present.
        exp_name = getattr(opt, 'exp_name', None)
        if exp_name is not None:
            opt.out_dir = os.path.join(opt.work_dir, exp_name)
        elif opt.out_dir is None:
            raise ValueError(
                "cannot determine output directory: provide 'exp_name' in "
                "the JSON config or pass --out_dir")
        opt.use_slic = (opt.use_slic == 1)
        opt.use_wandb = (opt.use_wandb == 1)
        # logging
        self.print_options(opt)
        self.save_options(opt)
        return opt

    def update_with_json(self, args):
        """Return a new Namespace with the JSON config merged over ``args``.

        Every top-level key of the JSON object in ``args.config_file``
        overrides (or adds) the attribute of the same name, so JSON values
        take precedence over command-line values. Note that JSON values
        bypass argparse ``type=`` conversion. ``args`` itself is not
        modified. If no config file is given (``None`` or ``''``), ``args``
        is returned unchanged.
        """
        config_file = getattr(args, 'config_file', None)
        if not config_file:
            return args
        with open(config_file, 'r', encoding='utf-8') as f:
            file_args = json.load(f)
        merged = dict(vars(args))
        merged.update(file_args)
        return argparse.Namespace(**merged)
|