MCPcopy
hub / github.com/hpcaitech/ColossalAI / log_img

Method log_img

examples/images/diffusion/main.py:456–497  ·  view source on GitHub ↗
(self, pl_module, batch, batch_idx, split="train")

Source from the content-addressed store, hash-verified

454 Image.fromarray(grid).save(path)
455
def log_img(self, pl_module, batch, batch_idx, split="train"):
    """Generate sample images from ``pl_module`` and log them locally and via the logger.

    Logging happens only when all of the following hold:
    - ``self.check_frequency`` approves the current index;
    - ``pl_module`` exposes a callable ``log_images``;
    - ``self.max_images`` is positive.

    The index is ``batch_idx`` when ``self.log_on_batch_idx`` is set, otherwise
    ``pl_module.global_step``.

    A module that is in training mode is switched to eval mode while images are
    produced. It is always switched back afterwards, even if image generation or
    logging raises.

    Args:
        pl_module: Module providing ``log_images``, ``logger``, ``training``,
            ``global_step`` and ``current_epoch``.
        batch: Batch forwarded unchanged to ``pl_module.log_images``.
        batch_idx: Index of the current batch. It is used for the frequency check
            when ``self.log_on_batch_idx`` is set, and is always passed to
            ``self.log_local``.
        split: Split name forwarded to ``log_images``, ``log_local`` and the
            logger hook (default ``"train"``).
    """
    check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step

    # Guard clause: only log when it is time to, and when the module can produce images.
    if not (
        self.check_frequency(check_idx)
        and hasattr(pl_module, "log_images")
        and callable(pl_module.log_images)
        and self.max_images > 0
    ):
        return

    # Logger class, used to pick the matching entry in self.logger_log_images.
    logger = type(pl_module.logger)

    is_train = pl_module.training
    if is_train:
        pl_module.eval()

    # try/finally ensures the module is never left stuck in eval mode when
    # image generation or logging raises.
    try:
        with torch.no_grad():
            images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

        # Keep at most self.max_images items per entry. Tensors are detached,
        # moved to CPU and, if self.clamp is set, clamped to [-1, 1].
        for k in images:
            n_keep = min(images[k].shape[0], self.max_images)
            images[k] = images[k][:n_keep]
            if isinstance(images[k], torch.Tensor):
                images[k] = images[k].detach().cpu()
                if self.clamp:
                    images[k] = torch.clamp(images[k], -1.0, 1.0)

        # Save the images to the local file system under the logger's save_dir.
        self.log_local(
            pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, batch_idx
        )

        # Dispatch to a logger-specific hook; unknown logger types are a no-op.
        logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
        logger_log_images(pl_module, images, pl_module.global_step, split)
    finally:
        # Restore training mode if it was active on entry.
        if is_train:
            pl_module.train()
498
499 # The function checks if it's time to log an image batch
500 def check_frequency(self, check_idx):

Callers 1

Calls 9

check_frequency (Method) · 0.95
log_local (Method) · 0.95
no_grad (Method) · 0.80
detach (Method) · 0.80
eval (Method) · 0.45
log_images (Method) · 0.45
cpu (Method) · 0.45
get (Method) · 0.45
train (Method) · 0.45

Tested by

no test coverage detected