Repository: github.com/state-spaces/mamba

Function alloc_tile_workspace

mamba_ssm/utils/determinism.py, lines 80–88

Allocate buffer for deterministic per-program reductions.

(base_shape, tile_dim, dtype, device, deterministic, *, zero_init=True)
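The docstring does not say why per-program slots matter. One common motivation, offered here as an assumption about this code, is that floating-point addition is not associative: if many GPU programs add partial results into one shared buffer, the final bits can depend on which program finishes first. This pure-Python sketch uses made-up partial values (no GPU, no torch, and hypothetical helper names) to contrast order-dependent accumulation with giving each program its own slot and reducing the slots in a fixed order.

```python
# Made-up partial results from four hypothetical "programs" (think GPU thread blocks).
PARTIALS = [1e16, 1.0, -1e16, 1.0]


def accumulate_in_completion_order(order):
    """Mimic atomic accumulation: one shared total, updated as programs finish."""
    total = 0.0
    for pid in order:
        total += PARTIALS[pid]
    return total


def reduce_per_program_slots(order):
    """Mimic a tile workspace: each program fills its own slot, then slots are summed in index order."""
    workspace = [0.0] * len(PARTIALS)   # one slot per program, like the extra tile_dim axis
    for pid in order:                   # programs may finish in any order...
        workspace[pid] = PARTIALS[pid]  # ...but each writes only to its own slot
    total = 0.0
    for value in workspace:             # the reduction order is fixed by slot index
        total += value
    return total


print(accumulate_in_completion_order([0, 1, 2, 3]))  # 1.0 (1e16 + 1.0 rounds back to 1e16)
print(accumulate_in_completion_order([0, 2, 1, 3]))  # 2.0
print(reduce_per_program_slots([0, 1, 2, 3]))        # 1.0
print(reduce_per_program_slots([3, 1, 2, 0]))        # 1.0
```

The fixed-order result (1.0) is not the exact sum (2.0): a per-program workspace buys run-to-run reproducibility, not extra accuracy.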


def alloc_tile_workspace(base_shape, tile_dim, dtype, device, deterministic, *, zero_init=True):
    """Allocate buffer for deterministic per-program reductions."""
    if base_shape is None:
        return None, 0
    if deterministic:
        factory = torch.zeros if zero_init else torch.empty
        tensor = factory(*base_shape, tile_dim, device=device, dtype=dtype)
        return tensor, tensor.stride(-1)
    return torch.empty(*base_shape, device=device, dtype=dtype), 0


def finalize_tile_workspace(tensor, deterministic):
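A usage sketch, assuming PyTorch is installed and running on CPU. The function body is copied from the listing so the block is self-contained; the shapes, the per-program partial values, and the `ws.sum(dim=-1)` reduction are illustrative assumptions (the body of `finalize_tile_workspace` is not shown here, so the sum only stands in for it). It shows the two return contracts: in deterministic mode the buffer gains a trailing axis of size `tile_dim` whose stride is 1, while the fallback returns a plain `base_shape` buffer with stride 0.

```python
import torch


def alloc_tile_workspace(base_shape, tile_dim, dtype, device, deterministic, *, zero_init=True):
    """Allocate buffer for deterministic per-program reductions."""
    if base_shape is None:
        return None, 0
    if deterministic:
        factory = torch.zeros if zero_init else torch.empty
        tensor = factory(*base_shape, tile_dim, device=device, dtype=dtype)
        return tensor, tensor.stride(-1)
    return torch.empty(*base_shape, device=device, dtype=dtype), 0


# Hypothetical sizes: a (4, 8) gradient that 3 programs contribute to.
base_shape, n_programs = (4, 8), 3

# Deterministic mode: shape (4, 8, 3); the tile axis is innermost and contiguous, so its stride is 1.
ws, tile_stride = alloc_tile_workspace(base_shape, n_programs, torch.float32, "cpu", deterministic=True)

# Stand-in for a kernel: program `pid` writes its partial result into its own slot.
partials = [torch.full(base_shape, float(pid + 1)) for pid in range(n_programs)]
for pid in reversed(range(n_programs)):  # completion order does not affect the slot contents
    ws[..., pid] = partials[pid]

# Hypothetical finalize step (an assumption, not the real finalize_tile_workspace).
grad = ws.sum(dim=-1)

# Non-deterministic mode: plain (4, 8) buffer, and a tile stride of 0.
buf, stride0 = alloc_tile_workspace(base_shape, n_programs, torch.float32, "cpu", deterministic=False)
```

Returning a stride of 0 in the non-deterministic case presumably lets the same kernel code compute `pid * stride` offsets in both modes, with every program landing on one shared element (typically via atomic adds); that reading of the callers is an assumption. Note also that the non-deterministic path always uses `torch.empty`, so `zero_init` has no effect there.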

Callers (6 total; 4 listed)

_chunk_scan_bwd_dC (function · 0.90)
_chunk_scan_bwd_dx (function · 0.90)
_chunk_state_bwd_dx (function · 0.90)
_chunk_state_bwd_db (function · 0.90)

Calls

no outgoing calls within the repository (the body uses only torch.zeros, torch.empty, and Tensor.stride)

Tested by

no test coverage detected