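Build the memory references. The memory buffers themselves are created by the build_memory_buffer API; this function assigns each weight listed in weight_buffer_meta a slice of the memory buffer that has the matching dtype. Args: weight_buffer_meta maps each weight name to a dict with its 'shape' and 'dtype'; memory_buffers maps each dtype to its MemoryBuffer. Returns: a dict mapping each weight name to its buffer slice.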
(weight_buffer_meta: Dict[str, Dict], memory_buffers: Dict[torch.dtype, MemoryBuffer])
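The offset bookkeeping described above can be illustrated without torch. The sketch below is a simplified model, not the library's code: dtypes are plain strings, and DTYPE_BITS, padded_numel and plan_references are hypothetical names. The 128-bit alignment rule is an assumption about what calc_padded_numel does, and unlike the original (which raises KeyError for a dtype that has no memory buffer) the sketch starts any new dtype at offset 0. It records (dtype, start, end) offsets instead of slicing real buffers, and also reports the padded size each dtype's buffer would need.

```python
import math
from typing import Dict, Tuple

# Hypothetical element widths in bits; a stand-in for torch.finfo(dtype).bits.
DTYPE_BITS = {"float32": 32, "float16": 16, "bfloat16": 16}


def padded_numel(shape: Tuple[int, ...], dtype: str, align_bits: int = 128) -> int:
    """Round the element count up to a multiple of align_bits / element bits (assumed 128)."""
    align_numel = align_bits // DTYPE_BITS[dtype]
    numel = math.prod(shape)
    return (numel + align_numel - 1) // align_numel * align_numel


def plan_references(
    weight_buffer_meta: Dict[str, Dict],
) -> Tuple[Dict[str, Tuple[str, int, int]], Dict[str, int]]:
    """Mirror the sorted, per-dtype walk, recording offsets instead of slicing buffers.

    Returns (plan, sizes): plan maps name -> (dtype, start, end) and sizes maps
    dtype -> the padded element count a buffer for that dtype would need.
    """
    start_idx: Dict[str, int] = {}
    plan: Dict[str, Tuple[str, int, int]] = {}
    # Sorted order makes the layout deterministic for a given set of names.
    for name, meta in sorted(weight_buffer_meta.items()):
        shape, dtype = meta["shape"], meta["dtype"]
        start = start_idx.get(dtype, 0)
        plan[name] = (dtype, start, start + math.prod(shape))
        # Advance by the padded size, leaving a gap before the next weight's slice.
        start_idx[dtype] = start + padded_numel(shape, dtype)
    return plan, start_idx


meta = {
    "fc.weight": {"shape": (3, 5), "dtype": "float32"},
    "fc.bias": {"shape": (3,), "dtype": "float32"},
    "emb": {"shape": (2, 3), "dtype": "bfloat16"},
}
plan, sizes = plan_references(meta)
print(plan)   # {'emb': ('bfloat16', 0, 6), 'fc.bias': ('float32', 0, 3), 'fc.weight': ('float32', 4, 19)}
print(sizes)  # {'bfloat16': 8, 'float32': 20}
```

Under the assumed alignment, fc.bias holds 3 float32 elements but reserves 4, so fc.weight starts at offset 4 rather than 3. Weights of different dtypes never share a buffer, so each dtype keeps its own running offset.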
```python
def build_memory_reference(weight_buffer_meta: Dict[str, Dict], memory_buffers: Dict[torch.dtype, MemoryBuffer]):
    """Build the memory references. The memory buffers are built using the build_memory_buffer API.
    This function assigns each weight a slice of the memory buffer with the matching dtype,
    according to weight_buffer_meta.

    Args:
        weight_buffer_meta: Maps each weight name to a dict with keys 'shape' and 'dtype'.
        memory_buffers: Maps each dtype to the MemoryBuffer holding all weights of that dtype.

    Returns:
        A dict mapping each weight name to the buffer returned by MemoryBuffer.get.
    """
    # Next free element offset within each dtype's memory buffer.
    start_idx = {dtype: 0 for dtype in memory_buffers}
    weight_buffers = {}

    # Iterate in sorted name order so the offsets are deterministic.
    for name, meta_info in sorted(weight_buffer_meta.items()):
        shape = meta_info['shape']
        dtype = meta_info['dtype']

        buffer = memory_buffers[dtype].get(shape, start_index=start_idx[dtype])
        # Advance past this weight's padded element count.
        start_idx[dtype] += calc_padded_numel(shape, dtype)
        weight_buffers[name] = buffer

    return weight_buffers
```
| 140 | class MemoryBufferModuleWrapper: |