MCPcopy
hub / github.com/microsoft/AI-System / main

Function main

Labs/BasicLabs/Lab1/mnist_basic.py:101–155  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

99
100
def main(argv=None):
    """Train and evaluate ``Net`` on MNIST using command-line settings.

    Parses the command line, seeds torch's global RNG, builds the MNIST
    train/test ``DataLoader``s rooted at ``../data``, then for each epoch
    calls ``train`` and ``test`` and steps a ``StepLR`` schedule attached to
    an ``Adadelta`` optimizer. With ``--save-model``, the final model
    ``state_dict`` is written to ``mnist_cnn.pt`` in the current working
    directory.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, i.e. the original script behaviour; passing a
            list lets notebooks/tests drive the entry point directly.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    # parse_args(None) falls back to sys.argv[1:], so existing callers of
    # main() are unaffected.
    args = parser.parse_args(argv)
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    # A loader worker and pinned host memory are only requested when a GPU
    # is in use.
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # One preprocessing pipeline shared by both splits. The Normalize
    # constants are presumably the MNIST training-set mean/std (not derived
    # here -- confirm).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transform),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    # The test set is not shuffled: shuffling evaluation data adds work and,
    # with no explicit generator, draws a seed from torch's global RNG on
    # every test pass, perturbing the randomness seen by later training
    # epochs. NOTE(review): assumes test() aggregates over the whole loader,
    # so order does not affect its result -- confirm against test().
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transform),
        batch_size=args.test_batch_size, shuffle=False, **kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # step_size=1: the learning rate is multiplied by gamma once per epoch
    # (scheduler.step() is called at the end of each epoch below).
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
156
157
158if __name__ == '__main__':

Callers 1

mnist_basic.py · File · 0.70

Calls 3

Net · Class · 0.70
train · Function · 0.70
test · Function · 0.70

Tested by

no test coverage detected