github.com/huggingface/pytorch-image-models / feature_take_indices

Function feature_take_indices

timm/models/_features.py:28–66

Determine the absolute feature indices to 'take' from. Note: this function can be called in forward(), so it must be TorchScript compatible, which requires some incomplete typing and workaround hacks.
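The three accepted forms of `indices` (None, an int n, or a list/tuple that may contain negative entries) can be illustrated with a dependency-free sketch. `take_indices_sketch` is a hypothetical name: it mirrors the resolution logic of the function below, but raises `ValueError` where the original calls `_assert`, and it drops the TorchScript guard around `as_set`.

```python
from typing import List, Optional, Set, Tuple, Union


def take_indices_sketch(
    num_features: int,
    indices: Optional[Union[int, List[int]]] = None,
    as_set: bool = False,
) -> Tuple[Union[List[int], Set[int]], int]:
    """Hypothetical pure-Python mirror of the index-resolution rules (not timm code)."""
    if indices is None:
        indices = num_features  # None -> take every feature

    take: List[int]
    if isinstance(indices, int):
        # int n -> the last n positions, in ascending order
        if not 0 < indices <= num_features:
            raise ValueError(f"last-n ({indices}) is out of range (1 to {num_features})")
        take = list(range(num_features - indices, num_features))
    else:
        take = []
        for i in indices:
            idx = num_features + i if i < 0 else i  # negative -> count from the end
            if not 0 <= idx < num_features:
                raise ValueError(f"feature index {idx} is out of range (0 to {num_features - 1})")
            take.append(idx)

    if as_set:
        return set(take), max(take)
    return take, max(take)


print(take_indices_sketch(4))           # ([0, 1, 2, 3], 3)
print(take_indices_sketch(4, 2))        # ([2, 3], 3)
print(take_indices_sketch(4, [0, -1]))  # ([0, 3], 3)
```

Every form ends up as non-negative indices in input order; duplicates in an explicit list are kept (they only collapse with `as_set=True`). An empty list is not checked explicitly, so `max()` raises `ValueError` in both the sketch and the original.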

(
        num_features: int,
        indices: Optional[Union[int, List[int]]] = None,
        as_set: bool = False,
)


# Imports assumed for running this excerpt standalone (not part of lines 28–66).
from typing import List, Optional, Tuple, Union

import torch
from timm.layers import _assert


def feature_take_indices(
        num_features: int,
        indices: Optional[Union[int, List[int]]] = None,
        as_set: bool = False,
) -> Tuple[List[int], int]:
    """ Determine the absolute feature indices to 'take' from.

    Note: This function can be called in forward() so must be torchscript compatible,
    which requires some incomplete typing and workaround hacks.

    Args:
        num_features: total number of features to select from
        indices: indices to select,
          None -> select all
          int -> select last n
          list/tuple of int -> return specified (-ve indices specify from end)
        as_set: return as a set

    Returns:
        List (or set) of absolute (from beginning) indices, Maximum index
    """
    if indices is None:
        indices = num_features  # all features if None

    if isinstance(indices, int):
        # convert int -> last n indices
        _assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')
        take_indices = [num_features - indices + i for i in range(indices)]
    else:
        take_indices: List[int] = []
        for i in indices:
            idx = num_features + i if i < 0 else i
            _assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})')
            take_indices.append(idx)

    if not torch.jit.is_scripting() and as_set:
        return set(take_indices), max(take_indices)

    return take_indices, max(take_indices)
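Returning the maximum index alongside the indices lets a caller stop as soon as the deepest requested feature has been produced. The sketch below is hypothetical (`collect_intermediates` and the toy `blocks` are not timm code); timm's models appear to use a similar idea by slicing their block list to `max_index + 1`, but check a specific model's `forward_intermediates` before relying on that. It accepts a list or a set for `take_indices`, which is one place a set from `as_set=True` can serve for membership checks.

```python
from typing import Callable, Collection, List, Sequence


def collect_intermediates(
    blocks: Sequence[Callable[[float], float]],
    x: float,
    take_indices: Collection[int],
    max_index: int,
) -> List[float]:
    """Hypothetical caller: run blocks in order, keep the outputs whose index is
    in take_indices, and never run any block after max_index."""
    outputs: List[float] = []
    for i, block in enumerate(blocks[: max_index + 1]):
        x = block(x)
        if i in take_indices:
            outputs.append(x)
    return outputs


# Four toy "blocks"; with max_index 2 the last block is never called.
blocks = [lambda v: v + 1, lambda v: v * 2, lambda v: v - 3, lambda v: v * 10]
print(collect_intermediates(blocks, 1.0, [1, 2], 2))  # [4.0, 1.0]
```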

Calls 1

_assert · Function · 0.90

Tested by

no test coverage detected