MCPcopy Index your code
hub / github.com/InternLM/InternLM / SimpleMemoryProfiler

Class SimpleMemoryProfiler

internlm/utils/simple_memory_profiler.py:205–591  ·  view source on GitHub ↗

A memory profiler for an LLM model. Args: model (torch.nn.Module): The model to profile. optimizer (torch.optim.Optimizer): The optimizer used for training the model. log_folder (str): The folder to write the memory state information to. total_steps: number of steps to trace.

Source from the content-addressed store, hash-verified

203
204
205class SimpleMemoryProfiler:
206 """
207 A memory profiler for a llm model.
208
209 Args:
210 model (torch.nn.Module): The model to profile.
211 optimizer (torch.optim.Optimizer): The optimizer used for training the model.
212 log_folder (str): The folder to write the memory state information to.
213 total_steps: number of steps to trace.
214 """
215
def __init__(
    self,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    log_folder: str,
    total_steps: int = 5,
):
    """
    Set up the profiler state, install the activation hooks and take the first record.

    Args:
        model (torch.nn.Module): The model to profile; it is passed through
            ``_unpack_naive_wrapper`` to obtain the model and its chunk count.
        optimizer (torch.optim.Optimizer): The optimizer whose ``param_groups`` are measured.
        log_folder (str): The log folder; it is created here when it does not exist yet.
        total_steps (int): Initial value of the remaining-steps counter. Defaults to 5.
    """
    unwrapped_model, chunk_count = _unpack_naive_wrapper(model)
    self._model = unwrapped_model
    self._num_model_chunks = chunk_count
    self._optimizer = optimizer
    self._log_folder = log_folder
    self._remaining_steps = total_steps

    self._stoped = False
    self._record_start_time = time.time()

    # Activation memory bookkeeping: current value, running maximum and per-chunk base state.
    self._activation_mem: int = 0
    self._activation_mem_max: int = 0
    self._activation_base_mems = ActivationMemState(self._num_model_chunks)

    # Make sure the log folder exists (no error if it is already there).
    os.makedirs(self._log_folder, exist_ok=True)

    # Install the activation tracing hooks: a single model is registered as chunk 0,
    # otherwise every chunk is registered under its own index.
    if self._num_model_chunks <= 1:
        self._register_activation_trace_hooks(0, self._model)
    else:
        for chunk_idx in range(self._num_model_chunks):
            self._register_activation_trace_hooks(chunk_idx, self._model[chunk_idx])

    # Static parameter cuda memory.
    self._param_mem_state = SimpleMemState("param_mem")
    self._calc_tensor_memory(self._param_mem_state, self._model.named_parameters())

    # Static grad cuda memory (same parameter walk, with the third argument set to True).
    self._grad_mem_state = SimpleMemState("grad_mem")
    self._calc_tensor_memory(self._grad_mem_state, self._model.named_parameters(), True)

    # Static optimizer state cuda memory; only the os_params state is filled in here.
    self._os_params_mem_state = SimpleMemState("os_params_mem")
    self._os_state_mem_state = SimpleMemState("os_state_mem")
    indexed_param_groups = list(enumerate(self._optimizer.param_groups))
    self._calc_tensor_group_memory(self._os_params_mem_state, indexed_param_groups)

    # Emit the first memory record.
    self.point(with_options="params,grads,os_params", create=True)
260
261 def point(self, with_options: str = "", create: bool = False) -> None:
262 """

Callers 2

mainFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected