MCPcopy
hub / github.com/pytorch/vision / main

Function main

references/detection/train.py:182–329  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

180
181
182def main(args):
183 if args.backend.lower() == "tv_tensor" and not args.use_v2:
184 raise ValueError("Use --use-v2 if you want to use the tv_tensor backend.")
185 if args.dataset not in ("coco", "coco_kp"):
186 raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}")
187 if "keypoint" in args.model and args.dataset != "coco_kp":
188 raise ValueError("Oops, if you want Keypoint detection, set --dataset coco_kp")
189 if args.dataset == "coco_kp" and args.use_v2:
190 raise ValueError("KeyPoint detection doesn't support V2 transforms yet")
191
192 if args.output_dir:
193 utils.mkdir(args.output_dir)
194
195 utils.init_distributed_mode(args)
196 print(args)
197
198 device = torch.device(args.device)
199
200 if args.use_deterministic_algorithms:
201 torch.use_deterministic_algorithms(True)
202
203 # Data loading code
204 print("Loading data")
205
206 dataset, num_classes = get_dataset(is_train=True, args=args)
207 dataset_test, _ = get_dataset(is_train=False, args=args)
208
209 print("Creating data loaders")
210 if args.distributed:
211 train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
212 test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
213 else:
214 train_sampler = torch.utils.data.RandomSampler(dataset)
215 test_sampler = torch.utils.data.SequentialSampler(dataset_test)
216
217 if args.aspect_ratio_group_factor >= 0:
218 group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
219 train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
220 else:
221 train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)
222
223 train_collate_fn = utils.collate_fn
224 if args.use_copypaste:
225 if args.data_augmentation != "lsj":
226 raise RuntimeError("SimpleCopyPaste algorithm currently only supports the 'lsj' data augmentation policies")
227
228 train_collate_fn = copypaste_collate_fn
229
230 data_loader = torch.utils.data.DataLoader(
231 dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=train_collate_fn
232 )
233
234 data_loader_test = torch.utils.data.DataLoader(
235 dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
236 )
237
238 print("Creating model")
239 kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}

Callers 1

train.py (File) · 0.70

Calls 10

GroupedBatchSampler (Class) · 0.90
evaluate (Function) · 0.90
train_one_epoch (Function) · 0.90
device (Method) · 0.80
to (Method) · 0.80
load (Method) · 0.80
print (Function) · 0.70
get_dataset (Function) · 0.70
set_epoch (Method) · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…