hub / github.com/PRIME-RL/PRIME / build_memory_reference_from_module

Function build_memory_reference_from_module

training/verl/utils/memory_buffer.py:97–110
(module: torch.nn.Module,
 memory_buffers: Dict[torch.dtype, MemoryBuffer],
 maintain_weight=True)


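from typing import Dict

import torch

# MemoryBuffer and calc_padded_numel are referenced from the surrounding module (not shown here).


def build_memory_reference_from_module(module: torch.nn.Module,
                                       memory_buffers: Dict[torch.dtype, MemoryBuffer],
                                       maintain_weight=True):
    # One running offset per dtype, each starting at the front of its buffer.
    start_index = {}
    for dtype in memory_buffers.keys():
        start_index[dtype] = 0
    # Sort by name so parameters are laid out in a deterministic order.
    for name, param in sorted(module.named_parameters()):
        memory_buffer = memory_buffers[param.dtype]
        buffer = memory_buffer.get(shape=param.shape, start_index=start_index[param.dtype])
        # Advance by this parameter's padded size. Use param.dtype here, not the
        # `dtype` variable left over from the loop above.
        start_index[param.dtype] += calc_padded_numel(param.shape, param.dtype)
        if maintain_weight:
            # Preserve the current weights by copying them into the buffer slice.
            buffer.copy_(param.data)
        # Rebind the parameter to the tensor returned by the memory buffer.
        param.data = buffer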
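
The per-dtype offset bookkeeping in this function can be illustrated without torch. The following is only a minimal sketch under stated assumptions: `padded_numel` and `plan_offsets` are hypothetical names, dtypes are plain strings, and the padding rule (round the element count up to a multiple of a per-dtype alignment) is an assumption standing in for `calc_padded_numel`, whose actual rule is not shown on this page.

```python
from math import prod
from typing import Dict, List, Tuple


def padded_numel(shape: Tuple[int, ...], align: int) -> int:
    """Hypothetical stand-in for calc_padded_numel: round up to a multiple of `align`."""
    numel = prod(shape)
    return (numel + align - 1) // align * align


def plan_offsets(params: List[Tuple[str, str, Tuple[int, ...]]],
                 align: Dict[str, int]) -> Dict[str, Tuple[str, int]]:
    """Map each parameter name to (dtype, start offset within that dtype's buffer)."""
    # One running offset per dtype, all starting at zero.
    start_index = {dtype: 0 for dtype in align}
    plan = {}
    # Sorting by name keeps the layout deterministic, as in the listing.
    for name, dtype, shape in sorted(params):
        plan[name] = (dtype, start_index[dtype])
        # Advance by the padded size for this parameter's own dtype.
        start_index[dtype] += padded_numel(shape, align[dtype])
    return plan


# Illustrative parameters only; names, shapes, and alignments are made up.
params = [
    ("layer.weight", "float32", (3, 5)),
    ("layer.bias", "float32", (3,)),
    ("embed.weight", "bfloat16", (10, 2)),
]
plan = plan_offsets(params, align={"float32": 4, "bfloat16": 8})
```

Because each dtype keeps its own counter, parameters of different dtypes never affect one another's offsets, and the sorted iteration means two processes that build the same module should arrive at the same layout.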

Callers 4

__init__ · Method · 0.85

Calls 3

calc_padded_numel · Function · 0.85
named_parameters · Method · 0.80
get · Method · 0.45

Tested by

no test coverage detected