Spaces:
Runtime error
Runtime error
import os | |
import json | |
import argparse | |
class BaseOptions:
    """Command-line / JSON option handling for the segmentation experiments.

    Typical use is ``opt = BaseOptions().gather_options()``, which parses the
    command line, overlays the JSON config file, derives ``out_dir``, prints
    the options and writes them to ``<out_dir>/exp_args.txt``.
    """

    def initialize(self, parser):
        """Register every command-line option on *parser* and return it."""
        parser.add_argument('--nChannel', metavar='N', default=100, type=int,
                            help='number of channels')
        parser.add_argument('--maxIter', metavar='T', default=1000, type=int,
                            help='number of maximum iterations')
        parser.add_argument('--lr', metavar='LR', default=0.1, type=float,
                            help='learning rate')
        parser.add_argument('--nConv', metavar='M', default=2, type=int,
                            help='number of convolutional layers')
        parser.add_argument("--work_dir", type=str, default="/home/xli/WORKDIR",
                            help='project directory')
        parser.add_argument("--out_dir", type=str, default=None,
                            help='logging output')
        parser.add_argument("--use_wandb", type=int, default=0,
                            help='use wandb or not')
        parser.add_argument("--data_path", type=str, default="/home/xli/DATA/BSR_processed/train_extend",
                            help="data path")
        parser.add_argument("--img_path", type=str, default=None,
                            help="image path")
        parser.add_argument('--crop_size', type=int, default=224,
                            help='crop_size')
        parser.add_argument("--batch_size", type=int, default=1,
                            help='batch size')
        parser.add_argument('--workers', type=int, default=4,
                            help='number of data loading workers')
        parser.add_argument("--use_slic", default=1, type=int,
                            help="choose to use slic or gt label")
        parser.add_argument("-f", "--config_file", type=str, default='models/week0417/json/single_scale_grouping_ft.json',
                            help='json files including all arguments')
        parser.add_argument("--log_freq", type=int, default=10,
                            help='frequency to print log')
        parser.add_argument("--display_freq", type=int, default=100,
                            help='frequency to save visualization')
        parser.add_argument("--pretrained_ae", type=str,
                            default="/home/xli/WORKDIR/07-16/transformer/cpk.pth")
        parser.add_argument("--pretrained_path", type=str, default=None,
                            help='pretrained reconstruction model')
        parser.add_argument('--momentum', type=float, default=0.5,
                            help='momentum for sgd, alpha parameter for adam')
        parser.add_argument('--beta', type=float, default=0.999,
                            help='beta parameter for adam')
        parser.add_argument("--l1_loss_wt", default=1.0, type=float)
        parser.add_argument("--perceptual_loss_wt", default=1.0, type=float)
        parser.add_argument('--project_name', type=str, default='test_time',
                            help='project name')
        parser.add_argument("--save_freq", type=int, default=2000,
                            help='frequency to save model')
        parser.add_argument("--local_rank", type=int)
        parser.add_argument('--lr_decay_freq', type=int, default=3000,
                            help='frequency to decay learning rate')
        parser.add_argument('--no_ganFeat_loss', action='store_true',
                            help='if specified, do *not* use discriminator feature matching loss')
        parser.add_argument('--sp_num', type=int, default=None,
                            help='superpixel number')
        parser.add_argument('--add_self_loops', type=int, default=1,
                            help='set to 1 to add self loops in GCNs')
        # NOTE(review): this help text duplicates --add_self_loops and looks
        # copy-pasted; confirm the intended meaning of --test_time.
        parser.add_argument('--test_time', type=int, default=0,
                            help='set to 1 to add self loops in GCNs')
        parser.add_argument('--add_texture_epoch', type=int, default=1000,
                            help='when to add texture synthesis')
        parser.add_argument('--add_clustering_epoch', type=int, default=1000,
                            help='when to add grouping')
        # NOTE(review): declared int, so fractional temperatures cannot be
        # passed on the command line (a JSON config can still set one).
        parser.add_argument('--temperature', type=int, default=1,
                            help='temperature in SoftMax')
        parser.add_argument('--gumbel', type=int, default=0,
                            help='if use gumbel SoftMax')
        parser.add_argument('--patch_size', type=int, default=40,
                            help='patch size in texture synthesis')
        parser.add_argument("--netE_num_downsampling_sp", default=4, type=int)
        parser.add_argument('--num_classes', type=int, default=0)
        parser.add_argument(
            "--netG_num_base_resnet_layers",
            default=2, type=int,
            help="The number of resnet layers before the upsampling layers."
        )
        parser.add_argument("--netG_scale_capacity", default=1.0, type=float)
        parser.add_argument("--netG_resnet_ch", type=int, default=256)
        parser.add_argument("--spatial_code_ch", default=8, type=int)
        parser.add_argument("--texture_code_ch", default=256, type=int)
        parser.add_argument("--netE_scale_capacity", default=1.0, type=float)
        parser.add_argument("--netE_num_downsampling_gl", default=2, type=int)
        parser.add_argument("--netE_nc_steepness", default=2.0, type=float)
        parser.add_argument("--spatial_code_dim", type=int, default=256, help="codebook entry dimension")
        return parser

    def _format_options(self, opt):
        """Return one formatted line per option, annotating non-default values.

        Requires ``self.parser`` (set by ``gather_options``). Keys that are not
        registered on the parser (e.g. extra keys from the JSON config) have a
        parser default of None.
        """
        lines = []
        for k, v in sorted(vars(opt).items()):
            default = self.parser.get_default(k)
            comment = '\t[default: %s]' % str(default) if v != default else ''
            lines.append('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))
        return lines

    def print_options(self, opt):
        """Print all options to stdout, showing defaults for changed values.

        Only prints; writing the options to disk is done by ``save_options``.
        """
        message = ''.join(
            ['----------------- Options ---------------\n']
            + self._format_options(opt)
            + ['----------------- End -------------------']
        )
        print(message)

    def save_options(self, opt):
        """Write the options table to ``<opt.out_dir>/exp_args.txt``.

        Creates ``opt.out_dir`` if it does not exist and overwrites any
        existing ``exp_args.txt``.
        """
        os.makedirs(opt.out_dir, exist_ok=True)
        file_name = os.path.join(opt.out_dir, 'exp_args.txt')
        with open(file_name, 'wt') as opt_file:
            opt_file.writelines(self._format_options(opt))

    def gather_options(self, args=None):
        """Parse options, merge the JSON config, resolve ``out_dir`` and log.

        Args:
            args: Optional list of argument strings to parse. When None,
                argparse falls back to ``sys.argv[1:]`` (original behaviour).

        Returns:
            argparse.Namespace with JSON-config values overriding command-line
            values, ``out_dir`` resolved, and ``use_slic`` / ``use_wandb``
            converted to bool.

        Raises:
            ValueError: if the config supplies no ``exp_name`` and no
                ``--out_dir`` was given, so no output directory is known.
        """
        parser = argparse.ArgumentParser(description='PyTorch Unsupervised Segmentation')
        self.parser = self.initialize(parser)
        opt = self.parser.parse_args(args)
        opt = self.update_with_json(opt)

        # exp_name is not a registered option; it can only come from the JSON
        # config. When present it decides out_dir (as before); otherwise fall
        # back to an explicit --out_dir instead of crashing on a missing attr.
        exp_name = getattr(opt, 'exp_name', None)
        if exp_name is not None:
            opt.out_dir = os.path.join(opt.work_dir, exp_name)
        elif opt.out_dir is None:
            raise ValueError(
                "no output directory: set 'exp_name' in the config file "
                "or pass --out_dir")

        # The integer flags are converted to real booleans for downstream use.
        opt.use_slic = (opt.use_slic == 1)
        opt.use_wandb = (opt.use_wandb == 1)

        # logging
        self.print_options(opt)
        self.save_options(opt)
        return opt

    def update_with_json(self, args):
        """Return a new Namespace where keys from ``args.config_file`` win.

        Every top-level key of the JSON object is copied onto the result,
        including keys with no matching command-line option. When
        ``config_file`` is empty or None, *args* is returned unchanged.
        The input namespace is never mutated.

        Raises:
            OSError: if the config file cannot be opened.
            ValueError: if the file is not valid JSON or its top level is not
                a JSON object.
        """
        if not args.config_file:
            return args
        with open(args.config_file, 'r', encoding='utf-8') as f:
            file_args = json.load(f)
        if not isinstance(file_args, dict):
            raise ValueError(
                'config file %s must contain a JSON object, got %s'
                % (args.config_file, type(file_args).__name__))
        # Copy first: vars() returns the namespace's own __dict__.
        arg_dict = dict(vars(args))
        arg_dict.update(file_args)
        return argparse.Namespace(**arg_dict)