Repository: github.com/PRIME-RL/PRIME

Function build_memory_buffer

training/verl/utils/memory_buffer.py:68–94

Builds the memory buffers described by weight_buffer_meta.

Args:
  weight_buffer_meta: a mapping from tensor name to a dictionary holding that tensor's shape (a torch.Size) and dtype (a torch.dtype).

Returns: a dictionary mapping each dtype to one large MemoryBuffer that can hold all tensors of that dtype.
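Because the result has one entry per distinct dtype, every tensor that shares a dtype also shares a buffer. A minimal stdlib-only sketch of that grouping follows; it assumes dtypes are plain strings such as "float32" standing in for torch.dtype objects, and group_by_dtype is a hypothetical helper, not part of the repository.

```python
from collections import defaultdict
from typing import Dict, List


def group_by_dtype(weight_buffer_meta: Dict[str, Dict]) -> Dict[str, List[str]]:
    """Hypothetical helper: list which tensor names would share each per-dtype buffer.

    Dtypes are plain strings here, standing in for torch.dtype (an assumption).
    """
    groups: Dict[str, List[str]] = defaultdict(list)
    # Iterate in sorted name order, mirroring the sorted() loop in build_memory_buffer.
    for name, meta in sorted(weight_buffer_meta.items()):
        groups[meta["dtype"]].append(name)
    return dict(groups)


meta = {
    "layer.weight": {"shape": (4, 4), "dtype": "bfloat16"},
    "layer.bias": {"shape": (4,), "dtype": "bfloat16"},
    "norm.scale": {"shape": (4,), "dtype": "float32"},
}
print(group_by_dtype(meta))
# -> {'bfloat16': ['layer.bias', 'layer.weight'], 'float32': ['norm.scale']}
```

Two bfloat16 tensors end up in one buffer and the float32 tensor in another, which is why the function returns a dictionary keyed by dtype rather than a single buffer.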

(weight_buffer_meta: Dict[str, Dict]) -> Dict[torch.dtype, MemoryBuffer]


def build_memory_buffer(weight_buffer_meta: Dict[str, Dict]) -> Dict[torch.dtype, MemoryBuffer]:
    """Build the memory buffer given weight_buffer_meta

    Args:
        weight_buffer_meta: contains mapping from name to a dictionary containing shape and dtype of the tensors

    Returns: a large memory buffer for each dtype that can hold all the tensors

    """
    memory_buffers = {}
    total_numel_map = {}  # map from dtype to the total numel
    # First pass: sum the padded element count of every tensor, grouped by dtype.
    for name, meta_info in sorted(weight_buffer_meta.items()):
        shape = meta_info['shape']
        dtype = meta_info['dtype']

        assert isinstance(shape, torch.Size)
        assert isinstance(dtype, torch.dtype)

        if dtype not in total_numel_map:
            total_numel_map[dtype] = 0

        total_numel_map[dtype] += calc_padded_numel(shape, dtype)

    # Second pass: allocate one MemoryBuffer per dtype, sized to that dtype's padded total.
    for dtype, total_numel in total_numel_map.items():
        memory_buffers[dtype] = MemoryBuffer(total_numel, total_numel, dtype)

    return memory_buffers
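The function works in two passes: it first totals the padded element counts per dtype, then allocates exactly one buffer per dtype. The sketch below reproduces that flow with the standard library only. It is an illustration under stated assumptions, not the repository's code: dtypes are strings, the bit widths and the 128-bit alignment are assumed, a zero-filled bytearray stands in for MemoryBuffer, and padded_numel and build_buffers are hypothetical names.

```python
from collections import defaultdict
from math import prod
from typing import Dict, Tuple

# Assumed bits per element for a few dtypes; strings stand in for torch.dtype.
DTYPE_BITS = {"float32": 32, "bfloat16": 16, "float16": 16}
ALIGN_BITS = 128  # assumed alignment granularity for each tensor's slot


def padded_numel(shape: Tuple[int, ...], dtype: str) -> int:
    """Round the element count up to a multiple of ALIGN_BITS // bits-per-element."""
    align = ALIGN_BITS // DTYPE_BITS[dtype]
    return (prod(shape) + align - 1) // align * align


def build_buffers(weight_buffer_meta: Dict[str, Dict]) -> Dict[str, bytearray]:
    """Sketch of the two passes: sum padded sizes per dtype, then allocate once per dtype."""
    totals: Dict[str, int] = defaultdict(int)
    # Pass 1: accumulate padded element counts per dtype.
    for _name, meta in sorted(weight_buffer_meta.items()):
        totals[meta["dtype"]] += padded_numel(tuple(meta["shape"]), meta["dtype"])
    # Pass 2: one zero-filled allocation per dtype (bytearray stands in for MemoryBuffer).
    return {dtype: bytearray(n * DTYPE_BITS[dtype] // 8) for dtype, n in totals.items()}


meta = {
    "w": {"shape": (3, 3), "dtype": "float32"},
    "b": {"shape": (3,), "dtype": "float32"},
    "e": {"shape": (5,), "dtype": "bfloat16"},
}
buffers = build_buffers(meta)
print({dtype: len(buf) for dtype, buf in buffers.items()})
# -> {'float32': 64, 'bfloat16': 16}
```

Here "w" pads from 9 to 12 elements and "b" from 3 to 4, giving 16 float32 elements (64 bytes); "e" pads from 5 to 8 bfloat16 elements (16 bytes). Sizing everything up front means each dtype needs a single allocation instead of one per tensor.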

Callers 3

_build_param_buffer (method) · 0.90
__init__ (method) · 0.85

Calls 2

calc_padded_numel (function) · 0.85
MemoryBuffer (class) · 0.70
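build_memory_buffer sums the results of calc_padded_numel, whose body is not shown on this page. The sketch below assumes it rounds each tensor's element count up so the tensor fills a whole number of 128-bit chunks; under that assumption the padding depends on the dtype's width. padded_numel_for and round_up are hypothetical names, and bit widths are passed as plain integers.

```python
from math import prod
from typing import Tuple


def round_up(n: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return (n + multiple - 1) // multiple * multiple


def padded_numel_for(shape: Tuple[int, ...], bits_per_element: int, align_bits: int = 128) -> int:
    """Hypothetical stand-in for calc_padded_numel.

    Pads the element count so the tensor occupies whole align_bits-sized chunks.
    Assumes bits_per_element divides align_bits (128 bits is an assumed default).
    """
    elems_per_chunk = align_bits // bits_per_element
    return round_up(prod(shape), elems_per_chunk)


# 10 float32 -> 12, 10 bfloat16 -> 16, 10 float64 -> 10 (already a multiple of 2)
print(padded_numel_for((10,), 32), padded_numel_for((10,), 16), padded_numel_for((10,), 64))
```

Narrower dtypes pack more elements into each chunk, so they can need more padding elements for the same shape, while a count that already fills whole chunks is left unchanged.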

Tested by

no test coverage detected