| 4482 | self.checkpoint_engine.save(state_dict=state, path=save_path) |
| 4483 | |
def _get_buffer_names(self):
    """Return the dotted names of every persistent buffer under ``self.module``.

    The names are recorded so that the real buffers can later be picked out
    of the saved ``state_dict["module"]`` of a non-zero checkpoint, where
    they sit intermixed with param placeholders.

    Names are listed in pre-order: a module's own buffers first, then each
    child's subtree in ``named_children()`` order, with every name prefixed
    by the dotted path of the module that owns it.
    """
    collected = []

    # The tree is walked node by node (rather than relying on a single
    # recursive ``named_buffers()`` call) because the non-persistent set
    # lives on each individual module and must be consulted per node.
    pending = [("", self.module)]
    while pending:
        scope, layer = pending.pop()

        collected.extend(scope + key for key, tensor in layer.named_buffers(recurse=False)
                         if tensor is not None and key not in layer._non_persistent_buffers_set)

        # Push children in reverse so the first child is popped (and its
        # whole subtree visited) before its siblings.
        subtrees = [(scope + key + ".", sub) for key, sub in layer.named_children() if sub is not None]
        pending.extend(reversed(subtrees))

    return collected
| 4504 | |
def _get_param_shape_func(self, param):
    """Return the shape to report for ``param``.

    A parameter that carries a ``ds_id`` attribute reports ``ds_shape``;
    any other parameter reports its ordinary ``shape``.

    NOTE(review): ``ds_id`` presumably marks a ZeRO-partitioned parameter
    whose ``shape`` no longer reflects the full tensor -- confirm against
    the code that sets it.
    """
    if not hasattr(param, 'ds_id'):
        return param.shape
    return param.ds_shape