MCPcopy
hub / github.com/hpcaitech/ColossalAI / ImageLogger

Class ImageLogger

examples/images/diffusion/main.py:383–524  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

381
382# PyTorch Lightning callback for logging images during training and validation of a deep learning model
383class ImageLogger(Callback):
384 def __init__(
385 self,
386 batch_frequency, # Frequency of batches on which to log images
387 max_images, # Maximum number of images to log
388 clamp=True, # Whether to clamp pixel values to [-1,1]
389 increase_log_steps=True, # Whether to increase frequency of log steps exponentially
390 rescale=True, # Whether to rescale pixel values to [0,1]
391 disabled=False, # Whether to disable logging
392 log_on_batch_idx=False, # Whether to log on batch index instead of global step
393 log_first_step=False, # Whether to log on the first step
394 log_images_kwargs=None,
395 ): # Additional keyword arguments to pass to log_images method
396 super().__init__()
397 self.rescale = rescale
398 self.batch_freq = batch_frequency
399 self.max_images = max_images
400 self.logger_log_images = {
401 # Dictionary of logger classes and their corresponding logging methods
402 pl.loggers.CSVLogger: self._testtube,
403 }
404 # Create a list of exponentially increasing log steps, starting from 1 and ending at batch_frequency
405 self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
406 if not increase_log_steps:
407 self.log_steps = [self.batch_freq]
408 self.clamp = clamp
409 self.disabled = disabled
410 self.log_on_batch_idx = log_on_batch_idx
411 self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
412 self.log_first_step = log_first_step
413
414 @rank_zero_only # Ensure that only the first process in distributed training executes this method
415 def _testtube(
416 self, # The PyTorch Lightning module
417 pl_module, # A dictionary of images to log.
418 images, #
419 batch_idx, # The batch index.
420 split, # The split (train/val) on which to log the images
421 ):
422 # Method for logging images using test-tube logger
423 for k in images:
424 grid = torchvision.utils.make_grid(images[k])
425 grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
426
427 tag = f"{split}/{k}"
428 # Add image grid to logger's experiment
429 pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step)
430
431 @rank_zero_only
432 def log_local(
433 self,
434 save_dir,
435 split, # The split (train/val) on which to log the images
436 images, # A dictionary of images to log
437 global_step, # The global step
438 current_epoch, # The current epoch.
439 batch_idx,
440 ):

Callers 1

main.py (File) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…