(config, scaler=None)
| 197 | |
| 198 | |
| 199 | def train_distill(config, scaler=None): |
| 200 | EPOCH = config["epoch"] |
| 201 | topk = config["topk"] |
| 202 | |
| 203 | batch_size = config["TRAIN"]["batch_size"] |
| 204 | num_workers = config["TRAIN"]["num_workers"] |
| 205 | train_loader = build_dataloader( |
| 206 | "train", batch_size=batch_size, num_workers=num_workers |
| 207 | ) |
| 208 | |
| 209 | # build metric |
| 210 | metric_func = create_metric |
| 211 | |
| 212 | # model = distillmv3_large_x0_5(class_dim=100) |
| 213 | model = build_model(config) |
| 214 | |
| 215 | # pact quant train |
| 216 | if "quant_train" in config and config["quant_train"] is True: |
| 217 | quanter = QAT(config=quant_config, act_preprocess=PACT) |
| 218 | quanter.quantize(model) |
| 219 | elif "prune_train" in config and config["prune_train"] is True: |
| 220 | model = prune_model(model, [1, 3, 32, 32], 0.1) |
| 221 | else: |
| 222 | pass |
| 223 | |
| 224 | # build_optimizer |
| 225 | optimizer, lr_scheduler = create_optimizer( |
| 226 | config, parameter_list=model.parameters() |
| 227 | ) |
| 228 | |
| 229 | # load model |
| 230 | pre_best_model_dict = load_model(config, model, optimizer) |
| 231 | if len(pre_best_model_dict) > 0: |
| 232 | pre_str = "The metric of loaded metric as follows {}".format( |
| 233 | ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()]) |
| 234 | ) |
| 235 | logger.info(pre_str) |
| 236 | |
| 237 | model.train() |
| 238 | model = paddle.DataParallel(model) |
| 239 | |
| 240 | # build loss function |
| 241 | loss_func_distill = LossDistill(model_name_list=["student", "student1"]) |
| 242 | loss_func_dml = DMLLoss(model_name_pairs=["student", "student1"]) |
| 243 | loss_func_js = KLJSLoss(mode="js") |
| 244 | |
| 245 | data_num = len(train_loader) |
| 246 | |
| 247 | best_acc = {} |
| 248 | for epoch in range(EPOCH): |
| 249 | st = time.time() |
| 250 | for idx, data in enumerate(train_loader): |
| 251 | img_batch, label = data |
| 252 | img_batch = paddle.transpose(img_batch, [0, 3, 1, 2]) |
| 253 | label = paddle.unsqueeze(label, -1) |
| 254 | if scaler is not None: |
| 255 | with paddle.amp.auto_cast(): |
| 256 | outs = model(img_batch) |
no test coverage detected
searching dependent graphs…