| 92 | self.global_step = -1 |
| 93 | |
def train_epoch(self, data_set):
    """Run one training pass over ``data_set`` and return the mean loss.

    Each batch is split into input features and the stored per-state
    history loss, passed through ``self.model`` and ``self.tra``, and
    optimized against the mean squared error between the prediction and
    ``label``. When ``self.tra`` returns a ``prob`` tensor, a
    Sinkhorn-based assignment regularizer weighted by
    ``self.lamb * self.rho ** self.global_step`` is subtracted from the
    loss. The per-state squared errors of every batch are written back
    through ``data_set.assign_data``.

    Args:
        data_set: Iterable batch source. It must provide ``train()``,
            ``horizon`` and ``assign_data(index, loss)``, and yield
            mappings with ``"data"``, ``"label"`` and ``"index"`` entries.

    Returns:
        The sample-weighted mean of the per-batch losses, or ``nan`` when
        no batch was processed.
    """
    self.model.train()
    self.tra.train()

    data_set.train()

    # NOTE(review): the per-epoch step budget is derived from ``n_epochs``;
    # confirm this is intended rather than the number of batches.
    max_steps = self.n_epochs
    if self.max_steps_per_epoch is not None:
        max_steps = min(self.max_steps_per_epoch, self.n_epochs)

    total_loss = 0.0
    total_count = 0
    for step, batch in enumerate(tqdm(data_set, total=max_steps), start=1):
        if step > max_steps:
            break

        self.global_step += 1

        data, label, index = batch["data"], batch["label"], batch["index"]

        # The trailing ``num_states`` channels of ``data`` carry the stored
        # per-state losses; everything before them is the model input.
        # assumes data is (batch, time, channels) -- TODO confirm
        feature = data[:, :, : -self.tra.num_states]
        hist_loss = data[:, : -data_set.horizon, -self.tra.num_states :]

        hidden = self.model(feature)
        pred, all_preds, prob = self.tra(hidden, hist_loss)

        loss = (pred - label).pow(2).mean()

        # Detached so the stored losses never contribute gradients.
        L = (all_preds.detach() - label[:, None]).pow(2)
        L -= L.min(dim=-1, keepdim=True).values  # normalize & ensure positive input

        data_set.assign_data(index, L)  # save loss to memory

        if prob is not None:
            P = sinkhorn(-L, epsilon=0.01)  # sample assignment matrix
            # The regularization weight changes geometrically with the
            # global step count (factor ``rho`` per step).
            lamb = self.lamb * (self.rho**self.global_step)
            reg = prob.log().mul(P).sum(dim=-1).mean()
            loss = loss - lamb * reg

        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        # ``loss`` is already a per-batch mean, so weight it by the batch
        # size before accumulating; dividing the plain sum of means by the
        # sample count would shrink the result by roughly the batch size.
        batch_size = len(pred)
        total_loss += loss.item() * batch_size
        total_count += batch_size

    if total_count == 0:
        # Nothing was trained on (empty data set or zero step budget);
        # report "no measurement" instead of raising ZeroDivisionError.
        return float("nan")

    return total_loss / total_count
| 145 | |
| 146 | def test_epoch(self, data_set, return_pred=False): |
| 147 | self.model.eval() |