()
| 235 | |
| 236 | |
def main():
    """Run the Siamese network example from the command line.

    Parses the training options, seeds PyTorch, selects the accelerator
    (unless ``--no-accel`` is given or none is available) or the CPU,
    builds the ``APP_MATCHER`` train/test data loaders, and then calls
    ``train`` followed by ``test`` once per epoch for ``--epochs`` epochs.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch Siamese network Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-accel', action='store_true',
                        help='disables accelerator')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    # The accelerator is used only when the user did not opt out AND one is
    # actually present on this machine.
    use_accel = not args.no_accel and torch.accelerator.is_available()

    torch.manual_seed(args.seed)

    if use_accel:
        device = torch.accelerator.current_accelerator()
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")

    # Shuffling is a property of the data split, not of the device: the
    # training loader is always shuffled and the test loader never is.
    # (Previously 'shuffle' lived in the accelerator-only options, so CPU
    # runs trained on unshuffled data and accelerator runs shuffled the
    # test set as well.)
    train_kwargs = {'batch_size': args.batch_size, 'shuffle': True}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_accel:
        # Loader options that are only applied when running on an accelerator.
        accel_kwargs = {'num_workers': 1,
                        'pin_memory': True}
        train_kwargs.update(accel_kwargs)
        test_kwargs.update(accel_kwargs)

    # Only the training split is constructed with download=True.
    train_dataset = APP_MATCHER('../data', train=True, download=True)
    test_dataset = APP_MATCHER('../data', train=False)
    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)

    model = SiameseNetwork().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    # NOTE(review): no scheduler.step() call and no use of args.save_model is
    # visible in this excerpt. Confirm both are handled after the lines shown
    # here; otherwise --gamma and --save-model have no effect.
    # args (including dry_run and log_interval) is handed to train() as a
    # whole; presumably those flags are consumed there - verify in train().
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
no test coverage detected