This function should be called after loading parameters from a checkpoint / initial parameters file.
(model, blobs=None, cpu_mode=False)
| 928 | |
| 929 | |
def FinalizeAfterCheckpoint(model, blobs=None, cpu_mode=False):
    '''
    This function should be called after loading parameters from a
    checkpoint / initial parameters file.

    On the first call it builds a "checkpoint_sync_net" that synchronizes
    the selected blobs via _SyncAllParams. When the model has a rendezvous
    with more than one shard, it also builds and runs a one-off
    "checkpoint_init_net". The sync net is created in the workspace and
    cached as model._checkpoint_net. Every call (including the first) then
    runs the cached sync net.

    Args:
        model: the data-parallel model whose parameters were just loaded.
        blobs: optional iterable of blob names to synchronize. Each name is
            passed through stripBlobName and duplicates are dropped (order
            preserved). When None, the names returned by
            _ComputeBlobsToSync(model) are used. Only consulted on the call
            that builds the net; later calls reuse the cached net.
        cpu_mode: when False, the created nets are marked RunAllOnGPU().
    '''
    if not hasattr(model, "_checkpoint_net"):
        if blobs is None:
            (_, uniq_blob_names) = _ComputeBlobsToSync(model)
        else:
            # Blobs given with a device scope (e.g. one per device) can strip
            # to the same name; keep only the first occurrence so each blob
            # is synchronized once, preserving the caller's order.
            uniq_blob_names = []
            seen = set()
            for p in blobs:
                stripped = stripBlobName(p)
                if stripped not in seen:
                    seen.add(stripped)
                    uniq_blob_names.append(stripped)

        # Synchronize to the blob lookup map, as the provided
        # blobs might have non-parameters, such as momentum blobs.
        log.info("Creating checkpoint synchronization net")
        devices = model.GetDevices()
        for name in uniq_blob_names:
            if name not in model._device_grouped_blobs:
                model._device_grouped_blobs[name] = {
                    d: core.BlobReference("{}_{}{}{}".format(
                        model._device_prefix,
                        d,
                        scope._NAMESCOPE_SEPARATOR,
                        name,
                    ))
                    for d in devices
                }

        # Build the sync net in a local and publish it on the model only
        # after it has been created in the workspace. Otherwise an exception
        # below would leave a half-built net cached, and every later call
        # would skip construction and fail in RunNet instead.
        checkpoint_net = core.Net("checkpoint_sync_net")
        if not cpu_mode:
            checkpoint_net.RunAllOnGPU()

        checkpoint_init_net = None
        if model._rendezvous is not None and model._rendezvous['num_shards'] > 1:
            checkpoint_init_net = core.Net("checkpoint_init_net")
            if not cpu_mode:
                checkpoint_init_net.RunAllOnGPU()

        _SyncAllParams(
            devices,
            model,
            checkpoint_init_net,
            checkpoint_net,
            model._rendezvous,
            uniq_blob_names,
            max_concurrent_distributed_ops=1,
        )
        if checkpoint_init_net is not None:
            workspace.RunNetOnce(checkpoint_init_net)

        workspace.CreateNet(checkpoint_net)
        model._checkpoint_net = checkpoint_net

    # Run the sync
    log.info("Run checkpoint net")
    workspace.RunNet(model._checkpoint_net.Proto().name)
| 985 | |
| 986 | |
| 987 | def GetLearningRateBlobNames(model): |
nothing calls this directly
no test coverage detected
searching dependent graphs…