(
shutdown_only, udf_kind, target_max_block_size_infinite_or_default
)
| 1111 | |
@pytest.mark.parametrize("udf_kind", ["coroutine", "async_gen"])
def test_async_flat_map(
    shutdown_only, udf_kind, target_max_block_size_infinite_or_default
):
    """Check that ``flat_map`` accepts an async callable class.

    Each input row ``{"id": k}`` (k even, ``0 <= k < n``) is expanded into the
    two rows ``{"id": k}`` and ``{"id": k + 1}``, so the output ids should be
    exactly ``range(n)``. Two UDF flavors are exercised:

    * ``"async_gen"``: an async generator that yields the two rows one at a
      time, with an await between them.
    * ``"coroutine"``: a coroutine that returns both rows as a list.

    The fixture arguments are requested for their setup/teardown side effects
    and are not referenced in the body.
    """

    class AsyncActor:
        if udf_kind == "async_gen":

            async def __call__(self, row):
                row_id = row["id"]
                yield {"id": row_id}
                # Random 0-50 ms pause between the two yielded rows.
                await asyncio.sleep(random.randint(0, 5) / 100)
                yield {"id": row_id + 1}

        elif udf_kind == "coroutine":

            async def __call__(self, row):
                row_id = row["id"]
                # Random 0-50 ms pause before returning both rows.
                await asyncio.sleep(random.randint(0, 5) / 100)
                return [{"id": row_id}, {"id": row_id + 1}]

        else:
            pytest.fail(f"Unknown udf_kind: {udf_kind}")

    n = 10
    ds = ray.data.from_items([{"id": i} for i in range(0, n, 2)])
    # NOTE(review): presumably a single actor (concurrency=1) that may run up
    # to two calls at once (max_concurrency=2); confirm against Ray Data docs.
    ds = ds.flat_map(AsyncActor, concurrency=1, max_concurrency=2)
    output = ds.take_all()
    # Ids are sorted before comparison, so output order does not matter; the
    # list equality also checks that every id appears exactly once.
    assert sorted(extract_values("id", output)) == list(range(n))
| 1143 | |
| 1144 | |
| 1145 | class TestGenerateTransformFnForAsyncMap: |
nothing calls this directly
no test coverage detected
searching dependent graphs…