| 37 | anchor: kd.ckpts.InitTransform | None = None |
| 38 | |
def transform(self, state: kd.train.TrainState) -> kd.train.TrainState:
    """Loads the `policy` and `anchor` parameter subtrees of `state`.

    The incoming `state.params` must have exactly the top-level keys
    `'policy'` and `'anchor'`.

    - Policy: `self.policy.transform` is applied to a copy of `state` whose
      `params` is the `'policy'` subtree.
    - Anchor, when `self.anchor` is set: `self.anchor.transform` is applied
      the same way to the `'anchor'` subtree.
    - Anchor, when `self.anchor` is None: the incoming `'anchor'` subtree is
      passed to `_checkpoint.release_memory`. The anchor then becomes a
      leaf-by-leaf `jnp.copy` of the freshly loaded policy params.

    Args:
      state: Train state whose params hold the `'policy'` and `'anchor'`
        subtrees.

    Returns:
      `state` with `params` replaced by a new dict holding the loaded
      `'policy'` and `'anchor'` params. Only the `params` of the
      intermediate transformed states are kept; any other fields they carry
      are discarded.

    Raises:
      ValueError: If the top-level param keys are not exactly
        `{'policy', 'anchor'}`.
    """
    params = state.params
    if set(params.keys()) != {'policy', 'anchor'}:
        raise ValueError(
            'AnchoredPolicyLoader is meant to be used with'
            ' `model=gm.nn.AnchoredPolicy`.'
        )

    # The policy is always loaded first; the anchor may be derived from it.
    loaded_policy = self.policy.transform(
        dataclasses.replace(state, params=params['policy'])
    )

    if self.anchor is not None:
        # A dedicated anchor transform loads the anchor subtree on its own.
        loaded_anchor = self.anchor.transform(
            dataclasses.replace(state, params=params['anchor'])
        )
        anchor_params = loaded_anchor.params
    else:
        # No anchor transform: free the anchor's incoming buffers, then use
        # a separate copy of the loaded policy params as the anchor.
        _checkpoint.release_memory(params['anchor'])
        anchor_params = jax.tree.map(jnp.copy, loaded_policy.params)

    # Reassemble both subtrees onto the original state.
    return dataclasses.replace(
        state,
        params={
            'policy': loaded_policy.params,
            'anchor': anchor_params,
        },
    )