Test propagating additional kwargs to map tasks.
(ray_start_regular_shared, use_actors)
| 484 | |
@pytest.mark.parametrize("use_actors", [False, True])
def test_map_kwargs(ray_start_regular_shared, use_actors):
    """Test propagating additional kwargs to map tasks.

    Kwargs registered through ``MapOperator.add_map_task_kwargs_fn`` must be
    visible to the map function via ``TaskContext.kwargs``. The check runs
    under both the task-pool and the actor-pool compute strategy:

    * ``foo`` is passed by value and must arrive unchanged.
    * ``bar`` is passed as an ``ObjectRef`` (via ``ray.put``) and must arrive
      already dereferenced, i.e. equal to the original array.
    """
    foo = 1
    bar = np.random.random(1024 * 1024)
    kwargs = {
        "foo": foo,  # Pass by value
        "bar": ray.put(bar),  # Pass by ObjectRef
    }

    def map_fn(block_iter: Iterable[Block], ctx: TaskContext) -> Iterable[Block]:
        # `foo` and `bar` are only read here; the closure captures them as the
        # expected values. They are never rebound, so no `nonlocal` is needed.
        assert ctx.kwargs["foo"] == foo
        # bar should be automatically deref'ed.
        assert np.array_equal(ctx.kwargs["bar"], bar)

        yield from block_iter

    input_op = InputDataBuffer(
        DataContext.get_current(),
        make_ref_bundles([[i] for i in range(10)]),
    )
    compute_strategy = ActorPoolStrategy() if use_actors else TaskPoolStrategy()
    op = MapOperator.create(
        create_map_transformer_from_block_fn(map_fn),
        input_op=input_op,
        data_context=DataContext.get_current(),
        name="TestMapper",
        compute_strategy=compute_strategy,
    )
    op.add_map_task_kwargs_fn(lambda: kwargs)
    op.start(ExecutionOptions())
    if use_actors:
        # Wait for the actor to start.
        run_op_tasks_sync(op)

    while input_op.has_next():
        # For actors, check capacity before adding input. While the actor pool
        # is at capacity, run queued tasks one at a time to free up a slot.
        while use_actors and not op.can_add_input():
            run_one_op_task(op)

        assert op.can_add_input()
        op.add_input(input_op.get_next(), 0)
    op.all_inputs_done()
    run_op_tasks_sync(op)

    _take_outputs(op)
    assert op.has_completed()
| 536 | |
| 537 | @pytest.mark.parametrize( |
nothing calls this directly
no test coverage detected
searching dependent graphs…