A memory profiler for an LLM model. Args: model (torch.nn.Module): The model to profile. optimizer (torch.optim.Optimizer): The optimizer used for training the model. log_folder (str): The folder to write the memory state information to. total_steps (int): number of steps to trace.
| 203 | |
| 204 | |
| 205 | class SimpleMemoryProfiler: |
| 206 | """ |
| 207 | A memory profiler for a llm model. |
| 208 | |
| 209 | Args: |
| 210 | model (torch.nn.Module): The model to profile. |
| 211 | optimizer (torch.optim.Optimizer): The optimizer used for training the model. |
| 212 | log_folder (str): The folder to write the memory state information to. |
| 213 | total_steps (int): number of steps to trace. |
| 214 | """ |
| 215 | |
def __init__(
    self,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    log_folder: str,
    total_steps: int = 5,
):
    """
    Set up the profiler state and emit the first memory record.

    Args:
        model (torch.nn.Module): The model to profile. It is passed through
            ``_unpack_naive_wrapper``, which yields the model object and its
            number of model chunks.
        optimizer (torch.optim.Optimizer): The optimizer whose ``param_groups``
            are measured.
        log_folder (str): Folder used by the profiler; created if missing.
        total_steps (int): Initial value of the remaining-steps counter.
    """
    unpacked_model, chunk_count = _unpack_naive_wrapper(model)
    self._model = unpacked_model
    self._num_model_chunks = chunk_count
    self._optimizer = optimizer
    self._log_folder = log_folder
    self._remaining_steps = total_steps

    # Bookkeeping for the lifetime of the profiling session.
    self._stoped = False
    self._record_start_time = time.time()

    # Activation memory counters (current value, peak value, per-chunk base state).
    self._activation_mem: int = 0
    self._activation_mem_max: int = 0
    self._activation_base_mems = ActivationMemState(self._num_model_chunks)

    # Make sure the log folder is available before any hook is installed.
    os.makedirs(self._log_folder, exist_ok=True)

    # Install the activation tracing hooks: a single model is registered as
    # chunk 0, otherwise every chunk is registered under its own index.
    if self._num_model_chunks <= 1:
        self._register_activation_trace_hooks(0, self._model)
    else:
        for idx in range(self._num_model_chunks):
            self._register_activation_trace_hooks(idx, self._model[idx])

    # Static parameter cuda memory.
    param_state = SimpleMemState("param_mem")
    self._param_mem_state = param_state
    self._calc_tensor_memory(param_state, self._model.named_parameters())

    # Static grad cuda memory (third positional argument set to True).
    grad_state = SimpleMemState("grad_mem")
    self._grad_mem_state = grad_state
    self._calc_tensor_memory(grad_state, self._model.named_parameters(), True)

    # Static optimizer state cuda memory; only the param-group state is
    # calculated here, the "os_state_mem" state is merely created.
    self._os_params_mem_state = SimpleMemState("os_params_mem")
    self._os_state_mem_state = SimpleMemState("os_state_mem")
    indexed_groups = list(enumerate(self._optimizer.param_groups))
    self._calc_tensor_group_memory(self._os_params_mem_state, indexed_groups)

    # Produce the very first memory record.
    self.point(with_options="params,grads,os_params", create=True)
| 260 | |
| 261 | def point(self, with_options: str = "", create: bool = False) -> None: |
| 262 | """ |
no outgoing calls
no test coverage detected