(module: torch.nn.Module,
memory_buffers: Dict[torch.dtype, MemoryBuffer],
maintain_weight=True)
| 95 | |
| 96 | |
def build_memory_reference_from_module(module: torch.nn.Module,
                                       memory_buffers: Dict[torch.dtype, MemoryBuffer],
                                       maintain_weight=True):
    """Rebind every parameter of ``module`` to a tensor taken from the buffer for its dtype.

    Parameters are visited in order of their name. For each parameter:

    - a tensor of ``param.shape`` is obtained from ``memory_buffers[param.dtype]``
      at that dtype's running offset;
    - the offset is advanced by ``calc_padded_numel(param.shape, param.dtype)``;
    - ``param.data`` is replaced with that tensor.

    Args:
        module: Module whose parameters are rebound in place.
        memory_buffers: One buffer per parameter dtype. NOTE(review): presumably
            each buffer was sized and laid out using the same name-sorted, padded
            order used here. Confirm against the code that builds the buffers.
        maintain_weight: If True, copy each parameter's current values into its
            buffer slice before rebinding. If False, the parameter simply takes
            whatever the buffer slice already contains.

    Raises:
        KeyError: If a parameter's dtype has no entry in ``memory_buffers``.
    """
    # Running offset into each per-dtype buffer; every buffer starts at 0.
    start_index = {dtype: 0 for dtype in memory_buffers}
    # Sort by name only. Names are unique, so this yields the same order as
    # sorting the (name, param) tuples, without ever comparing tensors.
    for _name, param in sorted(module.named_parameters(), key=lambda item: item[0]):
        dtype = param.dtype
        buffer = memory_buffers[dtype].get(shape=param.shape, start_index=start_index[dtype])
        # Pad using this parameter's own dtype, so the offset stays in step with
        # the buffer it was sliced from. A stale dtype from elsewhere would
        # misplace later parameters when several dtypes are present.
        start_index[dtype] += calc_padded_numel(param.shape, dtype)
        if maintain_weight:
            buffer.copy_(param.data)
        param.data = buffer
| 111 | |
| 112 | |
| 113 | def build_memory_reference(weight_buffer_meta: Dict[str, Dict], memory_buffers: Dict[torch.dtype, MemoryBuffer]): |
no test coverage detected