| 430 | |
@rank_zero_only
def log_local(
    self,
    save_dir,
    split,  # The split (train/val) on which to log the images
    images,  # A dictionary of images to log
    global_step,  # The global step
    current_epoch,  # The current epoch.
    batch_idx,
):
    """Save every entry of ``images`` as a PNG image grid on the local file system.

    Files are written to ``<save_dir>/images/<split>/`` and named
    ``<key>_gs-<global_step>_e-<current_epoch>_b-<batch_idx>.png``, with each
    number zero-padded to six digits.

    The ``rank_zero_only`` decorator presumably restricts execution to the
    rank-0 process in distributed runs (NOTE(review): confirm against its import).

    Args:
        save_dir: Root directory; an ``images/<split>`` sub-folder is created in it.
        split: Name of the data split (e.g. ``"train"`` / ``"val"``), used as a sub-folder.
        images: Mapping from a name to an image batch accepted by
            ``torchvision.utils.make_grid`` (presumably a ``(B, C, H, W)`` tensor;
            if ``self.rescale`` is set, values are assumed to lie in [-1, 1]).
        global_step: Global training step, embedded in the file name.
        current_epoch: Current epoch, embedded in the file name.
        batch_idx: Batch index, embedded in the file name.
    """
    root = os.path.join(save_dir, "images", split)
    # Every grid goes into the same folder, so create it once up front.
    os.makedirs(root, exist_ok=True)
    for k in images:
        grid = torchvision.utils.make_grid(images[k], nrow=4)
        if self.rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
        # Detach and move to host memory so ``.numpy()`` also works for tensors
        # that live on an accelerator or still track gradients.
        grid = grid.detach().cpu()
        # c,h,w -> h,w,c as PIL expects; drop a trailing singleton channel axis if present.
        grid = grid.permute(1, 2, 0).squeeze(-1).numpy()
        # Clip before casting: out-of-range float -> uint8 conversion is undefined
        # in NumPy and would otherwise corrupt pixels outside [0, 1].
        grid = np.clip(grid * 255, 0, 255).astype(np.uint8)
        filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
        path = os.path.join(root, filename)
        # Save image grid as PNG file
        Image.fromarray(grid).save(path)
| 455 | |
| 456 | def log_img(self, pl_module, batch, batch_idx, split="train"): |
| 457 | # Function for logging images to both the logger and local file system. |