| 24 | |
| 25 | |
| 26 | def parse_option(): |
| 27 | parser = argparse.ArgumentParser('argument for training') |
| 28 | |
| 29 | parser.add_argument('--print_freq', type=int, default=10, |
| 30 | help='print frequency') |
| 31 | parser.add_argument('--save_freq', type=int, default=50, |
| 32 | help='save frequency') |
| 33 | parser.add_argument('--batch_size', type=int, default=256, |
| 34 | help='batch_size') |
| 35 | parser.add_argument('--num_workers', type=int, default=16, |
| 36 | help='num of workers to use') |
| 37 | parser.add_argument('--epochs', type=int, default=500, |
| 38 | help='number of training epochs') |
| 39 | |
| 40 | # optimization |
| 41 | parser.add_argument('--learning_rate', type=float, default=0.2, |
| 42 | help='learning rate') |
| 43 | parser.add_argument('--lr_decay_epochs', type=str, default='350,400,450', |
| 44 | help='where to decay lr, can be a list') |
| 45 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, |
| 46 | help='decay rate for learning rate') |
| 47 | parser.add_argument('--weight_decay', type=float, default=1e-4, |
| 48 | help='weight decay') |
| 49 | parser.add_argument('--momentum', type=float, default=0.9, |
| 50 | help='momentum') |
| 51 | |
| 52 | # model dataset |
| 53 | parser.add_argument('--model', type=str, default='resnet50') |
| 54 | parser.add_argument('--dataset', type=str, default='cifar10', |
| 55 | choices=['cifar10', 'cifar100'], help='dataset') |
| 56 | |
| 57 | # other setting |
| 58 | parser.add_argument('--cosine', action='store_true', |
| 59 | help='using cosine annealing') |
| 60 | parser.add_argument('--syncBN', action='store_true', |
| 61 | help='using synchronized batch normalization') |
| 62 | parser.add_argument('--warm', action='store_true', |
| 63 | help='warm-up for large batch training') |
| 64 | parser.add_argument('--trial', type=str, default='0', |
| 65 | help='id for recording multiple runs') |
| 66 | |
| 67 | opt = parser.parse_args() |
| 68 | |
| 69 | # set the path according to the environment |
| 70 | opt.data_folder = './datasets/' |
| 71 | opt.model_path = './save/SupCon/{}_models'.format(opt.dataset) |
| 72 | opt.tb_path = './save/SupCon/{}_tensorboard'.format(opt.dataset) |
| 73 | |
| 74 | iterations = opt.lr_decay_epochs.split(',') |
| 75 | opt.lr_decay_epochs = list([]) |
| 76 | for it in iterations: |
| 77 | opt.lr_decay_epochs.append(int(it)) |
| 78 | |
| 79 | opt.model_name = 'SupCE_{}_{}_lr_{}_decay_{}_bsz_{}_trial_{}'.\ |
| 80 | format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay, |
| 81 | opt.batch_size, opt.trial) |
| 82 | |
| 83 | if opt.cosine: |