(self, left_structure, right_structure)
| 783 | ), |
| 784 | ) |
def test_order_like(self, left_structure, right_structure):
  """Checks that `left.order_like(right)` keeps left's data in right's layout.

  `left` is built from `left_structure` with sequential values in every leaf.
  `right` is built from `right_structure` with all-zero leaves. After
  reordering, the result must:
    * hold the same canonicalized contents as `left`, and
    * have exactly the same pytree structure (and leaf shapes/dtypes) as
      `right`.
  """

  def _sequential_fill(leaf):
    # Fill the leaf with 0, 1, 2, ... (cast to the leaf's dtype) and reshape
    # to the leaf's shape. This gives each position recognizable content,
    # assuming the dtype can represent the values — TODO confirm for the
    # parameterized dtypes.
    num_elements = np.prod(leaf.shape)
    flat = jnp.arange(num_elements, dtype=leaf.dtype)
    return flat.reshape(leaf.shape)

  left = jax.tree_util.tree_map(_sequential_fill, left_structure)
  right = jax.tree_util.tree_map(jnp.zeros_like, right_structure)

  reordered = left.order_like(right)

  # Contents must be unchanged relative to `left`. The comparison is done on
  # canonical forms so that ordering differences do not matter.
  expected_contents = left.canonicalize()
  actual_contents = reordered.canonicalize()
  chex.assert_trees_all_equal(actual_contents, expected_contents)

  # Layout must match `right`.
  chex.assert_trees_all_equal_structs(reordered, right)
| 800 | |
| 801 | @parameterized.named_parameters( |
| 802 | dict( |
nothing calls this directly
no test coverage detected