Return compacted copies of *yv* and *yi* based on a per-row bitmask. Only the elements whose index appears among the active bits of *bitmask* are kept; the rest are replaced by *sentinel*. Kept elements preserve their original left-to-right order.
```python
def compaction(yv, yi, bitmask, sentinel=-1):
    """
    Return compacted copies of *yv* and *yi* based on a per-row bitmask.

    Only the elements whose index appears among the active bits of *bitmask*
    are kept; the rest are replaced by *sentinel*. Kept elements preserve
    their original left-to-right order.

    Parameters
    ----------
    yv : torch.Tensor, shape (B, K)
        Values tensor.
    yi : torch.Tensor, shape (B, K), dtype torch.long
        Integer indices (0 ≤ index < 32) associated with *yv*.
    bitmask : torch.Tensor, shape (B,) **or** (B, 32)
        Per-row mask of active indices. See the in-place version for details.
    sentinel : int, default -1
        Value written into dropped positions of the returned tensors.

    Returns
    -------
    (yv_out, yi_out) : Tuple[torch.Tensor, torch.Tensor], each shape (B, K)
        New tensors with the same dtype/device as the inputs.

    """

    n_rows, n_cols = yi.shape
    ret_yv = torch.empty_like(yv)
    ret_yi = torch.empty_like(yi)
    if isinstance(bitmask, Bitmatrix):
        bitmask = bitmask.storage.data

    _masked_compaction[(n_rows, )](
        yv, yi, bitmask, bitmask.stride(0), bitmask.stride(1),  # inputs
        ret_yv, ret_yi,  # outputs
        sentinel,  # sentinel
        K=n_cols  # constants
    )
    return ret_yv, ret_yi


def compaction_torch(yv: torch.Tensor, yi: torch.Tensor, bitmask: torch.Tensor, sentinel=-1):
    ...  # body lies outside this excerpt
```
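The semantics described in the `compaction` docstring can be illustrated without a GPU. The following is a minimal pure-Python sketch, not the kernel's actual implementation. The helper name `compaction_reference` is hypothetical, and the sketch makes two assumptions that the excerpt does not confirm: each row's bitmask is a single integer whose bit `j` is set when index `j` is active, and kept elements are moved to the front of the row with the sentinel filling the tail.

```python
from typing import List, Sequence, Tuple


def compaction_reference(
    yv: Sequence[Sequence[float]],
    yi: Sequence[Sequence[int]],
    bitmask: Sequence[int],
    sentinel: int = -1,
) -> Tuple[List[list], List[list]]:
    """Pure-Python sketch of per-row masked compaction (hypothetical helper).

    Assumes one integer mask per row and a front-compacted output layout.
    """
    yv_out, yi_out = [], []
    for row_v, row_i, mask in zip(yv, yi, bitmask):
        # Keep a (value, index) pair only if bit `index` of the row mask is set.
        kept = [(v, i) for v, i in zip(row_v, row_i) if (mask >> i) & 1]
        pad = len(row_i) - len(kept)
        # Kept pairs come first in their original order; the tail is sentinel.
        yv_out.append([v for v, _ in kept] + [sentinel] * pad)
        yi_out.append([i for _, i in kept] + [sentinel] * pad)
    return yv_out, yi_out


yv = [[0.5, 0.3, 0.2], [0.6, 0.3, 0.1]]
yi = [[4, 1, 7], [2, 0, 5]]
# Row 0 activates bits 4 and 7; row 1 activates bits 0 and 2.
bitmask = [0b10010000, 0b00000101]

yv_out, yi_out = compaction_reference(yv, yi, bitmask)
print(yv_out)  # [[0.5, 0.2, -1], [0.6, 0.3, -1]]
print(yi_out)  # [[4, 7, -1], [2, 0, -1]]
```

In row 0, index 1 is not among the active bits, so its value 0.3 is dropped and the surviving pairs (0.5, 4) and (0.2, 7) keep their relative order. A list-based version like this is mainly useful as a readable oracle when checking a vectorized or kernel implementation on small inputs.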