(self, state: State, logger: Logger)
| 276 | self._save_at_batch_end = True |
| 277 | |
| 278 | def batch_end(self, state: State, logger: Logger) -> None: |
| 279 | assert state.profiler is not None |
| 280 | timestamp = state.timestamp |
| 281 | trace_folder = format_name_with_dist(self.folder, run_name=state.run_name) |
| 282 | if self._save_at_batch_end: |
| 283 | # no longer active, but was previously active. |
| 284 | # Empty the queue and save the trace file |
| 285 | trace_filename = os.path.join( |
| 286 | trace_folder, |
| 287 | format_name_with_dist_and_time(self.filename, state.run_name, timestamp), |
| 288 | ) |
| 289 | trace_dirname = os.path.dirname(trace_filename) |
| 290 | if trace_dirname: |
| 291 | os.makedirs(trace_dirname, exist_ok=True) |
| 292 | with open(trace_filename, 'w+') as f: |
| 293 | is_first_line = True |
| 294 | f.write('[\n') |
| 295 | while True: |
| 296 | try: |
| 297 | s = self._queue.get_nowait() |
| 298 | except queue.Empty: |
| 299 | break |
| 300 | if not is_first_line: |
| 301 | s = ',\n' + s |
| 302 | is_first_line = False |
| 303 | f.write(s) |
| 304 | f.write('\n]\n') |
| 305 | |
| 306 | if self.remote_file_name is not None: |
| 307 | remote_file_name = format_name_with_dist_and_time(self.remote_file_name, state.run_name, timestamp) |
| 308 | logger.upload_file( |
| 309 | remote_file_name=remote_file_name, |
| 310 | file_path=trace_filename, |
| 311 | overwrite=self.overwrite, |
| 312 | ) |
| 313 | # Gather the filenames |
| 314 | trace_files = [pathlib.Path(x) for x in dist.all_gather_object(trace_filename)] |
| 315 | self.saved_traces.append((timestamp, trace_files)) |
| 316 | |
| 317 | # Ensure that all traces have been saved. |
| 318 | dist.barrier() |
| 319 | |
| 320 | if self.merged_trace_filename is not None and dist.get_local_rank() == 0: |
| 321 | # Merge together all traces from the node into one file |
| 322 | start_rank = dist.get_global_rank() |
| 323 | end_rank = dist.get_global_rank() + dist.get_local_world_size() |
| 324 | trace_files_to_merge = trace_files[start_rank:end_rank] |
| 325 | merged_trace_filename = os.path.join( |
| 326 | trace_folder, |
| 327 | format_name_with_dist( |
| 328 | self.merged_trace_filename, |
| 329 | state.run_name, |
| 330 | ), |
| 331 | ) |
| 332 | merged_trace_dirname = os.path.dirname(merged_trace_filename) |
| 333 | if merged_trace_dirname: |
| 334 | os.makedirs(merged_trace_dirname, exist_ok=True) |
| 335 |
nothing calls this directly
no test coverage detected