Encode an array of locations in a hypercube into a Hilbert integer. This is a vectorized-ish version of the Hilbert curve implementation by John Skilling as described in: Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
(locs, num_dims, num_bits)
| 89 | |
| 90 | |
| 91 | def encode(locs, num_dims, num_bits): |
| 92 | """Decode an array of locations in a hypercube into a Hilbert integer. |
| 93 | |
| 94 | This is a vectorized-ish version of the Hilbert curve implementation by John |
| 95 | Skilling as described in: |
| 96 | |
| 97 | Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference |
| 98 | Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics. |
| 99 | |
| 100 | Params: |
| 101 | ------- |
| 102 | locs - An ndarray of locations in a hypercube of num_dims dimensions, in |
| 103 | which each dimension runs from 0 to 2**num_bits-1. The shape can |
| 104 | be arbitrary, as long as the last dimension of the same has size |
| 105 | num_dims. |
| 106 | |
| 107 | num_dims - The dimensionality of the hypercube. Integer. |
| 108 | |
| 109 | num_bits - The number of bits for each dimension. Integer. |
| 110 | |
| 111 | Returns: |
| 112 | -------- |
| 113 | The output is an ndarray of uint64 integers with the same shape as the |
| 114 | input, excluding the last dimension, which needs to be num_dims. |
| 115 | """ |
| 116 | |
| 117 | # Keep around the original shape for later. |
| 118 | orig_shape = locs.shape |
| 119 | bitpack_mask = 1 << torch.arange(0, 8).to(locs.device) |
| 120 | bitpack_mask_rev = bitpack_mask.flip(-1) |
| 121 | |
| 122 | if orig_shape[-1] != num_dims: |
| 123 | raise ValueError( |
| 124 | """ |
| 125 | The shape of locs was surprising in that the last dimension was of size |
| 126 | %d, but num_dims=%d. These need to be equal. |
| 127 | """ |
| 128 | % (orig_shape[-1], num_dims) |
| 129 | ) |
| 130 | |
| 131 | if num_dims * num_bits > 63: |
| 132 | raise ValueError( |
| 133 | """ |
| 134 | num_dims=%d and num_bits=%d for %d bits total, which can't be encoded |
| 135 | into a int64. Are you sure you need that many points on your Hilbert |
| 136 | curve? |
| 137 | """ |
| 138 | % (num_dims, num_bits, num_dims * num_bits) |
| 139 | ) |
| 140 | |
| 141 | # Treat the location integers as 64-bit unsigned and then split them up into |
| 142 | # a sequence of uint8s. Preserve the association by dimension. |
| 143 | locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1) |
| 144 | |
| 145 | # Now turn these into bits and truncate to num_bits. |
| 146 | gray = ( |
| 147 | locs_uint8.unsqueeze(-1) |
| 148 | .bitwise_and(bitpack_mask_rev) |
nothing calls this directly
no test coverage detected