Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	File size: 2,999 Bytes
			
			| a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 | import argparse
# Single parser shared by every option declared in this module.
parser = argparse.ArgumentParser(description="LANet")
# Argument groups are collected here as add_argument_group() creates them.
arg_lists = []
def str2bool(v):
    """Convert a command-line string into a boolean.

    The comparison is case-insensitive: only "true" and "1" map to True,
    every other string maps to False.
    """
    accepted = ("true", "1")
    lowered = v.lower()
    if lowered in accepted:
        return True
    return False
def add_argument_group(name):
    """Create a titled argument group on the shared parser.

    The new group is also recorded in ``arg_lists`` before being returned,
    so the caller can attach options to it.
    """
    group = parser.add_argument_group(name)
    arg_lists.append(group)
    return group
def _tuple_of(cast):
    """Return an argparse ``type`` callable parsing "a,b,..." into a tuple.

    ``type=tuple`` would split the raw command-line string into single
    characters (e.g. "240,320" -> ('2', '4', '0', ',', '3', '2', '0')), so
    tuple-valued options need an explicit parser. Surrounding brackets and
    whitespace are tolerated, so "(240, 320)" is accepted as well.

    Args:
        cast: Callable applied to every comma-separated field (e.g. ``int``).

    Returns:
        A function mapping the option string to a tuple of ``cast`` values;
        it raises ``argparse.ArgumentTypeError`` on malformed input.
    """

    def parse(text):
        pieces = [part.strip() for part in text.strip().strip("()[]").split(",")]
        pieces = [part for part in pieces if part]
        if not pieces:
            raise argparse.ArgumentTypeError(
                "expected comma-separated values, got {!r}".format(text)
            )
        try:
            return tuple(cast(part) for part in pieces)
        except ValueError:
            raise argparse.ArgumentTypeError(
                "expected comma-separated {} values, got {!r}".format(
                    cast.__name__, text
                )
            )

    return parse


# train data params
traindata_arg = add_argument_group("Traindata Params")
traindata_arg.add_argument("--train_txt", type=str, default="", help="Train set.")
traindata_arg.add_argument(
    "--train_root", type=str, default="", help="Where the train images are."
)
traindata_arg.add_argument(
    "--batch_size", type=int, default=8, help="# of images in each batch of data"
)
traindata_arg.add_argument(
    "--num_workers",
    type=int,
    default=4,
    help="# of subprocesses to use for data loading",
)
traindata_arg.add_argument(
    "--pin_memory",
    type=str2bool,
    default=True,
    # The previous help text was a copy of the --num_workers description.
    help="Whether to use pinned memory for data loading",
)
traindata_arg.add_argument(
    "--shuffle",
    type=str2bool,
    default=True,
    help="Whether to shuffle the train and valid indices",
)
# Tuple defaults are not strings, so argparse leaves them untouched; only
# values supplied on the command line go through the parsers below.
traindata_arg.add_argument(
    "--image_shape",
    type=_tuple_of(int),
    default=(240, 320),
    help="Comma-separated integers, e.g. 240,320",
)
traindata_arg.add_argument(
    "--jittering",
    type=_tuple_of(float),
    default=(0.5, 0.5, 0.2, 0.05),
    help="Comma-separated floats, e.g. 0.5,0.5,0.2,0.05",
)
# Options controlling how stored artefacts are named.
storage_arg = add_argument_group("Storage")
storage_arg.add_argument(
    "--ckpt_name",
    type=str,
    default="PointModel",
    help="",
)
# Training options, registered from a (flag, type, default, help) table in
# the same order as before.
train_arg = add_argument_group("Training Params")
for _flag, _kind, _default, _help in (
    ("--start_epoch", int, 0, ""),
    ("--max_epoch", int, 12, ""),
    ("--init_lr", float, 3e-4, "Initial learning rate value."),
    ("--lr_factor", float, 0.5, "Reduce learning rate value."),
    ("--momentum", float, 0.9, "Nesterov momentum value."),
    ("--display", int, 50, ""),
):
    train_arg.add_argument(_flag, type=_kind, default=_default, help=_help)
# Keep the loop temporaries out of the module namespace.
del _flag, _kind, _default, _help
# loss function params: weights of the individual loss terms plus the
# correspondence threshold.
loss_arg = add_argument_group("Loss function Params")
loss_arg.add_argument("--score_weight", type=float, default=1.0, help="")
loss_arg.add_argument("--loc_weight", type=float, default=1.0, help="")
loss_arg.add_argument("--desc_weight", type=float, default=4.0, help="")
loss_arg.add_argument("--corres_weight", type=float, default=0.5, help="")
# The default is the float 4.0, so the option is parsed as float as well:
# with type=int, "--corres_threshold 4.0" was rejected and any override
# silently changed the value's type relative to the default.
loss_arg.add_argument("--corres_threshold", type=float, default=4.0, help="")
# Miscellaneous options, registered from a (flag, type, default, help) table
# in the same order as before.
misc_arg = add_argument_group("Misc.")
for _flag, _kind, _default, _help in (
    ("--use_gpu", str2bool, True, "Whether to run on the GPU."),
    ("--gpu", int, 0, "Which GPU to run on."),
    ("--seed", int, 1001, "Seed to ensure reproducibility."),
    (
        "--ckpt_dir",
        str,
        "./checkpoints",
        "Directory in which to save model checkpoints.",
    ),
):
    misc_arg.add_argument(_flag, type=_kind, default=_default, help=_help)
# Keep the loop temporaries out of the module namespace.
del _flag, _kind, _default, _help
def get_config():
    """Parse the process command line with the shared parser.

    Returns:
        A ``(config, unparsed)`` pair: the namespace of recognised options
        and the list of arguments the parser did not recognise.
    """
    parsed = parser.parse_known_args()
    return parsed[0], parsed[1]
 |