Build the memory buffer given weight_buffer_meta. Args: weight_buffer_meta: a mapping from tensor name to a dictionary containing the shape and dtype of that tensor. Returns: a large memory buffer for each dtype that can hold all the tensors of that dtype.
(weight_buffer_meta: Dict[str, Dict])
| 66 | |
| 67 | |
def build_memory_buffer(weight_buffer_meta: Dict[str, Dict]) -> Dict[torch.dtype, MemoryBuffer]:
    """Allocate one MemoryBuffer per dtype, sized to hold every tensor of that dtype.

    Args:
        weight_buffer_meta: maps a tensor name to a dict with a ``'shape'`` entry
            (asserted to be a ``torch.Size``) and a ``'dtype'`` entry (asserted
            to be a ``torch.dtype``).

    Returns:
        A dict from dtype to a MemoryBuffer whose size is the sum of
        ``calc_padded_numel(shape, dtype)`` over all tensors of that dtype.
    """
    # Entries are visited in name-sorted order; this also fixes the order in
    # which dtypes appear as keys of the returned dict.
    numel_per_dtype = {}
    for _, info in sorted(weight_buffer_meta.items()):
        tensor_shape, tensor_dtype = info['shape'], info['dtype']

        assert isinstance(tensor_shape, torch.Size)
        assert isinstance(tensor_dtype, torch.dtype)

        # Padded size of this tensor, accumulated into its dtype's running total.
        numel_per_dtype[tensor_dtype] = (numel_per_dtype.get(tensor_dtype, 0) +
                                         calc_padded_numel(tensor_shape, tensor_dtype))

    return {dt: MemoryBuffer(numel, numel, dt) for dt, numel in numel_per_dtype.items()}
| 95 | |
| 96 | |
| 97 | def build_memory_reference_from_module(module: torch.nn.Module, |
no test coverage detected