Repository: github.com/tensorlayer/TensorLayer

Function load_and_assign_npz_dict

tensorlayer/files/utils.py:2077–2112

Restore the parameters saved by ``tl.files.save_npz_dict()``, matching each stored array to a network weight by name.

(name='model.npz', network=None, skip=False)
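The function reads the file with `np.load`, iterates over its keys, and compares each key against `w.name` for the network's weights. The file is therefore, in effect, a NumPy `.npz` archive that maps weight names to arrays. The minimal sketch below writes such an archive with plain `np.savez` rather than TensorLayer. The TensorFlow-style names (`dense/kernel:0`, `dense/bias:0`) and the values are hypothetical stand-ins, not output of `tl.files.save_npz_dict()`:

```python
import os
import tempfile

import numpy as np

# Hypothetical weights keyed by TensorFlow-style names ("<scope>/<var>:0").
named_weights = {
    "dense/kernel:0": np.arange(12, dtype=np.float32).reshape(4, 3),
    "dense/bias:0": np.zeros(3, dtype=np.float32),
}

path = os.path.join(tempfile.mkdtemp(), "model.npz")
# Each keyword argument becomes one array in the archive, stored under that name.
np.savez(path, **named_weights)

# np.load on an .npz file returns a lazy, dict-like NpzFile; allow_pickle=True
# mirrors the call made by load_and_assign_npz_dict.
with np.load(path, allow_pickle=True) as archive:
    stored_names = sorted(archive.files)
    kernel = archive["dense/kernel:0"]  # reads the array fully into memory

print(stored_names)  # ['dense/bias:0', 'dense/kernel:0']
print(kernel.shape)  # (4, 3)
```

If the names stored in the archive differ from the `w.name` values of the network being restored, those keys are the ones the `skip` argument governs.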


```python
import os

import numpy as np

from tensorlayer import logging

# assign_tf_variable(variable, value) is a helper defined elsewhere in
# tensorlayer/files/utils.py; it assigns `value` to the given variable.


def load_and_assign_npz_dict(name='model.npz', network=None, skip=False):
    """Restore the parameters saved by ``tl.files.save_npz_dict()``.

    Parameters
    -------------
    name : str
        The name of the `.npz` file.
    network : :class:`Model`
        The network to be assigned.
    skip : boolean
        If `skip` is True, loaded weights whose names are not found among the network's
        weights are skipped with a warning.
        If `skip` is False, a RuntimeError is raised at the first such mismatch. Default False.

    Returns
    -------------
    False if the file does not exist; otherwise None.

    """
    if not os.path.exists(name):
        logging.error("file {} doesn't exist.".format(name))
        return False

    # Keys of the archive are the names the arrays were saved under.
    weights = np.load(name, allow_pickle=True)
    if len(weights.keys()) != len(set(weights.keys())):
        raise Exception("Duplication in model npz_dict %s" % name)

    # Weight names, in the same order as network.all_weights.
    net_weights_name = [w.name for w in network.all_weights]

    for key in weights.keys():
        if key not in net_weights_name:
            if skip:
                logging.warning("Weights named '%s' not found in network. Skip it." % key)
            else:
                raise RuntimeError(
                    "Weights named '%s' not found in network. Hint: set argument skip=True "
                    "if you want to skip redundant or mismatched weights." % key
                )
        else:
            # Copy the stored array into the network weight with the same name.
            assign_tf_variable(network.all_weights[net_weights_name.index(key)], weights[key])
    logging.info("[*] Model restored from npz_dict %s" % name)
```
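The restore loop can be reproduced without TensorFlow or TensorLayer to see how `skip` behaves. In this minimal sketch a plain dict of NumPy arrays stands in for `network.all_weights`, and in-place array assignment stands in for `assign_tf_variable`. The helper `restore_by_name` and all names and values are hypothetical illustrations, not a TensorLayer API:

```python
import numpy as np


def restore_by_name(archive, targets, skip=False):
    """Copy arrays from a name->array mapping into `targets` (name->array).

    Hypothetical, framework-free stand-in for load_and_assign_npz_dict:
    `targets` plays the role of network.all_weights and in-place NumPy
    assignment plays the role of assign_tf_variable. Returns skipped names.
    """
    skipped = []
    for key in archive:
        if key not in targets:
            if skip:
                skipped.append(key)  # the original logs a warning here
                continue
            raise RuntimeError(
                "Weights named '%s' not found in network. Hint: set argument "
                "skip=True if you want to skip redundant or mismatched weights." % key
            )
        targets[key][...] = archive[key]  # overwrite the target's values in place
    return skipped


# Hypothetical "network" weights and a saved archive with one extra entry.
net = {"dense/kernel:0": np.zeros((2, 2)), "dense/bias:0": np.zeros(2)}
saved = {
    "dense/kernel:0": np.ones((2, 2)),
    "dense/bias:0": np.full(2, 0.5),
    "extra/scale:0": np.ones(1),  # no matching weight in `net`
}

skipped = restore_by_name(saved, net, skip=True)
print(skipped)              # ['extra/scale:0']
print(net["dense/bias:0"])  # [0.5 0.5]

try:
    restore_by_name(saved, net, skip=False)
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True
```

The sketch makes one property of the loop visible: as in the function above, matched weights are assigned as the loop proceeds. When `skip=False` and an unknown name appears partway through the archive, the weights before it have already been overwritten by the time the `RuntimeError` is raised. Keying the targets by name also replaces the `list.index` lookup with a dict lookup; that is a choice of this sketch, not of the original.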

Callers

nothing calls this directly

Calls (2)

assign_tf_variable (function)
load (method)

Tested by

no test coverage detected
