(args)
| 180 | |
| 181 | |
| 182 | def main(args): |
| 183 | if args.backend.lower() == "tv_tensor" and not args.use_v2: |
| 184 | raise ValueError("Use --use-v2 if you want to use the tv_tensor backend.") |
| 185 | if args.dataset not in ("coco", "coco_kp"): |
| 186 | raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}") |
| 187 | if "keypoint" in args.model and args.dataset != "coco_kp": |
| 188 | raise ValueError("Oops, if you want Keypoint detection, set --dataset coco_kp") |
| 189 | if args.dataset == "coco_kp" and args.use_v2: |
| 190 | raise ValueError("KeyPoint detection doesn't support V2 transforms yet") |
| 191 | |
| 192 | if args.output_dir: |
| 193 | utils.mkdir(args.output_dir) |
| 194 | |
| 195 | utils.init_distributed_mode(args) |
| 196 | print(args) |
| 197 | |
| 198 | device = torch.device(args.device) |
| 199 | |
| 200 | if args.use_deterministic_algorithms: |
| 201 | torch.use_deterministic_algorithms(True) |
| 202 | |
| 203 | # Data loading code |
| 204 | print("Loading data") |
| 205 | |
| 206 | dataset, num_classes = get_dataset(is_train=True, args=args) |
| 207 | dataset_test, _ = get_dataset(is_train=False, args=args) |
| 208 | |
| 209 | print("Creating data loaders") |
| 210 | if args.distributed: |
| 211 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) |
| 212 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False) |
| 213 | else: |
| 214 | train_sampler = torch.utils.data.RandomSampler(dataset) |
| 215 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) |
| 216 | |
| 217 | if args.aspect_ratio_group_factor >= 0: |
| 218 | group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor) |
| 219 | train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size) |
| 220 | else: |
| 221 | train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True) |
| 222 | |
| 223 | train_collate_fn = utils.collate_fn |
| 224 | if args.use_copypaste: |
| 225 | if args.data_augmentation != "lsj": |
| 226 | raise RuntimeError("SimpleCopyPaste algorithm currently only supports the 'lsj' data augmentation policies") |
| 227 | |
| 228 | train_collate_fn = copypaste_collate_fn |
| 229 | |
| 230 | data_loader = torch.utils.data.DataLoader( |
| 231 | dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=train_collate_fn |
| 232 | ) |
| 233 | |
| 234 | data_loader_test = torch.utils.data.DataLoader( |
| 235 | dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn |
| 236 | ) |
| 237 | |
| 238 | print("Creating model") |
| 239 | kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers} |
no test coverage detected
searching dependent graphs…