(self, pl_module, batch, batch_idx, split="train")
| 454 | Image.fromarray(grid).save(path) |
| 455 | |
def log_img(self, pl_module, batch, batch_idx, split="train"):
    """Generate sample images from ``pl_module`` and log them locally and to its logger.

    Logging happens only when all of these hold:

    * ``self.check_frequency`` accepts the current index. The index is
      ``batch_idx`` if ``self.log_on_batch_idx`` is set, otherwise
      ``pl_module.global_step``.
    * ``pl_module`` has a callable ``log_images`` method.
    * ``self.max_images > 0``.

    The module is temporarily switched to eval mode while images are produced.
    If it was in training mode, training mode is restored afterwards. This also
    happens when generation or logging raises; the exception is then propagated.

    Args:
        pl_module: Module providing ``log_images``, ``logger``, ``global_step``,
            ``current_epoch`` and the ``training`` flag.
        batch: Batch forwarded unchanged to ``pl_module.log_images``.
        batch_idx: Index of the current batch; also forwarded to ``self.log_local``.
        split: Split name forwarded to ``log_images``, ``self.log_local`` and the
            logger callback.
    """
    # Choose which counter drives the logging frequency.
    check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
    # check_frequency stays first in the chain (same evaluation order as before).
    if not (
        self.check_frequency(check_idx)
        and hasattr(pl_module, "log_images")
        and callable(pl_module.log_images)
        and self.max_images > 0
    ):
        return

    # Logger callbacks in self.logger_log_images are keyed by the logger's class.
    logger = type(pl_module.logger)

    is_train = pl_module.training
    if is_train:
        pl_module.eval()

    try:
        with torch.no_grad():
            images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

        # Keep at most self.max_images entries along the first dimension.
        # Tensors are also detached, moved to CPU and optionally clamped to [-1, 1].
        for key in images:
            n_keep = min(images[key].shape[0], self.max_images)
            images[key] = images[key][:n_keep]
            if isinstance(images[key], torch.Tensor):
                images[key] = images[key].detach().cpu()
                if self.clamp:
                    images[key] = torch.clamp(images[key], -1.0, 1.0)

        # Save images to the local file system under the logger's save directory.
        self.log_local(
            pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, batch_idx
        )

        # Unknown logger types fall back to a no-op callback.
        logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
        logger_log_images(pl_module, images, pl_module.global_step, split)
    finally:
        # Restore training mode even if image generation or logging failed;
        # otherwise the model would silently stay in eval mode.
        if is_train:
            pl_module.train()
| 499 | # The function checks if it's time to log an image batch |
| 500 | def check_frequency(self, check_idx): |
no test coverage detected