MCPcopy
hub / github.com/InternLM/InternLM / step

Method step

internlm/utils/simple_memory_profiler.py:322–363  ·  view source on GitHub ↗

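Recompute the memory usage of the optimizer state and advance the profiler by one step; on the final step, dump the full memory layout and render the sunburst charts. Returns: None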

(self)

Source from the content-addressed store, hash-verified

320 writer.write("\n")
321
def step(self) -> None:
    """
    Advance the profiler by one step and refresh the optimizer state memory record.

    Each call decrements the remaining-step counter and rebuilds the optimizer
    state memory record from the optimizer's current ``state_dict()["state"]``.
    While steps remain, only the "os_state" point is recorded. On the step that
    brings the counter to zero, the profiler is marked as stopped, the "all"
    point is recorded, and the per-category and summary sunburst charts are
    rendered. Calls made after the profiler has stopped do nothing.

    Returns:
        None
    """
    # Nothing to do once profiling has finished
    if self._stoped:
        return

    self._remaining_steps -= 1
    if self._remaining_steps == 0:
        self._stoped = True

    # Rebuild the os state memory record from the optimizer's current state
    self._os_state_mem_state = SimpleMemState("os_state_mem")
    os_state_mem = self._os_state_mem_state
    os_state_items = list(self._optimizer.state_dict()["state"].items())
    self._calc_tensor_group_memory(os_state_mem, os_state_items)

    if not self._stoped:
        # Open question kept from the original code: is the os_state layout
        # constant across steps, i.e. does it need to be printed every time?
        self.point(with_options="os_state")
        return

    # Final step: dump the complete memory layout
    self.point(with_options="all")

    # One sunburst chart per memory category
    params_chart_data = self._param_mem_state.to_json()["children"]
    self._render_sunburst_chart(params_chart_data, "params_memory_sunburst")

    grads_chart_data = self._grad_mem_state.to_json()["children"]
    self._render_sunburst_chart(grads_chart_data, "grads_memory_sunburst")

    os_chart_data = [self._os_params_mem_state.to_json(), self._os_state_mem_state.to_json()]
    self._render_sunburst_chart(os_chart_data, "os_memory_sunburst")

    activation_chart_data = self._activation_base_mems.to_json()
    self._render_sunburst_chart(activation_chart_data, "activation_memory_sunburst")

    # Summary sunburst chart: one entry per category, integer-divided by `mb`
    category_totals = (
        ("params", self._param_mem_state.total_mem),
        ("grads", self._grad_mem_state.total_mem),
        ("os_params", self._os_params_mem_state.total_mem),
        ("os_state", self._os_state_mem_state.total_mem),
        ("activation", self._activation_mem_max),
    )
    summary_sunburst_data = [
        {"name": category, "value": total // mb} for category, total in category_totals
    ]

    self._render_sunburst_chart(summary_sunburst_data, "summary_sunburst")
364
365 def _render_sunburst_chart(self, data: Any, name: str) -> None:
366 pyecharts.charts.Sunburst(init_opts=pyecharts.options.InitOpts(width="1000px", height="1000px")).add(

Callers 2

mainFunction · 0.95

Calls 6

pointMethod · 0.95
SimpleMemStateClass · 0.85
state_dictMethod · 0.45
to_jsonMethod · 0.45

Tested by

no test coverage detected