Restore the parameters saved by ``tl.files.save_npz_dict()``. Parameters: ``name`` (str) — the name of the `.npz` file; ``network`` (:class:`Model`) — the network to be assigned; ``skip`` (boolean) — if ``skip`` is True, loaded weights whose name is not found in the network's weights will be skipped.
(name='model.npz', network=None, skip=False)
| 2075 | |
| 2076 | |
def load_and_assign_npz_dict(name='model.npz', network=None, skip=False):
    """Restore the parameters saved by ``tl.files.save_npz_dict()``.

    Each array stored in the ``.npz`` archive is matched by its key against the
    ``name`` attribute of the entries in ``network.all_weights``. Each match is
    written into that weight with ``assign_tf_variable``.

    Parameters
    -------------
    name : str
        The name of the `.npz` file.
    network : :class:`Model`
        The network to be assigned. Its ``all_weights`` entries must each have
        a ``name`` attribute.
    skip : boolean
        If 'skip' == True, loaded weights whose name is not found in network's weights will be skipped.
        If 'skip' is False, error will be raised when mismatch is found. Default False.

    Returns
    -------
    bool or None
        ``False`` if the file ``name`` does not exist (an error is logged);
        otherwise ``None`` once every stored array has been processed.

    Raises
    ------
    ValueError
        If the file exists but ``network`` is None.
    Exception
        If the archive contains duplicated keys.
    RuntimeError
        If a stored key has no matching network weight and ``skip`` is False.

    """
    if not os.path.exists(name):
        logging.error("file {} doesn't exist.".format(name))
        return False

    if network is None:
        raise ValueError("`network` must be provided to assign weights loaded from %s" % name)

    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
    # execute arbitrary code -- only load archives that come from a trusted source.
    # The context manager closes the archive's underlying file handle when done.
    with np.load(name, allow_pickle=True) as weights:
        keys = list(weights.keys())
        if len(keys) != len(set(keys)):
            raise Exception("Duplication in model npz_dict %s" % name)

        # Build the name -> weight map once so each lookup is O(1) rather than a
        # linear list scan. setdefault keeps the FIRST weight for a repeated
        # name, matching what ``list.index`` would have picked.
        name_to_weight = {}
        for weight in network.all_weights:
            name_to_weight.setdefault(weight.name, weight)

        for key in keys:
            if key not in name_to_weight:
                if skip:
                    logging.warning("Weights named '%s' not found in network. Skip it." % key)
                    continue
                raise RuntimeError(
                    "Weights named '%s' not found in network. Hint: set argument skip=True "
                    "if you want to skip redundant or mismatch weights." % key
                )
            # weights[key] reads the array out of the archive, so it must happen
            # while the archive is still open.
            assign_tf_variable(name_to_weight[key], weights[key])

    logging.info("[*] Model restored from npz_dict %s" % name)
| 2113 | |
| 2114 | |
| 2115 | def save_ckpt(mode_name='model.ckpt', save_dir='checkpoint', var_list=None, global_step=None, printable=False): |
nothing calls this directly
no test coverage detected
searching dependent graphs…