Predict next state embedding. emb: (B, T, D); act_emb: (B, T, A_emb)
(self, emb, act_emb)
| 45 | return info |
| 46 | |
| 47 | def predict(self, emb, act_emb): |
| 48 | """Predict next state embedding |
| 49 | emb: (B, T, D) |
| 50 | act_emb: (B, T, A_emb) |
| 51 | """ |
| 52 | preds = self.predictor(emb, act_emb) |
| 53 | preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d")) |
| 54 | preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0)) |
| 55 | return preds |
| 56 | |
| 57 | #################### |
| 58 | ## Inference only ## |
no outgoing calls
no test coverage detected