build loaders for contrastive training
(opt, ngpus_per_node)
| 337 | |
| 338 | |
def build_contrast_loader(opt, ngpus_per_node):
    """Build the dataset, loader and sampler for contrastive training.

    Reads these fields from ``opt``: ``data_folder``, ``aug``, ``modal``,
    ``jigsaw``, ``mem``, ``batch_size``, ``world_size`` and ``num_workers``.

    Images are loaded from ``<opt.data_folder>/train`` with
    ``ImageFolderInstance``. Two crops per image are requested unless a
    memory bank is used (``opt.mem == 'bank'``). A jigsaw transform is
    attached only when ``opt.jigsaw`` is truthy.

    Args:
        opt: parsed options/namespace holding the fields listed above.
        ngpus_per_node: number of GPUs on this node. It is used to split
            ``opt.num_workers`` across the per-GPU processes.

    Returns:
        tuple: ``(train_dataset, train_loader, train_sampler)``.
    """
    data_folder = opt.data_folder
    aug = opt.aug
    modal = opt.modal
    use_jigsaw = opt.jigsaw
    use_memory_bank = (opt.mem == 'bank')

    # Per-process batch size: the global batch is split evenly across all
    # processes. Integer floor division avoids a float round-trip.
    batch_size = opt.batch_size // opt.world_size
    # Ceiling division, so every process on this node gets at least its
    # share of the data-loading workers.
    num_workers = (opt.num_workers + ngpus_per_node - 1) // ngpus_per_node

    train_transform, jigsaw_transform = \
        build_transforms(aug, modal, use_memory_bank)

    train_dir = os.path.join(data_folder, 'train')
    # Both dataset variants share these arguments; the jigsaw variant
    # only adds its extra transform.
    dataset_kwargs = {
        'transform': train_transform,
        'two_crop': not use_memory_bank,
    }
    if use_jigsaw:
        dataset_kwargs['jigsaw_transform'] = jigsaw_transform
    train_dataset = ImageFolderInstance(train_dir, **dataset_kwargs)

    # No num_replicas/rank are passed. NOTE(review): per the PyTorch docs,
    # these are then read from the default process group, so torch.distributed
    # presumably must already be initialized when this runs. The caller likely
    # needs to call train_sampler.set_epoch(epoch) each epoch; confirm.
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)

    # A sampler is always supplied, so DataLoader-level shuffling must stay
    # off (the two options are mutually exclusive); the sampler decides order.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True, sampler=train_sampler)

    print('train images: {}'.format(len(train_dataset)))

    return train_dataset, train_loader, train_sampler
| 374 | |
| 375 | |
| 376 | def build_linear_loader(opt, ngpus_per_node): |
no test coverage detected