Allocate buffer for deterministic per-program reductions.
(base_shape, tile_dim, dtype, device, deterministic, *, zero_init=True)
| 78 | |
| 79 | |
def alloc_tile_workspace(base_shape, tile_dim, dtype, device, deterministic, *, zero_init=True):
    """Allocate the workspace buffer used for per-program reductions.

    Args:
        base_shape: Iterable of leading dimension sizes, or ``None`` when no
            workspace is required.
        tile_dim: Size of the extra trailing dimension that is appended to
            ``base_shape`` in deterministic mode.
        dtype: Element type of the allocated tensor.
        device: Device on which the tensor is allocated.
        deterministic: When truthy, allocate ``(*base_shape, tile_dim)``;
            otherwise allocate just ``base_shape``.
        zero_init: Deterministic mode only. Zero-fill the buffer when true,
            leave it uninitialised when false.

    Returns:
        A ``(tensor, stride)`` pair:

        * ``(None, 0)`` if ``base_shape`` is ``None``;
        * in deterministic mode, the buffer and the stride of its last
          dimension;
        * otherwise an uninitialised ``base_shape`` buffer and ``0``.
    """
    # No workspace requested at all.
    if base_shape is None:
        return None, 0

    # Non-deterministic path: plain buffer without the tile axis. It is never
    # zero-filled here, regardless of ``zero_init``.
    if not deterministic:
        return torch.empty(*base_shape, device=device, dtype=dtype), 0

    # Deterministic path: one extra trailing axis of length ``tile_dim``.
    if zero_init:
        make = torch.zeros
    else:
        make = torch.empty
    workspace = make(*base_shape, tile_dim, device=device, dtype=dtype)
    tile_stride = workspace.stride(-1)
    return workspace, tile_stride
| 89 | |
| 90 | |
| 91 | def finalize_tile_workspace(tensor, deterministic): |
no outgoing calls
no test coverage detected