(args)
| 13 | |
| 14 | |
| 15 | def main(args): |
| 16 | if args.output_dir: |
| 17 | utils.mkdir(args.output_dir) |
| 18 | |
| 19 | utils.init_distributed_mode(args) |
| 20 | print(args) |
| 21 | |
| 22 | if args.post_training_quantize and args.distributed: |
| 23 | raise RuntimeError("Post training quantization example should not be performed on distributed mode") |
| 24 | |
| 25 | # Set backend engine to ensure that quantized model runs on the correct kernels |
| 26 | if args.qbackend not in torch.backends.quantized.supported_engines: |
| 27 | raise RuntimeError("Quantized backend not supported: " + str(args.qbackend)) |
| 28 | torch.backends.quantized.engine = args.qbackend |
| 29 | |
| 30 | device = torch.device(args.device) |
| 31 | torch.backends.cudnn.benchmark = True |
| 32 | |
| 33 | # Data loading code |
| 34 | print("Loading data") |
| 35 | train_dir = os.path.join(args.data_path, "train") |
| 36 | val_dir = os.path.join(args.data_path, "val") |
| 37 | |
| 38 | dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) |
| 39 | data_loader = torch.utils.data.DataLoader( |
| 40 | dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True |
| 41 | ) |
| 42 | |
| 43 | data_loader_test = torch.utils.data.DataLoader( |
| 44 | dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True |
| 45 | ) |
| 46 | |
| 47 | print("Creating model", args.model) |
| 48 | # when training quantized models, we always start from a pre-trained fp32 reference model |
| 49 | prefix = "quantized_" |
| 50 | model_name = args.model |
| 51 | if not model_name.startswith(prefix): |
| 52 | model_name = prefix + model_name |
| 53 | model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only) |
| 54 | model.to(device) |
| 55 | |
| 56 | if not (args.test_only or args.post_training_quantize): |
| 57 | model.fuse_model(is_qat=True) |
| 58 | model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.qbackend) |
| 59 | torch.ao.quantization.prepare_qat(model, inplace=True) |
| 60 | |
| 61 | if args.distributed and args.sync_bn: |
| 62 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) |
| 63 | |
| 64 | optimizer = torch.optim.SGD( |
| 65 | model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay |
| 66 | ) |
| 67 | |
| 68 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) |
| 69 | |
| 70 | criterion = nn.CrossEntropyLoss() |
| 71 | model_without_ddp = model |
| 72 | if args.distributed: |
no test coverage detected
searching dependent graphs…