MCPcopy
hub / github.com/lucas-maes/le-wm / lejepa_forward

Function lejepa_forward

train.py:17–45  ·  view source on GitHub ↗

Encode observations, predict next states, and compute losses.

(self, batch, stage, cfg)

Source from the content-addressed store, hash-verified

15
16
def lejepa_forward(self, batch, stage, cfg):
    """Encode observations, predict next states, and compute losses.

    Args:
        self: Object exposing ``model`` (with ``encode`` and ``predict``
            methods), a ``sigreg`` callable and a ``log_dict`` method.
        batch: Mapping with at least an ``"action"`` tensor. It is handed to
            ``self.model.encode``. Note that its ``"action"`` entry is
            overwritten with a NaN-free copy, so the caller's mapping is
            modified.
        stage: Prefix for the logged loss names (``f"{stage}/<name>"``).
        cfg: Config providing ``history_size``, ``num_preds`` and
            ``loss.sigreg.weight``.

    Returns:
        The dict returned by ``self.model.encode``, extended with the
        ``"pred_loss"``, ``"sigreg_loss"`` and ``"loss"`` entries.
    """
    history = cfg.history_size
    horizon = cfg.num_preds
    sigreg_weight = cfg.loss.sigreg.weight

    # NaN actions occur at sequence boundaries; zero them before encoding.
    batch["action"] = torch.nan_to_num(batch["action"], nan=0.0)

    output = self.model.encode(batch)
    emb = output["emb"]  # (B, T, D) per the upstream annotation
    act_emb = output["act_emb"]

    # Context: the first `history` steps. Target: embeddings from step
    # `horizon` onwards.
    # NOTE(review): assumes `predict` returns a tensor whose shape matches
    # the target slice; a merely broadcastable shape would not raise here.
    predicted = self.model.predict(emb[:, :history], act_emb[:, :history])
    target = emb[:, horizon:]

    # LeWM loss: mean squared prediction error plus the weighted regulariser.
    pred_loss = (predicted - target).pow(2).mean()
    # The regulariser receives the embeddings with their first two dims swapped.
    sigreg_loss = self.sigreg(emb.transpose(0, 1))
    output["pred_loss"] = pred_loss
    output["sigreg_loss"] = sigreg_loss
    output["loss"] = pred_loss + sigreg_weight * sigreg_loss

    # Log a detached copy of every entry whose key contains "loss"; the
    # returned tensors keep their autograd graph.
    logged = {
        f"{stage}/{name}": value.detach()
        for name, value in output.items()
        if "loss" in name
    }
    self.log_dict(logged, on_step=True, sync_dist=True)
    return output
46
47@hydra.main(version_base=None, config_path="./config/train", config_name="lewm")
48def run(cfg):

Callers

nothing calls this directly

Calls 2

encode — Method · 0.80
predict — Method · 0.80

Tested by

no test coverage detected