`get_parser()`: Get default arguments.
| 9 | import random |
| 10 | |
def get_parser():
    """Get default arguments.

    Build the argument parser for a transfer-learning run. Options can be
    given on the command line or in a YAML file passed via ``--config``
    (parsed with ``configargparse.YAMLConfigFileParser``).

    ``--data_dir``, ``--src_domain`` and ``--tgt_domain`` are required.
    Boolean-like options are converted with ``str2bool``, which is defined
    elsewhere in this module. NOTE(review): presumably it maps strings such
    as "true"/"false" to bool; confirm against its definition.

    Returns:
        configargparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = configargparse.ArgumentParser(
        description="Transfer learning config parser",
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
    )

    # general configuration
    # ``parser.add`` is only configargparse's alias for ``add_argument``; the
    # standard name is used everywhere for consistency.
    parser.add_argument("--config", is_config_file=True, help="config file path")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num_workers", type=int, default=0)

    # network related
    parser.add_argument("--backbone", type=str, default="resnet50")
    parser.add_argument("--use_bottleneck", type=str2bool, default=True)

    # data loading related
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--src_domain", type=str, required=True)
    parser.add_argument("--tgt_domain", type=str, required=True)

    # training related
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--n_epoch", type=int, default=100)
    parser.add_argument("--early_stop", type=int, default=0, help="Early stopping")
    parser.add_argument(
        "--epoch_based_training",
        type=str2bool,
        default=False,
        help="Epoch-based training / Iteration-based training",
    )
    parser.add_argument(
        "--n_iter_per_epoch",
        type=int,
        default=20,
        help="Used in Iteration-based training",
    )

    # optimizer related
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight_decay", type=float, default=5e-4)

    # learning rate scheduler related
    parser.add_argument("--lr_gamma", type=float, default=0.0003)
    parser.add_argument("--lr_decay", type=float, default=0.75)
    parser.add_argument("--lr_scheduler", type=str2bool, default=True)

    # transfer related
    # argparse does not pass non-string defaults through ``type``, so the
    # default must already be a float; otherwise the attribute would be an
    # int when omitted but a float when supplied on the CLI / in the config.
    parser.add_argument("--transfer_loss_weight", type=float, default=10.0)
    parser.add_argument("--transfer_loss", type=str, default="mmd")
    return parser
| 53 | |
| 54 | def set_random_seed(seed=0): |
| 55 | # seed setting |