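Determine the absolute feature indices to 'take' from. Note: This function can be called in forward() so must be torchscript compatible, which requires some incomplete typing and workaround hacks. Args: num_features: total number of features to select from; indices: indices to select (None -> all, int -> last n, list/tuple of int -> the specified indices, with negative values counting from the end); as_set: return as a set.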
(
num_features: int,
indices: Optional[Union[int, List[int]]] = None,
as_set: bool = False,
)
| 26 | |
| 27 | |
def feature_take_indices(
        num_features: int,
        indices: Optional[Union[int, List[int]]] = None,
        as_set: bool = False,
) -> Tuple[List[int], int]:
    """ Determine the absolute feature indices to 'take' from.

    Note: This function can be called in forward() so must be torchscript compatible,
    which requires some incomplete typing and workaround hacks.

    Args:
        num_features: total number of features to select from
        indices: indices to select,
          None -> select all
          int -> select last n
          list/tuple of int -> return specified (-ve indices specify from end)
        as_set: return as a set (only honoured in eager mode; when scripting a list
          is always returned)

    Returns:
        List (or set) of absolute (from beginning) indices, Maximum index

    Raises:
        Fails via ``_assert`` if an int ``indices`` is not in ``1..num_features``, or if a
        listed index is outside ``-num_features..num_features - 1``.
        ValueError: if ``indices`` is an empty sequence (``max()`` of an empty list).
    """
    if indices is None:
        indices = num_features  # all features if None

    if isinstance(indices, int):
        # convert int -> last n indices
        _assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')
        take_indices = [num_features - indices + i for i in range(indices)]
    else:
        take_indices: List[int] = []
        for i in indices:
            # negative indices count back from the end
            idx = num_features + i if i < 0 else i
            # Report the index exactly as the caller passed it (not the wrapped value),
            # together with the full accepted range, which includes negative indices.
            _assert(
                0 <= idx < num_features,
                f'feature index {i} is out of range (-{num_features} to {num_features - 1})',
            )
            take_indices.append(idx)

    # NOTE: an empty `indices` sequence leaves take_indices empty and max() raises ValueError.
    max_index = max(take_indices)

    if not torch.jit.is_scripting() and as_set:
        # set conversion only happens in eager mode; scripted callers always get a list
        return set(take_indices), max_index

    return take_indices, max_index
| 67 | |
| 68 | |
| 69 | def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]: |
no test coverage detected