hub / github.com/OpenMotionLab/MotionGPT / getCheckpointCallback

Function getCheckpointCallback

mGPT/callback.py:19–140
(cfg, logger=None, **kwargs)

Source from the content-addressed store, hash-verified

    return callbacks


def getCheckpointCallback(cfg, logger=None, **kwargs):
    callbacks = []
    # Logging
    metric_monitor = {
        "loss_total": "total/train",
        "Train_jf": "recons/text2jfeats/train",
        "Val_jf": "recons/text2jfeats/val",
        "Train_rf": "recons/text2rfeats/train",
        "Val_rf": "recons/text2rfeats/val",
        "APE root": "Metrics/APE_root",
        "APE mean pose": "Metrics/APE_mean_pose",
        "AVE root": "Metrics/AVE_root",
        "AVE mean pose": "Metrics/AVE_mean_pose",
        "R_TOP_1": "Metrics/R_precision_top_1",
        "R_TOP_2": "Metrics/R_precision_top_2",
        "R_TOP_3": "Metrics/R_precision_top_3",
        "gt_R_TOP_3": "Metrics/gt_R_precision_top_3",
        "FID": "Metrics/FID",
        "gt_FID": "Metrics/gt_FID",
        "Diversity": "Metrics/Diversity",
        "MM dist": "Metrics/Matching_score",
        "Accuracy": "Metrics/accuracy",
    }
    callbacks.append(
        progressLogger(logger,
                       metric_monitor=metric_monitor,
                       log_every_n_steps=1))

    # Save the latest checkpoints (save_top_k=8, plus the "last" checkpoint)
    checkpointParams = {
        'dirpath': os.path.join(cfg.FOLDER_EXP, "checkpoints"),
        'filename': "{epoch}",
        'monitor': "step",
        'mode': "max",
        'every_n_epochs': cfg.LOGGER.VAL_EVERY_STEPS,
        'save_top_k': 8,
        'save_last': True,
        'save_on_train_epoch_end': True
    }
    callbacks.append(ModelCheckpoint(**checkpointParams))

    # Save checkpoint every n*10 epochs
    checkpointParams.update({
        'every_n_epochs': cfg.LOGGER.VAL_EVERY_STEPS * 10,
        'save_top_k': -1,
        'save_last': False
    })
    callbacks.append(ModelCheckpoint(**checkpointParams))

    metrics = cfg.METRIC.TYPE
    metric_monitor_map = {
        'TemosMetric': {
            'Metrics/APE_root': {
                'abbr': 'APEroot',
                'mode': 'min'
            },
        },
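
The listing above builds two `ModelCheckpoint` configurations from one dict: a rolling one, then a periodic one obtained by calling `update` on the same dict. The following is a minimal sketch of that pattern under stated assumptions: `build_checkpoint_params` is a hypothetical helper and `cfg` is a stand-in built from `SimpleNamespace` with made-up values, neither of which is part of MotionGPT. It only constructs the keyword dicts and does not import PyTorch Lightning.

```python
import os
from types import SimpleNamespace


def build_checkpoint_params(cfg):
    """Return (rolling, periodic) keyword dicts mirroring the listing above.

    Hypothetical helper for illustration only; not part of MotionGPT.
    """
    rolling = {
        "dirpath": os.path.join(cfg.FOLDER_EXP, "checkpoints"),
        "filename": "{epoch}",
        "monitor": "step",
        "mode": "max",
        "every_n_epochs": cfg.LOGGER.VAL_EVERY_STEPS,
        "save_top_k": 8,
        "save_last": True,
        "save_on_train_epoch_end": True,
    }
    # Copy before overriding so the first dict stays intact. The original
    # code instead calls update() on the same dict, after the first
    # ModelCheckpoint has already been constructed from it.
    periodic = dict(rolling)
    periodic.update({
        "every_n_epochs": cfg.LOGGER.VAL_EVERY_STEPS * 10,
        "save_top_k": -1,
        "save_last": False,
    })
    return rolling, periodic


# Stand-in config; the values are invented for this example.
cfg = SimpleNamespace(
    FOLDER_EXP="experiments/demo",
    LOGGER=SimpleNamespace(VAL_EVERY_STEPS=10),
)
rolling, periodic = build_checkpoint_params(cfg)
print(rolling["every_n_epochs"], rolling["save_top_k"])    # 10 8
print(periodic["every_n_epochs"], periodic["save_top_k"])  # 100 -1
```

Copying the dict rather than mutating it is a design choice of this sketch, not of the original function; it keeps both configurations available for inspection. As far as the Lightning `ModelCheckpoint` semantics go, monitoring `step` with `mode="max"` should make the top-k set the most recent checkpoints, and `save_top_k=-1` keeps every checkpoint that is written.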

Callers 1

build_callbacks · Function · 0.85

Calls 3

progressLogger · Class · 0.85
keys · Method · 0.80
update · Method · 0.45

Tested by

no test coverage detected