MCPcopy
hub / github.com/google-deepmind/gemma / transform

Method transform

gemma/gm/ckpts/_policy.py:39–71  ·  view source on GitHub ↗
(self, state: kd.train.TrainState)

Source from the content-addressed store, hash-verified

37 anchor: kd.ckpts.InitTransform | None = None
38
def transform(self, state: kd.train.TrainState) -> kd.train.TrainState:
  """Loads the policy and anchor params of an anchored-policy train state.

  `state.params` must have exactly two top-level keys, `'policy'` and
  `'anchor'`.

  - The policy sub-tree is loaded by `self.policy.transform`.
  - The anchor sub-tree is loaded by `self.anchor.transform`. When
    `self.anchor` is `None`, it is instead replaced by a copy of the
    freshly loaded policy params.

  Each sub-transform receives a copy of `state` whose `params` is only its
  own sub-tree. Only the `params` of the sub-transform results are kept.
  Every other field of the returned state comes from the input `state`.

  Args:
    state: Train state whose `params` is `{'policy': ..., 'anchor': ...}`.

  Returns:
    A copy of `state` whose `params` is the loaded
    `{'policy': ..., 'anchor': ...}` mapping.

  Raises:
    ValueError: If the top-level keys of `state.params` are not exactly
      `{'policy', 'anchor'}`.
  """
  param_keys = set(state.params.keys())
  if param_keys != {'policy', 'anchor'}:
    raise ValueError(
        'AnchoredPolicyLoader is meant to be used with'
        ' `model=gm.nn.AnchoredPolicy`. Expected `params` to have exactly'
        " the keys ['anchor', 'policy'], got"
        f' {sorted(param_keys, key=str)}.'
    )

  # Load the policy params: the policy transform only sees the policy
  # sub-tree.
  policy_state = dataclasses.replace(state, params=state.params['policy'])
  policy_state = self.policy.transform(policy_state)

  # Load the anchor params.
  if self.anchor is None:
    # No anchor loader: the anchor starts as a copy of the policy params.
    # The existing anchor params are handed to `release_memory` before the
    # copy is allocated (presumably to keep peak memory down).
    _checkpoint.release_memory(state.params['anchor'])
    # `jnp.copy` gives the anchor its own arrays instead of reusing the
    # policy's, presumably so the two trees never share buffers.
    anchor_params = jax.tree.map(jnp.copy, policy_state.params)
  else:
    anchor_state = dataclasses.replace(state, params=state.params['anchor'])
    anchor_params = self.anchor.transform(anchor_state).params

  # Merge the two param trees back into the original state. Any non-params
  # fields modified by the sub-transforms are intentionally dropped.
  return dataclasses.replace(
      state,
      params={
          'policy': policy_state.params,
          'anchor': anchor_params,
      },
  )

Callers

nothing calls this directly

Calls 1

map · Method · 0.45

Tested by

no test coverage detected