(cfg, logger=None, **kwargs)
| 17 | return callbacks |
| 18 | |
| 19 | def getCheckpointCallback(cfg, logger=None, **kwargs): |
| 20 | callbacks = [] |
| 21 | # Logging |
| 22 | metric_monitor = { |
| 23 | "loss_total": "total/train", |
| 24 | "Train_jf": "recons/text2jfeats/train", |
| 25 | "Val_jf": "recons/text2jfeats/val", |
| 26 | "Train_rf": "recons/text2rfeats/train", |
| 27 | "Val_rf": "recons/text2rfeats/val", |
| 28 | "APE root": "Metrics/APE_root", |
| 29 | "APE mean pose": "Metrics/APE_mean_pose", |
| 30 | "AVE root": "Metrics/AVE_root", |
| 31 | "AVE mean pose": "Metrics/AVE_mean_pose", |
| 32 | "R_TOP_1": "Metrics/R_precision_top_1", |
| 33 | "R_TOP_2": "Metrics/R_precision_top_2", |
| 34 | "R_TOP_3": "Metrics/R_precision_top_3", |
| 35 | "gt_R_TOP_3": "Metrics/gt_R_precision_top_3", |
| 36 | "FID": "Metrics/FID", |
| 37 | "gt_FID": "Metrics/gt_FID", |
| 38 | "Diversity": "Metrics/Diversity", |
| 39 | "MM dist": "Metrics/Matching_score", |
| 40 | "Accuracy": "Metrics/accuracy", |
| 41 | } |
| 42 | callbacks.append( |
| 43 | progressLogger(logger,metric_monitor=metric_monitor,log_every_n_steps=1)) |
| 44 | |
| 45 | # Save the 8 latest checkpoints (save_top_k=8, ranked by highest "step") |
| 46 | checkpointParams = { |
| 47 | 'dirpath': os.path.join(cfg.FOLDER_EXP, "checkpoints"), |
| 48 | 'filename': "{epoch}", |
| 49 | 'monitor': "step", |
| 50 | 'mode': "max", |
| 51 | 'every_n_epochs': cfg.LOGGER.VAL_EVERY_STEPS, |
| 52 | 'save_top_k': 8, |
| 53 | 'save_last': True, |
| 54 | 'save_on_train_epoch_end': True |
| 55 | } |
| 56 | callbacks.append(ModelCheckpoint(**checkpointParams)) |
| 57 | |
| 58 | # Save checkpoint every n*10 epochs |
| 59 | checkpointParams.update({ |
| 60 | 'every_n_epochs': |
| 61 | cfg.LOGGER.VAL_EVERY_STEPS * 10, |
| 62 | 'save_top_k': |
| 63 | -1, |
| 64 | 'save_last': |
| 65 | False |
| 66 | }) |
| 67 | callbacks.append(ModelCheckpoint(**checkpointParams)) |
| 68 | |
| 69 | metrics = cfg.METRIC.TYPE |
| 70 | metric_monitor_map = { |
| 71 | 'TemosMetric': { |
| 72 | 'Metrics/APE_root': { |
| 73 | 'abbr': 'APEroot', |
| 74 | 'mode': 'min' |
| 75 | }, |
| 76 | }, |
no test coverage detected