Update the tracked memory usage of the optimizer state. Returns: None
(self)
| 320 | writer.write("\n") |
| 321 | |
def step(self) -> None:
    """
    Advance the profiler by one step and refresh the optimizer-state memory record.

    Each call that is not short-circuited:

    1. decrements ``self._remaining_steps`` and marks the profiler as stopped
       once the counter is exhausted;
    2. rebuilds ``self._os_state_mem_state`` from
       ``self._optimizer.state_dict()["state"]``;
    3. calls ``self.point(with_options="os_state")`` while still running, or,
       on the final step, calls ``self.point(with_options="all")`` and renders
       the params / grads / optimizer / activation / summary sunburst charts.

    Once ``self._stoped`` is set, further calls return immediately.

    Returns:
        None
    """
    if self._stoped:
        return

    self._remaining_steps -= 1
    # ``<=`` rather than ``==``: a counter that starts at zero (or below) would
    # otherwise skip past 0 and the profiler would never stop or dump results.
    if self._remaining_steps <= 0:
        self._stoped = True

    # Update os state memory usage: start from a fresh record every step so
    # stale entries from the previous step are not carried over.
    self._os_state_mem_state = SimpleMemState("os_state_mem")
    self._calc_tensor_group_memory(
        self._os_state_mem_state,
        list(self._optimizer.state_dict()["state"].items()),
    )

    if not self._stoped:
        # Do we need to print os_state_layout every time? Is it always constant?
        self.point(with_options="os_state")
        return

    # Final step: dump memory layout
    self.point(with_options="all")

    # Generate sunburst charts
    self._render_sunburst_chart(self._param_mem_state.to_json()["children"], "params_memory_sunburst")
    self._render_sunburst_chart(self._grad_mem_state.to_json()["children"], "grads_memory_sunburst")
    self._render_sunburst_chart(
        [self._os_params_mem_state.to_json(), self._os_state_mem_state.to_json()],
        "os_memory_sunburst",
    )
    self._render_sunburst_chart(self._activation_base_mems.to_json(), "activation_memory_sunburst")

    # Generate summary sunburst chart.
    # NOTE(review): ``mb`` is defined elsewhere in this module; presumably the
    # number of bytes per megabyte, so values are whole megabytes — confirm.
    summary_sunburst_data = [
        {"name": "params", "value": self._param_mem_state.total_mem // mb},
        {"name": "grads", "value": self._grad_mem_state.total_mem // mb},
        {"name": "os_params", "value": self._os_params_mem_state.total_mem // mb},
        {"name": "os_state", "value": self._os_state_mem_state.total_mem // mb},
        {"name": "activation", "value": self._activation_mem_max // mb},
    ]

    self._render_sunburst_chart(summary_sunburst_data, "summary_sunburst")
| 364 | |
| 365 | def _render_sunburst_chart(self, data: Any, name: str) -> None: |
| 366 | pyecharts.charts.Sunburst(init_opts=pyecharts.options.InitOpts(width="1000px", height="1000px")).add( |
no test coverage detected