(
ray_start_regular_shared, udf_kind, target_max_block_size_infinite_or_default
)
| 346 | |
@pytest.mark.parametrize("udf_kind", ["gen", "func"])
def test_flat_map(
    ray_start_regular_shared, udf_kind, target_max_block_size_infinite_or_default
):
    """Check that ``flat_map`` expands each input row into several output rows.

    The UDF is supplied in two flavours, selected by ``udf_kind``:

    * ``"gen"``  -- a generator that yields the output rows one at a time.
    * ``"func"`` -- a plain function that returns the output rows as a list.

    Either way, every input row ``{"id": i}`` is turned into two copies of
    ``{"id": i + 1}``, so a three-row input must produce six rows whose
    sorted ``id`` values are ``[1, 1, 2, 2, 3, 3]``.

    ``ray_start_regular_shared`` and
    ``target_max_block_size_infinite_or_default`` are requested only as
    pytest fixtures; the test body never reads their values.
    """
    ds = ray.data.range(3)

    if udf_kind == "gen":

        # Generator flavour: emits two row dicts per input row.
        def _udf(item: dict) -> Iterator[dict]:
            for _ in range(2):
                yield {"id": item["id"] + 1}

    elif udf_kind == "func":

        # Eager flavour: returns the same two row dicts as a list.
        def _udf(item: dict) -> list:
            return [{"id": item["id"] + 1} for _ in range(2)]

    else:
        # Guards against a new parametrize value being added without a
        # matching UDF; without this, `_udf` would be unbound below.
        pytest.fail(f"Invalid udf_kind: {udf_kind}")

    # Sort before comparing: only the multiset of ids is asserted, not the
    # order in which rows come back.
    assert sorted(extract_values("id", ds.flat_map(_udf).take())) == [
        1,
        1,
        2,
        2,
        3,
        3,
    ]
| 375 | |
| 376 | |
| 377 | # Helper function to process timestamp data in nanoseconds |
nothing calls this directly
no test coverage detected
searching dependent graphs…