Repository: github.com/PRIME-RL/PRIME

Function build_memory_reference

training/verl/utils/memory_buffer.py:113–137

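Build the memory references. The memory buffers are expected to have been built with the build_memory_buffer API; this function gives each weight a slice of the buffer for its dtype, laid out according to weight_buffer_meta.
Args: weight_buffer_meta: maps each parameter name to a dict with 'shape' and 'dtype' keys. memory_buffers: maps each torch.dtype to the MemoryBuffer for that dtype.
Returns: a dict mapping each parameter name to its buffer slice, as returned by MemoryBuffer.get.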
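The offset arithmetic can be written down per dtype. This is a sketch under an assumption: that calc_padded_numel rounds an element count n up to a dtype-dependent alignment a_d (its exact definition is not shown on this page). Write p_1, ..., p_m for the parameters of dtype d in sorted-name order, with element counts n_1, ..., n_m:

$$
\tilde{n}_k = a_d \left\lceil \frac{n_k}{a_d} \right\rceil, \qquad
s_k = \sum_{j=1}^{k-1} \tilde{n}_j \quad (s_1 = 0),
$$

so p_k occupies elements s_k through s_k + n_k - 1 of the buffer for d, and the gap up to s_{k+1} is padding. For example, with a_d = 8 and counts 5 and 15, the padded sizes are 8 and 16 and the start offsets are 0 and 8.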

(weight_buffer_meta: Dict[str, Dict], memory_buffers: Dict[torch.dtype, MemoryBuffer])


def build_memory_reference(weight_buffer_meta: Dict[str, Dict], memory_buffers: Dict[torch.dtype, MemoryBuffer]):
    """Build the memory references.

    The memory buffers are expected to have been built with the build_memory_buffer API.
    For each entry in weight_buffer_meta, this function takes a slice of the memory
    buffer for that entry's dtype.

    Args:
        weight_buffer_meta: maps each parameter name to a dict with 'shape' and 'dtype' keys.
        memory_buffers: maps each torch.dtype to the MemoryBuffer holding parameters of that dtype.

    Returns:
        A dict mapping each parameter name to its buffer slice, as returned by MemoryBuffer.get.
    """
    # Each dtype has its own buffer, so keep a separate write offset per dtype.
    start_idx = {dtype: 0 for dtype in memory_buffers}
    weight_buffers = {}

    # Sorting by name makes the layout deterministic; any code that packs weights
    # into the same buffers must walk them in the same order.
    for name, meta_info in sorted(weight_buffer_meta.items()):
        shape = meta_info['shape']
        dtype = meta_info['dtype']

        buffer = memory_buffers[dtype].get(shape, start_index=start_idx[dtype])
        # Advance by the padded element count so the next tensor starts aligned.
        start_idx[dtype] += calc_padded_numel(shape, dtype)
        weight_buffers[name] = buffer

    return weight_buffers
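To make the offset bookkeeping concrete, here is a torch-free sketch that reproduces the loop above on plain tuples. It assumes, for illustration only, that calc_padded_numel pads to 128-bit alignment and that MemoryBuffer.get(shape, start_index) takes the next prod(shape) elements. DTYPE_BITS, padded_numel, and plan_offsets are hypothetical helpers, not part of verl.

```python
import math
from typing import Dict, Tuple

# Hypothetical dtype table: names and bit widths chosen for illustration only.
DTYPE_BITS = {"float32": 32, "bfloat16": 16, "float16": 16}


def padded_numel(shape: Tuple[int, ...], dtype: str, align_bits: int = 128) -> int:
    """Element count rounded up to a multiple of align_bits // bits(dtype).

    The 128-bit alignment is an assumption standing in for calc_padded_numel.
    """
    align = align_bits // DTYPE_BITS[dtype]
    numel = math.prod(shape)  # math.prod(()) == 1, matching a scalar tensor
    return (numel + align - 1) // align * align


def plan_offsets(weight_buffer_meta: Dict[str, dict]):
    """Mirror the offset bookkeeping of build_memory_reference without torch.

    Returns (plan, totals): plan maps name -> (dtype, start, end), where
    [start, end) is the slice of that dtype's flat buffer the parameter
    would view; totals maps dtype -> padded elements consumed.
    """
    start_idx: Dict[str, int] = {}
    plan: Dict[str, Tuple[str, int, int]] = {}
    # Sorted by name, as in build_memory_reference, so the layout is deterministic.
    for name, meta in sorted(weight_buffer_meta.items()):
        shape, dtype = meta["shape"], meta["dtype"]
        start = start_idx.get(dtype, 0)
        plan[name] = (dtype, start, start + math.prod(shape))
        # Advance by the padded size so the next parameter starts aligned.
        start_idx[dtype] = start + padded_numel(shape, dtype)
    return plan, start_idx


# Example metadata (hypothetical parameter names and shapes).
meta = {
    "layer.weight": {"shape": (3, 5), "dtype": "bfloat16"},
    "layer.bias": {"shape": (5,), "dtype": "bfloat16"},
    "head.weight": {"shape": (2, 2), "dtype": "float32"},
}
plan, totals = plan_offsets(meta)
# plan   → {'head.weight': ('float32', 0, 4), 'layer.bias': ('bfloat16', 0, 5),
#           'layer.weight': ('bfloat16', 8, 23)}
# totals → {'float32': 4, 'bfloat16': 24}
```

Because offsets advance by the padded size, layer.weight starts at element 8 rather than 5; the three bfloat16 elements in between are padding. The totals give the per-dtype size a buffer would need to hold this layout.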

Callers: 1

Calls: calc_padded_numel (function), get (method)

Tested by: no test coverage detected