MCPcopy
hub / github.com/mosaicml/composer / batch_end

Method batch_end

composer/profiler/json_trace_handler.py:278–367  ·  view source on GitHub ↗
(self, state: State, logger: Logger)

Source from the content-addressed store, hash-verified

276 self._save_at_batch_end = True
277
278 def batch_end(self, state: State, logger: Logger) -> None:
279 assert state.profiler is not None
280 timestamp = state.timestamp
281 trace_folder = format_name_with_dist(self.folder, run_name=state.run_name)
282 if self._save_at_batch_end:
283 # no longer active, but was previously active.
284 # Epty the queue and save the trace file
285 trace_filename = os.path.join(
286 trace_folder,
287 format_name_with_dist_and_time(self.filename, state.run_name, timestamp),
288 )
289 trace_dirname = os.path.dirname(trace_filename)
290 if trace_dirname:
291 os.makedirs(trace_dirname, exist_ok=True)
292 with open(trace_filename, 'w+') as f:
293 is_first_line = True
294 f.write('[\n')
295 while True:
296 try:
297 s = self._queue.get_nowait()
298 except queue.Empty:
299 break
300 if not is_first_line:
301 s = ',\n' + s
302 is_first_line = False
303 f.write(s)
304 f.write('\n]\n')
305
306 if self.remote_file_name is not None:
307 remote_file_name = format_name_with_dist_and_time(self.remote_file_name, state.run_name, timestamp)
308 logger.upload_file(
309 remote_file_name=remote_file_name,
310 file_path=trace_filename,
311 overwrite=self.overwrite,
312 )
313 # Gather the filenames
314 trace_files = [pathlib.Path(x) for x in dist.all_gather_object(trace_filename)]
315 self.saved_traces.append((timestamp, trace_files))
316
317 # Ensure that all traces have been saved.
318 dist.barrier()
319
320 if self.merged_trace_filename is not None and dist.get_local_rank() == 0:
321 # Merge together all traces from the node into one file
322 start_rank = dist.get_global_rank()
323 end_rank = dist.get_global_rank() + dist.get_local_world_size()
324 trace_files_to_merge = trace_files[start_rank:end_rank]
325 merged_trace_filename = os.path.join(
326 trace_folder,
327 format_name_with_dist(
328 self.merged_trace_filename,
329 state.run_name,
330 ),
331 )
332 merged_trace_dirname = os.path.dirname(merged_trace_filename)
333 if merged_trace_dirname:
334 os.makedirs(merged_trace_dirname, exist_ok=True)
335

Callers

nothing calls this directly

Calls 6

format_name_with_dist — Function · 0.90
merge_traces — Function · 0.90
exists — Method · 0.80
write — Method · 0.45
upload_file — Method · 0.45

Tested by

no test coverage detected