| 5 | import deepspeed |
| 6 | |
def add_argument(argv=None):
    """Build the CIFAR training command-line parser and parse the arguments.

    Besides the script's own options (CUDA / EMA toggles, batch size, epoch
    count and ``--local_rank``), the parser is extended with DeepSpeed's
    configuration arguments via ``deepspeed.add_config_arguments``.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) ``argparse`` falls back to ``sys.argv[1:]``, which is the
            original behaviour of this function.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(description='CIFAR')

    # cuda
    parser.add_argument('--with_cuda', default=False, action='store_true',
                        help='use CUDA (GPU); leave unset to run on CPU '
                             'when there is no GPU support')
    parser.add_argument('--use_ema', default=False, action='store_true',
                        help='whether to use exponential moving average')

    # train
    parser.add_argument('-b', '--batch_size', default=32, type=int,
                        help='mini-batch size (default: 32)')
    parser.add_argument('-e', '--epochs', default=30, type=int,
                        help='number of total epochs (default: 30)')
    # Supplied by the distributed launcher; stays -1 when the flag is not
    # passed on the command line.
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='local rank passed from distributed launcher')

    # Include DeepSpeed configuration arguments (the helper returns the
    # parser to keep using).
    parser = deepspeed.add_config_arguments(parser)

    # parse_args(None) reads sys.argv[1:], so the default call is unchanged.
    return parser.parse_args(argv)
| 32 | |
| 33 | ######################################################################## |
| 34 | # Torchvision datasets output PILImage images of range [0, 1]. |