()
| 99 | |
| 100 | |
def main():
    """Train and evaluate ``Net`` on MNIST as a command-line script.

    Reads hyper-parameters from the command line, builds the MNIST train and
    test loaders under ``../data`` (downloading if missing), trains with
    Adadelta while decaying the learning rate by ``--gamma`` after every
    epoch, calls ``test`` after each epoch, and, when ``--save-model`` is
    given, writes the final ``state_dict`` to ``mnist_cnn.pt``.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    # A worker process and pinned host memory are only requested when
    # batches will be copied to a CUDA device.
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # Both splits share one preprocessing pipeline: convert each image to a
    # tensor, then standardize with fixed constants (0.1307 / 0.3081 are
    # presumably the MNIST training-set mean and std — confirm if changed).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transform),
        batch_size=args.batch_size, shuffle=True, **kwargs)

    # The test split is not shuffled: this assumes test() aggregates over the
    # whole split, so evaluation order is irrelevant. NOTE: dropping the
    # shuffle also changes how the seeded global RNG advances, so results
    # will not bit-match runs made with the old shuffled test loader.
    # download=True makes this loader independent of the training set having
    # been fetched first; it is a no-op when the files already exist.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, download=True,
                       transform=transform),
        batch_size=args.test_batch_size, shuffle=False, **kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # step_size=1 together with scheduler.step() once per epoch multiplies
    # the learning rate by args.gamma after every epoch.
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
| 156 | |
| 157 | |
| 158 | if __name__ == '__main__': |
no test coverage detected