MCPcopy Index your code
hub / github.com/pytorch/examples / main

Function main

siamese_network/main.py:237–298  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

235
236
def main():
    """Train and evaluate the Siamese network on the APP_MATCHER dataset.

    Parses command-line options, selects the accelerator (unless
    ``--no-accel`` is given or none is available) or the CPU, builds the
    train/test DataLoaders, and runs ``args.epochs`` rounds of ``train``
    followed by ``test``. After every epoch the Adadelta learning rate is
    multiplied by ``args.gamma`` via ``StepLR``. When ``--save-model`` is
    given, the final model ``state_dict`` is written with ``torch.save``.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch Siamese network Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-accel', action='store_true',
                        help='disables accelerator')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    use_accel = not args.no_accel and torch.accelerator.is_available()

    torch.manual_seed(args.seed)

    if use_accel:
        device = torch.accelerator.current_accelerator()
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")

    # Shuffle training batches on every device, not only when an accelerator
    # is present. The test loader keeps a fixed order.
    # NOTE(review): this assumes test() aggregates metrics over the whole
    # set, so batch order does not matter — confirm.
    train_kwargs = {'batch_size': args.batch_size, 'shuffle': True}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_accel:
        # Worker/pinned-memory settings only apply when feeding an accelerator.
        accel_kwargs = {'num_workers': 1,
                        'pin_memory': True}
        train_kwargs.update(accel_kwargs)
        test_kwargs.update(accel_kwargs)

    train_dataset = APP_MATCHER('../data', train=True, download=True)
    test_dataset = APP_MATCHER('../data', train=False)
    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)

    model = SiameseNetwork().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # Multiply the learning rate by args.gamma once per epoch.
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        # StepLR only advances when step() is called; without this the LR
        # stays at args.lr for every epoch and --gamma is ignored.
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "siamese_network.pt")

Callers 1

main.py · File · 0.70

Calls 6

APP_MATCHER · Class · 0.85
SiameseNetwork · Class · 0.85
update · Method · 0.80
save · Method · 0.80
train · Function · 0.70
test · Function · 0.70

Tested by

no test coverage detected