MCPcopy
hub / github.com/jax-ml/jax / bind

Method bind

jax/_src/core.py:660–707  ·  view source on GitHub ↗
(self, *args, **params)

Source from the content-addressed store, hash-verified

658 return f'{self.name}'
659
def bind(self, *args, **params):
  """Canonicalize ``args`` and dispatch this primitive through its trace.

  Each positional argument is processed in order:

  1. It is canonicalized with ``dtypes.canonicalize_value`` and its abstract
     value is computed with ``typeof``.
  2. Unless ``self.skip_canonicalization`` is set, a ``ShapedArray`` argument
     whose sharding mesh is non-empty and differs from
     ``mesh_lib.get_abstract_mesh()`` may be cast. This happens only when the
     current mesh has at least one Manual axis and all of its axes are Auto
     or Manual. In that case an argument on an all-Auto mesh is resharded
     onto the current mesh with a ``P(None, ..., None)`` spec (one ``None``
     per dimension). An argument whose mesh has Explicit axes is rejected.
  3. A ``Tracer`` whose trace is no longer valid is rejected.

  The processed arguments and their avals are passed to ``bind_with_trace``
  along with the trace that was current on entry. While that call runs, the
  trace context is temporarily switched to ``eval_trace``; the entry trace
  is always restored afterwards.

  Args:
    *args: operands of the primitive.
    **params: primitive parameters, forwarded unchanged to
      ``bind_with_trace``.

  Returns:
    Whatever ``bind_with_trace`` returns.

  Raises:
    TypeError: if an argument cannot be interpreted as a JAX value.
    NotImplementedError: if an argument sharded on Explicit mesh axes is
      encountered while the current abstract mesh differs and has Manual
      axes (the message describes this as closing over inputs to shard_map).
    The exception built by ``escaped_tracer_error``: for a ``Tracer``
      argument whose trace is no longer valid.
  """
  in_vals = []
  in_avals = []
  for pos, x in enumerate(args):
    try:
      val = dtypes.canonicalize_value(x)
      val_aval = typeof(val)
    except TypeError as err:
      raise TypeError(
          f"Error interpreting argument to {self} as a JAX value."
          f" The problematic value is of type {type(x)} and was passed to"
          f" {self} at position {pos}.\n"
      ) from err

    # Note: `skip_canonicalization` only gates the mesh-casting step below;
    # `canonicalize_value` above runs for every argument regardless.
    if not self.skip_canonicalization and isinstance(val_aval, ShapedArray):
      arg_mesh = val_aval.sharding.mesh
      if not arg_mesh.empty:
        cur_mesh = mesh_lib.get_abstract_mesh()
        # TODO(yashkatariya): Casting to Explicit is not yet allowed. Maybe we
        # need cast_and_slice_p for it since shape might change?
        # At least one axis of the current mesh must be Manual, and every
        # other axis must be Manual or Auto, before casting is allowed.
        if (cur_mesh != arg_mesh and cur_mesh._any_axis_manual
            and cur_mesh._are_all_axes_auto_or_manual):
          if arg_mesh.are_all_axes_auto:
            # Function-local import; presumably avoids an import cycle with
            # pjit — confirm before hoisting.
            from jax._src.pjit import reshard  # pyrefly: ignore[missing-import]
            target_sharding = NamedSharding(
                cur_mesh, P(*[None] * val_aval.ndim))
            val = reshard(val, target_sharding)
            val_aval = typeof(val)
          elif arg_mesh._any_axis_explicit:
            raise NotImplementedError(
                "Closing over inputs to shard_map where the input is sharded "
                "on `Explicit` axes is not implemented. As a workaround, "
                "please pass those inputs as an argument to shard_map. Got "
                f"input with shape {val_aval.str_short(True, True)}")

    if isinstance(val, Tracer) and not val._trace.is_valid():
      raise escaped_tracer_error(val)
    in_vals.append(val)
    in_avals.append(val_aval)

  # Same effect as `with take_current_trace():`, written out by hand because
  # bind() is called frequently and skipping the context-manager object is
  # slightly faster.
  entry_trace = trace_ctx.trace
  trace_ctx.set_trace(eval_trace)
  try:
    return self.bind_with_trace(entry_trace, in_vals, in_avals, params)
  finally:
    trace_ctx.set_trace(entry_trace)
708
709 def bind_with_trace(self, trace, args, avals, params, /):
710 if self.is_high(*avals, **params) and trace.requires_low:

Callers 15

to_qarray (Function) · 0.80
from_qarray (Function) · 0.80
jvp (Method) · 0.80
immutbox_new (Function) · 0.80
immutbox_get (Function) · 0.80

Calls 10

bind_with_trace (Method) · 0.95
reshard (Function) · 0.90
NamedSharding (Class) · 0.90
P (Class) · 0.90
typeof (Function) · 0.85
escaped_tracer_error (Function) · 0.85
is_valid (Method) · 0.80
set_trace (Method) · 0.80
str_short (Method) · 0.45
append (Method) · 0.45

Tested by 15

to_qarray (Function) · 0.64
from_qarray (Function) · 0.64
jvp (Method) · 0.64
immutbox_new (Function) · 0.64
immutbox_get (Function) · 0.64