Repository: github.com/pytorch/pytorch

Function FinalizeAfterCheckpoint

caffe2/python/data_parallel_model.py, lines 930–984

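Call this function after loading parameters from a checkpoint or initial-parameters file. It synchronizes the loaded blobs across the model's devices (and, when the model's rendezvous reports more than one shard, across shards) by building and running a dedicated checkpoint sync net.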
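A minimal, framework-free sketch of what "synchronizing after a checkpoint load" means, assuming that `_SyncAllParams` broadcasts each blob from the first device in `devices` to the others (an assumption; its body is not shown on this page). The dict-of-lists layout and the name `broadcast_from_first_device` are hypothetical; the real function builds a caffe2 net of copy/broadcast operators rather than copying Python lists.

```python
# Hypothetical, framework-free sketch: each device holds its own replica of
# every blob, and after a checkpoint load only devices[0] is trusted.
from typing import Dict, List

Replicas = Dict[str, Dict[int, List[float]]]  # blob name -> device -> values


def broadcast_from_first_device(replicas: Replicas, devices: List[int]) -> None:
    """Overwrite every other device's replica with a copy of devices[0]'s."""
    master = devices[0]
    for per_device in replicas.values():
        source = per_device[master]
        for d in devices[1:]:
            # Copy rather than alias so replicas stay independent buffers.
            per_device[d] = list(source)


if __name__ == "__main__":
    devs = [0, 1, 2]
    state: Replicas = {
        "fc_w": {0: [0.5, -1.0], 1: [0.0, 0.0], 2: [0.0, 0.0]},
        # Non-parameter blobs such as momentum need the same treatment.
        "fc_w_momentum": {0: [0.1, 0.2], 1: [0.0, 0.0], 2: [9.9, 9.9]},
    }
    broadcast_from_first_device(state, devs)
    print(state["fc_w"][2], state["fc_w_momentum"][1])  # [0.5, -1.0] [0.1, 0.2]
```

Copying from a single source device keeps every replica identical before training resumes, regardless of what the other devices held beforehand.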

(model, blobs=None, cpu_mode=False)
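When `blobs` is given, the function strips any device scope from each name with `stripBlobName`; then, for every name not yet in `model._device_grouped_blobs`, it builds one `<prefix>_<device>/<name>` reference per device. The sketch below mimics that with plain strings. `strip_blob_name`, `group_by_device`, and the `/` separator are assumptions standing in for `stripBlobName`, the grouping loop, and `scope._NAMESCOPE_SEPARATOR`. The stand-in leaves unscoped names unchanged, which the real helper may not do, and it collapses duplicate names, which the original list comprehension does not.

```python
# Simplified, hypothetical stand-ins for FinalizeAfterCheckpoint's name
# handling: strip a device scope, then rebuild one scoped name per device.
from typing import Dict, Iterable, List

NAMESCOPE_SEPARATOR = "/"  # assumed value of caffe2's scope separator


def strip_blob_name(name: str) -> str:
    """Drop everything up to and including the first separator.

    Unscoped names are returned unchanged (a simplification of stripBlobName).
    """
    _, sep, rest = name.partition(NAMESCOPE_SEPARATOR)
    return rest if sep else name


def group_by_device(
    names: Iterable[str], device_prefix: str, devices: List[int]
) -> Dict[str, Dict[int, str]]:
    """Map each bare name to {device: "<prefix>_<device>/<name>"}."""
    grouped: Dict[str, Dict[int, str]] = {}
    for raw in names:
        name = strip_blob_name(raw)
        if name not in grouped:  # the same blob seen on several devices collapses
            grouped[name] = {
                d: "{}_{}{}{}".format(device_prefix, d, NAMESCOPE_SEPARATOR, name)
                for d in devices
            }
    return grouped


if __name__ == "__main__":
    g = group_by_device(["gpu_0/fc_w", "gpu_1/fc_w", "fc_b"], "gpu", [0, 1])
    print(g["fc_w"])  # {0: 'gpu_0/fc_w', 1: 'gpu_1/fc_w'}
```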


```python
def FinalizeAfterCheckpoint(model, blobs=None, cpu_mode=False):
    '''
    This function should be called after loading parameters from a
    checkpoint / initial parameters file.
    '''

    if not hasattr(model, "_checkpoint_net"):
        if blobs is None:
            (_, uniq_blob_names) = _ComputeBlobsToSync(model)
        else:
            uniq_blob_names = [stripBlobName(p) for p in blobs]

        # Synchronize to the blob lookup map, as the provided
        # blobs might have non-parameters, such as momentum blobs.
        log.info("Creating checkpoint synchronization net")
        devices = model.GetDevices()
        for name in uniq_blob_names:
            if name not in model._device_grouped_blobs:
                grouped = {
                    d:
                    core.BlobReference("{}_{}{}{}".format(
                        model._device_prefix,
                        d,
                        scope._NAMESCOPE_SEPARATOR,
                        name)
                    ) for d in devices}
                model._device_grouped_blobs[name] = grouped

        model._checkpoint_net = core.Net("checkpoint_sync_net")
        if not cpu_mode:
            model._checkpoint_net.RunAllOnGPU()

        checkpoint_init_net = None
        if (model._rendezvous is not None and model._rendezvous['num_shards'] > 1):
            checkpoint_init_net = core.Net("checkpoint_init_net")
            if not cpu_mode:
                checkpoint_init_net.RunAllOnGPU()

        _SyncAllParams(
            devices,
            model,
            checkpoint_init_net,
            model._checkpoint_net,
            model._rendezvous,
            uniq_blob_names,
            max_concurrent_distributed_ops=1
        )
        if (checkpoint_init_net):
            workspace.RunNetOnce(checkpoint_init_net)

        workspace.CreateNet(model._checkpoint_net)

    # Run the sync
    log.info("Run checkpoint net")
    workspace.RunNet(model._checkpoint_net.Proto().name)
```
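The control flow follows a build-once, run-many pattern: the sync net (and, with more than one shard, a one-off init net) is created only on the first call and cached on `model._checkpoint_net`; every call, including the first, then re-runs that cached net. The sketch below reproduces just that caching logic in plain Python. `FakeModel`, `finalize_after_checkpoint`, the `num_shards` attribute, and the log strings are hypothetical, and no caffe2 APIs are used.

```python
# Hypothetical sketch of the build-once, run-many caching in
# FinalizeAfterCheckpoint; plain Python, no caffe2 involved.
from typing import List


class FakeModel:
    """Stand-in for a data-parallel model (not a caffe2 class)."""

    def __init__(self, num_shards: int = 1) -> None:
        self.num_shards = num_shards
        self.log: List[str] = []


def finalize_after_checkpoint(model: FakeModel, cpu_mode: bool = False) -> None:
    if not hasattr(model, "_checkpoint_net"):
        # First call only: "build" the sync net on the requested device type.
        model.log.append("build sync net on " + ("CPU" if cpu_mode else "GPU"))
        # Cross-shard setup runs once, and only when there is more than one shard.
        if model.num_shards > 1:
            model.log.append("run init net once")

        def sync_net() -> None:
            model.log.append("run sync net")

        # Cache on the model, mirroring model._checkpoint_net.
        model._checkpoint_net = sync_net
    # Every call, including the first, re-runs the cached net.
    model._checkpoint_net()


if __name__ == "__main__":
    m = FakeModel(num_shards=2)
    finalize_after_checkpoint(m)
    finalize_after_checkpoint(m, cpu_mode=True)  # cpu_mode ignored: net is cached
    print(m.log)
    # ['build sync net on GPU', 'run init net once', 'run sync net', 'run sync net']
```

One consequence of this design, visible in the original code as well: `blobs` and `cpu_mode` only take effect on the first call. Later calls reuse the cached net whatever arguments are passed.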

Callers

nothing calls this directly

Calls (9)

RunAllOnGPU (method) · 0.95
_ComputeBlobsToSync (function) · 0.85
stripBlobName (function) · 0.85
_SyncAllParams (function) · 0.85
info (method) · 0.80
GetDevices (method) · 0.80
Net (method) · 0.80
format (method) · 0.45
Proto (method) · 0.45

Tested by

no test coverage detected
