(self)
| 96 | self.logvar = posterior_variance.clamp(min=1e-20).log() |
| 97 | |
| 98 | def train(self): |
| 99 | args, config = self.args, self.config |
| 100 | tb_logger = self.config.tb_logger |
| 101 | dataset, test_dataset = get_dataset(args, config) |
| 102 | train_loader = data.DataLoader( |
| 103 | dataset, |
| 104 | batch_size=config.training.batch_size, |
| 105 | shuffle=True, |
| 106 | num_workers=config.data.num_workers, |
| 107 | ) |
| 108 | model = Model(config) |
| 109 | |
| 110 | model = model.to(self.device) |
| 111 | model = torch.nn.DataParallel(model) |
| 112 | |
| 113 | optimizer = get_optimizer(self.config, model.parameters()) |
| 114 | |
| 115 | if self.config.model.ema: |
| 116 | ema_helper = EMAHelper(mu=self.config.model.ema_rate) |
| 117 | ema_helper.register(model) |
| 118 | else: |
| 119 | ema_helper = None |
| 120 | |
| 121 | start_epoch, step = 0, 0 |
| 122 | if self.args.resume_training: |
| 123 | states = torch.load(os.path.join(self.args.log_path, "ckpt.pth")) |
| 124 | model.load_state_dict(states[0]) |
| 125 | |
| 126 | states[1]["param_groups"][0]["eps"] = self.config.optim.eps |
| 127 | optimizer.load_state_dict(states[1]) |
| 128 | start_epoch = states[2] |
| 129 | step = states[3] |
| 130 | if self.config.model.ema: |
| 131 | ema_helper.load_state_dict(states[4]) |
| 132 | |
| 133 | for epoch in range(start_epoch, self.config.training.n_epochs): |
| 134 | data_start = time.time() |
| 135 | data_time = 0 |
| 136 | for i, (x, y) in enumerate(train_loader): |
| 137 | n = x.size(0) |
| 138 | data_time += time.time() - data_start |
| 139 | model.train() |
| 140 | step += 1 |
| 141 | |
| 142 | x = x.to(self.device) |
| 143 | x = data_transform(self.config, x) |
| 144 | e = torch.randn_like(x) |
| 145 | b = self.betas |
| 146 | |
| 147 | # antithetic sampling |
| 148 | t = torch.randint( |
| 149 | low=0, high=self.num_timesteps, size=(n // 2 + 1,) |
| 150 | ).to(self.device) |
| 151 | t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n] |
| 152 | loss = loss_registry[config.model.type](model, x, t, e, b) |
| 153 | |
| 154 | tb_logger.add_scalar("loss", loss, global_step=step) |
| 155 |
no test coverage detected