| 658 | return f'{self.name}' |
| 659 | |
def bind(self, *args, **params):
  """Canonicalize the positional arguments and bind them under the current trace.

  Every positional argument is run through ``dtypes.canonicalize_value`` and
  ``typeof``. Unless ``self.skip_canonicalization`` is set, a ``ShapedArray``
  argument carrying a non-empty mesh is compared against the current abstract
  mesh and, when the meshes differ in a castable way, resharded onto the
  current mesh. The canonicalized values and their avals are then handed to
  ``bind_with_trace`` together with the trace that was current on entry.

  Args:
    *args: operands of the primitive.
    **params: parameters forwarded unchanged to ``bind_with_trace``.

  Returns:
    Whatever ``bind_with_trace`` returns.

  Raises:
    TypeError: if an argument cannot be canonicalized or typed; the message
      reports the argument's type and position.
    NotImplementedError: if the current mesh permits casting but the
      argument's mesh is not all-Auto and has an Explicit axis.
    The error built by ``escaped_tracer_error``: if a canonicalized argument
      is a ``Tracer`` whose trace is no longer valid.
  """
  bound_args = []
  arg_avals = []
  for i, arg in enumerate(args):
    try:
      value = dtypes.canonicalize_value(arg)
      aval = typeof(value)
    except TypeError as e:
      raise TypeError(
          f"Error interpreting argument to {self} as a JAX value."
          f" The problematic value is of type {type(arg)} and was passed to"
          f" {self} at position {i}.\n"
      ) from e

    # Only shaped arrays that already live on some mesh are candidates for a
    # mesh cast, and only when canonicalization has not been opted out of.
    mesh_sensitive = (
        not self.skip_canonicalization
        and isinstance(aval, ShapedArray)
        and not aval.sharding.mesh.empty
    )
    if mesh_sensitive:
      ambient_mesh = mesh_lib.get_abstract_mesh()
      # TODO(yashkatariya): Casting to Explicit is not yet allowed. Maybe we
      # need cast_and_slice_p for it since shape might change?
      # At least 1 mesh axis should be Manual and all other axes should be
      # Manual or Auto to allow casting.
      cast_allowed = (
          ambient_mesh != aval.sharding.mesh
          and ambient_mesh._any_axis_manual
          and ambient_mesh._are_all_axes_auto_or_manual
      )
      if cast_allowed:
        if aval.sharding.mesh.are_all_axes_auto:
          # Imported locally, as in the original (presumably to avoid an
          # import cycle with pjit -- confirm before moving it).
          from jax._src.pjit import reshard  # pyrefly: ignore[missing-import]
          # Target spec has one ``None`` entry per array dimension.
          target_sharding = NamedSharding(ambient_mesh, P(*[None] * aval.ndim))
          value = reshard(value, target_sharding)
          aval = typeof(value)
        elif aval.sharding.mesh._any_axis_explicit:
          raise NotImplementedError(
              "Closing over inputs to shard_map where the input is sharded "
              "on `Explicit` axes is not implemented. As a workaround, "
              "please pass those inputs as an argument to shard_map. Got "
              f"input with shape {aval.str_short(True, True)}")

    # A tracer whose trace has been invalidated must not be bound.
    if isinstance(value, Tracer) and not value._trace.is_valid():
      raise escaped_tracer_error(value)

    bound_args.append(value)
    arg_avals.append(aval)

  # Hand-rolled equivalent of "with take_current_trace()": bind() is called
  # very frequently, and skipping the context-manager object is slightly
  # faster. The saved trace is always restored, even if binding raises.
  saved_trace = trace_ctx.trace
  trace_ctx.set_trace(eval_trace)
  try:
    return self.bind_with_trace(saved_trace, bound_args, arg_avals, params)
  finally:
    trace_ctx.set_trace(saved_trace)
| 708 | |
| 709 | def bind_with_trace(self, trace, args, avals, params, /): |
| 710 | if self.is_high(*avals, **params) and trace.requires_low: |