| 67 | |
| 68 | |
| 69 | def get_args(): |
| 70 | parser = argparse.ArgumentParser(description='DG') |
| 71 | parser.add_argument('--algorithm', type=str, default="ERM") |
| 72 | parser.add_argument('--alpha', type=float, |
| 73 | default=0.1, help="DANN dis alpha") |
| 74 | parser.add_argument('--batch_size', type=int, |
| 75 | default=32, help="batch_size") |
| 76 | parser.add_argument('--beta1', type=float, default=0.5, help="Adam") |
| 77 | parser.add_argument('--bottleneck', type=int, default=256) |
| 78 | parser.add_argument('--checkpoint_freq', type=int, |
| 79 | default=100, help='Checkpoint every N steps') |
| 80 | parser.add_argument('--classifier', type=str, |
| 81 | default="linear", choices=["linear", "wn"]) |
| 82 | parser.add_argument('--class_balanced', type=int, default=0) |
| 83 | parser.add_argument('--data_file', type=str, default='') |
| 84 | parser.add_argument('--dataset', type=str, default='dsads') |
| 85 | parser.add_argument('--data_dir', type=str, default='') |
| 86 | parser.add_argument('--dis_hidden', type=int, default=256) |
| 87 | parser.add_argument('--gpu_id', type=str, nargs='?', |
| 88 | default='0', help="device id to run") |
| 89 | parser.add_argument('--layer', type=str, default="bn", |
| 90 | choices=["ori", "bn"]) |
| 91 | parser.add_argument('--ldmarginlosstype', type=str, default='avg_top_k', |
| 92 | choices=['all_top_k', 'worst_top_k', 'avg_top_k']) |
| 93 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") |
| 94 | parser.add_argument('--lr_decay1', type=float, |
| 95 | default=1.0, help='for pretrained featurizer') |
| 96 | parser.add_argument('--lr_decay2', type=float, default=1.0) |
| 97 | parser.add_argument('--max_epoch', type=int, |
| 98 | default=150, help="max iterations") |
| 99 | parser.add_argument('--mixupalpha', type=float, default=0.2) |
| 100 | parser.add_argument('--mixup_ld_margin', type=float, default=10000) |
| 101 | parser.add_argument('--mixupregtype', type=str, |
| 102 | default='l-margin', choices=['ld-margin']) |
| 103 | parser.add_argument('--net', type=str, |
| 104 | default='ActNetwork', help="ActNetwork") |
| 105 | parser.add_argument('--N_WORKERS', type=int, default=4) |
| 106 | parser.add_argument('--schuse', action='store_true') |
| 107 | parser.add_argument('--schusech', type=str, default='cos') |
| 108 | parser.add_argument('--seed', type=int, default=0) |
| 109 | parser.add_argument('--seed1', type=int, default=0) |
| 110 | parser.add_argument('--task', type=str, |
| 111 | default="cross_people", choices=['cross_people']) |
| 112 | parser.add_argument('--test_envs', type=int, nargs='+', default=[0]) |
| 113 | parser.add_argument('--top_k', type=int, default=1) |
| 114 | parser.add_argument('--output', type=str, default="train_output") |
| 115 | parser.add_argument('--valid', action='store_true') |
| 116 | parser.add_argument('--valid_size', type=float, default=0.2) |
| 117 | parser.add_argument('--weight_decay', type=float, default=5e-4) |
| 118 | parser.add_argument('--wtype', type=str, default='ori', |
| 119 | choices=['ori', 'abs', 'fea']) |
| 120 | args = parser.parse_args() |
| 121 | return args |
| 122 | |
| 123 | |
| 124 | def init_args(args): |