Lazily defines a ray.remote function. Datasets uses this to avoid a circular import with ray.remote: ray imports ray.data so that ``ray.data.read_foo()`` works, which means ray.remote cannot be used at the top level of ray.data. NOTE: Dynamic arguments should not be passed in directly, and should be set with ``options`` instead: ``cached_remote_fn(fn, **static_args).options(**dynamic_args)``.
(fn: Any, **ray_remote_args)
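
The lazy, cached wrapping described above can be illustrated without Ray installed. The following is a minimal sketch, not Ray's implementation: `fake_remote`, `cached_wrapper`, `DECORATE_CALLS`, and `double` are hypothetical names introduced here, and `fake_remote` only stands in for the `ray.remote(**args)(fn)` decorator-factory shape. The cache key is simplified to a sorted tuple of flat keyword arguments.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical stand-in for ray.remote, so the sketch runs without Ray.
# It records each decoration so we can see when wrapping actually happens.
DECORATE_CALLS: List[Tuple[str, Dict[str, Any]]] = []


def fake_remote(**static_args: Any) -> Callable[[Callable], Callable]:
    def decorator(fn: Callable) -> Callable:
        DECORATE_CALLS.append((fn.__name__, dict(static_args)))

        def wrapper(*args: Any, **kwargs: Any) -> Any:
            return fn(*args, **kwargs)

        return wrapper

    return decorator


# Cache keyed on the function plus a hash of its static arguments.
_CACHE: Dict[Tuple[Callable, int], Callable] = {}


def cached_wrapper(fn: Callable, **static_args: Any) -> Callable:
    # Assumption: static_args is flat and its values are hashable.
    key = (fn, hash(tuple(sorted(static_args.items()))))
    if key not in _CACHE:
        # Wrapping is deferred to the first call instead of import time.
        _CACHE[key] = fake_remote(**static_args)(fn)
    return _CACHE[key]


def double(x: int) -> int:
    return 2 * x


first = cached_wrapper(double, num_cpus=1)
second = cached_wrapper(double, num_cpus=1)  # same args: served from the cache
third = cached_wrapper(double, num_cpus=2)   # different args: new wrapper
```

Because each distinct set of static arguments creates a new cache entry, passing per-call values this way would grow the cache without bound, which is presumably why the docstring directs dynamic arguments to ``options`` instead.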

    def cached_remote_fn(fn: Any, **ray_remote_args) -> Any:
        """Lazily defines a ray.remote function.

        This is used in Datasets to avoid circular import issues with ray.remote.
        (ray imports ray.data in order to allow ``ray.data.read_foo()`` to work,
        which means ray.remote cannot be used top-level in ray.data).

        NOTE: Dynamic arguments should not be passed in directly,
        and should be set with ``options`` instead:
        ``cached_remote_fn(fn, **static_args).options(**dynamic_args)``.
        """

        # NOTE: Hash of the passed in arguments guarantees that we're caching
        # complete instantiation of the Ray's remote method
        #
        # To compute the hash of passed in arguments and make sure it's deterministic
        #   - Sort all KV-pairs by the keys
        #   - Convert sorted list into tuple
        #   - Compute hash of the resulting tuple
        hashable_args = _make_hashable(ray_remote_args)
        args_hash = hash(hashable_args)

        if (fn, args_hash) not in CACHED_FUNCTIONS:
            default_ray_remote_args = {
                # Use the default scheduling strategy for all tasks so that we will
                # not inherit a placement group from the caller, if there is one.
                # The caller of this function may override the scheduling strategy
                # as needed.
                "scheduling_strategy": "DEFAULT",
                "max_retries": -1,
            }
            ray_remote_args = {**default_ray_remote_args, **ray_remote_args}
            _add_system_error_to_retry_exceptions(ray_remote_args)

            CACHED_FUNCTIONS[(fn, args_hash)] = ray.remote(**ray_remote_args)(fn)

        return CACHED_FUNCTIONS[(fn, args_hash)]


    def _make_hashable(obj):
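
The body of `_make_hashable` is cut off in this listing, so the following is only a sketch of the three steps named in the comments (sort KV-pairs by key, convert to a tuple, hash the tuple). `make_hashable_sketch` is a hypothetical name, and the recursion into nested dicts and lists is an assumption about what such a helper would need, not a copy of Ray's code.

```python
from typing import Any


def make_hashable_sketch(obj: Any) -> Any:
    """Recursively turn dicts and lists into tuples so the result can be hashed."""
    if isinstance(obj, dict):
        # Sort KV-pairs by key so insertion order does not change the result.
        # Assumption: keys are mutually comparable (e.g. all strings).
        return tuple(
            sorted((key, make_hashable_sketch(value)) for key, value in obj.items())
        )
    if isinstance(obj, (list, tuple)):
        return tuple(make_hashable_sketch(item) for item in obj)
    # Anything else is returned unchanged and assumed to be hashable already.
    return obj


# Same arguments, different insertion order at both nesting levels.
args_a = {"num_cpus": 1, "resources": {"gpu_mem": 2, "disk": 1}}
args_b = {"resources": {"disk": 1, "gpu_mem": 2}, "num_cpus": 1}

key_a = make_hashable_sketch(args_a)
key_b = make_hashable_sketch(args_b)
```

Both calls produce the same nested tuple, so `hash(key_a) == hash(key_b)` and the two argument sets map to the same cache entry.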