(carry_data, slice_data)
| 203 | ] |
| 204 | |
def body(carry_data, slice_data):
    """Runs one step of the stacked-layer scan.

    This is the step function handed to ``jax.lax.scan``. The carry holds the
    running argument plus ``shared_var_states``; the scanned slice holds this
    step's ``named_array_slices`` plus ``sliced_var_states``.

    Args:
      carry_data: Pair ``(cur_arg, shared_var_states)``.
      slice_data: Pair ``(named_array_slices, sliced_var_states)``; the slices
        are inserted into ``stacked_array_selection`` to build this step's
        sublayer.

    Returns:
      ``(new_carry, new_sliced_var_states)``. ``new_carry`` is
      ``(next_arg, new_shared_var_states)`` passed through
      ``named_axes.order_like`` against ``carry_data``.
      ``new_sliced_var_states`` are the updated sliced states, with every
      ``NamedArrayBase`` leaf converted via ``with_positional_prefix()``.

    Raises:
      KeyError: If the sublayer's returned states lack a state whose label
        matches one of the input variable states.
    """
    cur_arg, shared_var_states = carry_data
    named_array_slices, sliced_var_states = slice_data

    # Materialize this step's sublayer by filling the stacked selection with
    # the current slices.
    sublayer = stacked_array_selection.set_sequence(named_array_slices)

    # The sublayer receives the shared states first, followed by this step's
    # sliced states.
    input_var_states = shared_var_states + sliced_var_states
    next_arg, returned_states = sublayer.stateless_call(
        input_var_states, cur_arg, **pure_side_inputs
    )

    # Index the returned states by label (a later duplicate label wins), so
    # each input group can be rebuilt in its original order.
    state_for_label = {}
    for returned in returned_states:
        state_for_label[returned.label] = returned

    def _updated(states):
        # Replace each input state with the returned state sharing its label.
        return [state_for_label[state.label] for state in states]

    new_shared_var_states = _updated(shared_var_states)
    per_step_states = _updated(sliced_var_states)

    # Ensure results are named arrays with positional axes at the front, so
    # that a new (stacking) axis can safely be added to them.
    named_leaves = selectors.select(per_step_states).at_instances_of(
        named_axes.NamedArrayBase
    )
    new_sliced_var_states = named_leaves.apply(
        lambda arr: arr.with_positional_prefix()
    )

    # Reorder the new carry like the incoming carry; presumably this keeps
    # the carry layout scan sees consistent across steps (TODO confirm
    # order_like semantics).
    new_carry = named_axes.order_like(
        (next_arg, new_shared_var_states), carry_data
    )
    return new_carry, new_sliced_var_states
| 236 | |
| 237 | (final_value, new_shared_var_states), new_sliced_var_states = jax.lax.scan( |
| 238 | body, |
no test coverage detected