(test_data, batch_dims: Tuple[int, ...])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
@pytest.mark.parametrize("batch_dims", [(), (2,), (1, 2)])
def test_isect(test_data, batch_dims: Tuple[int, ...]):
    """Compare the CUDA tile-intersection ops against the PyTorch reference.

    Synthetic per-camera inputs (2D means, int32 radii, depths) are generated
    under a fixed seed and passed through ``expand`` together with
    ``batch_dims``. They are then fed to both ``isect_tiles`` /
    ``isect_offset_encode`` (CUDA wrapper) and ``_isect_tiles`` /
    ``_isect_offset_encode`` (torch reference). Every output pair must match
    under ``torch.testing.assert_close``.

    NOTE(review): the ``test_data`` fixture value is never read; this test
    builds its own inputs. The parameter is kept so the test signature is
    unchanged.
    """
    from gsplat.cuda._torch_impl import _isect_offset_encode, _isect_tiles
    from gsplat.cuda._wrapper import isect_offset_encode, isect_tiles

    torch.manual_seed(42)

    n_batch = math.prod(batch_dims)  # math.prod(()) == 1 for the unbatched case
    n_cams, n_gauss = 3, 1000
    # Passed to the offset encoders as their count argument (batch * cameras).
    n_images = n_batch * n_cams
    img_w, img_h = 40, 60

    # The three random draws must stay in this order (randn, randint, rand);
    # reordering them would change the seeded inputs.
    per_camera = {
        "means2d": torch.randn(n_cams, n_gauss, 2, device=device) * img_w,
        "radii": torch.randint(
            0, img_w, (n_cams, n_gauss, 2), device=device, dtype=torch.int32
        ),
        "depths": torch.rand(n_cams, n_gauss, device=device),
    }
    # presumably broadcasts each tensor over batch_dims — `expand` is defined
    # elsewhere in this file; confirm there.
    batched = expand(per_camera, batch_dims)
    gauss_inputs = (batched["means2d"], batched["radii"], batched["depths"])

    tile_size = 16
    tile_grid = (math.ceil(img_w / tile_size), math.ceil(img_h / tile_size))

    # CUDA implementation under test.
    cuda_tiles_per_gauss, cuda_isect_ids, cuda_flatten_ids = isect_tiles(
        *gauss_inputs, tile_size, *tile_grid
    )
    cuda_offsets = isect_offset_encode(cuda_isect_ids, n_images, *tile_grid)

    # Pure-PyTorch reference implementation.
    ref_tiles_per_gauss, ref_isect_ids, ref_gauss_ids = _isect_tiles(
        *gauss_inputs, tile_size, *tile_grid
    )
    ref_offsets = _isect_offset_encode(ref_isect_ids, n_images, *tile_grid)

    for actual, expected in (
        (cuda_tiles_per_gauss, ref_tiles_per_gauss),
        (cuda_isect_ids, ref_isect_ids),
        (cuda_flatten_ids, ref_gauss_ids),
        (cuda_offsets, ref_offsets),
    ):
        torch.testing.assert_close(actual, expected)
| 488 | |
| 489 | |
| 490 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") |
nothing calls this directly
no test coverage detected
searching dependent graphs…