| 381 | |
| 382 | # PyTorch Lightning callback for logging images during training and validation of a deep learning model |
| 383 | class ImageLogger(Callback): |
def __init__(
    self,
    batch_frequency,
    max_images,
    clamp=True,
    increase_log_steps=True,
    rescale=True,
    disabled=False,
    log_on_batch_idx=False,
    log_first_step=False,
    log_images_kwargs=None,
):
    """Store the image-logging configuration for this callback.

    Args:
        batch_frequency: Frequency of batches on which to log images. When
            ``increase_log_steps`` is true it is also the upper bound of the
            power-of-two warm-up schedule stored in ``self.log_steps``.
        max_images: Maximum number of images to log (stored as
            ``self.max_images``; applied by code outside this method).
        clamp: Whether to clamp pixel values to [-1, 1] (stored only).
        increase_log_steps: If true, ``self.log_steps`` is
            ``[1, 2, 4, ..., 2**floor(log2(batch_frequency))]``; otherwise it is
            ``[batch_frequency]``.
        rescale: Whether to rescale pixel values to [0, 1] (stored only).
        disabled: Whether to disable logging (stored only).
        log_on_batch_idx: Whether to log on batch index instead of global step
            (stored only).
        log_first_step: Whether to log on the first step (stored only).
        log_images_kwargs: Extra keyword arguments for the model's
            ``log_images`` call; ``None`` (or any falsy value) becomes ``{}``.
    """
    super().__init__()
    self.rescale = rescale
    self.batch_freq = batch_frequency
    self.max_images = max_images
    # Maps a logger class to the method that writes image grids to it.
    # `_testtube` calls `logger.experiment.add_image(...)`. TensorBoardLogger's
    # experiment (a SummaryWriter) provides that method; CSVLogger's experiment
    # writer does not, so mapping CSVLogger here raised AttributeError.
    self.logger_log_images = {
        pl.loggers.TensorBoardLogger: self._testtube,
    }
    if increase_log_steps:
        # Powers of two from 1 up to the largest one <= batch_frequency.
        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
    else:
        # Only computed when needed: np.log2 on a non-positive frequency
        # would otherwise fail even though the result is discarded.
        self.log_steps = [self.batch_freq]
    self.clamp = clamp
    self.disabled = disabled
    self.log_on_batch_idx = log_on_batch_idx
    self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
    self.log_first_step = log_first_step
| 413 | |
@rank_zero_only  # Only the rank-zero process executes this in distributed training
def _testtube(
    self,
    pl_module,
    images,
    batch_idx,
    split,
):
    """Write one image grid per entry of ``images`` to the module's logger.

    Args:
        pl_module: The Lightning module; ``pl_module.logger.experiment.add_image``
            receives each grid, stepped at ``pl_module.global_step``.
        images: A dictionary of images to log; each value is tiled with
            ``torchvision.utils.make_grid``. NOTE(review): values are assumed
            to lie in [-1, 1], since the grid is shifted to [0, 1] — confirm
            with the caller.
        batch_idx: The batch index; not used here (the step logged is
            ``pl_module.global_step``).
        split: The split (train/val); used as the tag prefix ``"{split}/{key}"``.
    """
    for key in images:
        # Tile the batch into a single (c, h, w) grid and map [-1, 1] -> [0, 1].
        tiled = (torchvision.utils.make_grid(images[key]) + 1.0) / 2.0
        pl_module.logger.experiment.add_image(
            f"{split}/{key}",
            tiled,
            global_step=pl_module.global_step,
        )
| 430 | |
| 431 | @rank_zero_only |
| 432 | def log_local( |
| 433 | self, |
| 434 | save_dir, |
| 435 | split, # The split (train/val) on which to log the images |
| 436 | images, # A dictionary of images to log |
| 437 | global_step, # The global step |
| 438 | current_epoch, # The current epoch. |
| 439 | batch_idx, |
| 440 | ): |
no outgoing calls
no test coverage detected
searching dependent graphs…