MCPcopy
hub / github.com/pytorch/vision / main

Function main

references/classification/train_quantization.py:15–154  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

13
14
15def main(args):
16 if args.output_dir:
17 utils.mkdir(args.output_dir)
18
19 utils.init_distributed_mode(args)
20 print(args)
21
22 if args.post_training_quantize and args.distributed:
23 raise RuntimeError("Post training quantization example should not be performed on distributed mode")
24
25 # Set backend engine to ensure that quantized model runs on the correct kernels
26 if args.qbackend not in torch.backends.quantized.supported_engines:
27 raise RuntimeError("Quantized backend not supported: " + str(args.qbackend))
28 torch.backends.quantized.engine = args.qbackend
29
30 device = torch.device(args.device)
31 torch.backends.cudnn.benchmark = True
32
33 # Data loading code
34 print("Loading data")
35 train_dir = os.path.join(args.data_path, "train")
36 val_dir = os.path.join(args.data_path, "val")
37
38 dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
39 data_loader = torch.utils.data.DataLoader(
40 dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
41 )
42
43 data_loader_test = torch.utils.data.DataLoader(
44 dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
45 )
46
47 print("Creating model", args.model)
48 # when training quantized models, we always start from a pre-trained fp32 reference model
49 prefix = "quantized_"
50 model_name = args.model
51 if not model_name.startswith(prefix):
52 model_name = prefix + model_name
53 model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
54 model.to(device)
55
56 if not (args.test_only or args.post_training_quantize):
57 model.fuse_model(is_qat=True)
58 model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.qbackend)
59 torch.ao.quantization.prepare_qat(model, inplace=True)
60
61 if args.distributed and args.sync_bn:
62 model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
63
64 optimizer = torch.optim.SGD(
65 model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
66 )
67
68 lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
69
70 criterion = nn.CrossEntropyLoss()
71 model_without_ddp = model
72 if args.distributed:

Callers 1

Calls 11

load_data · Function · 0.90
evaluate · Function · 0.90
train_one_epoch · Function · 0.90
device · Method · 0.80
to · Method · 0.80
load · Method · 0.80
prepare · Method · 0.80
train · Method · 0.80
print · Function · 0.70
fuse_model · Method · 0.45
set_epoch · Method · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…