```python
import mxnet as mx


def load_checkpoint(prefix, epoch):
    """
    Load a model checkpoint from file.

    :param prefix: Prefix of the model name.
    :param epoch: Epoch number of the checkpoint to load.
    :return: (arg_params, aux_params)
        arg_params : dict of str to NDArray
            Model parameters, a dict mapping names to NDArrays of the net's weights.
        aux_params : dict of str to NDArray
            Model parameters, a dict mapping names to NDArrays of the net's auxiliary states.
    """
    # Parameter files are named '<prefix>-<epoch, zero-padded to 4 digits>.params'.
    save_dict = mx.nd.load('%s-%04d.params' % (prefix, epoch))
    arg_params = {}
    aux_params = {}
    for k, v in save_dict.items():
        # Each key has the form '<type>:<name>', where <type> is 'arg' or 'aux'.
        tp, name = k.split(':', 1)
        if tp == 'arg':
            arg_params[name] = v
        elif tp == 'aux':
            aux_params[name] = v
    return arg_params, aux_params


def convert_context(params, ctx):
    ...  # body continues beyond this excerpt
```
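The key-parsing step in `load_checkpoint` can be exercised without MXNet installed. The following is a minimal sketch, assuming plain Python lists stand in for NDArrays; the helper names `checkpoint_filename` and `split_params`, and the parameter names in `saved`, are hypothetical and not part of the MXNet API. It shows the `<prefix>-<epoch>.params` naming and how each `type:name` key is routed into `arg_params` or `aux_params`.

```python
def checkpoint_filename(prefix, epoch):
    """Hypothetical helper: build the file name that load_checkpoint reads."""
    return '%s-%04d.params' % (prefix, epoch)


def split_params(save_dict):
    """Hypothetical helper: partition 'arg:'/'aux:' keys into two dicts.

    Keys with any other type prefix are ignored, matching load_checkpoint.
    """
    arg_params, aux_params = {}, {}
    for key, value in save_dict.items():
        # split(':', 1) keeps any further ':' characters inside the name.
        tp, name = key.split(':', 1)
        if tp == 'arg':
            arg_params[name] = value
        elif tp == 'aux':
            aux_params[name] = value
    return arg_params, aux_params


# Plain lists stand in for NDArrays; the parameter names are illustrative only.
saved = {
    'arg:conv1_weight': [0.1, 0.2],
    'arg:fc1_bias': [0.0],
    'aux:bn1_moving_mean': [0.5],
}
args, auxs = split_params(saved)
print(checkpoint_filename('resnet', 7))  # resnet-0007.params
print(sorted(args))  # ['conv1_weight', 'fc1_bias']
print(sorted(auxs))  # ['bn1_moving_mean']
```

As in the original function, a key that contains no `:` raises `ValueError` when the result of `split` is unpacked into two names.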