()
| 13 | |
| 14 | |
| 15 | def get_args(): |
| 16 | parser = argparse.ArgumentParser(description='DG') |
| 17 | parser.add_argument('--algorithm', type=str, default="ERM") |
| 18 | parser.add_argument('--alpha', type=float, |
| 19 | default=1, help='DANN dis alpha') |
| 20 | parser.add_argument('--anneal_iters', type=int, |
| 21 | default=500, help='Penalty anneal iters used in VREx') |
| 22 | parser.add_argument('--batch_size', type=int, |
| 23 | default=32, help='batch_size') |
| 24 | parser.add_argument('--beta', type=float, |
| 25 | default=1, help='DIFEX beta') |
| 26 | parser.add_argument('--beta1', type=float, default=0.5, |
| 27 | help='Adam hyper-param') |
| 28 | parser.add_argument('--bottleneck', type=int, default=256) |
| 29 | parser.add_argument('--checkpoint_freq', type=int, |
| 30 | default=3, help='Checkpoint every N epoch') |
| 31 | parser.add_argument('--classifier', type=str, |
| 32 | default="linear", choices=["linear", "wn"]) |
| 33 | parser.add_argument('--data_file', type=str, default='', |
| 34 | help='root_dir') |
| 35 | parser.add_argument('--dataset', type=str, default='office') |
| 36 | parser.add_argument('--data_dir', type=str, default='', help='data dir') |
| 37 | parser.add_argument('--dis_hidden', type=int, |
| 38 | default=256, help='dis hidden dimension') |
| 39 | parser.add_argument('--disttype', type=str, default='2-norm', |
| 40 | choices=['1-norm', '2-norm', 'cos', 'norm-2-norm', 'norm-1-norm']) |
| 41 | parser.add_argument('--gpu_id', type=str, nargs='?', |
| 42 | default='0', help="device id to run") |
| 43 | parser.add_argument('--groupdro_eta', type=float, |
| 44 | default=1, help="groupdro eta") |
| 45 | parser.add_argument('--inner_lr', type=float, |
| 46 | default=1e-2, help="learning rate used in MLDG") |
| 47 | parser.add_argument('--lam', type=float, |
| 48 | default=1, help="tradeoff hyperparameter used in VREx") |
| 49 | parser.add_argument('--layer', type=str, default="bn", |
| 50 | choices=["ori", "bn"]) |
| 51 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") |
| 52 | parser.add_argument('--lr_decay', type=float, default=0.75, help='for sgd') |
| 53 | parser.add_argument('--lr_decay1', type=float, |
| 54 | default=1.0, help='for pretrained featurizer') |
| 55 | parser.add_argument('--lr_decay2', type=float, default=1.0, |
| 56 | help='inital learning rate decay of network') |
| 57 | parser.add_argument('--lr_gamma', type=float, |
| 58 | default=0.0003, help='for optimizer') |
| 59 | parser.add_argument('--max_epoch', type=int, |
| 60 | default=120, help="max iterations") |
| 61 | parser.add_argument('--mixupalpha', type=float, |
| 62 | default=0.2, help='mixup hyper-param') |
| 63 | parser.add_argument('--mldg_beta', type=float, |
| 64 | default=1, help="mldg hyper-param") |
| 65 | parser.add_argument('--mmd_gamma', type=float, |
| 66 | default=1, help='MMD, CORAL hyper-param') |
| 67 | parser.add_argument('--momentum', type=float, |
| 68 | default=0.9, help='for optimizer') |
| 69 | parser.add_argument('--net', type=str, default='resnet50', |
| 70 | help="featurizer: vgg16, resnet50, resnet101,DTNBase") |
| 71 | parser.add_argument('--N_WORKERS', type=int, default=4) |
| 72 | parser.add_argument('--rsc_f_drop_factor', type=float, |
no test coverage detected