MCPcopy
hub / github.com/google-deepmind/penzai / body

Method body

penzai/nn/layer_stack.py:205–235  ·  view source on GitHub ↗
(carry_data, slice_data)

Source from the content-addressed store, hash-verified

203 ]
204
def body(carry_data, slice_data):
    """Run one layer of the stack; used as the step function of a scan.

    The loop-carried value pairs the current activation with the variable
    states shared by every layer. The scanned-over value pairs this layer's
    slices of the stacked named arrays with this layer's own (sliced)
    variable states.

    Args:
      carry_data: Pair ``(current_argument, shared_variable_states)``.
      slice_data: Pair ``(named_array_slices, sliced_variable_states)`` for
        the layer being executed.

    Returns:
      A pair ``(new_carry, updated_sliced_states)``.
      ``new_carry`` holds the layer output and the updated shared states, passed
      through ``named_axes.order_like`` against ``carry_data``. (Presumably this
      keeps the carry's named-axis layout identical across iterations, as
      ``jax.lax.scan`` requires. NOTE(review): confirm ``order_like``
      semantics.) ``updated_sliced_states`` is the list of updated per-layer
      states, with named arrays moved to positional-prefix form.

    Raises:
      KeyError: if ``stateless_call`` does not hand back a state for one of
        the labels that were passed in.
    """
    layer_input, shared_states = carry_data
    array_slices, per_layer_states = slice_data

    # Materialize this step's sublayer by substituting the current slices for
    # the selected stacked named arrays.
    layer = stacked_array_selection.set_sequence(array_slices)

    # Execute functionally: every variable state goes in (shared ones first,
    # then the per-layer ones) and the updated states come back out.
    layer_output, returned_states = layer.stateless_call(
        shared_states + per_layer_states, layer_input, **pure_side_inputs
    )

    state_for_label = dict(
        (state.label, state) for state in returned_states
    )

    def reorder_like(reference_states):
        # Match returned states back to the input order via their labels.
        return [state_for_label[ref.label] for ref in reference_states]

    updated_shared = reorder_like(shared_states)
    updated_per_layer = reorder_like(per_layer_states)

    def to_positional_prefix(named_array):
        return named_array.with_positional_prefix()

    # The scan collects the per-layer outputs along a new axis. Normalize
    # every named array into a NamedArray with its positional axes at the
    # front, so that adding that axis is safe.
    named_array_leaves = selectors.select(updated_per_layer).at_instances_of(
        named_axes.NamedArrayBase
    )
    updated_per_layer = named_array_leaves.apply(to_positional_prefix)

    new_carry = named_axes.order_like(
        (layer_output, updated_shared), carry_data
    )
    return new_carry, updated_per_layer
236
237 (final_value, new_shared_var_states), new_sliced_var_states = jax.lax.scan(
238 body,

Callers 4

__call__Method · 0.80
__call__Method · 0.80
__call__Method · 0.80
__call__Method · 0.80

Calls 7

set_sequenceMethod · 0.80
stateless_callMethod · 0.80
at_instances_ofMethod · 0.80
selectMethod · 0.80
order_likeMethod · 0.80
applyMethod · 0.45

Tested by

no test coverage detected