()
| 38 | |
| 39 | |
| 40 | def main(): |
| 41 | # parse options |
| 42 | cfg = parse_args(phase="test") # parse config file |
| 43 | cfg.FOLDER = cfg.TEST.FOLDER |
| 44 | |
| 45 | # Logger |
| 46 | logger = create_logger(cfg, phase="test") |
| 47 | logger.info(OmegaConf.to_yaml(cfg)) |
| 48 | |
| 49 | # Output dir |
| 50 | model_name = cfg.model.target.split('.')[-2].lower() |
| 51 | output_dir = Path( |
| 52 | os.path.join(cfg.FOLDER, model_name, cfg.NAME, "samples_" + cfg.TIME)) |
| 53 | if cfg.TEST.SAVE_PREDICTIONS: |
| 54 | output_dir.mkdir(parents=True, exist_ok=True) |
| 55 | logger.info(f"Saving predictions to {str(output_dir)}") |
| 56 | |
| 57 | # Seed |
| 58 | pl.seed_everything(cfg.SEED_VALUE) |
| 59 | |
| 60 | # Environment Variables |
| 61 | os.environ["TOKENIZERS_PARALLELISM"] = "false" |
| 62 | |
| 63 | # Callbacks |
| 64 | callbacks = build_callbacks(cfg, logger=logger, phase="test") |
| 65 | logger.info("Callbacks initialized") |
| 66 | |
| 67 | # Dataset |
| 68 | datamodule = build_data(cfg) |
| 69 | logger.info("datasets module {} initialized".format("".join( |
| 70 | cfg.DATASET.target.split('.')[-2]))) |
| 71 | |
| 72 | # Model |
| 73 | model = build_model(cfg, datamodule) |
| 74 | logger.info("model {} loaded".format(cfg.model.target)) |
| 75 | |
| 76 | # Lightning Trainer |
| 77 | trainer = pl.Trainer( |
| 78 | benchmark=False, |
| 79 | max_epochs=cfg.TRAIN.END_EPOCH, |
| 80 | accelerator=cfg.ACCELERATOR, |
| 81 | devices=list(range(len(cfg.DEVICE))), |
| 82 | default_root_dir=cfg.FOLDER_EXP, |
| 83 | reload_dataloaders_every_n_epochs=1, |
| 84 | deterministic=False, |
| 85 | detect_anomaly=False, |
| 86 | enable_progress_bar=True, |
| 87 | logger=None, |
| 88 | callbacks=callbacks, |
| 89 | ) |
| 90 | |
| 91 | # Strict load vae model |
| 92 | if cfg.TRAIN.PRETRAINED_VAE: |
| 93 | load_pretrained_vae(cfg, model, logger) |
| 94 | |
| 95 | # loading state dict |
| 96 | if cfg.TEST.CHECKPOINTS: |
| 97 | load_pretrained(cfg, model, logger, phase="test") |
no test coverage detected