Encode observations, predict next states, and compute losses.
(self, batch, stage, cfg)
| 15 | |
| 16 | |
def lejepa_forward(self, batch, stage, cfg):
    """Encode a batch, predict from the context window, and assemble the losses.

    NaN entries of ``batch["action"]`` are replaced by zeros and written back
    into ``batch`` before it is handed to ``self.model.encode``. The dict that
    ``encode`` returns is extended with three entries:

    * ``pred_loss``   -- mean squared difference between the predictor output
      and the embeddings sliced from index ``cfg.num_preds`` onward,
    * ``sigreg_loss`` -- ``self.sigreg`` applied to the embeddings with their
      first two axes swapped,
    * ``loss``        -- ``pred_loss + cfg.loss.sigreg.weight * sigreg_loss``.

    Every entry whose key contains ``"loss"`` is logged (detached) through
    ``self.log_dict`` under the name ``"{stage}/<key>"``.

    Returns:
        The output dict from ``self.model.encode`` with the loss entries added.
    """
    history = cfg.history_size
    horizon = cfg.num_preds
    sigreg_weight = cfg.loss.sigreg.weight

    # NaNs show up at sequence boundaries; zero them before encoding.
    batch["action"] = torch.nan_to_num(batch["action"], nan=0.0)

    output = self.model.encode(batch)
    embeddings = output["emb"]  # (B, T, D) per the upstream note
    action_embeddings = output["act_emb"]

    # Context window: the first `history` steps of both streams.
    context = embeddings[:, :history]
    context_actions = action_embeddings[:, :history]

    # Targets are the same embeddings shifted forward by `horizon` steps.
    targets = embeddings[:, horizon:]
    predictions = self.model.predict(context, context_actions)

    # LeWM objective: prediction error plus weighted SIGReg term.
    residual = predictions - targets
    pred_loss = residual.pow(2).mean()
    output["pred_loss"] = pred_loss

    sigreg_loss = self.sigreg(embeddings.transpose(0, 1))
    output["sigreg_loss"] = sigreg_loss

    output["loss"] = pred_loss + sigreg_weight * sigreg_loss

    # Log only the loss entries, detached from the autograd graph.
    to_log = {}
    for key, value in output.items():
        if "loss" in key:
            to_log[f"{stage}/{key}"] = value.detach()
    self.log_dict(to_log, on_step=True, sync_dist=True)

    return output
| 46 | |
| 47 | @hydra.main(version_base=None, config_path="./config/train", config_name="lewm") |
| 48 | def run(cfg): |